
CUDA semantics

Created On: Jan 16, 2017 | Last Updated On: Sep 04, 2025

torch.cuda is used to set up and run CUDA operations. It keeps track of the currently selected GPU, and all CUDA tensors you allocate will by default be created on that device. The selected device can be changed with a torch.cuda.device context manager.

However, once a tensor is allocated, you can do operations on it irrespective of the selected device, and the results will always be placed on the same device as the tensor.

Cross-GPU operations are not allowed by default, with the exception of copy_() and other methods with copy-like functionality such as to() and cuda(). Unless you enable peer-to-peer memory access, any attempts to launch ops on tensors spread across different devices will raise an error.

Below you can find a small example showcasing this:

```python
import torch

cuda = torch.device('cuda')     # Default CUDA device
cuda0 = torch.device('cuda:0')
cuda2 = torch.device('cuda:2')  # GPU 2 (these are 0-indexed)

x = torch.tensor([1., 2.], device=cuda0)
# x.device is device(type='cuda', index=0)
y = torch.tensor([1., 2.]).cuda()
# y.device is device(type='cuda', index=0)

with torch.cuda.device(1):
    # allocates a tensor on GPU 1
    a = torch.tensor([1., 2.], device=cuda)

    # transfers a tensor from CPU to GPU 1
    b = torch.tensor([1., 2.]).cuda()
    # a.device and b.device are device(type='cuda', index=1)

    # You can also use ``Tensor.to`` to transfer a tensor:
    b2 = torch.tensor([1., 2.]).to(device=cuda)
    # b.device and b2.device are device(type='cuda', index=1)

    c = a + b
    # c.device is device(type='cuda', index=1)

    z = x + y
    # z.device is device(type='cuda', index=0)

    # even within a context, you can specify the device
    # (or give a GPU index to the .cuda call)
    d = torch.randn(2, device=cuda2)
    e = torch.randn(2).to(cuda2)
    f = torch.randn(2).cuda(cuda2)
    # d.device, e.device, and f.device are all device(type='cuda', index=2)
```

TensorFloat-32 (TF32) on Ampere (and later) devices

Starting with PyTorch 2.9, we provide a new set of APIs to control TF32 behavior in a more fine-grained way, and we suggest using these new APIs for better control. float32 precision can be set per backend and per operator, and the global setting can be overridden for a specific operator.

```python
torch.backends.fp32_precision = "ieee"
torch.backends.cuda.matmul.fp32_precision = "ieee"
torch.backends.cudnn.fp32_precision = "ieee"
torch.backends.cudnn.conv.fp32_precision = "tf32"
torch.backends.cudnn.rnn.fp32_precision = "tf32"
```

fp32_precision can be set to ieee or tf32 for cuda/cudnn. ieee indicates that FP32 is used as the internal computation precision; tf32 indicates that TF32 is allowed as the internal computation precision.

A generic setting can be overridden for a specific operator by setting that operator's fp32_precision to ieee:

```python
torch.backends.cudnn.fp32_precision = "tf32"
torch.backends.cudnn.conv.fp32_precision = "ieee"
torch.backends.cudnn.rnn.fp32_precision = "ieee"
```

Likewise, a generic setting can be overridden for a specific backend by setting that backend's fp32_precision to ieee:

```python
torch.backends.fp32_precision = "tf32"
torch.backends.cudnn.fp32_precision = "ieee"
torch.backends.cudnn.conv.fp32_precision = "ieee"
torch.backends.cudnn.rnn.fp32_precision = "ieee"
```

In both cases above, torch.backends.cudnn.conv.fp32_precision and torch.backends.cudnn.rnn.fp32_precision are overridden to ieee.

We suggest using the new settings for better control. Mixing the old and new settings is not supported.

Warning

The old allow_tf32 settings described below are going to be deprecated. Use the new settings above for better control, and do not mix old and new settings.

Starting in PyTorch 1.7, there is a new flag called allow_tf32. For matmuls, this flag defaults to True in PyTorch 1.7 through PyTorch 1.11, and to False in PyTorch 1.12 and later (the cuDNN flag still defaults to True). This flag controls whether PyTorch is allowed to use the TensorFloat32 (TF32) tensor cores, available on NVIDIA GPUs since Ampere, internally to compute matmuls (matrix multiplies and batched matrix multiplies) and convolutions.

TF32 tensor cores are designed to achieve better performance on matmul and convolutions on torch.float32 tensors by rounding input data to have 10 bits of mantissa, and accumulating results with FP32 precision, maintaining FP32 dynamic range.
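To make "10 bits of mantissa" concrete, the following pure-Python sketch emulates rounding a float32 input to TF32 precision: it keeps the sign and 8-bit exponent (so the FP32 dynamic range is unchanged) and discards the 13 low mantissa bits. The helper names are hypothetical, and round-to-nearest-even is an assumption made for illustration; this is not a model of what the tensor cores do internally.

```python
import struct


def float32_bits(x: float) -> int:
    """Return the IEEE-754 float32 bit pattern of x as an unsigned int."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def bits_float32(bits: int) -> float:
    """Interpret an unsigned 32-bit pattern as a float32 value."""
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFFFFF))[0]


def round_to_tf32(x: float) -> float:
    """Round a finite float32 value to TF32 (10 explicit mantissa bits).

    Assumes round-to-nearest-even. The exponent bits are left untouched,
    so only precision is reduced, not range.
    """
    bits = float32_bits(x)
    dropped = 23 - 10  # float32 stores 23 mantissa bits, TF32 keeps 10
    keep_lsb = (bits >> dropped) & 1  # lowest kept bit, used to break ties to even
    bias = (1 << (dropped - 1)) - 1 + keep_lsb
    bits = (bits + bias) & ~((1 << dropped) - 1)  # round, then clear dropped bits
    return bits_float32(bits)


if __name__ == "__main__":
    for value in (1.0, 1 + 2**-11, 1 + 3 * 2**-11, 1 / 3):
        print(f"{value!r:>22} -> {round_to_tf32(value)!r}")
```

For example, 1/3 becomes 1365/4096 (about 0.33325), a relative change of roughly 2.4e-4, which is the scale of input error TF32 trades for speed.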

Matmuls and convolutions are controlled separately, and their corresponding flags can be accessed at:

```python
# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
# in PyTorch 1.12 and later.
torch.backends.cuda.matmul.allow_tf32 = True

# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
torch.backends.cudnn.allow_tf32 = True
```

The precision of matmuls can also be set more broadly (not limited to CUDA) via set_float32_matmul_precision(). Note that besides matmuls and convolutions themselves, functions and nn modules that internally use matmuls or convolutions are also affected. These include nn.Linear, nn.Conv*, cdist, tensordot, affine grid and grid sample, adaptive log softmax, GRU and LSTM.

To get an idea of the precision and speed, see the example code and benchmark data (on A100) below:

```python
import torch

a_full = torch.randn(10240, 10240, dtype=torch.double, device='cuda')
b_full = torch.randn(10240, 10240, dtype=torch.double, device='cuda')
ab_full = a_full @ b_full
mean = ab_full.abs().mean()  # 80.7277

a = a_full.float()
b = b_full.float()

# Do matmul at TF32 mode.
torch.backends.cuda.matmul.allow_tf32 = True
ab_tf32 = a @ b  # takes 0.016s on GA100
error = (ab_tf32 - ab_full).abs().max()  # 0.1747
relative_error = error / mean  # 0.0022

# Do matmul with TF32 disabled.
torch.backends.cuda.matmul.allow_tf32 = False
ab_fp32 = a @ b  # takes 0.11s on GA100
error = (ab_fp32 - ab_full).abs().max()  # 0.0031
relative_error = error / mean  # 0.000039
```

From the above example, we can see that with TF32 enabled, the matmul runs ~7x faster on A100, and that the relative error compared to double precision is approximately 2 orders of magnitude larger. Note that the exact ratio of TF32 to single precision speed depends on the hardware generation, as properties such as the ratio of memory bandwidth to compute as well as the ratio of TF32 to FP32 matmul throughput may vary from generation to generation or model to model. If full FP32 precision is needed, users can disable TF32 by:
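The two headline ratios can be recomputed from the numbers quoted in the A100 example. The sketch below only does arithmetic on those copied values (nothing is measured and no GPU is involved); it shows the speedup is about 6.9x and the error gap about 56x, i.e. roughly 1.75 orders of magnitude.

```python
import math

# Values copied from the A100 example above; they are not re-measured here.
t_tf32_s, t_fp32_s = 0.016, 0.11
rel_err_tf32, rel_err_fp32 = 0.0022, 0.000039

speedup = t_fp32_s / t_tf32_s                 # how much faster TF32 ran
error_ratio = rel_err_tf32 / rel_err_fp32     # how much larger the TF32 error is
orders_of_magnitude = math.log10(error_ratio)

print(f"speedup: {speedup:.1f}x")
print(f"error ratio: {error_ratio:.0f}x")
print(f"orders of magnitude: {orders_of_magnitude:.2f}")
```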

```python
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
```

To toggle the TF32 flags off in C++, you can do:

```cpp
at::globalContext().setAllowTF32CuBLAS(false);
at::globalContext().setAllowTF32CuDNN(false);
```

For more information about TF32, see:

Reduced Precision Reduction in FP16 GEMMs

(This is distinct from full FP16 accumulation, which is intended for hardware that has higher throughput with FP16 accumulation than with FP32 accumulation; see Full FP16 Accumulation in FP16 GEMMs.)

fp16 GEMMs are potentially done with some intermediate reduced precision reductions (e.g., in fp16 rather than fp32). These selective reductions in precision can allow for higher performance on certain workloads (particularly those with a large k dimension) and GPU architectures at the cost of numerical precision and potential for overflow.

Some example benchmark data on V100:

```
[--------------------------- bench_gemm_transformer --------------------------]
      [  m ,  k  ,  n  ]    |  allow_fp16_reduc=True  |  allow_fp16_reduc=False
1 threads: --------------------------------------------------------------------
      [4096, 4048, 4096]    |           1634.6        |           1639.8
      [4096, 4056, 4096]    |           1670.8        |           1661.9
      [4096, 4080, 4096]    |           1664.2        |           1658.3
      [4096, 4096, 4096]    |           1639.4        |           1651.0
      [4096, 4104, 4096]    |           1677.4        |           1674.9
      [4096, 4128, 4096]    |           1655.7        |           1646.0
      [4096, 4144, 4096]    |           1796.8        |           2519.6
      [4096, 5096, 4096]    |           2094.6        |           3190.0
      [4096, 5104, 4096]    |           2144.0        |           2663.5
      [4096, 5112, 4096]    |           2149.1        |           2766.9
      [4096, 5120, 4096]    |           2142.8        |           2631.0
      [4096, 9728, 4096]    |           3875.1        |           5779.8
      [4096, 16384, 4096]   |           6182.9        |           9656.5

Times are in microseconds (us).
```
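To read the table as speedups, the sketch below divides the allow_fp16_reduc=False time by the allow_fp16_reduc=True time for each k. The values are copied from the V100 table; TIMINGS_US and reduced_precision_speedup are illustrative names. Up to k = 4128 the ratio stays within about 1% of 1.0, while from k = 4144 onward reduced precision reductions are roughly 1.2x to 1.6x faster on this data.

```python
# Timings in microseconds copied from the V100 table above (m = n = 4096).
# k -> (allow_fp16_reduc=True, allow_fp16_reduc=False)
TIMINGS_US = {
    4048: (1634.6, 1639.8),
    4056: (1670.8, 1661.9),
    4080: (1664.2, 1658.3),
    4096: (1639.4, 1651.0),
    4104: (1677.4, 1674.9),
    4128: (1655.7, 1646.0),
    4144: (1796.8, 2519.6),
    5096: (2094.6, 3190.0),
    5104: (2144.0, 2663.5),
    5112: (2149.1, 2766.9),
    5120: (2142.8, 2631.0),
    9728: (3875.1, 5779.8),
    16384: (6182.9, 9656.5),
}


def reduced_precision_speedup(k: int) -> float:
    """Return time(reduction disabled) / time(reduction allowed) for a given k."""
    allowed_us, disallowed_us = TIMINGS_US[k]
    return disallowed_us / allowed_us


if __name__ == "__main__":
    for k in sorted(TIMINGS_US):
        print(f"k={k:>5}: {reduced_precision_speedup(k):.2f}x")
```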

If full precision reductions are needed, users can disable reduced precision reductions in fp16 GEMMs with:

torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction=False

To disable reduced precision reductions for fp16 GEMMs in C++, one can do:

at::globalContext().setAllowFP16ReductionCuBLAS(false);

Reduced Precision Reduction in BF16 GEMMs

A similar flag (as above) exists for BFloat16 GEMMs. Note that this switch is set to True by default for BF16; if you observe numerical instability in your workload, you may wish to set it to False.

If reduced precision reductions are not desired, users can disable reduced precision reductions in bf16 GEMMs with:

torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction=False

To disable reduced precision reductions for bf16 GEMMs in C++, one can do:

at::globalContext().setAllowBF16ReductionCuBLAS(false);

Full FP16 Accumulation in FP16 GEMMs

Certain GPUs have increased performance when doing all FP16 GEMM accumulation in FP16, at the cost of numerical precision and greater likelihood of overflow. Note that this setting only has an effect on GPUs of compute capability 7.0 (Volta) or newer.
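The precision and overflow costs can be seen with a small pure-Python model that rounds every product and every partial sum to fp16, using the struct module's half-precision "e" format. This is a hypothetical worst-case model (the helper names are ours), not a description of how cuBLAS actually schedules its accumulations. Summing 4096 ones stalls at 2048 because fp16 values in [2048, 4096) are spaced 2 apart, so 2048 + 1 rounds back to 2048; and 300 * 300 = 90000 exceeds the largest finite fp16 value, 65504.

```python
import struct


def to_fp16(x: float) -> float:
    """Round x to the nearest IEEE half-precision (fp16) value."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        # struct refuses values outside the fp16 range; saturate to +/-inf
        return float("inf") if x > 0 else float("-inf")


def dot_fp16_accumulate(a, b):
    """Dot product in which every product and every partial sum is rounded to fp16."""
    acc = 0.0
    for x, y in zip(a, b):
        acc = to_fp16(acc + to_fp16(x * y))
    return acc


def dot_wide_accumulate(a, b):
    """Same fp16-rounded products, but summed in Python's double precision."""
    return sum(to_fp16(x * y) for x, y in zip(a, b))


if __name__ == "__main__":
    ones = [1.0] * 4096  # a GEMM with k = 4096 and all-ones inputs
    print("fp16 accumulation:", dot_fp16_accumulate(ones, ones))
    print("wide accumulation:", dot_wide_accumulate(ones, ones))
    print("overflow:", dot_fp16_accumulate([300.0], [300.0]))
```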

This behavior can be enabled via:

torch.backends.cuda.matmul.allow_fp16_accumulation=True

To enable full FP16 accumulation in C++, one can do:

at::globalContext().setAllowFP16AccumulationCuBLAS(true);

Asynchronous execution

By default, GPU operations are asynchronous. When you call a function that uses the GPU, the operations are enqueued to the particular device, but not necessarily executed until later. This allows us to execute more computations in parallel, including operations on CPU or other GPUs.

In general, the effect of asynchronous computation is invisible to the caller, because (1) each device executes operations in the order they are queued, and (2) PyTorch automatically performs necessary synchronization when copying data between CPU and GPU or between two GPUs. Hence, computation will proceed as if every operation was executed synchronously.

You can force synchronous computation by setting the environment variable CUDA_LAUNCH_BLOCKING=1. This can be handy when an error occurs on the GPU. (With asynchronous execution, such an error isn’t reported until after the operation is actually executed, so the stack trace does not show where it was requested.)

A consequence of the asynchronous computation is that time measurements without synchronizations are not accurate. To get precise measurements, one should either call torch.cuda.synchronize() before measuring, or use torch.cuda.Event to record times as follows:

```python
import torch

start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()

# Run some things here

end_event.record()
torch.cuda.synchronize()  # Wait for the events to be recorded!
elapsed_time_ms = start_event.elapsed_time(end_event)
```

As an exception, several functions such as to() and copy_() admit an explicit non_blocking argument, which lets the caller bypass synchronization when it is unnecessary. Another exception is CUDA streams, explained below.

CUDA streams

A CUDA stream is a linear sequence of execution that belongs to a specific device. You normally do not need to create one explicitly: by default, each device uses its own “default” stream.

Operations inside each stream are serialized in the order they are created, but operations from different streams can execute concurrently in any relative order, unless explicit synchronization functions (such as synchronize() or wait_stream()) are used. For example, the following code is incorrect:

```python
import torch

cuda = torch.device('cuda')
s = torch.cuda.Stream()  # Create a new stream.
A = torch.empty((100, 100), device=cuda).normal_(0.0, 1.0)
with torch.cuda.stream(s):
    # sum() may start execution before normal_() finishes!
    B = torch.sum(A)
```

When the “current stream” is the default stream, PyTorch automatically performs necessary synchronization when data is moved around, as explained above. However, when using non-default streams, it is the user’s responsibility to ensure proper synchronization. The fixed version of this example is:

```python
import torch

cuda = torch.device('cuda')
s = torch.cuda.Stream()  # Create a new stream.
A = torch.empty((100, 100), device=cuda).normal_(0.0, 1.0)
s.wait_stream(torch.cuda.default_stream(cuda))  # NEW!
with torch.cuda.stream(s):
    B = torch.sum(A)
A.record_stream(s)  # NEW!
```

There are two new additions. The torch.cuda.Stream.wait_stream() call ensures that the normal_() execution has finished before we start running sum(A) on a side stream. The torch.Tensor.record_stream() call ensures that we do not deallocate A before sum(A) has completed. You can also manually wait on the stream at some later point in time with torch.cuda.default_stream(cuda).wait_stream(s) (note that it is pointless to wait immediately, since that will prevent the stream execution from running in parallel with other work on the default stream). See the documentation for torch.Tensor.record_stream() for more details on when to use one or the other.

Note that this synchronization is necessary even when there is no read dependency, e.g., as seen in this example:

```python
import torch

cuda = torch.device('cuda')
s = torch.cuda.Stream()  # Create a new stream.
A = torch.empty((100, 100), device=cuda)
s.wait_stream(torch.cuda.default_stream(cuda))  # STILL REQUIRED!
with torch.cuda.stream(s):
    A.normal_(0.0, 1.0)
    A.record_stream(s)
```

Despite the computation on s not reading the contents of A and no other uses of A, it is still necessary to synchronize, because A may correspond to memory reallocated by the CUDA caching allocator, with pending operations from the old (deallocated) memory.

Stream semantics of backward passes

Each backward CUDA op runs on the same stream that was used for its corresponding forward op. If your forward pass runs independent ops in parallel on different streams, this helps the backward pass exploit that same parallelism.

The stream semantics of a backward call with respect to surrounding ops are the same as for any other call. The backward pass inserts internal syncs to ensure this even when backward ops run on multiple streams as described in the previous paragraph. More concretely, when calling autograd.backward, autograd.grad, or tensor.backward, and optionally supplying CUDA tensor(s) as the initial gradient(s) (e.g., autograd.backward(..., grad_tensors=initial_grads), autograd.grad(..., grad_outputs=initial_grads), or tensor.backward(..., gradient=initial_grad)), the acts of

  1. optionally populating initial gradient(s),

  2. invoking the backward pass, and

  3. using the gradients

have the same stream-semantics relationship as any group of ops:

```python
s = torch.cuda.Stream()

# Safe, grads are used in the same stream context as backward()
with torch.cuda.stream(s):
    loss.backward()
    use grads

# Unsafe
with torch.cuda.stream(s):
    loss.backward()
use grads

# Safe, with synchronization
with torch.cuda.stream(s):
    loss.backward()
torch.cuda.current_stream().wait_stream(s)
use grads

# Safe, populating initial grad and invoking backward are in the same stream context
with torch.cuda.stream(s):
    loss.backward(gradient=torch.ones_like(loss))

# Unsafe, populating initial_grad and invoking backward are in different stream contexts,
# without synchronization
initial_grad = torch.ones_like(loss)
with torch.cuda.stream(s):
    loss.backward(gradient=initial_grad)

# Safe, with synchronization
initial_grad = torch.ones_like(loss)
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    initial_grad.record_stream(s)
    loss.backward(gradient=initial_grad)
```

BC note: Using grads on the default stream

In prior versions of PyTorch (1.9 and earlier), the autograd engine always synced the default stream with all backward ops, so the following pattern:

```python
with torch.cuda.stream(s):
    loss.backward()
use grads
```

was safe as long as use grads happened on the default stream. In present PyTorch, that pattern is no longer safe. If backward() and use grads are in different stream contexts, you must sync the streams:

```python
with torch.cuda.stream(s):
    loss.backward()
torch.cuda.current_stream().wait_stream(s)
use grads
```

even if use grads is on the default stream.

Memory management

PyTorch uses a caching memory allocator to speed up memory allocations. This allows fast memory deallocation without device synchronizations. However, the unused memory managed by the allocator will still show as if used in nvidia-smi. You can use memory_allocated() and max_memory_allocated() to monitor memory occupied by tensors, and use memory_reserved() and max_memory_reserved() to monitor the total amount of memory managed by the caching allocator. Calling empty_cache() releases all unused cached memory from PyTorch so that it can be used by other GPU applications. However, GPU memory occupied by tensors will not be freed, so empty_cache() cannot increase the amount of GPU memory available for PyTorch.

To better understand how CUDA memory is being used over time, Understanding CUDA Memory Usage describes tools for capturing and visualizing traces of memory use.

For more advanced users, we offer more comprehensive memory benchmarking via memory_stats(). We also offer the capability to capture a complete snapshot of the memory allocator state via memory_snapshot(), which can help you understand the underlying allocation patterns produced by your code.

Optimizing memory usage with PYTORCH_CUDA_ALLOC_CONF

Use of a caching allocator can interfere with memory checking tools such as cuda-memcheck. To debug memory errors using cuda-memcheck, set PYTORCH_NO_CUDA_MEMORY_CACHING=1 in your environment to disable caching.

The behavior of the caching allocator can be controlled via the environment variable PYTORCH_CUDA_ALLOC_CONF. The format is PYTORCH_CUDA_ALLOC_CONF=<option>:<value>,<option2>:<value2>... Available options:

  • backend allows selecting the underlying allocator implementation. Currently, valid options are native, which uses PyTorch’s native implementation, and cudaMallocAsync, which uses CUDA’s built-in asynchronous allocator. cudaMallocAsync requires CUDA 11.4 or newer. The default is native. backend applies to all devices used by the process, and can’t be specified on a per-device basis.

  • max_split_size_mb prevents the native allocator from splitting blocks larger than this size (in MB). This can reduce fragmentation and may allow some borderline workloads to complete without running out of memory. Performance cost can range from ‘zero’ to ‘substantial’ depending on allocation patterns. Default value is unlimited, i.e. all blocks can be split. The memory_stats() and memory_summary() methods are useful for tuning. This option should be used as a last resort for a workload that is aborting due to ‘out of memory’ and showing a large amount of inactive split blocks. max_split_size_mb is only meaningful with backend:native. With backend:cudaMallocAsync, max_split_size_mb is ignored.

  • roundup_power2_divisions helps with rounding the requested allocation size to the nearest power-2 division and making better use of the blocks. In the native CUDACachingAllocator, sizes are rounded up to multiples of the 512-byte block size, so this works fine for smaller sizes. However, this can be inefficient for large nearby allocations, as each will go to a different block size and reuse of those blocks is minimized. This might create lots of unused blocks and will waste GPU memory capacity. This option enables the rounding of allocation size to the nearest power-2 division. For example, if we need to round up a size of 1200 and the number of divisions is 4, the size 1200 lies between 1024 and 2048, and if we do 4 divisions between them, the values are 1024, 1280, 1536, and 1792. So, an allocation size of 1200 will be rounded to 1280 as the nearest ceiling of power-2 division. Specify a single value to apply to all allocation sizes, or specify an array of key-value pairs to set the power-2 division individually for each power-of-two interval. For example, to set 1 division for all allocations under 256MB, 2 divisions for allocations between 256MB and 512MB, 4 divisions for allocations between 512MB and 1GB, and 8 divisions for any larger allocations, set the knob value to: [256:1,512:2,1024:4,>:8]. roundup_power2_divisions is only meaningful with backend:native. With backend:cudaMallocAsync, roundup_power2_divisions is ignored.

  • max_non_split_rounding_mb allows non-split blocks to be reused for smaller requests, e.g., a 1024MB cached block can be reused for a 512MB allocation request. In the default case, we only allow up to 20MB of rounding for non-split blocks, so a 512MB request can only be served by a block of 512-532 MB. If we set the value of this option to 1024, blocks of 512-1536 MB can be used for a 512MB request, which increases reuse of larger blocks. This also helps reduce stalls by avoiding expensive cudaMalloc calls.

  • garbage_collection_threshold helps actively reclaim unused GPU memory to avoid triggering an expensive sync-and-reclaim-all operation (release_cached_blocks), which can be unfavorable to latency-critical GPU applications (e.g., servers). Upon setting this threshold (e.g., 0.8), the allocator will start reclaiming GPU memory blocks if the GPU memory capacity usage exceeds the threshold (i.e., 80% of the total memory allocated to the GPU application). The algorithm prefers to free old & unused blocks first to avoid freeing blocks that are actively being reused. The threshold value should be greater than 0.0 and less than 1.0. The default value is set at 1.0.

    garbage_collection_threshold is only meaningful with backend:native. With backend:cudaMallocAsync, garbage_collection_threshold is ignored.

  • expandable_segments (experimental, default: False) If set to True, this setting instructs the allocator to create CUDA allocations that can later be expanded to better handle cases where a job changes allocation sizes frequently, such as having a changing batch size. Normally for large (>2MB) allocations, the allocator calls cudaMalloc to get allocations that are the same size as what the user requests. In the future, parts of these allocations can be reused for other requests if they are free. This works well when the program makes many requests of exactly the same size or of sizes that are even multiples of that size. Many deep learning models follow this behavior. However, one common exception is when the batch size changes slightly from one iteration to the next, e.g. in batched inference. When the program runs initially with batch size N, it will make allocations appropriate for that size. If in the future it runs at size N - 1, the existing allocations will still be big enough. However, if it runs at size N + 1, then it will have to make new allocations that are slightly larger. Not all the tensors are the same size. Some might be (N + 1)*A and others (N + 1)*A*B where A and B are some non-batch dimensions in the model. Because the allocator reuses existing allocations when they are big enough, some number of (N + 1)*A allocations will actually fit in the already existing N*B*A segments, though not perfectly. As the model runs, it will partially fill up all of these segments, leaving unusable free slices of memory at the end of these segments. The allocator at some point will need to cudaMalloc a new (N + 1)*A*B segment. If there is not enough memory, there is now no way to recover the slices of memory that are free at the end of existing segments. With models 50+ layers deep, this pattern might repeat 50+ times, creating many slivers.

    expandable_segments allows the allocator to create a segment initially and then expand its size later when more memory is needed. Instead of making one segment per allocation, it tries to make one segment (per stream) that grows as necessary. Now when the N + 1 case runs, the allocations will tile nicely into the one large segment until it fills up. Then more memory is requested and appended to the end of the segment. This process does not create as many slivers of unusable memory, so it is more likely to succeed at finding this memory.

  • pinned_use_cuda_host_register option is a boolean flag that determines whether to use the CUDA API’s cudaHostRegister function for allocating pinned memory instead of the default cudaHostAlloc. When set to True, the memory is allocated using regular malloc and then pages are mapped to the memory before calling cudaHostRegister. This pre-mapping of pages helps reduce the lock time during the execution of cudaHostRegister.

  • pinned_num_register_threads option is only valid when pinned_use_cuda_host_register is set to True. By default, one thread is used to map the pages. This option allows using more threads to parallelize the page mapping operations to reduce the overall allocation time of pinned memory. A good value for this option is 8 based on benchmarking results.

  • pinned_use_background_threads option is a boolean flag to enable a background thread for processing events. This avoids any slow path associated with querying/processing of events in the fast allocation path. This feature is disabled by default.

  • graph_capture_record_stream_reuse (experimental, default: False) If set to True, the CUDA caching allocator will attempt to reclaim device memory during CUDA Graph capture by using the graph topology (instead of CUDA events) to determine when a freed block is safe to reuse. This can reduce peak memory during long captures that free and reallocate buffers across multiple streams, especially when the capture DAG frequently reaches joined frontiers. Note: Enabling this option can significantly increase the time spent capturing the graph.
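Three of the numeric knobs above can be illustrated with plain arithmetic. The sketch below is a hypothetical model written only from the descriptions in the list: roundup_power2_divisions reproduces the 1200 to 1280 worked example, can_serve_unsplit encodes the documented 20MB default window for max_non_split_rounding_mb, and gc_should_start compares usage against garbage_collection_threshold. The function names are ours, and the real allocator applies further rules (minimum block sizes, per-interval division tables, block ages) that are not modeled.

```python
def roundup_power2_divisions(size: int, divisions: int) -> int:
    """Round size up to the next of `divisions` equal steps inside its
    power-of-two interval (illustrative; not the allocator's code)."""
    if size <= 0:
        raise ValueError("size must be positive")
    if divisions <= 1:
        # One division per interval means rounding up to the next power of two.
        return 1 << (size - 1).bit_length()
    lower = 1 << (size.bit_length() - 1)  # largest power of two <= size
    if size == lower:
        return size
    step = max(lower // divisions, 1)
    return -(-size // step) * step  # ceiling to a multiple of step


def can_serve_unsplit(block_mb: int, request_mb: int,
                      max_non_split_rounding_mb: int = 20) -> bool:
    """Could an unsplit cached block of block_mb serve a request of request_mb?"""
    return request_mb <= block_mb <= request_mb + max_non_split_rounding_mb


def gc_should_start(used_bytes: int, capacity_bytes: int,
                    threshold: float = 1.0) -> bool:
    """Would reclamation start once usage exceeds threshold * capacity?"""
    if not 0.0 < threshold <= 1.0:
        raise ValueError("threshold must be in (0.0, 1.0]")
    return used_bytes > threshold * capacity_bytes


if __name__ == "__main__":
    print(roundup_power2_divisions(1200, 4))                         # worked example
    print(can_serve_unsplit(532, 512), can_serve_unsplit(533, 512))  # default window
    print(can_serve_unsplit(1536, 512, max_non_split_rounding_mb=1024))
    print(gc_should_start(81, 100, threshold=0.8))
```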

Note

Some stats reported by the CUDA memory management API are specific to backend:native, and are not meaningful with backend:cudaMallocAsync. See each function’s docstring for details.
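Because several options are packed into one environment variable, it can help to build the string programmatically. The sketch below is a hypothetical helper (build_alloc_conf is our name) that serializes a dict into the <option>:<value>,<option2>:<value2> format and renders a dict value in the bracketed form used by the roundup_power2_divisions example. It assumes the allocator accepts that bracketed form inside a multi-option string, and that the variable must be in the environment before the caching allocator reads it; setting it before importing torch is the conservative choice, not something this page states.

```python
import os


def build_alloc_conf(options: dict) -> str:
    """Serialize options into the '<option>:<value>,<option2>:<value2>' format.

    A dict value becomes a bracketed key:value list, e.g. [256:1,512:2].
    """
    def render(value):
        if isinstance(value, bool):
            return "True" if value else "False"
        if isinstance(value, dict):
            return "[" + ",".join(f"{k}:{v}" for k, v in value.items()) + "]"
        return str(value)

    return ",".join(f"{name}:{render(value)}" for name, value in options.items())


conf = build_alloc_conf({
    "backend": "native",
    "garbage_collection_threshold": 0.8,
    "roundup_power2_divisions": {256: 1, 512: 2, 1024: 4, ">": 8},
})
# Export before CUDA is initialized (conservatively, before `import torch`).
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = conf
print(conf)
```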

Using custom memory allocators for CUDA

It is possible to define allocators as simple functions in C/C++ and compile them as a shared library. The code below shows a basic allocator that just traces all the memory operations.

```cpp
#include <sys/types.h>
#include <cuda_runtime_api.h>
#include <iostream>
// Compile with g++ alloc.cc -o alloc.so -I/usr/local/cuda/include -shared -fPIC
extern "C" {
void* my_malloc(ssize_t size, int device, cudaStream_t stream) {
   void *ptr;
   cudaMalloc(&ptr, size);
   std::cout << "alloc " << ptr << " " << size << std::endl;
   return ptr;
}

void my_free(void* ptr, ssize_t size, int device, cudaStream_t stream) {
   std::cout << "free " << ptr << " " << stream << std::endl;
   cudaFree(ptr);
}
}
```

This can be used in Python through torch.cuda.memory.CUDAPluggableAllocator. The user is responsible for supplying the path to the .so file and the names of the alloc/free functions that match the signatures specified above.

```python
import torch

# Load the allocator
new_alloc = torch.cuda.memory.CUDAPluggableAllocator(
    'alloc.so', 'my_malloc', 'my_free')
# Swap the current allocator
torch.cuda.memory.change_current_allocator(new_alloc)
# This will allocate memory in the device using the new allocator
b = torch.zeros(10, device='cuda')
```

The allocator must be swapped before any CUDA memory is allocated; the following fails because the current allocator was already instantiated:

```python
import torch

# Do an initial memory allocation
b = torch.zeros(10, device='cuda')
# Load the allocator
new_alloc = torch.cuda.memory.CUDAPluggableAllocator(
    'alloc.so', 'my_malloc', 'my_free')
# This will error since the current allocator was already instantiated
torch.cuda.memory.change_current_allocator(new_alloc)
```

Mixing different CUDA system allocators in the same program

Depending on your use case, change_current_allocator() may not be what you want to use, since it swaps the CUDA allocator for the entire program (similar to PYTORCH_CUDA_ALLOC_CONF=backend:cudaMallocAsync). For instance, if the swapped allocator doesn’t have a caching mechanism, you will lose all the benefits of PyTorch’s CUDACachingAllocator. Instead, you can selectively mark a region of PyTorch code to use a custom allocator using torch.cuda.MemPool. This will let you use multiple CUDA system allocators in the same PyTorch program, along with most of the benefits of the CUDACachingAllocator (e.g. caching). Using torch.cuda.MemPool, you can utilize custom allocators that enable several features, such as:

  • Allocating output buffers for an all-reduce using the ncclMemAlloc allocator can enable NVLink Switch Reductions (NVLS). This can reduce contention between overlapping compute and communication kernels on GPU resources (SMs and Copy Engines), especially on tensor-parallel workloads.

  • For Grace CPU based systems, allocating host output buffers for an all-gather using cuMemCreate and specifying CU_MEM_LOCATION_TYPE_HOST_NUMA can enable Extended GPU Memory (EGM) based memory transfers from source GPUs to the destination CPU. This accelerates the all-gather since the transfer happens over NVLinks, which otherwise would have happened over bandwidth-limited Network Interface Card (NIC) links. Such an accelerated all-gather can in turn speed up model checkpointing.

  • If you are crafting a model and don’t want to think about the optimal memory placements of a memory-intensive module at first (e.g. an embedding table), or perhaps you have a module which is not performance sensitive and doesn’t fit in the GPU, then you could just allocate that module with cudaMallocManaged with preferred CPU location and get your model working first.

Note

While cudaMallocManaged offers convenient automatic memory management using CUDA Unified Virtual Memory (UVM), it is not recommended for DL workloads. For DL workloads that fit in GPU memory, explicit placement consistently outperforms UVM, since there are no page faults and access patterns remain predictable. When GPU memory gets saturated, UVM has to perform costly double transfers, evicting pages to CPU before bringing in new ones.

The code below shows ncclMemAlloc wrapped in a torch.cuda.memory.CUDAPluggableAllocator.

```python
import os

import torch
import torch.distributed as dist
from torch.cuda.memory import CUDAPluggableAllocator
from torch.distributed.distributed_c10d import _get_default_group
from torch.utils import cpp_extension


# create allocator
nccl_allocator_source = """
#include <nccl.h>
#include <iostream>
extern "C" {

void* nccl_alloc_plug(size_t size, int device, void* stream) {
  std::cout << "Using ncclMemAlloc" << std::endl;
  void* ptr;
  ncclResult_t err = ncclMemAlloc(&ptr, size);
  return ptr;

}

void nccl_free_plug(void* ptr, size_t size, int device, void* stream) {
  std::cout << "Using ncclMemFree" << std::endl;
  ncclResult_t err = ncclMemFree(ptr);
}

}
"""
nccl_allocator_libname = "nccl_allocator"
nccl_allocator = cpp_extension.load_inline(
    name=nccl_allocator_libname,
    cpp_sources=nccl_allocator_source,
    with_cuda=True,
    extra_ldflags=["-lnccl"],
    verbose=True,
    is_python_module=False,
    build_directory="./",
)

allocator = CUDAPluggableAllocator(
    f"./{nccl_allocator_libname}.so", "nccl_alloc_plug", "nccl_free_plug"
).allocator()

# setup distributed
rank = int(os.getenv("RANK"))
local_rank = int(os.getenv("LOCAL_RANK"))
world_size = int(os.getenv("WORLD_SIZE"))
torch.cuda.set_device(local_rank)
dist.init_process_group(backend="nccl")
device = torch.device(f"cuda:{local_rank}")
default_pg = _get_default_group()
backend = default_pg._get_backend(device)

# Note: for convenience, ProcessGroupNCCL backend provides
# the ncclMemAlloc allocator as backend.mem_allocator
allocator = backend.mem_allocator
```

You can now define a new memory pool by passing this allocator to torch.cuda.MemPool:

pool=torch.cuda.MemPool(allocator)

The pool can then be used with the torch.cuda.use_mem_pool context manager to allocate tensors into that pool:

```python
with torch.cuda.use_mem_pool(pool):
    # tensor gets allocated with ncclMemAlloc passed in the pool
    tensor = torch.arange(1024 * 1024 * 2, device=device)
    print(f"tensor ptr on rank {rank} is {hex(tensor.data_ptr())}")

# register user buffers using ncclCommRegister (called under the hood)
backend.register_mem_pool(pool)

# Collective uses Zero Copy NVLS
dist.all_reduce(tensor[0:4])
torch.cuda.synchronize()
print(tensor[0:4])
```

Note the usage of register_mem_pool in the above example. This is an extra step for NVLS reductions, where the user buffers need to be registered with NCCL. A user can de-register the buffers with a similar deregister_mem_pool call.

To reclaim memory, users will first need to ensure nothing is using the pool. When none of the tensors are holding a reference to the pool, empty_cache() will be called internally on deletion of the pool, hence returning all the memory to the system.

del tensor, pool

Users can optionally specify a use_on_oom bool (which is False by default) during MemPool creation. If true, then the CUDACachingAllocator will be able to use memory in this pool as a last resort instead of OOMing.

```python
pool = torch.cuda.MemPool(allocator, use_on_oom=True)
with torch.cuda.use_mem_pool(pool):
    # allocate 40 MiB of uint8 from the pool
    a = torch.empty(40 * 1024 * 1024, dtype=torch.uint8, device="cuda")
del a

# at the memory limit, this will succeed by using pool's memory in order to avoid the oom
b = torch.empty(40 * 1024 * 1024, dtype=torch.uint8, device="cuda")
```

The following torch.cuda.MemPool.use_count() and torch.cuda.MemPool.snapshot() APIs can be used for debugging purposes:

```python
pool = torch.cuda.MemPool(allocator)

# pool's use count should be 1 at this point as MemPool object
# holds a reference
assert pool.use_count() == 1

nelem_1mb = 1024 * 1024 // 4

with torch.cuda.use_mem_pool(pool):
    out_0 = torch.randn(nelem_1mb, device="cuda")

    # pool's use count should be 2 at this point as use_mem_pool
    # holds a reference
    assert pool.use_count() == 2

# pool's use count should be back to 1 at this point as use_mem_pool
# released its reference
assert pool.use_count() == 1

with torch.cuda.use_mem_pool(pool):
    # pool should have 1 segment since we made a small allocation (1 MB)
    # above and so the CUDACachingAllocator packed it into a 2 MB buffer
    assert len(pool.snapshot()) == 1

    out_1 = torch.randn(nelem_1mb, device="cuda")

    # pool should still have 1 segment since we made another small allocation
    # (1 MB) that got packed into the existing 2 MB buffer
    assert len(pool.snapshot()) == 1

    out_2 = torch.randn(nelem_1mb, device="cuda")

    # pool now should have 2 segments since the CUDACachingAllocator had
    # to make a new 2 MB buffer to accommodate out_2
    assert len(pool.snapshot()) == 2
```
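The segment counts asserted above follow from how the caching allocator packs small requests. The sketch below is a simplified pure-Python model, not the allocator itself: it assumes that requests of at most 1 MiB are rounded up to a multiple of 512 bytes and served from 2 MiB segments using first-fit (the real allocator's block selection and splitting rules are more involved). Under those assumptions it reproduces the 1, 1, 2 progression from the example.

```python
MIB = 1024 * 1024
SMALL_SIZE = 1 * MIB     # assumption: requests <= 1 MiB use the small pool
SMALL_SEGMENT = 2 * MIB  # assumption: small-pool segments are 2 MiB
ROUND = 512              # assumption: sizes are rounded up to 512 bytes


class SmallPoolModel:
    """First-fit packing of small requests into fixed-size segments (hypothetical)."""

    def __init__(self):
        self.segment_free = []  # free bytes remaining in each segment

    def malloc(self, nbytes):
        if nbytes > SMALL_SIZE:
            raise ValueError("this sketch only models small (<= 1 MiB) requests")
        size = -(-nbytes // ROUND) * ROUND  # round up to a multiple of ROUND
        for i, free in enumerate(self.segment_free):
            if free >= size:
                self.segment_free[i] -= size  # reuse space in an existing segment
                return
        self.segment_free.append(SMALL_SEGMENT - size)  # open a new segment

    def num_segments(self):
        return len(self.segment_free)


model = SmallPoolModel()
nelem_1mb = 1024 * 1024 // 4
for _ in range(3):
    model.malloc(nelem_1mb * 4)  # float32: 4 bytes per element
    print(model.num_segments())  # prints 1, then 1, then 2
```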

Note

  • torch.cuda.MemPool holds a reference to the pool. When you use the torch.cuda.use_mem_pool context manager, it will also acquire another reference to the pool. On exit of the context manager, it will release its reference. After that, ideally only tensors should be holding references to the pool. Once the tensors release their references, the use count of the pool will be 1, reflecting that only the torch.cuda.MemPool object is holding a reference. Only at that point can the memory held by the pool be returned to the system when the pool’s destructor is called using del.

  • torch.cuda.MemPool doesn’t currently support the expandable_segments mode of CUDACachingAllocator.

  • NCCL has specific requirements for a buffer to be compatible with NVLS reductions. These requirements can be broken in a dynamic workload; for instance, the buffer being sent to NCCL by the CUDACachingAllocator might be split and hence not correctly aligned. In those cases, NCCL can use a fallback algorithm instead of NVLS.

  • Allocators like ncclMemAlloc can use more memory than requested, due to alignment requirements (CU_MULTICAST_GRANULARITY_RECOMMENDED, CU_MULTICAST_GRANULARITY_MINIMUM), and can cause your workload to run out of memory.
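To get a feel for the overhead mentioned in the last note, the following sketch rounds request sizes up to an allocation granularity. The 2 MiB granularity is a hypothetical value chosen for illustration; the real granularity is reported by the CUDA driver and depends on the GPU, the driver version, and whether the minimum or recommended value is used.

```python
MIB = 1024 * 1024


def padded_allocation(requested_bytes, granularity_bytes):
    """Round a request up to the next multiple of the allocation granularity."""
    return -(-requested_bytes // granularity_bytes) * granularity_bytes


granularity = 2 * MIB  # hypothetical; query the driver for the real value
for request in (1 * MIB, 3 * MIB, 4 * MIB + 1):
    padded = padded_allocation(request, granularity)
    # requested bytes, bytes actually reserved, bytes lost to alignment
    print(request, padded, padded - request)
```

Small or awkwardly sized buffers pay proportionally the most: a 1 MiB request in this model reserves twice what was asked for.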

Tuning NVLink Performance with Custom Memory Allocator on H100/H200 GPUs

In rare cases, performance of NVLink on H100/H200 GPUs can be influenced by the physical memory layout of data, creating an opportunity for developers to tune their applications for optimal throughput.

An example of how the physical memory layout of data affects performance is when communication kernels issue unbalanced NVLink read/write operations. In the following figure, we can see that each warp accesses memory addresses with a consistent strided pattern in each wave. We can achieve a more balanced load by tuning the stride size in the workload, or we can implement a custom CUDA allocator.

```text
_____________________________________________________________________________________________
| Warp 0 Reading | No-reading |    | Warp 1 Reading | No-reading |  ...  | Warp N Reading | No-reading |
_____________________________________________________________________________________________
<----------------------------->
          Stride size
```

Such an allocator can maintain contiguous virtual memory addresses for the kernel while strategically arranging the mapping to physical memory addresses (e.g., through shuffling). This technique allows developers to explore different physical access patterns to find the most efficient one, unlocking higher performance without modifying the kernel’s logic. A practical implementation of such an allocator can be achieved using PyTorch’s custom allocator support as mentioned before, where the malloc and free functions are:

```cpp
// Assuming a system with 8 GPUs.
#include <cuda.h>
#include <cuda_runtime.h>

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <vector>

// Helpers used below. They are not part of PyTorch or CUDA; these
// definitions are one possible sketch.
#define ROUND_UP(x, n) ((((x) + (n) - 1) / (n)) * (n))
#define CHECK_CUDA(call)                                                 \
  do {                                                                   \
    CUresult err_ = (call);                                              \
    if (err_ != CUDA_SUCCESS) {                                          \
      std::fprintf(stderr, "CUDA driver error %d at %s:%d\n", (int)err_, \
                   __FILE__, __LINE__);                                  \
      std::abort();                                                      \
    }                                                                    \
  } while (0)

// Grants read/write access to [ptr, ptr + size) to every listed device.
static void setupMultiGPUAccess(CUdeviceptr ptr, size_t size,
                                const std::vector<int>& devices) {
  std::vector<CUmemAccessDesc> descs(devices.size());
  for (size_t d = 0; d < devices.size(); ++d) {
    descs[d].location.type = CU_MEM_LOCATION_TYPE_DEVICE;
    descs[d].location.id = devices[d];
    descs[d].flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
  }
  CHECK_CUDA(cuMemSetAccess(ptr, size, descs.data(), descs.size()));
}

struct CustomAllocInfo {
  void** devPtr;  // This will be the usable virtual memory address
  CUdeviceptr dptr;
  size_t totalSize;  // Total size of the allocated memory
  size_t padded_size;
  int device_id;
  std::vector<CUmemGenericAllocationHandle> handles;  // Handles to physical memory allocations
};

cudaError_t customCudaMalloc(CustomAllocInfo* info) {
  if (!info) return cudaErrorInvalidValue;
  CUdeviceptr dptr;
  // Handles to redundant physical memory allocations which help truncate
  // the stride pattern in physical memory
  std::vector<CUmemGenericAllocationHandle> handles_redundant;

  size_t granularity = 0;
  CUmemAllocationProp prop = {};
  int currentDev = info->device_id;
  size_t totalSize = info->totalSize;

  prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
  prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
  prop.location.id = currentDev;
  CHECK_CUDA(cuMemGetAllocationGranularity(&granularity, &prop,
                                           CU_MEM_ALLOC_GRANULARITY_MINIMUM));

  // Pages are 64 * granularity bytes (64 * granularity with shift_size = 2 works)
  size_t iter_granularity = granularity * 64;
  uint32_t iteration_count = (totalSize + iter_granularity - 1) / iter_granularity;
  // Pad to whole pages so that every page mapped below lies inside the reservation
  size_t padded_size = ROUND_UP(totalSize, iter_granularity);
  info->padded_size = padded_size;
  CHECK_CUDA(cuMemAddressReserve(&dptr, padded_size, 0ULL, 0ULL, 0ULL));

  const int shift_size = 2;
  // Loop over pages in groups of shift_size
  for (size_t i = 0; i < iteration_count; i += shift_size) {
    CUmemGenericAllocationHandle allocHandle[shift_size];
    // The last group may hold fewer than shift_size pages
    const int group = static_cast<int>(std::min<size_t>(shift_size, iteration_count - i));
    for (int shift = 0; shift < group; shift++) {
      CHECK_CUDA(cuMemCreate(&allocHandle[shift], iter_granularity, &prop, 0));
      info->handles.push_back(allocHandle[shift]);
    }
    for (int shift = 0; shift < group; shift++) {
      // Mapping makes the shift: virtual page (i + shift) is backed by
      // physical handle (shift + 1) % group
      CUdeviceptr page = dptr + (i + shift) * iter_granularity;
      CHECK_CUDA(cuMemMap(page, iter_granularity, 0, allocHandle[(shift + 1) % group], 0));
      setupMultiGPUAccess(page, iter_granularity, {0, 1, 2, 3, 4, 5, 6, 7});  // Enable access for all 8 GPUs
    }
    // Allocate one redundant page (one granularity unit). This is an extra
    // optimization on top of the swizzling that helps "break" the physical
    // access pattern even more. It can be left out if the workload already
    // performs at SOL with just swizzling.
    CUmemGenericAllocationHandle allocHandle_redundant;
    CHECK_CUDA(cuMemCreate(&allocHandle_redundant, granularity, &prop, 0));
    handles_redundant.push_back(allocHandle_redundant);
  }

  *info->devPtr = (void*)dptr;
  info->dptr = dptr;

  // Release each redundant allocation
  for (auto handle : handles_redundant) {
    CHECK_CUDA(cuMemRelease(handle));
  }
  return cudaSuccess;
}

void customCudaFree(CustomAllocInfo* info) {
  if (!info) return;
  CHECK_CUDA(cuMemUnmap(info->dptr, info->padded_size));
  // Release each physical allocation
  for (auto handle : info->handles) {
    CHECK_CUDA(cuMemRelease(handle));
  }
  info->handles.clear();
  // Unreserve the virtual address space
  CHECK_CUDA(cuMemAddressFree(info->dptr, info->padded_size));
}
```
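The page permutation that the allocator above sets up can be inspected without a GPU. The following pure-Python sketch (the function names are hypothetical) computes, for each virtual page, which physically allocated page backs it when pages are swapped within groups of shift_size. It assumes physical pages are numbered in creation order and that a shorter final group wraps onto itself, and it ignores the redundant pages, which only influence where the driver places later allocations.

```python
def swizzled_page_map(num_pages, shift_size=2):
    """Return phys, where phys[v] is the physical page backing virtual page v."""
    phys = []
    for i in range(0, num_pages, shift_size):
        group_len = min(shift_size, num_pages - i)  # last group may be shorter
        for shift in range(group_len):
            # virtual page (i + shift) -> page created at (shift + 1) % group_len
            phys.append(i + (shift + 1) % group_len)
    return phys


def physical_offset(virtual_offset, page_bytes, phys):
    """Translate a byte offset in the contiguous virtual range to (physical page, offset)."""
    page, within = divmod(virtual_offset, page_bytes)
    return phys[page], within


mapping = swizzled_page_map(6)
print(mapping)  # [1, 0, 3, 2, 5, 4]
```

The kernel keeps seeing one contiguous virtual range; only the backing pages are permuted, which is why no kernel changes are needed.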

cuBLAS workspaces

For each combination of cuBLAS handle and CUDA stream, a cuBLAS workspace will be allocated if that handle and stream combination executes a cuBLAS kernel that requires a workspace. In order to avoid repeatedly allocating workspaces, these workspaces are not deallocated unless torch._C._cuda_clearCublasWorkspaces() is called. The workspace size per allocation can be specified via the environment variable CUBLAS_WORKSPACE_CONFIG with the format :[SIZE]:[COUNT], where SIZE is given in KiB and several :[SIZE]:[COUNT] pairs can be chained. As an example, the default workspace size per allocation is CUBLAS_WORKSPACE_CONFIG=:4096:2:16:8, which specifies a total size of 2 * 4096 + 8 * 16 KiB. To force cuBLAS to avoid using workspaces, set CUBLAS_WORKSPACE_CONFIG=:0:0.
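As a quick check of the arithmetic above, this sketch parses a CUBLAS_WORKSPACE_CONFIG string into its :[SIZE]:[COUNT] pairs and sums SIZE * COUNT in KiB. The helper name is hypothetical, and it only validates the shape of the string, not which values cuBLAS itself accepts.

```python
def cublas_workspace_kib(config):
    """Total KiB requested by a CUBLAS_WORKSPACE_CONFIG value such as ':4096:2:16:8'."""
    if not config.startswith(":"):
        raise ValueError("expected the form :[SIZE]:[COUNT][:[SIZE]:[COUNT]...]")
    fields = config[1:].split(":")
    if len(fields) % 2 != 0:
        raise ValueError("every SIZE needs a matching COUNT")
    # Each pair requests COUNT workspaces of SIZE KiB.
    return sum(int(size) * int(count) for size, count in zip(fields[0::2], fields[1::2]))


print(cublas_workspace_kib(":4096:2:16:8"))  # 8320, i.e. 2 * 4096 + 8 * 16
print(cublas_workspace_kib(":0:0"))          # 0, no workspaces
```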

cuFFT plan cache

For each CUDA device, an LRU cache of cuFFT plans is used to speed up repeatedly running FFT methods (e.g., torch.fft.fft()) on CUDA tensors of the same geometry with the same configuration. Because some cuFFT plans may allocate GPU memory, these caches have a maximum capacity.

You may control and query the properties of the cache of the current device with the following APIs:

  • torch.backends.cuda.cufft_plan_cache.max_size gives the capacity of the cache (default is 4096 on CUDA 10 and newer, and 1023 on older CUDA versions). Setting this value directly modifies the capacity.

  • torch.backends.cuda.cufft_plan_cache.size gives the number of plans currently residing in the cache.

  • torch.backends.cuda.cufft_plan_cache.clear() clears the cache.

To control and query plan caches of a non-default device, you can index the torch.backends.cuda.cufft_plan_cache object with either a torch.device object or a device index, and access one of the above attributes. E.g., to set the capacity of the cache for device 1, one can write torch.backends.cuda.cufft_plan_cache[1].max_size = 10.
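The cache behaves like a standard least-recently-used cache keyed by FFT geometry and configuration. The class below is a toy model (not PyTorch code, and the tuple keys are hypothetical stand-ins for plan keys) that mirrors the max_size, size, and clear() surface listed above so the eviction order is easy to see. It assumes that shrinking max_size trims the oldest entries immediately, which is a choice of this model.

```python
from collections import OrderedDict


class PlanCacheModel:
    """Toy LRU cache with a max_size / size / clear() surface (hypothetical)."""

    def __init__(self, max_size=4096):
        self._plans = OrderedDict()  # key -> plan, least recently used first
        self._max_size = max_size

    @property
    def max_size(self):
        return self._max_size

    @max_size.setter
    def max_size(self, value):
        self._max_size = value
        self._evict()  # this model trims immediately when capacity shrinks

    @property
    def size(self):
        return len(self._plans)

    def clear(self):
        self._plans.clear()

    def get_plan(self, key):
        """Return the plan for key, 'building' one on a cache miss."""
        if key in self._plans:
            self._plans.move_to_end(key)  # mark as most recently used
            return self._plans[key]
        plan = ("plan", key)  # stand-in for an expensive cuFFT plan
        self._plans[key] = plan
        self._evict()
        return plan

    def _evict(self):
        while len(self._plans) > self._max_size:
            self._plans.popitem(last=False)  # drop the least recently used plan


cache = PlanCacheModel(max_size=2)
cache.get_plan((64,))
cache.get_plan((128,))
cache.get_plan((64,))   # hit: (64,) becomes most recently used
cache.get_plan((256,))  # miss at capacity: evicts (128,)
print(cache.size)       # 2
```

In PyTorch itself you would instead read or set the torch.backends.cuda.cufft_plan_cache attributes described above.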

Just-in-Time Compilation

PyTorch just-in-time compiles some operations, like torch.special.zeta, when performed on CUDA tensors. This compilation can be time consuming (up to a few seconds depending on your hardware and software) and may occur multiple times for a single operator, since many PyTorch operators actually select from a variety of kernels, each of which must be compiled once, depending on their input. This compilation occurs once per process, or just once if a kernel cache is used.

By default, PyTorch creates a kernel cache in $XDG_CACHE_HOME/torch/kernels if XDG_CACHE_HOME is defined and $HOME/.cache/torch/kernels if it’s not (except on Windows, where the kernel cache is not yet supported). The caching behavior can be directly controlled with two environment variables. If USE_PYTORCH_KERNEL_CACHE is set to 0, then no cache will be used, and if PYTORCH_KERNEL_CACHE_PATH is set, then that path will be used as a kernel cache instead of the default location.
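The lookup order described above can be written as a small function. This is a hypothetical helper that models the documented rules on an explicit environment dictionary using POSIX-style paths; it assumes that USE_PYTORCH_KERNEL_CACHE=0 takes precedence over PYTORCH_KERNEL_CACHE_PATH, which the text does not state explicitly.

```python
import posixpath


def resolve_kernel_cache_dir(env, platform="linux"):
    """Return the kernel cache directory implied by env, or None if no cache is used."""
    if platform == "win32":
        return None  # the kernel cache is not yet supported on Windows
    if env.get("USE_PYTORCH_KERNEL_CACHE") == "0":
        return None  # assumption: disabling wins over a custom path
    if env.get("PYTORCH_KERNEL_CACHE_PATH"):
        return env["PYTORCH_KERNEL_CACHE_PATH"]
    if env.get("XDG_CACHE_HOME"):
        return posixpath.join(env["XDG_CACHE_HOME"], "torch", "kernels")
    return posixpath.join(env["HOME"], ".cache", "torch", "kernels")


print(resolve_kernel_cache_dir({"HOME": "/home/alice"}))  # /home/alice/.cache/torch/kernels
```

Passing dict(os.environ) applies the model to the current process environment.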

Best practices

Device-agnostic code

Due to the structure of PyTorch, you may need to explicitly write device-agnostic (CPU or GPU) code; an example may be creating a new tensor as the initial hidden state of a recurrent neural network.

The first step is to determine whether the GPU should be used or not. A common pattern is to use Python’s argparse module to read in user arguments, and have a flag that can be used to disable CUDA, in combination with is_available(). In the following, args.device results in a torch.device object that can be used to move tensors to CPU or CUDA.

```python
import argparse
import torch

parser = argparse.ArgumentParser(description='PyTorch Example')
parser.add_argument('--disable-cuda', action='store_true',
                    help='Disable CUDA')
args = parser.parse_args()
args.device = None
if not args.disable_cuda and torch.cuda.is_available():
    args.device = torch.device('cuda')
else:
    args.device = torch.device('cpu')
```

Note

When assessing the availability of CUDA in a given environment (is_available()), PyTorch’s default behavior is to call the CUDA Runtime API method cudaGetDeviceCount. Because this call in turn initializes the CUDA Driver API (via cuInit) if it is not already initialized, subsequent forks of a process that has run is_available() will fail with a CUDA initialization error.

One can set PYTORCH_NVML_BASED_CUDA_CHECK=1 in your environment before importing PyTorch modules that execute is_available() (or before executing it directly) in order to direct is_available() to attempt an NVML-based assessment (nvmlDeviceGetCount_v2). If the NVML-based assessment is successful (i.e., NVML discovery/initialization does not fail), is_available() calls will not poison subsequent forks.

If NVML discovery/initialization fails, is_available() will fall back to the standard CUDA Runtime API assessment and the aforementioned fork constraint will apply.

Note that the above NVML-based CUDA availability assessment provides a weaker guarantee than the default CUDA Runtime API approach (which requires CUDA initialization to succeed). In some circumstances, the NVML-based check may succeed while later CUDA initialization fails.

Now that we have args.device, we can use it to create a Tensor on the desired device.

```python
x = torch.empty((8, 42), device=args.device)
net = Network().to(device=args.device)  # Network stands for your own nn.Module subclass
```

This can be used in a number of cases to produce device-agnostic code. Below is an example when using a dataloader:

```python
cuda0 = torch.device('cuda:0')  # CUDA GPU 0
for i, x in enumerate(train_loader):
    x = x.to(cuda0)
```

When working with multiple GPUs on a system, you can use the CUDA_VISIBLE_DEVICES environment variable to manage which GPUs are available to PyTorch. As mentioned above, to manually control which GPU a tensor is created on, the best practice is to use a torch.cuda.device context manager.

```python
print("Outside device is 0")  # On device 0 (default in most scenarios)
with torch.cuda.device(1):
    print("Inside device is 1")  # On device 1
print("Outside device is still 0")  # On device 0
```
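CUDA_VISIBLE_DEVICES also renumbers devices: inside the process, cuda:0 refers to the first GPU listed in the variable, not to physical GPU 0. The sketch below (a hypothetical helper) models that renumbering for comma-separated numeric ordinals, stopping at the first entry that is not a non-negative integer, since CUDA hides devices listed after an invalid index. GPU UUID and MIG entries are out of scope for this sketch.

```python
def visible_device_map(cuda_visible_devices):
    """Map logical indices (cuda:0, cuda:1, ...) to physical GPU ordinals.

    Returns None when the variable is unset (all GPUs visible, no renumbering).
    Only numeric ordinals are modeled.
    """
    if cuda_visible_devices is None:
        return None
    mapping = {}
    for entry in cuda_visible_devices.split(","):
        entry = entry.strip()
        if not entry.isdigit():
            break  # devices after an invalid index are hidden
        mapping[len(mapping)] = int(entry)
    return mapping


print(visible_device_map("2,3"))  # {0: 2, 1: 3}: cuda:0 is physical GPU 2
```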

If you have a tensor and would like to create a new tensor of the same type on the same device, then you can use a torch.Tensor.new_* method (see torch.Tensor). Whilst the previously mentioned torch.* factory functions (Creation Ops) depend on the current GPU context and the attribute arguments you pass in, torch.Tensor.new_* methods preserve the device and other attributes of the tensor.

This is the recommended practice when creating modules in which new tensors need to be created internally during the forward pass.

```python
cuda = torch.device('cuda')
x_cpu = torch.empty(2)
x_gpu = torch.empty(2, device=cuda)
x_cpu_long = torch.empty(2, dtype=torch.int64)

y_cpu = x_cpu.new_full([3, 2], fill_value=0.3)
print(y_cpu)
# tensor([[0.3000, 0.3000],
#         [0.3000, 0.3000],
#         [0.3000, 0.3000]])

y_gpu = x_gpu.new_full([3, 2], fill_value=-5)
print(y_gpu)
# tensor([[-5.0000, -5.0000],
#         [-5.0000, -5.0000],
#         [-5.0000, -5.0000]], device='cuda:0')

y_cpu_long = x_cpu_long.new_tensor([[1, 2, 3]])
print(y_cpu_long)
# tensor([[1, 2, 3]])
```

If you want to create a tensor of the same type and size as another tensor, and fill it with either ones or zeros, ones_like() or zeros_like() are provided as convenient helper functions (which also preserve the torch.device and torch.dtype of a Tensor).

```python
x_cpu = torch.empty(2, 3)
x_gpu = torch.empty(2, 3, device='cuda')

y_cpu = torch.ones_like(x_cpu)
y_gpu = torch.zeros_like(x_gpu)
```

Use pinned memory buffers

Warning

This is an advanced tip. If you overuse pinned memory, it can cause serious problems when running low on RAM, and you should be aware that pinning is often an expensive operation.

Host-to-GPU copies are much faster when they originate from pinned (page-locked) memory. CPU tensors and storages expose a pin_memory() method that returns a copy of the object, with data put in a pinned region.

Also, once you pin a tensor or storage, you can use asynchronous GPU copies. Just pass an additional non_blocking=True argument to a to() or a cuda() call. This can be used to overlap data transfers with computation.

You can make the DataLoader return batches placed in pinned memory by passing pin_memory=True to its constructor.

Use nn.parallel.DistributedDataParallel instead of multiprocessing or nn.DataParallel

Most use cases involving batched inputs and multiple GPUs should default to using DistributedDataParallel to utilize more than one GPU.

There are significant caveats to using CUDA models with multiprocessing; unless care is taken to meet the data handling requirements exactly, it is likely that your program will have incorrect or undefined behavior.

It is recommended to use DistributedDataParallel instead of DataParallel for multi-GPU training, even if there is only a single node.

The difference between DistributedDataParallel and DataParallel is that DistributedDataParallel uses multiprocessing, where a process is created for each GPU, while DataParallel uses multithreading. By using multiprocessing, each GPU has its dedicated process, which avoids the performance overhead caused by the GIL of the Python interpreter.

If you use DistributedDataParallel, you can use the torch.distributed.launch utility to launch your program; see Launch utility.

CUDA Graphs

A CUDA graph is a record of the work (mostly kernels and their arguments) that a CUDA stream and its dependent streams perform. For general principles and details on the underlying CUDA API, see Getting Started with CUDA Graphs and the Graphs section of the CUDA C Programming Guide.

PyTorch supports the construction of CUDA graphs using stream capture, which puts a CUDA stream in capture mode. CUDA work issued to a capturing stream doesn’t actually run on the GPU. Instead, the work is recorded in a graph.

After capture, the graph can be launched to run the GPU work as many times as needed. Each replay runs the same kernels with the same arguments. For pointer arguments, this means the same memory addresses are used. By filling input memory with new data (e.g., from a new batch) before each replay, you can rerun the same work on new data.
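The capture-then-replay contract can be illustrated without a GPU. In the pure-Python analogy below, ToyGraph is hypothetical and unrelated to the real torch.cuda.CUDAGraph implementation: recording stores work instead of running it, and each recorded op is bound to specific buffer objects, much as a captured kernel is bound to fixed addresses. New data therefore only takes effect if it is copied into the captured buffer.

```python
from contextlib import contextmanager


class ToyGraph:
    """Records ops during capture and re-runs them on replay (analogy only)."""

    def __init__(self):
        self._ops = []
        self._capturing = False

    @contextmanager
    def capture(self):
        self._capturing = True
        try:
            yield self
        finally:
            self._capturing = False

    def record(self, op):
        # Like work issued to a capturing stream: stored, not executed.
        if not self._capturing:
            raise RuntimeError("record() is only valid inside capture()")
        self._ops.append(op)

    def replay(self):
        # Every replay runs the same ops on the same buffer objects.
        for op in self._ops:
            op()


def make_double_op(src, dst):
    """Bind an op to specific buffers, the way a kernel is bound to addresses."""
    def op():
        for i, v in enumerate(src):
            dst[i] = v * 2
    return op


static_input = [0.0] * 5  # long-lived "placeholder" buffers
static_output = [0.0] * 5

g = ToyGraph()
with g.capture():
    g.record(make_double_op(static_input, static_output))

static_input[:] = [3.0] * 5  # in-place copy into the captured buffer
g.replay()
print(static_output)  # [6.0, 6.0, 6.0, 6.0, 6.0]

new_batch = [4.0] * 5  # a different buffer that the graph never saw
g.replay()
print(static_output)  # unchanged, because new_batch was never copied in

static_input[:] = new_batch  # copy into the captured buffer, then replay
g.replay()
print(static_output)  # [8.0, 8.0, 8.0, 8.0, 8.0]
```

Handing new data to replay by binding a fresh buffer does nothing, which is why the PyTorch examples fill inputs with static_input.copy_(...) before each replay.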

Why CUDA Graphs?

Replaying a graph sacrifices the dynamic flexibility of typical eager execution in exchange for greatly reduced CPU overhead. A graph’s arguments and kernels are fixed, so a graph replay skips all layers of argument setup and kernel dispatch, including Python, C++, and CUDA driver overheads. Under the hood, a replay submits the entire graph’s work to the GPU with a single call to cudaGraphLaunch. Kernels in a replay also execute slightly faster on the GPU, but eliding CPU overhead is the main benefit.

You should try CUDA graphs if all or part of your network is graph-safe (usually this means static shapes and static control flow, but see the other constraints) and you suspect its runtime is at least somewhat CPU-limited.

PyTorch API

Warning

This API is in beta and may change in future releases.

PyTorch exposes graphs via a raw torch.cuda.CUDAGraph class and two convenience wrappers, torch.cuda.graph and torch.cuda.make_graphed_callables.

torch.cuda.graph is a simple, versatile context manager that captures CUDA work in its context. Before capture, warm up the workload to be captured by running a few eager iterations. Warmup must occur on a side stream. Because the graph reads from and writes to the same memory addresses in every replay, you must maintain long-lived references to tensors that hold input and output data during capture. To run the graph on new input data, copy new data to the capture’s input tensor(s), replay the graph, then read the new output from the capture’s output tensor(s). Example:

```python
g = torch.cuda.CUDAGraph()

# Placeholder input used for capture
static_input = torch.empty((5,), device="cuda")

# Warmup before capture
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        static_output = static_input * 2
torch.cuda.current_stream().wait_stream(s)

# Captures the graph
# To allow capture, automatically sets a side stream as the current stream in the context
with torch.cuda.graph(g):
    static_output = static_input * 2

# Fills the graph's input memory with new data to compute on
static_input.copy_(torch.full((5,), 3, device="cuda"))
g.replay()
# static_output holds the results
print(static_output)  # full of 3 * 2 = 6

# Fills the graph's input memory with more data to compute on
static_input.copy_(torch.full((5,), 4, device="cuda"))
g.replay()
print(static_output)  # full of 4 * 2 = 8
```

See Whole-network capture, Usage with torch.cuda.amp, and Usage with multiple streams for realistic and advanced patterns.

make_graphed_callables is more sophisticated. make_graphed_callables accepts Python functions and torch.nn.Modules. For each passed function or Module, it creates separate graphs of the forward-pass and backward-pass work. See Partial-network capture.

Constraints

A set of ops is capturable if it doesn’t violate any of the following constraints.

Constraints apply to all work in a torch.cuda.graph context and all work in the forward and backward passes of any callable you pass to torch.cuda.make_graphed_callables().

Violating any of these will likely cause a runtime error:

  • Capture must occur on a non-default stream. (This is only a concern if you use the raw CUDAGraph.capture_begin and CUDAGraph.capture_end calls. graph and make_graphed_callables() set a side stream for you.)

  • Ops that synchronize the CPU with the GPU (e.g., .item() calls) are prohibited.

  • CUDA RNG operations are permitted, and when using multiple torch.Generator instances within a graph, they must be registered using CUDAGraph.register_generator_state before graph capture. Avoid using Generator.get_state and Generator.set_state during capture; instead, use Generator.graphsafe_set_state and Generator.graphsafe_get_state for managing generator states safely within the graph context. This ensures proper RNG operation and generator management within CUDA graphs.

Violating any of these will likely cause silent numerical errors or undefined behavior:

  • Within a process, only one capture may be underway at a time.

  • No non-captured CUDA work may run in this process (on any thread) while capture is underway.

  • CPU work is not captured. If the captured ops include CPU work, that work will be elided during replay.

  • Every replay reads from and writes to the same (virtual) memory addresses.

  • Dynamic control flow (based on CPU or GPU data) is prohibited.

  • Dynamic shapes are prohibited. The graph assumes every tensor in the captured op sequence has the same size and layout in every replay.

  • Using multiple streams in a capture is allowed, but there are restrictions.

Non-constraints

  • Once captured, the graph may be replayed on any stream.

Whole-network capture

If your entire network is capturable, you can capture and replay an entire iteration:

```python
N, D_in, H, D_out = 640, 4096, 2048, 1024
model = torch.nn.Sequential(torch.nn.Linear(D_in, H),
                            torch.nn.Dropout(p=0.2),
                            torch.nn.Linear(H, D_out),
                            torch.nn.Dropout(p=0.1)).cuda()
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Placeholders used for capture
static_input = torch.randn(N, D_in, device='cuda')
static_target = torch.randn(N, D_out, device='cuda')

# warmup
# Uses static_input and static_target here for convenience,
# but in a real setting, because the warmup includes optimizer.step()
# you must use a few batches of real data.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for i in range(3):
        optimizer.zero_grad(set_to_none=True)
        y_pred = model(static_input)
        loss = loss_fn(y_pred, static_target)
        loss.backward()
        optimizer.step()
torch.cuda.current_stream().wait_stream(s)

# capture
g = torch.cuda.CUDAGraph()
# Sets grads to None before capture, so backward() will create
# .grad attributes with allocations from the graph's private pool
optimizer.zero_grad(set_to_none=True)
with torch.cuda.graph(g):
    static_y_pred = model(static_input)
    static_loss = loss_fn(static_y_pred, static_target)
    static_loss.backward()
    optimizer.step()

real_inputs = [torch.rand_like(static_input) for _ in range(10)]
real_targets = [torch.rand_like(static_target) for _ in range(10)]

for data, target in zip(real_inputs, real_targets):
    # Fills the graph's input memory with new data to compute on
    static_input.copy_(data)
    static_target.copy_(target)
    # replay() includes forward, backward, and step.
    # You don't even need to call optimizer.zero_grad() between iterations
    # because the captured backward refills static .grad tensors in place.
    g.replay()
    # Params have been updated. static_y_pred, static_loss, and .grad
    # attributes hold values from computing on this iteration's data.
```

Partial-network capture

If some of your network is unsafe to capture (e.g., due to dynamic control flow, dynamic shapes, CPU syncs, or essential CPU-side logic), you can run the unsafe part(s) eagerly and use torch.cuda.make_graphed_callables() to graph only the capture-safe part(s).

By default, callables returned by make_graphed_callables() are autograd-aware, and can be used in the training loop as direct replacements for the functions or nn.Modules you passed.

make_graphed_callables() internally creates CUDAGraph objects, runs warmup iterations, and maintains static inputs and outputs as needed. Therefore (unlike with torch.cuda.graph) you don’t need to handle those manually.

In the following example, data-dependent dynamic control flow means the network isn’t capturable end-to-end, but make_graphed_callables() lets us capture and run graph-safe sections as graphs regardless:

```python
from itertools import chain

import torch

N, D_in, H, D_out = 640, 4096, 2048, 1024

module1 = torch.nn.Linear(D_in, H).cuda()
module2 = torch.nn.Linear(H, D_out).cuda()
module3 = torch.nn.Linear(H, D_out).cuda()

loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.SGD(chain(module1.parameters(),
                                  module2.parameters(),
                                  module3.parameters()),
                            lr=0.1)

# Sample inputs used for capture
# requires_grad state of sample inputs must match
# requires_grad state of real inputs each callable will see.
x = torch.randn(N, D_in, device='cuda')
h = torch.randn(N, H, device='cuda', requires_grad=True)

module1 = torch.cuda.make_graphed_callables(module1, (x,))
module2 = torch.cuda.make_graphed_callables(module2, (h,))
module3 = torch.cuda.make_graphed_callables(module3, (h,))

real_inputs = [torch.rand_like(x) for _ in range(10)]
real_targets = [torch.randn(N, D_out, device="cuda") for _ in range(10)]

for data, target in zip(real_inputs, real_targets):
    optimizer.zero_grad(set_to_none=True)

    tmp = module1(data)  # forward ops run as a graph

    if tmp.sum().item() > 0:
        tmp = module2(tmp)  # forward ops run as a graph
    else:
        tmp = module3(tmp)  # forward ops run as a graph

    loss = loss_fn(tmp, target)
    # module2's or module3's (whichever was chosen) backward ops,
    # as well as module1's backward ops, run as graphs
    loss.backward()
    optimizer.step()
```

Usage with torch.cuda.amp

For typical optimizers, GradScaler.step syncs the CPU with the GPU, which is prohibited during capture. To avoid errors, either use partial-network capture, or (if forward, loss, and backward are capture-safe) capture forward, loss, and backward but not the optimizer step:

```python
# warmup
# In a real setting, use a few batches of real data.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for i in range(3):
        optimizer.zero_grad(set_to_none=True)
        with torch.cuda.amp.autocast():
            y_pred = model(static_input)
            loss = loss_fn(y_pred, static_target)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
torch.cuda.current_stream().wait_stream(s)

# capture
g = torch.cuda.CUDAGraph()
optimizer.zero_grad(set_to_none=True)
with torch.cuda.graph(g):
    with torch.cuda.amp.autocast():
        static_y_pred = model(static_input)
        static_loss = loss_fn(static_y_pred, static_target)
    scaler.scale(static_loss).backward()
    # don't capture scaler.step(optimizer) or scaler.update()

real_inputs = [torch.rand_like(static_input) for _ in range(10)]
real_targets = [torch.rand_like(static_target) for _ in range(10)]

for data, target in zip(real_inputs, real_targets):
    static_input.copy_(data)
    static_target.copy_(target)
    g.replay()
    # Runs scaler.step and scaler.update eagerly
    scaler.step(optimizer)
    scaler.update()
```

Usage with multiple streams

Capture mode automatically propagates to any streams that sync with a capturing stream. Within capture, you may expose parallelism by issuing calls to different streams, but the overall stream dependency DAG must branch out from the initial capturing stream after capture begins and rejoin the initial stream before capture ends:

```python
with torch.cuda.graph(g):
    # at context manager entrance, torch.cuda.current_stream()
    # is the initial capturing stream

    # INCORRECT (does not branch out from or rejoin initial stream)
    with torch.cuda.stream(s):
        cuda_work()

    # CORRECT:
    # branches out from initial stream
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        cuda_work()
    # rejoins initial stream before capture ends
    torch.cuda.current_stream().wait_stream(s)
```

Note

To avoid confusion for power users looking at replays in Nsight Systems or nvprof: unlike eager execution, the graph interprets a nontrivial stream DAG in capture as a hint, not a command. During replay, the graph may reorganize independent ops onto different streams or enqueue them in a different order (while respecting your original DAG’s overall dependencies).

Usage with DistributedDataParallel

NCCL < 2.9.6

NCCL versions earlier than 2.9.6 don’t allow collectives to be captured. You must use partial-network capture, which defers allreduces to happen outside graphed sections of backward.

Call make_graphed_callables() on graphable network sections before wrapping the network with DDP.

NCCL >= 2.9.6

NCCL versions 2.9.6 or later allow collectives in the graph. Approaches that capture an entire backward pass are a viable option, but need three setup steps.

  1. Disable DDP’s internal async error handling:

    ```python
    os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0"
    torch.distributed.init_process_group(...)
    ```

  2. Before full-backward capture, DDP must be constructed in a side-stream context:

    ```python
    with torch.cuda.stream(s):
        model = DistributedDataParallel(model)
    ```

  3. Your warmup must run at least 11 DDP-enabled eager iterations before capture.

Graph memory management

A captured graph acts on the same virtual addresses every time it replays. If PyTorch frees the memory, a later replay can hit an illegal memory access. If PyTorch reassigns the memory to new tensors, the replay can corrupt the values seen by those tensors. Therefore, the virtual addresses used by the graph must be reserved for the graph across replays. The PyTorch caching allocator achieves this by detecting when capture is underway and satisfying the capture’s allocations from a graph-private memory pool. The private pool stays alive until its CUDAGraph object and all tensors created during capture go out of scope.

Private pools are maintained automatically. By default, the allocator creates a separate private pool for each capture. If you capture multiple graphs, this conservative approach ensures graph replays never corrupt each other’s values, but sometimes needlessly wastes memory.

Sharing memory across captures

To economize the memory stashed in private pools, torch.cuda.graph and torch.cuda.make_graphed_callables() optionally allow different captures to share the same private pool. It’s safe for a set of graphs to share a private pool if you know they’ll always be replayed in the same order they were captured, and never be replayed concurrently.

torch.cuda.graph’s pool argument is a hint to use a particular private pool, and can be used to share memory across graphs as shown:

```python
g1 = torch.cuda.CUDAGraph()
g2 = torch.cuda.CUDAGraph()

# (create static inputs for g1 and g2, run warmups of their workloads...)

# Captures g1
with torch.cuda.graph(g1):
    static_out_1 = g1_workload(static_in_1)

# Captures g2, hinting that g2 may share a memory pool with g1
with torch.cuda.graph(g2, pool=g1.pool()):
    static_out_2 = g2_workload(static_in_2)

static_in_1.copy_(real_data_1)
static_in_2.copy_(real_data_2)
g1.replay()
g2.replay()
```

With torch.cuda.make_graphed_callables(), if you want to graph several callables and you know they’ll always run in the same order (and never concurrently), pass them as a tuple in the same order they’ll run in the live workload, and make_graphed_callables() will capture their graphs using a shared private pool.

If, in the live workload, your callables will run in an order that occasionally changes, or if they’ll run concurrently, passing them as a tuple to a single invocation of make_graphed_callables() is not allowed. Instead, you must call make_graphed_callables() separately for each one.
