A PyTorch Extension: Tools for easy mixed precision and distributed training in Pytorch
NVIDIA/apex
This repository holds NVIDIA-maintained utilities to streamline mixed precision and distributed training in PyTorch. Some of the code here will be included in upstream PyTorch eventually. The intent of Apex is to make up-to-date utilities available to users as quickly as possible.
Full API Documentation: https://nvidia.github.io/apex
GTC 2019 and PyTorch DevCon 2019 Slides
Deprecated. Use PyTorch AMP.
`apex.amp` is a tool to enable mixed precision training by changing only 3 lines of your script. Users can easily experiment with different pure and mixed precision training modes by supplying different flags to `amp.initialize`.
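Mixed precision training with `amp` depends on loss scaling, because many small gradient values cannot be represented in FP16. The stdlib-only sketch below (not apex code) uses Python's `struct` half-precision format to show a gradient of `1e-8` flushing to zero in FP16, and surviving when it is first multiplied by an example scale of `2**16`. The gradient value and the scale are illustrative assumptions, not values read from apex.

```python
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    # struct's 'e' format packs/unpacks binary16; tiny values underflow to 0.0
    return struct.unpack("<e", struct.pack("<e", x))[0]


grad = 1e-8           # a small gradient value, chosen for illustration
scale = 2.0 ** 16     # an example loss scale (assumed, not apex's setting)

# Without scaling: 1e-8 is below FP16's smallest subnormal, so it becomes 0.0
unscaled = to_fp16(grad)

# With scaling: scale up, store in FP16, then unscale in higher precision
scaled = to_fp16(grad * scale) / scale

print(unscaled)  # 0.0
print(scaled)    # close to 1e-8 (FP16 keeps about 11 significant bits)
```

Scaling shifts gradients into FP16's representable range; dividing by the same factor before the optimizer step recovers their magnitude.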
Webinar introducing Amp (the flag `cast_batchnorm` has been renamed to `keep_batchnorm_fp32`).
Comprehensive Imagenet example
Moving to the new Amp API (for users of the deprecated "Amp" and "FP16_Optimizer" APIs)
`apex.parallel.DistributedDataParallel` is deprecated. Use `torch.nn.parallel.DistributedDataParallel`.
`apex.parallel.DistributedDataParallel` is a module wrapper, similar to `torch.nn.parallel.DistributedDataParallel`. It enables convenient multiprocess distributed training, optimized for NVIDIA's NCCL communication library.
The Imagenet example shows use of `apex.parallel.DistributedDataParallel` along with `apex.amp`.
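Conceptually, a data-parallel wrapper such as `DistributedDataParallel` keeps replicas in sync by allreducing gradients after the backward pass, so every process applies the same averaged gradient. The pure-Python sketch below simulates that averaging for in-memory lists; `allreduce_mean` is a hypothetical name, and the real wrapper performs this with NCCL collectives (with extras such as bucketing) rather than Python loops.

```python
from typing import List


def allreduce_mean(per_rank_grads: List[List[float]]) -> List[List[float]]:
    """Simulate an allreduce that averages gradients across ranks.

    per_rank_grads[r][i] is gradient element i computed on rank r. After the
    allreduce, every rank holds the same element-wise mean.
    """
    world_size = len(per_rank_grads)
    # Sum element-wise across ranks (what an allreduce with SUM computes)...
    summed = [sum(vals) for vals in zip(*per_rank_grads)]
    # ...then divide by the number of processes to get the average.
    averaged = [s / world_size for s in summed]
    # Every rank receives an identical copy of the result.
    return [list(averaged) for _ in range(world_size)]


grads = [[1.0, 2.0], [3.0, 4.0]]  # two hypothetical ranks, two gradient elements
print(allreduce_mean(grads))  # [[2.0, 3.0], [2.0, 3.0]]
```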
Deprecated. Use `torch.nn.SyncBatchNorm`.
`apex.parallel.SyncBatchNorm` extends `torch.nn.modules.batchnorm._BatchNorm` to support synchronized BN. It allreduces stats across processes during multiprocess (DistributedDataParallel) training. Synchronous BN has been used in cases where only a small local minibatch can fit on each GPU. Allreduced stats increase the effective batch size for the BN layer to the global batch size across all processes (which, technically, is the correct formulation). Synchronous BN has been observed to improve converged accuracy in some of our research models.
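One way to see how allreduced stats yield global-batch statistics: if each process contributes its element count, sum, and sum of squares, the summed values give the exact mean and biased variance of the concatenated batch. The sketch below is a naive, pure-Python illustration with hypothetical helper names; it is not how apex's kernels compute the stats, and production implementations may use a more numerically stable scheme (for example Welford's algorithm) instead of `E[x^2] - E[x]^2`.

```python
from typing import List, Tuple


def local_stats(xs: List[float]) -> Tuple[int, float, float]:
    """Per-process count, sum, and sum of squares (the quantities to allreduce)."""
    return len(xs), sum(xs), sum(x * x for x in xs)


def global_mean_var(stats: List[Tuple[int, float, float]]) -> Tuple[float, float]:
    """Combine per-process stats into the global mean and biased variance."""
    n = sum(s[0] for s in stats)          # stands in for allreduce(count)
    total = sum(s[1] for s in stats)      # stands in for allreduce(sum)
    total_sq = sum(s[2] for s in stats)   # stands in for allreduce(sum of squares)
    mean = total / n
    var = total_sq / n - mean * mean      # E[x^2] - E[x]^2 (naive formula)
    return mean, var


rank0 = [1.0, 2.0, 3.0]   # hypothetical local minibatch on process 0
rank1 = [4.0, 5.0]        # hypothetical local minibatch on process 1
mean, var = global_mean_var([local_stats(rank0), local_stats(rank1)])
print(mean, var)  # 3.0 2.0, the same as computing over [1, 2, 3, 4, 5] directly
```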
To properly save and restore `amp` training state, we introduce `amp.state_dict()`, which contains all `loss_scaler`s and their corresponding unskipped steps, as well as `amp.load_state_dict()` to restore these attributes.
In order to get bitwise accuracy, we recommend the following workflow:
```python
import torch
from apex import amp

# Initialization (model, optimizer, and loss come from your training script)
opt_level = 'O1'
model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)

# Train your model
...
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
...

# Save checkpoint
checkpoint = {
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'amp': amp.state_dict()
}
torch.save(checkpoint, 'amp_checkpoint.pt')
...

# Restore
model = ...
optimizer = ...
checkpoint = torch.load('amp_checkpoint.pt')

model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
amp.load_state_dict(checkpoint['amp'])

# Continue training
...
```
Note that we recommend restoring the model using the same `opt_level`. Also note that we recommend calling the `load_state_dict` methods after `amp.initialize`.
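Restoring the loss-scaler state matters for reproducibility: a dynamic loss scaler changes its scale over time and counts clean (unskipped) steps, so a resumed run that started from the initial scale would skip and scale steps differently. The pure-Python class below is a hypothetical sketch of that bookkeeping, not apex's implementation; the default values (initial scale `2**16`, factor 2, window 2000) and the `state_dict` key names are assumptions chosen for illustration.

```python
class DynamicLossScaler:
    """Hypothetical dynamic loss scaler; not apex's implementation.

    Default values below are illustrative assumptions.
    """

    def __init__(self, init_scale: float = 2.0 ** 16, factor: float = 2.0,
                 window: int = 2000) -> None:
        self.loss_scale = init_scale
        self.factor = factor
        self.window = window
        self.unskipped = 0  # consecutive steps without overflow

    def update(self, found_overflow: bool) -> bool:
        """Update the scale after a step; return True if the step is skipped."""
        if found_overflow:
            self.loss_scale /= self.factor  # back off and skip this optimizer step
            self.unskipped = 0
            return True
        self.unskipped += 1
        if self.unskipped == self.window:   # long run of clean steps: try a larger scale
            self.loss_scale *= self.factor
            self.unskipped = 0
        return False

    def state_dict(self) -> dict:
        return {"loss_scale": self.loss_scale, "unskipped": self.unskipped}

    def load_state_dict(self, state: dict) -> None:
        self.loss_scale = state["loss_scale"]
        self.unskipped = int(state["unskipped"])


scaler = DynamicLossScaler(window=3)
for overflow in [False, False, True, False]:  # simulated per-step overflow checks
    scaler.update(overflow)
saved = scaler.state_dict()  # would be stored alongside model/optimizer state

restored = DynamicLossScaler(window=3)
restored.load_state_dict(saved)
print(restored.state_dict())  # {'loss_scale': 32768.0, 'unskipped': 1}
```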
Each `apex.contrib` module requires one or more install options other than `--cpp_ext` and `--cuda_ext`. Note that contrib modules do not necessarily support stable PyTorch releases.
NVIDIA PyTorch Containers are available on NGC: https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch. The containers come with all the custom extensions available at the moment.
See the NGC documentation for details such as:
- how to pull a container
- how to run a pulled container
- release notes
To install Apex from source, we recommend using the nightly PyTorch obtainable from https://github.com/pytorch/pytorch. The latest stable release obtainable from https://pytorch.org should also work.
We recommend installing Ninja to make compilation faster.
For performance and full functionality, we recommend installing Apex with CUDA and C++ extensions via

```bash
git clone https://github.com/NVIDIA/apex
cd apex
# if pip >= 23.1 (ref: https://pip.pypa.io/en/stable/news/#v23-1) which supports multiple `--config-settings` with the same key...
pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./
# otherwise
pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --global-option="--cpp_ext" --global-option="--cuda_ext" ./
```
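The choice between the two install commands depends only on the pip version: pip 23.1 and later accept repeated `--config-settings` with the same key, while older versions need `--global-option`. The helper below is a hypothetical sketch that assembles the matching argument list; it only builds the command and does not run pip. The function name and the tuple form of the version are assumptions made for illustration.

```python
import shlex
from typing import List, Tuple


def apex_pip_command(pip_version: Tuple[int, int], build_options: List[str]) -> List[str]:
    """Assemble the source-install command for a given pip version.

    Hypothetical helper: pip >= 23.1 accepts repeated --config-settings with
    the same key; older pip versions fall back to --global-option.
    """
    cmd = ["pip", "install", "-v", "--disable-pip-version-check",
           "--no-cache-dir", "--no-build-isolation"]
    for opt in build_options:
        if pip_version >= (23, 1):
            cmd += ["--config-settings", f"--build-option={opt}"]
        else:
            cmd.append(f"--global-option={opt}")
    cmd.append("./")  # install from the cloned apex directory
    return cmd


new_pip = apex_pip_command((23, 1), ["--cpp_ext", "--cuda_ext"])
old_pip = apex_pip_command((22, 3), ["--cpp_ext", "--cuda_ext"])
print(shlex.join(new_pip))
print(shlex.join(old_pip))
```

Returning a list rather than a single string keeps each argument intact, so the result can be passed directly to `subprocess.run` without shell quoting issues.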
To reduce the build time of APEX, parallel building can be enhanced via

```bash
NVCC_APPEND_FLAGS="--threads 4" pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext --cuda_ext --parallel 8" ./
```
When CPU cores or memory are limited, the `--parallel` option is generally preferred over `--threads`. See pull#1882 for more details.
APEX also supports a Python-only build via

```bash
pip install -v --disable-pip-version-check --no-build-isolation --no-cache-dir ./
```
A Python-only build omits:
- Fused kernels required to use `apex.optimizers.FusedAdam`.
- Fused kernels required to use `apex.normalization.FusedLayerNorm` and `apex.normalization.FusedRMSNorm`.
- Fused kernels that improve the performance and numerical stability of `apex.parallel.SyncBatchNorm`.
- Fused kernels that improve the performance of `apex.parallel.DistributedDataParallel` and `apex.amp`.

`DistributedDataParallel`, `amp`, and `SyncBatchNorm` will still be usable, but they may be slower.
`pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" .` may work if you were able to build PyTorch from source on your system. A Python-only build via `pip install -v --no-cache-dir .` is more likely to work.
If you installed PyTorch in a Conda environment, make sure to install Apex in that same environment.
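A quick sanity check before building: `conda activate` sets the `CONDA_PREFIX` environment variable, and `sys.prefix` reports where the running interpreter lives. If the two differ, packages installed with that interpreter's `pip` would land outside the activated environment. The helper below is a hypothetical sketch of that comparison; run it with the interpreter that your `pip` belongs to.

```python
import os
import sys


def pip_targets_active_conda_env() -> bool:
    """Return True if this interpreter lives in the currently activated conda env.

    Hypothetical helper: `conda activate` sets CONDA_PREFIX; if it differs from
    sys.prefix, packages installed with this interpreter go somewhere else.
    """
    conda_prefix = os.environ.get("CONDA_PREFIX")
    if conda_prefix is None:
        return False  # no conda environment is active
    # realpath resolves symlinks so equivalent paths compare equal
    return os.path.realpath(sys.prefix) == os.path.realpath(conda_prefix)


print(pip_targets_active_conda_env())
```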
If a requirement of a module is not met, then it will not be built.
Module Name | Install Option | Misc |
---|---|---|
apex_C | --cpp_ext | |
amp_C | --cuda_ext | |
syncbn | --cuda_ext | |
fused_layer_norm_cuda | --cuda_ext | apex.normalization |
mlp_cuda | --cuda_ext | |
scaled_upper_triang_masked_softmax_cuda | --cuda_ext | |
generic_scaled_masked_softmax_cuda | --cuda_ext | |
scaled_masked_softmax_cuda | --cuda_ext | |
fused_weight_gradient_mlp_cuda | --cuda_ext | Requires CUDA>=11 |
permutation_search_cuda | --permutation_search | apex.contrib.sparsity |
bnp | --bnp | apex.contrib.groupbn |
xentropy | --xentropy | apex.contrib.xentropy |
focal_loss_cuda | --focal_loss | apex.contrib.focal_loss |
fused_index_mul_2d | --index_mul_2d | apex.contrib.index_mul_2d |
fused_adam_cuda | --deprecated_fused_adam | apex.contrib.optimizers |
fused_lamb_cuda | --deprecated_fused_lamb | apex.contrib.optimizers |
fast_layer_norm | --fast_layer_norm | apex.contrib.layer_norm; different from fused_layer_norm |
fmhalib | --fmha | apex.contrib.fmha |
fast_multihead_attn | --fast_multihead_attn | apex.contrib.multihead_attn |
transducer_joint_cuda | --transducer | apex.contrib.transducer |
transducer_loss_cuda | --transducer | apex.contrib.transducer |
cudnn_gbn_lib | --cudnn_gbn | Requires cuDNN>=8.5; apex.contrib.cudnn_gbn |
peer_memory_cuda | --peer_memory | apex.contrib.peer_memory |
nccl_p2p_cuda | --nccl_p2p | Requires NCCL >= 2.10; apex.contrib.nccl_p2p |
fast_bottleneck | --fast_bottleneck | Requires peer_memory_cuda and nccl_p2p_cuda; apex.contrib.bottleneck |
fused_conv_bias_relu | --fused_conv_bias_relu | Requires cuDNN>=8.4; apex.contrib.conv_bias_relu |
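Because a module is silently skipped when its install option is missing or its requirement is not met, it can help to check afterwards which compiled extensions actually exist. The sketch below uses the standard `importlib.util.find_spec`, which returns `None` for a top-level module that cannot be found, without importing it. The helper name is hypothetical, and the module names come from the table above; this assumes they are installed as top-level modules.

```python
import importlib.util
from typing import Dict, Iterable


def built_extensions(module_names: Iterable[str]) -> Dict[str, bool]:
    """Report which top-level extension modules can be found on sys.path.

    find_spec returns None when a module is not installed, e.g. because its
    install option was not passed or its requirement was not met.
    """
    return {name: importlib.util.find_spec(name) is not None for name in module_names}


# Module names taken from the table above; extend the list as needed.
status = built_extensions(["apex_C", "amp_C", "syncbn", "fused_layer_norm_cuda"])
for name, found in status.items():
    print(f"{name}: {'built' if found else 'missing'}")
```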