A PyTorch Extension: Tools for easy mixed precision and distributed training in PyTorch

NVIDIA/apex

This repository holds NVIDIA-maintained utilities to streamline mixed precision and distributed training in PyTorch. Some of the code here will be included in upstream PyTorch eventually. The intent of Apex is to make up-to-date utilities available to users as quickly as possible.

Full API Documentation: https://nvidia.github.io/apex

Contents

1. Amp: Automatic Mixed Precision

Deprecated. Use PyTorch AMP.

apex.amp is a tool to enable mixed precision training by changing only 3 lines of your script. Users can easily experiment with different pure and mixed precision training modes by supplying different flags to amp.initialize.

Webinar introducing Amp (the flag cast_batchnorm has been renamed to keep_batchnorm_fp32).

API Documentation

Comprehensive Imagenet example

DCGAN example coming soon...

Moving to the new Amp API (for users of the deprecated "Amp" and "FP16_Optimizer" APIs)

2. Distributed Training

apex.parallel.DistributedDataParallel is deprecated. Use torch.nn.parallel.DistributedDataParallel.

apex.parallel.DistributedDataParallel is a module wrapper, similar to torch.nn.parallel.DistributedDataParallel. It enables convenient multiprocess distributed training, optimized for NVIDIA's NCCL communication library.
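Conceptually, the core step such a wrapper automates is gradient averaging: every process holds a replica of the model, and after each backward pass the gradients are allreduced so that all replicas apply the same update. The pure-Python sketch below simulates that averaging step for a hypothetical list of per-process gradients (the function name allreduce_mean is illustrative). The real wrappers perform the reduction with NCCL on GPU tensors, not with Python lists.

```python
from typing import List


def allreduce_mean(per_rank_grads: List[List[float]]) -> List[List[float]]:
    """Simulate an allreduce that averages gradients across ranks.

    per_rank_grads[r][i] is gradient element i computed on rank r.
    Afterwards every rank holds the same averaged vector.
    """
    world_size = len(per_rank_grads)
    if world_size == 0:
        raise ValueError("need at least one rank")
    length = len(per_rank_grads[0])
    if any(len(g) != length for g in per_rank_grads):
        raise ValueError("all ranks must hold gradients of the same shape")
    # "Reduce": sum element-wise across ranks.
    summed = [sum(g[i] for g in per_rank_grads) for i in range(length)]
    # Average, then give every rank its own copy (the "all" in allreduce).
    averaged = [s / world_size for s in summed]
    return [list(averaged) for _ in range(world_size)]


if __name__ == "__main__":
    grads = [[1.0, 2.0], [3.0, 6.0]]  # two simulated ranks
    print(allreduce_mean(grads))  # [[2.0, 4.0], [2.0, 4.0]]
```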

API Documentation

Python Source

Example/Walkthrough

The Imagenet example shows use of apex.parallel.DistributedDataParallel along with apex.amp.

Synchronized Batch Normalization

Deprecated. Use torch.nn.SyncBatchNorm.

apex.parallel.SyncBatchNorm extends torch.nn.modules.batchnorm._BatchNorm to support synchronized BN. It allreduces stats across processes during multiprocess (DistributedDataParallel) training. Synchronous BN has been used in cases where only a small local minibatch can fit on each GPU. Allreduced stats increase the effective batch size for the BN layer to the global batch size across all processes (which, technically, is the correct formulation). Synchronous BN has been observed to improve converged accuracy in some of our research models.
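The allreduced statistics can be illustrated for a single channel: each process contributes its element count, sum, and sum of squares, and the global mean and biased variance follow from the totals. This is a pure-Python sketch under those assumptions (the helper names are hypothetical). It uses the simple E[x^2] - E[x]^2 formula for clarity, whereas GPU kernels typically use a more numerically stable merge such as Welford's method.

```python
from typing import List, Sequence, Tuple

Partial = Tuple[int, float, float]  # (count, sum, sum of squares) for one channel


def local_stats(xs: Sequence[float]) -> Partial:
    """Statistics one process computes for a channel over its local minibatch."""
    return len(xs), float(sum(xs)), float(sum(x * x for x in xs))


def global_mean_var(partials: List[Partial]) -> Tuple[float, float]:
    """Sum the partials (the allreduce step) and derive mean and biased variance."""
    n = sum(p[0] for p in partials)
    total = sum(p[1] for p in partials)
    total_sq = sum(p[2] for p in partials)
    mean = total / n
    var = total_sq / n - mean * mean  # E[x^2] - E[x]^2
    return mean, var


if __name__ == "__main__":
    rank0, rank1 = [1.0, 2.0], [3.0, 4.0]
    print(global_mean_var([local_stats(rank0)]))                      # (1.5, 0.25): rank 0 alone
    print(global_mean_var([local_stats(rank0), local_stats(rank1)]))  # (2.5, 1.25): synchronized
```

With two ranks holding [1, 2] and [3, 4], each rank's local variance is 0.25, while the variance over the global batch of four values is 1.25.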

Checkpointing

To properly save and load your amp training, we introduce amp.state_dict(), which contains all loss_scalers and their corresponding unskipped steps, as well as amp.load_state_dict() to restore these attributes.

In order to get bitwise accuracy, we recommend the following workflow:

import torch
from apex import amp
# Initialization
opt_level = 'O1'
model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)
# Train your model
...
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
...
# Save checkpoint
checkpoint = {
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'amp': amp.state_dict()
}
torch.save(checkpoint, 'amp_checkpoint.pt')
...
# Restore
model = ...
optimizer = ...
checkpoint = torch.load('amp_checkpoint.pt')
model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
amp.load_state_dict(checkpoint['amp'])
# Continue training
...

Note that we recommend restoring the model using the same opt_level. Also note that we recommend calling the load_state_dict methods after amp.initialize.
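Dynamic loss scaling is stateful: the scale is halved when gradients overflow and grown after a run of clean steps, so the scale and the count of unskipped steps must be checkpointed for a resumed run to match. The class below is a hypothetical pure-Python sketch of that bookkeeping (ToyDynamicLossScaler, its defaults, and its dictionary keys are illustrative assumptions, not apex's implementation). It shows that a scaler restored from saved state continues exactly like the original, while a freshly constructed one does not.

```python
from typing import Dict, Union

State = Dict[str, Union[float, int]]


class ToyDynamicLossScaler:
    """Hypothetical dynamic loss scaler used only to illustrate checkpointing."""

    def __init__(self, init_scale: float = 2.0 ** 16, growth_window: int = 2000) -> None:
        self.loss_scale = init_scale
        self.growth_window = growth_window
        self.unskipped = 0  # consecutive steps without overflow

    def update(self, found_overflow: bool) -> bool:
        """Adjust the scale after one step; return True if that step should be skipped."""
        if found_overflow:
            self.loss_scale /= 2.0  # back off; the optimizer step is skipped
            self.unskipped = 0
            return True
        self.unskipped += 1
        if self.unskipped == self.growth_window:
            self.loss_scale *= 2.0  # grow after a full window of clean steps
            self.unskipped = 0
        return False

    def state_dict(self) -> State:
        return {"loss_scale": self.loss_scale, "unskipped": self.unskipped}

    def load_state_dict(self, state: State) -> None:
        self.loss_scale = float(state["loss_scale"])
        self.unskipped = int(state["unskipped"])


# Train for a few steps (two overflows, one clean step), then "checkpoint".
scaler = ToyDynamicLossScaler(init_scale=16.0, growth_window=3)
for overflow in [True, True, False]:
    scaler.update(overflow)
saved = scaler.state_dict()  # {'loss_scale': 4.0, 'unskipped': 1}

# Resume: the restored scaler tracks the original; a fresh one starts over.
restored = ToyDynamicLossScaler(init_scale=16.0, growth_window=3)
restored.load_state_dict(saved)
fresh = ToyDynamicLossScaler(init_scale=16.0, growth_window=3)
for s in (scaler, restored, fresh):
    s.update(False)
print(scaler.state_dict())    # {'loss_scale': 4.0, 'unskipped': 2}
print(restored.state_dict())  # {'loss_scale': 4.0, 'unskipped': 2}
print(fresh.state_dict())     # {'loss_scale': 16.0, 'unskipped': 1}
```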

Installation

Each apex.contrib module requires one or more install options other than --cpp_ext and --cuda_ext. Note that contrib modules do not necessarily support stable PyTorch releases.

Containers

NVIDIA PyTorch Containers are available on NGC: https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch. The containers come with all the custom extensions available at the moment.

See the NGC documentation for details such as:

  • how to pull a container
  • how to run a pulled container
  • release notes

From Source

To install Apex from source, we recommend using the nightly PyTorch obtainable from https://github.com/pytorch/pytorch.

The latest stable release obtainable fromhttps://pytorch.org should also work.

We recommend installing Ninja to make compilation faster.

Linux

For performance and full functionality, we recommend installing Apex with CUDA and C++ extensions via

git clone https://github.com/NVIDIA/apex
cd apex
# if pip >= 23.1 (ref: https://pip.pypa.io/en/stable/news/#v23-1) which supports multiple `--config-settings` with the same key...
pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./
# otherwise
pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --global-option="--cpp_ext" --global-option="--cuda_ext" ./
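The choice between the two pip invocations depends only on the pip version, so it can be scripted. The helper below is a hypothetical sketch, not something shipped with Apex: it parses a pip version string, compares it against 23.1, and returns the argument list in either the repeated --config-settings form or the --global-option form. It only builds the list; the installed pip version is available as pip.__version__ if you want to feed it in.

```python
import re
from typing import List, Sequence

BASE_ARGS = ["pip", "install", "-v", "--disable-pip-version-check",
             "--no-cache-dir", "--no-build-isolation"]


def build_apex_install_cmd(pip_version: str, build_options: Sequence[str]) -> List[str]:
    """Return a pip argv that installs Apex from ./ with the given build options.

    Hypothetical helper: pip >= 23.1 accepts repeated --config-settings with the
    same key; older pip versions fall back to --global-option.
    """
    match = re.match(r"(\d+)\.(\d+)", pip_version)
    if match is None:
        raise ValueError(f"unrecognized pip version: {pip_version!r}")
    major, minor = int(match.group(1)), int(match.group(2))
    cmd = list(BASE_ARGS)
    for opt in build_options:
        if (major, minor) >= (23, 1):
            cmd += ["--config-settings", f"--build-option={opt}"]
        else:
            cmd.append(f"--global-option={opt}")
    cmd.append("./")
    return cmd


if __name__ == "__main__":
    # Each list element is one argv entry, so no shell quoting is needed.
    print(build_apex_install_cmd("23.1", ["--cpp_ext", "--cuda_ext"]))
    print(build_apex_install_cmd("22.3.1", ["--cpp_ext", "--cuda_ext"]))
```

Because the result is an argv list, it could be passed directly to subprocess.run; that is why the shell quotes around "--build-option=..." do not appear in the list elements.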

To reduce the build time of APEX, parallel building can be enhanced via

NVCC_APPEND_FLAGS="--threads 4" pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext --cuda_ext --parallel 8" ./

When CPU cores or memory are limited, the --parallel option is generally preferred over --threads. See pull #1882 for more details.

APEX also supports a Python-only build via

pip install -v --disable-pip-version-check --no-build-isolation --no-cache-dir ./

A Python-only build omits:

  • Fused kernels required to use apex.optimizers.FusedAdam.
  • Fused kernels required to use apex.normalization.FusedLayerNorm and apex.normalization.FusedRMSNorm.
  • Fused kernels that improve the performance and numerical stability of apex.parallel.SyncBatchNorm.
  • Fused kernels that improve the performance of apex.parallel.DistributedDataParallel and apex.amp. DistributedDataParallel, amp, and SyncBatchNorm will still be usable, but they may be slower.
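One way to tell whether an install ended up as a full or a Python-only build is to check whether the compiled extension modules are present. The sketch below assumes that the module names from the install-options table in this README (apex_C, amp_C, syncbn, fused_layer_norm_cuda) are top-level modules when built. It uses importlib.util.find_spec, which locates a module without importing it, so a module that is found could still fail to load (for example, on a CUDA version mismatch).

```python
import importlib.util
from typing import Dict, Iterable

# Compiled module names taken from the install-options table in this README.
CORE_EXTENSIONS = ("apex_C", "amp_C", "syncbn", "fused_layer_norm_cuda")


def built_extensions(names: Iterable[str] = CORE_EXTENSIONS) -> Dict[str, bool]:
    """Report which compiled extension modules can be located on sys.path."""
    return {name: importlib.util.find_spec(name) is not None for name in names}


if __name__ == "__main__":
    status = built_extensions()
    for name, found in status.items():
        print(f"{name:25s} {'found' if found else 'missing'}")
    if not any(status.values()):
        print("No compiled extensions found: Python-only build, or Apex not installed.")
```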

[Experimental] Windows

The command pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" . may work if you were able to build PyTorch from source on your system. A Python-only build via pip install -v --no-cache-dir . is more likely to work.
If you installed PyTorch in a Conda environment, make sure to install Apex in that same environment.

Custom C++/CUDA Extensions and Install Options

If a requirement of a module is not met, then it will not be built.

| Module Name | Install Option | Misc |
|---|---|---|
| apex_C | --cpp_ext | |
| amp_C | --cuda_ext | |
| syncbn | --cuda_ext | |
| fused_layer_norm_cuda | --cuda_ext | apex.normalization |
| mlp_cuda | --cuda_ext | |
| scaled_upper_triang_masked_softmax_cuda | --cuda_ext | |
| generic_scaled_masked_softmax_cuda | --cuda_ext | |
| scaled_masked_softmax_cuda | --cuda_ext | |
| fused_weight_gradient_mlp_cuda | --cuda_ext | Requires CUDA>=11 |
| permutation_search_cuda | --permutation_search | apex.contrib.sparsity |
| bnp | --bnp | apex.contrib.groupbn |
| xentropy | --xentropy | apex.contrib.xentropy |
| focal_loss_cuda | --focal_loss | apex.contrib.focal_loss |
| fused_index_mul_2d | --index_mul_2d | apex.contrib.index_mul_2d |
| fused_adam_cuda | --deprecated_fused_adam | apex.contrib.optimizers |
| fused_lamb_cuda | --deprecated_fused_lamb | apex.contrib.optimizers |
| fast_layer_norm | --fast_layer_norm | apex.contrib.layer_norm; different from fused_layer_norm |
| fmhalib | --fmha | apex.contrib.fmha |
| fast_multihead_attn | --fast_multihead_attn | apex.contrib.multihead_attn |
| transducer_joint_cuda | --transducer | apex.contrib.transducer |
| transducer_loss_cuda | --transducer | apex.contrib.transducer |
| cudnn_gbn_lib | --cudnn_gbn | Requires cuDNN>=8.5; apex.contrib.cudnn_gbn |
| peer_memory_cuda | --peer_memory | apex.contrib.peer_memory |
| nccl_p2p_cuda | --nccl_p2p | Requires NCCL>=2.10; apex.contrib.nccl_p2p |
| fast_bottleneck | --fast_bottleneck | Requires peer_memory_cuda and nccl_p2p_cuda; apex.contrib.bottleneck |
| fused_conv_bias_relu | --fused_conv_bias_relu | Requires cuDNN>=8.4; apex.contrib.conv_bias_relu |
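Selecting contrib modules means collecting their install options, including the options for any modules they depend on. The hypothetical helper below encodes a subset of the table above as a dictionary and resolves dependencies depth-first, so that requesting fast_bottleneck also pulls in the options for peer_memory_cuda and nccl_p2p_cuda. The dependency data comes from the table; the helper itself is only an illustration and is not part of Apex.

```python
from typing import Dict, List, Sequence, Set, Tuple

# Subset of the table: module -> (install option, modules it requires).
MODULES: Dict[str, Tuple[str, Tuple[str, ...]]] = {
    "apex_C": ("--cpp_ext", ()),
    "amp_C": ("--cuda_ext", ()),
    "xentropy": ("--xentropy", ()),
    "focal_loss_cuda": ("--focal_loss", ()),
    "peer_memory_cuda": ("--peer_memory", ()),
    "nccl_p2p_cuda": ("--nccl_p2p", ()),
    "fast_bottleneck": ("--fast_bottleneck", ("peer_memory_cuda", "nccl_p2p_cuda")),
}


def install_options(wanted: Sequence[str]) -> List[str]:
    """Return deduplicated install options for `wanted`, dependencies first."""
    options: List[str] = []
    seen: Set[str] = set()

    def visit(name: str) -> None:
        if name in seen:
            return
        seen.add(name)
        option, deps = MODULES[name]  # KeyError for modules not in the subset
        for dep in deps:
            visit(dep)
        if option not in options:  # e.g. two modules sharing --transducer
            options.append(option)

    for name in wanted:
        visit(name)
    return options


if __name__ == "__main__":
    print(install_options(["fast_bottleneck"]))  # ['--peer_memory', '--nccl_p2p', '--fast_bottleneck']
```

The resulting options could then be joined with spaces into a single --build-option, in the same way the parallel-build command combines --cpp_ext --cuda_ext --parallel 8.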
