Performance Tuning Guide#
Created On: Sep 21, 2020 | Last Updated: Jul 09, 2025 | Last Verified: Nov 05, 2024
Author: Szymon Migacz
Performance Tuning Guide is a set of optimizations and best practices which can accelerate training and inference of deep learning models in PyTorch. Presented techniques often can be implemented by changing only a few lines of code and can be applied to a wide range of deep learning models across all domains.
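What you will learn: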
General optimization techniques for PyTorch models
CPU-specific performance optimizations
GPU acceleration strategies
Distributed training optimizations
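Prerequisites: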
PyTorch 2.0 or later
Python 3.8 or later
CUDA-capable GPU (recommended for GPU optimizations)
Linux, macOS, or Windows operating system
Overview#
Performance optimization is crucial for efficient deep learning model training and inference. This tutorial covers a comprehensive set of techniques to accelerate PyTorch workloads across different hardware configurations and use cases.
General optimizations#
import torch
import torchvision
Enable asynchronous data loading and augmentation#
torch.utils.data.DataLoader supports asynchronous data loading and data augmentation in separate worker subprocesses. The default setting for DataLoader is num_workers=0, which means that the data loading is synchronous and done in the main process. As a result the main training process has to wait for the data to be available to continue the execution.
Setting num_workers > 0 enables asynchronous data loading and overlap between the training and data loading. num_workers should be tuned depending on the workload, CPU, GPU, and location of training data.
DataLoader accepts a pin_memory argument, which defaults to False. When using a GPU it's better to set pin_memory=True; this instructs DataLoader to use pinned memory and enables faster and asynchronous memory copies from the host to the GPU.
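A minimal sketch, assuming an existing train_dataset and training on a GPU:

train_loader = torch.utils.data.DataLoader(
    train_dataset,      # hypothetical dataset
    batch_size=64,
    shuffle=True,
    num_workers=4,      # tune for your CPU, GPU and storage
    pin_memory=True,    # faster, asynchronous host-to-GPU copies
)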
Disable gradient calculation for validation or inference#
PyTorch saves intermediate buffers from all operations which involve tensors that require gradients. Typically gradients aren't needed for validation or inference. The torch.no_grad() context manager can be applied to disable gradient calculation within a specified block of code; this accelerates execution and reduces the amount of required memory. torch.no_grad() can also be used as a function decorator.
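For example, assuming a hypothetical model and val_loader:

model.eval()
with torch.no_grad():                  # no gradient buffers are saved inside this block
    for data, target in val_loader:
        output = model(data)

@torch.no_grad()                       # or as a decorator
def predict(model, data):
    return model(data)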
Disable bias for convolutions directly followed by a batch norm#
torch.nn.Conv2d() has a bias parameter which defaults to True (the same is true for Conv1d and Conv3d).
If an nn.Conv2d layer is directly followed by an nn.BatchNorm2d layer, then the bias in the convolution is not needed; instead use nn.Conv2d(..., bias=False, ...). Bias is not needed because in the first step BatchNorm subtracts the mean, which effectively cancels out the effect of bias.
This is also applicable to 1d and 3d convolutions as long as BatchNorm (or another normalization layer) normalizes on the same dimension as the convolution's bias.
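A minimal sketch of such a block (channel sizes are arbitrary):

import torch.nn as nn

block = nn.Sequential(
    nn.Conv2d(64, 128, kernel_size=3, padding=1, bias=False),  # bias is redundant here
    nn.BatchNorm2d(128),                                        # BatchNorm provides its own affine bias
    nn.ReLU(inplace=True),
)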
Models available from torchvision already implement this optimization.
Use parameter.grad = None instead of model.zero_grad() or optimizer.zero_grad()#
Instead of calling:
model.zero_grad()
# or
optimizer.zero_grad()
to zero out gradients, use the following method instead:
for param in model.parameters():
    param.grad = None
The second code snippet does not zero the memory of each individual parameter; additionally, the subsequent backward pass uses assignment instead of addition to store gradients, which reduces the number of memory operations.
Setting the gradient to None has a slightly different numerical behavior than setting it to zero; for more details refer to the documentation.
Alternatively, call model.zero_grad(set_to_none=True) or optimizer.zero_grad(set_to_none=True).
Fuse operations#
Pointwise operations such as elementwise addition, multiplication, and math functions like sin(), cos(), sigmoid(), etc., can be combined into a single kernel. This fusion helps reduce memory access and kernel launch times. Typically, pointwise operations are memory-bound; PyTorch eager mode initiates a separate kernel for each operation, which involves loading data from memory, executing the operation (often not the most time-consuming step), and writing the results back to memory.
By using a fused operator, only one kernel is launched for multiple pointwise operations, and data is loaded and stored just once. This efficiency is particularly beneficial for activation functions, optimizers, custom RNN cells, etc.
PyTorch 2 introduces a compile mode facilitated by TorchInductor, an underlying compiler that automatically fuses kernels. TorchInductor extends its capabilities beyond simple element-wise operations, enabling advanced fusion of eligible pointwise and reduction operations for optimized performance.
In the simplest case fusion can be enabled by applying the torch.compile decorator to the function definition, for example:
@torch.compile
def gelu(x):
    return x * 0.5 * (1.0 + torch.erf(x / 1.41421))
Refer to Introduction to torch.compile for more advanced use cases.
Enable channels_last memory format for computer vision models#
PyTorch supports the channels_last memory format for convolutional networks. This format is meant to be used in conjunction with AMP to further accelerate convolutional neural networks with Tensor Cores.
Support for channels_last is experimental, but it's expected to work for standard computer vision models (e.g. ResNet-50, SSD). To convert models to channels_last format follow the Channels Last Memory Format Tutorial. The tutorial includes a section on converting existing models.
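A minimal sketch, assuming a CUDA GPU and AMP:

model = torchvision.models.resnet50().cuda().to(memory_format=torch.channels_last)
input = torch.randn(32, 3, 224, 224, device="cuda").to(memory_format=torch.channels_last)
with torch.autocast(device_type="cuda"):
    output = model(input)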
Checkpoint intermediate buffers#
Buffer checkpointing is a technique to mitigate the memory capacity burden of model training. Instead of storing inputs of all layers to compute upstream gradients in backward propagation, it stores the inputs of a few layers and the others are recomputed during the backward pass. The reduced memory requirements enable increasing the batch size, which can improve utilization.
Checkpointing targets should be selected carefully. The best targets to recompute rather than store are large layer outputs that have a small re-computation cost. Example target layers are activation functions (e.g. ReLU, Sigmoid, Tanh), up/down sampling, and matrix-vector operations with small accumulation depth.
PyTorch supports a native torch.utils.checkpoint API to automatically perform checkpointing and recomputation.
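A minimal sketch; block and x are hypothetical (e.g. a transformer layer and its input):

from torch.utils.checkpoint import checkpoint

# activations inside `block` are recomputed during the backward pass instead of being stored
out = checkpoint(block, x, use_reentrant=False)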
Disable debugging APIs#
Many PyTorch APIs are intended for debugging and should be disabled forregular training runs:
anomaly detection: torch.autograd.detect_anomaly or torch.autograd.set_detect_anomaly(True)
profiler related: torch.autograd.profiler.emit_nvtx, torch.autograd.profiler.profile
autograd gradcheck: torch.autograd.gradcheck or torch.autograd.gradgradcheck
CPU specific optimizations#
Utilize Non-Uniform Memory Access (NUMA) Controls#
NUMA or non-uniform memory access is a memory layout design used in data center machines meant to take advantage of locality of memory in multi-socket machines with multiple memory controllers and blocks. Generally speaking, all deep learning workloads, training or inference, get better performance without accessing hardware resources across NUMA nodes. Thus, inference can be run with multiple instances, each running on one socket, to raise throughput. For training tasks on a single node, distributed training is recommended so that each training process runs on one socket.
In general, the following command executes a PyTorch script on the cores of the Nth node only, and avoids cross-socket memory access to reduce memory access overhead.
numactl --cpunodebind=N --membind=N python <pytorch_script>
More detailed descriptions can be found here.
Utilize OpenMP#
OpenMP is utilized to bring better performance for parallel computation tasks. OMP_NUM_THREADS is the easiest switch that can be used to accelerate computations: it determines the number of threads used for OpenMP computations. CPU affinity settings control how workloads are distributed over multiple cores. They affect communication overhead, cache line invalidation overhead, and page thrashing, so a proper CPU affinity setting brings performance benefits. GOMP_CPU_AFFINITY or KMP_AFFINITY determines how to bind OpenMP* threads to physical processing units. Detailed information can be found here.
With the following command, PyTorch runs the task on N OpenMP threads.
export OMP_NUM_THREADS=N
Typically, the following environment variables are used to set CPU affinity with the GNU OpenMP implementation. OMP_PROC_BIND specifies whether threads may be moved between processors; setting it to CLOSE keeps OpenMP threads close to the primary thread in contiguous place partitions. OMP_SCHEDULE determines how OpenMP threads are scheduled. GOMP_CPU_AFFINITY binds threads to specific CPUs. An important tuning parameter is core pinning, which prevents threads from migrating between CPUs, enhancing data locality and minimizing inter-core communication.
export OMP_SCHEDULE=STATIC
export OMP_PROC_BIND=CLOSE
export GOMP_CPU_AFFINITY="N-M"
Intel OpenMP Runtime Library (libiomp)#
By default, PyTorch uses GNU OpenMP (GNU libgomp) for parallel computation. On Intel platforms, the Intel OpenMP Runtime Library (libiomp) provides OpenMP API specification support. It sometimes brings more performance benefits compared to libgomp. The environment variable LD_PRELOAD can be used to switch the OpenMP library to libiomp:
export LD_PRELOAD=<path>/libiomp5.so:$LD_PRELOAD
Similar to CPU affinity settings in GNU OpenMP, environment variables are provided in libiomp to control CPU affinity settings. KMP_AFFINITY binds OpenMP threads to physical processing units. KMP_BLOCKTIME sets the time, in milliseconds, that a thread should wait, after completing the execution of a parallel region, before sleeping. In most cases, setting KMP_BLOCKTIME to 1 or 0 yields good performance. The following commands show common settings with the Intel OpenMP Runtime Library.
export KMP_AFFINITY=granularity=fine,compact,1,0
export KMP_BLOCKTIME=1
Switch Memory allocator#
For deep learning workloads, Jemalloc or TCMalloc can achieve better performance than the default malloc by reusing memory as much as possible. Jemalloc is a general-purpose malloc implementation that emphasizes fragmentation avoidance and scalable concurrency support. TCMalloc also features a couple of optimizations to speed up program execution. One of them is holding memory in caches to speed up access to commonly used objects. Holding such caches even after deallocation also helps avoid costly system calls if such memory is later re-allocated. Use the environment variable LD_PRELOAD to take advantage of one of them.
export LD_PRELOAD=<jemalloc.so/tcmalloc.so>:$LD_PRELOAD
Train a model on CPU with PyTorch DistributedDataParallel (DDP) functionality#
For small-scale models or memory-bound models, such as DLRM, training on CPU is also a good choice. On a machine with multiple sockets, distributed training brings highly efficient hardware resource usage to accelerate the training process. Torch-ccl, optimized with Intel(R) oneCCL (collective communications library) for efficient distributed deep learning training with collectives such as allreduce, allgather, and alltoall, implements the PyTorch C10D ProcessGroup API and can be dynamically loaded as an external ProcessGroup. On top of the optimizations implemented in the PyTorch DDP module, torch-ccl accelerates communication operations. Besides the optimizations made to communication kernels, torch-ccl also features simultaneous computation-communication functionality.
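A minimal CPU DDP sketch using the built-in gloo backend (torch-ccl would register a ccl backend instead); the model and data are placeholders, and the script is assumed to be launched with torchrun --nproc_per_node=<N>:

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

dist.init_process_group(backend="gloo")   # "ccl" when torch-ccl is installed and imported
model = nn.Linear(128, 10)                # placeholder model
ddp_model = DDP(model)                    # no device_ids on CPU
optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.01)

data, target = torch.randn(32, 128), torch.randn(32, 10)   # placeholder batch
loss = nn.functional.mse_loss(ddp_model(data), target)
loss.backward()                           # gradient all-reduce happens here
optimizer.step()
dist.destroy_process_group()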
GPU specific optimizations#
Enable Tensor cores#
Tensor cores are specialized hardware designed to compute matrix-matrix multiplication operations, primarily utilized in deep learning and AI workloads. Tensor cores have specific precision requirements which can be adjusted manually or via the Automatic Mixed Precision API.
In particular, tensor operations can take advantage of lower-precision workloads, which can be controlled via torch.set_float32_matmul_precision. The default is 'highest', which uses the full float32 data type. PyTorch also offers alternative precision settings, 'high' and 'medium'; these options prioritize computational speed over numerical precision.
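For example, to allow Tensor Core friendly float32 matrix multiplications:

torch.set_float32_matmul_precision("high")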
Use CUDA Graphs#
When using a GPU, work first must be launched from the CPU, and in some cases the context switch between CPU and GPU can lead to bad resource utilization. CUDA graphs are a way to keep computation within the GPU without paying the extra cost of kernel launches and host synchronization.
# It can be enabled using
torch.compile(m, mode="reduce-overhead")
# or
torch.compile(m, mode="max-autotune")
Support for CUDA graphs is in development, and its usage can incur increased device memory consumption and some models might not compile.
Enable cuDNN auto-tuner#
NVIDIA cuDNN supports many algorithms to compute a convolution. The autotuner runs a short benchmark and selects the kernel with the best performance on a given hardware for a given input size.
For convolutional networks (other types are currently not supported), enable the cuDNN autotuner before launching the training loop by setting:
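torch.backends.cudnn.benchmark = True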
the auto-tuner decisions may be non-deterministic; a different algorithm may be selected for different runs. For more details see PyTorch: Reproducibility
in some rare cases, such as with highly variable input sizes, it's better to run convolutional networks with the autotuner disabled to avoid the overhead associated with algorithm selection for each input size.
Avoid unnecessary CPU-GPU synchronization#
Avoid unnecessary synchronizations, to let the CPU run ahead of the accelerator as much as possible to make sure that the accelerator work queue contains many operations.
When possible, avoid operations which require synchronizations, for example:
print(cuda_tensor)
cuda_tensor.item()
memory copies: tensor.cuda(), cuda_tensor.cpu() and equivalent tensor.to(device) calls
cuda_tensor.nonzero()
python control flow which depends on results of operations performed on CUDA tensors, e.g. if (cuda_tensor != 0).all()
Create tensors directly on the target device#
Instead of calling torch.rand(size).cuda() to generate a random tensor, produce the output directly on the target device: torch.rand(size, device='cuda').
This is applicable to all functions which create new tensors and accept a device argument: torch.rand(), torch.zeros(), torch.full() and similar.
Use mixed precision and AMP#
Mixed precision leverages Tensor Cores and offers up to 3x overall speedup on Volta and newer GPU architectures. To use Tensor Cores, AMP should be enabled and matrix/tensor dimensions should satisfy requirements for calling kernels that use Tensor Cores.
To use Tensor Cores:
set sizes to multiples of 8 (to map onto dimensions of Tensor Cores)
see Deep Learning Performance Documentation for more details and guidelines specific to layer type
if layer size is derived from other parameters rather than fixed, it can still be explicitly padded, e.g. vocabulary size in NLP models
enable AMP
Introduction to Mixed Precision Training and AMP: slides
native PyTorch AMP is available: documentation, examples, tutorial; a minimal usage sketch is shown below
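A minimal AMP training-step sketch (assumes a CUDA GPU; model, optimizer, loss_fn, and loader are placeholders):

scaler = torch.cuda.amp.GradScaler()
for data, target in loader:
    optimizer.zero_grad(set_to_none=True)
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        output = model(data)
        loss = loss_fn(output, target)
    scaler.scale(loss).backward()   # scale the loss to avoid float16 underflow
    scaler.step(optimizer)          # unscale gradients and run the optimizer step
    scaler.update()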
Preallocate memory in case of variable input length#
Models for speech recognition or for NLP are often trained on input tensors with variable sequence length. Variable length can be problematic for the PyTorch caching allocator and can lead to reduced performance or to unexpected out-of-memory errors. If a batch with a short sequence length is followed by another batch with a longer sequence length, then PyTorch is forced to release intermediate buffers from the previous iteration and to re-allocate new buffers. This process is time consuming and causes fragmentation in the caching allocator which may result in out-of-memory errors.
A typical solution is to implement preallocation. It consists of the following steps (a minimal sketch follows the list):
generate a (usually random) batch of inputs with maximum sequence length (either corresponding to max length in the training dataset or to some predefined threshold)
execute a forward and a backward pass with the generated batch; do not execute an optimizer or a learning rate scheduler. This step preallocates buffers of maximum size, which can be reused in subsequent training iterations
zero out gradients
proceed to regular training
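A minimal preallocation sketch; model, vocab_size, batch_size, and max_seq_len are hypothetical, and the loss is an arbitrary scalar just to trigger the backward pass:

dummy = torch.randint(0, vocab_size, (batch_size, max_seq_len), device="cuda")
loss = model(dummy).sum()            # forward pass with maximum-size buffers
loss.backward()                      # backward pass preallocates gradient buffers
model.zero_grad(set_to_none=True)    # zero out gradients before regular training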
Distributed optimizations#
Use efficient data-parallel backend#
PyTorch has two ways to implement data-parallel training:
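torch.nn.DataParallel
torch.nn.parallel.DistributedDataParallel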
DistributedDataParallel offers much better performance and scaling to multiple GPUs. For more information refer to the relevant section of CUDA Best Practices from the PyTorch documentation.
Skip unnecessary all-reduce if training with DistributedDataParallel and gradient accumulation#
By default torch.nn.parallel.DistributedDataParallel executes gradient all-reduce after every backward pass to compute the average gradient over all workers participating in the training. If training uses gradient accumulation over N steps, then all-reduce is not necessary after every training step; it's only required to perform all-reduce after the last call to backward, just before the execution of the optimizer.
DistributedDataParallel provides a no_sync() context manager which disables gradient all-reduce for a particular iteration. no_sync() should be applied to the first N-1 iterations of gradient accumulation; the last iteration should follow the default execution and perform the required gradient all-reduce.
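A minimal gradient-accumulation sketch; ddp_model, loss_fn, optimizer, loader, and accumulation_steps are placeholders:

for i, (data, target) in enumerate(loader):
    if (i + 1) % accumulation_steps != 0:
        with ddp_model.no_sync():                          # skip all-reduce for the first N-1 steps
            loss_fn(ddp_model(data), target).backward()
    else:
        loss_fn(ddp_model(data), target).backward()        # all-reduce runs in this backward pass
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)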
Match the order of layers in constructors and during the execution if using DistributedDataParallel(find_unused_parameters=True)#
torch.nn.parallel.DistributedDataParallel with find_unused_parameters=True uses the order of layers and parameters from model constructors to build buckets for DistributedDataParallel gradient all-reduce. DistributedDataParallel overlaps all-reduce with the backward pass. All-reduce for a particular bucket is asynchronously triggered only when all gradients for parameters in a given bucket are available.
To maximize the amount of overlap, the order in model constructors should roughly match the order during the execution. If the order doesn't match, then all-reduce for the entire bucket waits for the gradient which is the last to arrive; this may reduce the overlap between backward pass and all-reduce, and all-reduce may end up being exposed, which slows down the training.
DistributedDataParallel with find_unused_parameters=False (which is the default setting) relies on automatic bucket formation based on the order of operations encountered during the backward pass. With find_unused_parameters=False it's not necessary to reorder layers or parameters to achieve optimal performance.
Load-balance workload in a distributed setting#
Load imbalance typically happens for models processing sequential data (speech recognition, translation, language models, etc.). If one device receives a batch of data with a sequence length longer than the sequence lengths for the remaining devices, then all devices wait for the worker which finishes last. The backward pass functions as an implicit synchronization point in a distributed setting with the DistributedDataParallel backend.
There are multiple ways to solve the load balancing problem. The core idea is to distribute the workload over all workers as uniformly as possible within each global batch. For example, Transformer solves imbalance by forming batches with an approximately constant number of tokens (and a variable number of sequences in a batch); other models solve imbalance by bucketing samples with similar sequence length or even by sorting the dataset by sequence length.
Conclusion#
This tutorial covered a comprehensive set of performance optimization techniques for PyTorch models. The key takeaways include:
General optimizations: Enable async data loading, disable gradients for inference, fuse operations with torch.compile, and use efficient memory formats
CPU optimizations: Leverage NUMA controls, optimize OpenMP settings, and use efficient memory allocators
GPU optimizations: Enable Tensor cores, use CUDA graphs, enable the cuDNN autotuner, and implement mixed precision training
Distributed optimizations: Use DistributedDataParallel, optimize gradientsynchronization, and balance workloads across devices
Many of these optimizations can be applied with minimal code changes and provide significant performance improvements across a wide range of deep learning models.