
CUDAGraph Trees#

Created On: May 19, 2023 | Last Updated On: Jul 30, 2025

Background#

CUDAGraph#

For a longer background on CUDAGraphs, read Accelerating PyTorch with CUDA Graphs.

CUDA Graphs, which made their debut in CUDA 10, let a series of CUDA kernels be defined and encapsulated as a single unit, i.e., a graph of operations, rather than a sequence of individually launched operations. This provides a mechanism to launch multiple GPU operations through a single CPU operation, and hence reduces launch overhead.

CUDA Graphs can give large speedups, especially for models with high CPU overhead or small compute. They also come with a number of limitations, which stem from requiring the same kernels to be run with the same arguments, dependencies, and memory addresses.

  • Control Flow is not possible

  • Kernels which trigger host to device syncs (such as .item()) raise errors

  • All input arguments to kernels are fixed to what they were when recorded

  • CUDA memory addresses are fixed; however, the values of the memory at those addresses can change

  • No essential CPU ops or CPU side effects

PyTorch CUDAGraph Integration#

PyTorch provides a convenience wrapper around CUDAGraphs that handles a couple of tricky interactions with PyTorch’s caching allocator.

The CachingAllocator uses a separate memory pool for all the new allocations. During CUDAGraph recording, memory is accounted for, allocated, and freed exactly as during eager run. On replay, just the kernels are invoked, and there are no changes to the allocator. Subsequent to initial recording, the allocator does not know which memory is actively being used in user programs.

Using a separate memory pool between eager allocations and cudagraph allocations may increase the memory of your program if there is substantial memory allocated to both.

Make Graphed Callables#

Make Graphed Callables is a PyTorch Abstraction to share a single memory pool over a series of callables. Graphed Callables takes advantage of the fact that on CUDA Graph recording, memory is exactly accounted for by the caching allocator to safely share memory between separate CUDA Graph recordings. In each invocation, outputs are preserved as live memory, preventing one callable from overwriting the live memory of another. Graphed Callables can only be invoked in a single order; memory addresses from the first run are burned into the second, and so forth.

TorchDynamo Previous CUDA Graphs Integration#

Running with cudagraph_trees=False does not reuse memory across separate graph captures, which can lead to large memory regressions. Even for a model that has no graph breaks, this has issues. The forward and backward are separate graph captures, so the memory pools for forward and backward are not shared. In particular, memory for activations that are saved in the forward cannot be reclaimed in the backward.

CUDAGraph Trees Integration#

Like Graphed Callables, CUDA Graph Trees use a single memory pool across all graph captures. However, instead of requiring a single sequence of invocations, CUDA Graph Trees create separate trees of CUDA Graph captures. Let’s take a look at an illustrative example:

    import torch

    @torch.compile(mode="reduce-overhead")
    def foo(x):
        # GRAPH 1
        y = x * x * x
        # graph break triggered here
        if y.sum() > 0:
            # GRAPH 2
            z = y ** y
        else:
            # GRAPH 3
            z = (y.abs() ** y.abs())
        torch._dynamo.graph_break()
        # GRAPH 4
        return z * torch.rand_like(z)

    # the first run warms up each graph, which does things like CuBlas or Triton benchmarking
    foo(torch.arange(0, 10, device="cuda"))
    # The second run does a CUDA Graph recording, and replays it
    foo(torch.arange(0, 10, device="cuda"))
    # Finally we hit the optimized, CUDA Graph replay path
    foo(torch.arange(0, 10, device="cuda"))

In this example, there are two separate paths that we make through the function: 1 -> 2 -> 4, or 1 -> 3 -> 4.

We share all of the memory in a single memory pool between separate recordings by building up a tape of CUDA Graph recordings, in this instance, 1 -> 2 -> 4. We add invariants to ensure that memory is always in the same location as it was when recorded, and that no live tensors exist in user programs that might be overwritten.

  • Same constraints from CUDA Graphs apply: same kernels must be invoked with the same arguments (static sizes, addresses, etc)

  • The same pattern of memory must be observed between recording and replay: if a tensor output of one graph dies subsequent to another graph during recording, it must also do so during replay.

  • Live memory in the CUDA pool forces a dependence between two recordings

  • These recordings can only be invoked in a single order: 1 -> 2 -> 4

All of the memory is shared in a single memory pool, so there is no additional memory overhead compared to eager. Now, what happens if we were to hit a new path and run Graph 3?

Graph 1 gets replayed, and then we hit Graph 3, which we have not yet recorded. On graph replays, the private memory pool is not updated, so y is not reflected in the allocator. Without care, we would overwrite it. To support reusing the same memory pool after replaying other graphs, we checkpoint the memory pool back to its state at the end of graph 1. Now that our live tensors are reflected in the caching allocator, we are safe to run a new graph.

First, we would hit the optimized, CUDAGraph.replay() path that we have already recorded in graph 1. Then we would hit Graph 3. Just as before, we will need to warm up the graph once before recording. On the warmup run, the memory addresses are not fixed, so graph 4 will also fall back to the inductor, non-cudagraph invocation.

The second time we hit graph 3 we are warmed up and ready to record. We record graph 3 and then record graph 4 again since the input memory addresses have changed. This creates a tree of CUDA Graph recordings. A CUDA Graph Tree!

        1
       / \
      2   3
       \   \
        4   4
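The shape of this tree can be reproduced with a toy model in plain Python. This is only a sketch: `RecordingNode` and `run_path` are hypothetical names, and the model ignores warmup runs, memory checkpointing, and everything else the real implementation does. It only shows why graph 4 ends up recorded twice, once under each parent.

```python
class RecordingNode:
    """One recorded graph, identified by its id, with its child recordings."""

    def __init__(self, graph_id):
        self.graph_id = graph_id
        self.children = {}


def run_path(root_children, path):
    """Walk `path` through the tree and return the graph ids newly recorded."""
    recorded = []
    children = root_children
    for graph_id in path:
        node = children.get(graph_id)
        if node is None:
            # No recording exists under this parent yet, so record one.
            node = RecordingNode(graph_id)
            children[graph_id] = node
            recorded.append(graph_id)
        # Otherwise this step is a replay of an existing recording.
        children = node.children
    return recorded


roots = {}
first = run_path(roots, [1, 2, 4])   # nothing recorded yet: records every graph
replay = run_path(roots, [1, 2, 4])  # same path again: pure replay
second = run_path(roots, [1, 3, 4])  # replays 1, then records 3 and a second 4
print(first, replay, second)
```

Keying children by parent is the design choice that matters here: a recording of graph 4 is only reusable beneath the parent it was recorded under, because its input addresses depend on that parent.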

Input Mutation Support#

An input mutation function is a function that performs in-place writes to an input tensor, as illustrated below:

    def foo(x, y):
        # mutates input x
        x.add_(1)
        return x + y

Input mutation functions generally lead to challenges for CUDAGraph Trees. Due to the static CUDA memory address requirement from CUDAGraph, for each input tensor x, CUDAGraph Trees may allocate a static memory address x’. During execution, CUDAGraph Trees first copy the input tensor x to the static memory address x’, and then replay the recorded CUDAGraph. For an input mutation function, x’ is updated in place, which is not reflected on the input tensor x since x and x’ reside at different CUDA memory addresses.
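The copy-then-replay problem can be sketched in plain Python, with lists standing in for tensors. The names `replay_foo` and `static_x` are hypothetical and this is not how CUDAGraph Trees are implemented; the sketch only shows why the in-place update lands on the static copy x’ rather than on x.

```python
def replay_foo(x, y, static_x):
    """Mimic a replay of `foo` that reads its input from a static buffer."""
    # Step 1: copy the caller's input into the fixed-address buffer x'.
    static_x[:] = x
    # Step 2: the recorded in-place update (x.add_(1)) runs against x'.
    for i in range(len(static_x)):
        static_x[i] += 1
    # Step 3: the recorded addition reads x' and y.
    return [a + b for a, b in zip(static_x, y)]


x = [1.0, 2.0]
y = [10.0, 20.0]
static_x = [0.0, 0.0]

out = replay_foo(x, y, static_x)
print(out)       # computed from the mutated copy
print(static_x)  # x' was mutated
print(x)         # the caller's x was not
```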

A closer look at input mutation functions reveals that there are three types of inputs:

  • Inputs from eager: We assume the addresses of these tensors vary from execution to execution. Because cudagraphs freeze memory addresses, we need to copy these inputs to a static address tensor prior to graph recording and execution.

  • Parameters and buffers: We assume (and runtime-check) that these tensors have the same addresses on every execution. We do not need to copy over their contents because the recorded memory address will be the same as the executed memory address.

  • Tensors which are prior outputs from CUDAGraph Trees: Because the output tensor addresses of a cudagraph are fixed, if we run CUDAGraph1, then run CUDAGraph2, the inputs which came from CUDAGraph1 into CUDAGraph2 will have a fixed memory address. These inputs, like parameters and buffers, do not require copying over to a static address tensor. We check to make sure that these inputs are stable at runtime, and if they’re not we will re-record.

CUDAGraph Trees support input mutation on parameters and buffers, and on tensors which are prior outputs from CUDAGraph Trees. For mutation on inputs from eager, CUDAGraph Trees will run the function without CUDAGraph and emit a “skipping due to mutated inputs” log. The following example shows CUDAGraph Trees’ support for tensors which are prior outputs from CUDAGraph Trees.

    import torch

    @torch.compile(mode="reduce-overhead")
    def foo(x):
        return x + 1

    @torch.compile(mode="reduce-overhead")
    def mut(x):
        return x.add_(2)

    # Enable input mutation support
    torch._inductor.config.triton.cudagraph_support_input_mutation = True

    for i in range(3):
        torch.compiler.cudagraph_mark_step_begin()
        inp = torch.rand([4], device="cuda")

        # CUDAGraph is applied since `foo` does not mutate `inp`
        tmp = foo(inp)
        # Although `mut` mutates `tmp`, `tmp` is an output of a CUDAGraph
        # managed function, so CUDAGraph is still applied.
        mut(tmp)

    torch.compiler.cudagraph_mark_step_begin()
    inp = torch.rand([4], device="cuda")

    tmp = foo(inp)
    # While `tmp` is a CUDAGraph Tree managed function's output, `tmp.clone()`
    # is not. So CUDAGraph is not applied to `mut` and there is a log
    # `skipping cudagraphs due to mutated inputs`
    mut(tmp.clone())

To enable CUDAGraph Trees for a function mutating inputs from eager, please rewrite the function to avoid input mutation.

Note
Enable input mutation support by setting torch._inductor.config.triton.cudagraph_support_input_mutation = True for “reduce-overhead” mode.
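The rules in this section can be summarized as a small decision function. This is a plain-Python sketch with hypothetical names (`InputKind`, `plan_for_input`), and it assumes input mutation support is enabled; it restates the prose above rather than mirroring any internal PyTorch code.

```python
from enum import Enum


class InputKind(Enum):
    EAGER = "input from eager"
    PARAM_OR_BUFFER = "parameter or buffer"
    PRIOR_CUDAGRAPH_OUTPUT = "prior output from CUDAGraph Trees"


def plan_for_input(kind, mutated):
    """Describe how one input is handled, following the rules in the text."""
    if kind is InputKind.EAGER:
        if mutated:
            # Mutating an eager input causes the function to run without CUDAGraph.
            return "skip CUDAGraph"
        # Eager inputs have varying addresses, so they are copied first.
        return "copy to static address, then replay"
    # Parameters, buffers, and prior CUDAGraph outputs have stable addresses,
    # so no copy is needed and in-place mutation on them is supported.
    return "replay without copy"


for kind in InputKind:
    for mutated in (False, True):
        print(f"{kind.value:35s} mutated={mutated!s:5s} -> {plan_for_input(kind, mutated)}")
```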

Dynamic Shape Support#

Dynamic shape means that an input tensor has different shapes across function calls. Since CUDAGraph requires fixed tensor addresses, CUDAGraph Trees re-record CUDAGraph for every unique shape of an input tensor. This leads to multiple CUDAGraphs for a single inductor graph. When there are limited shapes (e.g., batch sizes in inference), it is profitable to re-record CUDAGraphs. However, if input tensor shapes change frequently or even on every invocation, re-recording CUDAGraph may not be profitable. NVIDIA uses 64 KB of device memory per kernel launch in CUDAGraph, up until CUDA 12.4 and Driver Version 550+. This memory cost can be significant with many CUDAGraph re-recordings.
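As a rough illustration of that memory cost, the arithmetic can be written out as below. The 64 KB per kernel launch figure comes from the paragraph above (treated here as 64 * 1024 bytes); the kernel count and number of shapes are made-up numbers, and the function name is hypothetical.

```python
BYTES_PER_KERNEL_LAUNCH = 64 * 1024  # 64 KB, the figure quoted in the text


def cudagraph_launch_overhead_bytes(kernels_per_graph, num_recordings):
    """Estimate device memory attributable to recorded kernel launches."""
    return kernels_per_graph * num_recordings * BYTES_PER_KERNEL_LAUNCH


# Hypothetical workload: 500 kernels per graph, 200 distinct input shapes.
overhead = cudagraph_launch_overhead_bytes(500, 200)
print(f"{overhead / 2**30:.2f} GiB")
```

Because the estimate scales linearly with the number of recordings, capping the number of distinct shapes caps the overhead.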

For functions with frequently changing input tensor shapes, we suggest padding input tensors to a few fixed tensor shapes to still enjoy the benefits of CUDAGraph. In addition, setting torch._inductor.config.triton.cudagraph_skip_dynamic_graphs=True skips cudagraphing for functions with dynamic shape inputs, so that only functions with static input tensor shapes are cudagraphed.
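One way to pad to a few fixed shapes is to round each batch up to the nearest entry in a small set of bucket sizes. The sketch below uses plain Python lists; the bucket values and the names `pick_bucket` and `pad_batch` are assumptions for illustration. With real tensors, a padding op such as torch.nn.functional.pad would play the role of the list concatenation.

```python
import bisect

BUCKETS = (8, 16, 32, 64)  # hypothetical set of supported batch sizes, sorted


def pick_bucket(n, buckets=BUCKETS):
    """Return the smallest bucket that can hold n items."""
    i = bisect.bisect_left(buckets, n)
    if i == len(buckets):
        raise ValueError(f"{n} exceeds the largest bucket {buckets[-1]}")
    return buckets[i]


def pad_batch(batch, pad_value=0):
    """Pad a list up to its bucket size; also return the original length."""
    target = pick_bucket(len(batch))
    return batch + [pad_value] * (target - len(batch)), len(batch)


padded, valid = pad_batch([1, 2, 3, 4, 5])
print(len(padded), valid)
```

Returning the original length lets the caller slice the padded rows back off the output. With four buckets, at most four distinct shapes are ever seen, whatever the incoming batch sizes are.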

NCCL Support#

CUDAGraph Trees support functions with NCCL operators. While CUDAGraph Trees record CUDAGraphs on a per-device basis, NCCL support allows cross-device communication.

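    import torch
    import torch.distributed

    # Assumes a process group has already been initialized.
    @torch.compile(mode="reduce-overhead")
    def func(x):
        y = x * x
        # all_reduce updates y in place and returns None when async_op is False
        torch.distributed.all_reduce(y, op=torch.distributed.ReduceOp.SUM)
        x = torch.nn.functional.silu(x)
        return x * y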

Reasons for Skipping CUDAGraph#

Since CUDAGraph has requirements such as static input tensor addresses and not supporting CPU operators, CUDAGraph Trees check whether a function satisfies these requirements and may skip CUDAGraph when necessary. Here, we list common reasons for skipping CUDAGraph.

  • Input mutation: CUDAGraph Trees skip functions that mutate eager inputs in place. In-place mutation of parameters and buffers, or of output tensors from CUDAGraph Tree managed functions, is still supported. Please see the Input Mutation Support section for more details.

  • CPU operators: Functions containing CPU operators are skipped. Please split the function into multiple functions and apply CUDAGraph Trees to the functions with only GPU operators.

  • Multi-device operators: A function is skipped if it contains operators on multiple devices. Currently, CUDAGraph is applied on a per-device basis. Please use supported libraries such as NCCL for cross-device communication. Please see the NCCL Support section for more details.

  • Free unbacked symbols: Free unbacked symbols usually arise with dynamic shapes. CUDAGraph Trees currently record a CUDAGraph for every unique input tensor shape. Please see Dynamic Shape Support for more details.

  • CUDAGraph-unsafe custom ops: Some custom ops may include cudagraph-unsafe ops, which cause cudagraph to be skipped. Please see CUDAGraph Unsafe Custom Ops for more details.

  • Incompatible operators: CUDAGraph Trees skip a function if it contains incompatible operators. Please replace these operators with supported operators. We show an exhaustive list of incompatible operators:

    aten._fused_moving_avg_obs_fq_helper.default
    aten._fused_moving_avg_obs_fq_helper_functional.default
    aten.multinomial.default
    fbgemm.dense_to_jagged.default
    fbgemm.jagged_to_padded_dense.default
    run_and_save_rng_state
    run_with_rng_state
    aten._local_scalar_dense
    aten._assert_scalar

The following operators are incompatible when torch.are_deterministic_algorithms_enabled().

    aten._fused_moving_avg_obs_fq_helper.default
    aten._fused_moving_avg_obs_fq_helper_functional.default
    aten.multinomial.default
    fbgemm.dense_to_jagged.default
    fbgemm.jagged_to_padded_dense.default
    run_and_save_rng_state
    run_with_rng_state
    aten._local_scalar_dense
    aten._assert_scalar

CUDAGraph Unsafe Custom Ops#

Custom ops are assumed to be safe for CUDAGraph by default. However, some custom ops may include unsupported ops such as CPU ops. Since custom ops are treated as black boxes by the compiler, users must explicitly mark these ops as unsafe for CUDAGraph by setting the torch._C.Tag.cudagraph_unsafe tag, as demonstrated in the example below. When a function contains cudagraph-unsafe custom ops, it will be skipped by CUDAGraph unless CUDAGraph partition is enabled.

    import torch

    @torch.library.custom_op(
        "mylib::modify",
        mutates_args=(),
        tags=(torch._C.Tag.cudagraph_unsafe,),
    )
    def modify(pic: torch.Tensor) -> torch.Tensor:
        pic1 = pic + 1
        # The round trip through the CPU is what makes this op cudagraph-unsafe
        pic1_cpu = (pic1.cpu() + 1) * 2
        return pic1_cpu.cuda() + pic

    @modify.register_fake
    def _(pic):
        return torch.empty_like(pic)

CUDAGraph Partition#

As we discussed earlier, CUDAGraph does not support some ops (e.g., CPU ops), which may limit its adoption. CUDAGraph partition is a compiler solution that automatically splits off these ops, reorders ops to reduce the number of partitions, and applies CUDAGraph to each partition individually. Please set torch._inductor.config.graph_partition=True to enable CUDAGraph partition.

Consider the following example where x and y are GPU inputs but y_cpu is a CPU tensor. Without graph partition, this function must be skipped due to CPU ops. With graph partition, the CPU ops are split off, and the remaining GPU ops are cudagraphified, resulting in two separate CUDAGraphs.

    def f(x, y):
        x1 = x + 1
        y1 = y + 1
        y_cpu = y1.cpu() + 1
        z = x @ y
        return x1 + y1 + z + y_cpu.cuda()
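The splitting step for this example can be pictured with a toy partitioner in plain Python. This is a sketch under stated assumptions, not the inductor implementation: `split_partitions` is a hypothetical name, the device labels are hand-written, and the ops are listed in an order where the independent GPU op `z = x @ y` has already been moved next to the other GPU ops, which is an assumption about what the reordering would produce.

```python
def split_partitions(ops):
    """Group consecutive GPU ops; any non-GPU op closes the current group."""
    partitions, current = [], []
    for name, device in ops:
        if device == "gpu":
            current.append(name)
        elif current:
            partitions.append(current)
            current = []
    if current:
        partitions.append(current)
    return partitions


# Ops of `f`, tagged by where they run (labels are illustrative).
ops = [
    ("x1 = x + 1", "gpu"),
    ("y1 = y + 1", "gpu"),
    ("z = x @ y", "gpu"),
    ("t = y1.cpu()", "copy"),
    ("y_cpu = t + 1", "cpu"),
    ("u = y_cpu.cuda()", "copy"),
    ("x1 + y1 + z + u", "gpu"),
]

partitions = split_partitions(ops)
for i, part in enumerate(partitions):
    print(i, part)
```

In this toy, each printed group corresponds to one region that could be captured as its own CUDAGraph, with the device copies and the CPU op running between the two groups.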

Currently, CUDAGraph partition supports splitting off the following types of ops:

  • Non-GPU Ops: Popular examples include computation on cpu tensors.

  • Device Copy Ops: Data transfers between devices, such as the y1.cpu() in the example above.

  • Control Flow Ops: Control flow ops are split off since they are not yet supported by CUDAGraph.

  • CUDAGraph Unsafe Custom Ops: Custom ops tagged with torch._C.Tag.cudagraph_unsafe are split off. See the CUDAGraph Unsafe Custom Ops section for details.

  • Unbacked Symints: Please refer to the Dynamic Shape Support section for more information.

Limitations#

Because CUDA Graph fixes memory addresses, CUDA Graphs do not have a great way of handling live tensors from a previous invocation.

Let’s say we are benchmarking running inference with the following code:

    import torch

    @torch.compile(mode="reduce-overhead")
    def my_model(x):
        y = torch.matmul(x, x)
        return y

    x = torch.randn(10, 10, device="cuda")
    y1 = my_model(x)
    y2 = my_model(x)
    print(y1)
    # RuntimeError: Error: accessing tensor output of CUDAGraphs that has been overwritten by a subsequent run.

In the Separate CUDA Graph implementation, the output from the first invocation will be overwritten by the second invocation. In CUDAGraph Trees, we don’t want to add unintended dependencies between iterations that would cause us to not hit the hot path, nor do we want to prematurely free memory from a prior invocation. Our heuristics are as follows: in inference we start a new iteration on each invocation for torch.compile, and in training we do the same so long as there is not a pending backward that has not been invoked. If those heuristics are wrong, you can mark the start of a new iteration with torch.compiler.cudagraph_mark_step_begin(), or clone tensors of a prior iteration (outside of torch.compile) before you begin the next run.
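The overwriting behavior, and why cloning before the next run helps, can be shown with a toy stand-in in plain Python. `StaticOutputGraph` is a hypothetical class that writes into one fixed buffer, the way a replayed graph writes to a fixed output address; it models the silent overwrite of the Separate CUDA Graph implementation, not the error raised in the example above.

```python
class StaticOutputGraph:
    """Stand-in for a recorded graph that writes into one fixed output buffer."""

    def __init__(self, size):
        self._out = [0] * size  # fixed "address": the same list on every replay

    def replay(self, x):
        for i, v in enumerate(x):
            self._out[i] = v * v
        return self._out


g = StaticOutputGraph(3)
y1 = g.replay([1, 2, 3])
y1_saved = list(y1)       # analogous to cloning outside the compiled function
y2 = g.replay([4, 5, 6])  # overwrites the buffer that y1 still refers to

print(y1 is y2)   # both names refer to the same buffer
print(y1)         # now holds the second result
print(y1_saved)   # the copy kept the first result
```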

Comparisons#

| Footguns | Separate CudaGraph | CUDAGraph Trees |
| --- | --- | --- |
| Memory Can Increase | On each graph compilation (new sizes, etc.) | If you are also running non-cudagraph memory |
| Recordings | On any new invocation of a graph | Will re-record on any new, unique path you take through your program |
| Footguns | Invocation of one graph will overwrite prior invocation | Cannot persist memory between separate runs through your model - one training loop iteration, or one run of inference |