Frequently Asked Questions
Created On: Jun 16, 2025 | Last Updated On: Jun 16, 2025
Author: Mark Saroufim
Does torch.compile support training?
torch.compile supports training, using AOTAutograd to capture backward graphs:
1. The .forward() graph and optimizer.step() are captured by TorchDynamo's Python evalframe frontend.
2. For each segment of .forward() that TorchDynamo captures, it uses AOTAutograd to generate a backward graph segment.
3. Each pair of forward and backward graphs is (optionally) min-cut partitioned to save the minimal state between forward and backward.
4. The forward and backward pairs are wrapped in autograd.Function modules.
5. User code calling .backward() still triggers eager's autograd engine, which runs each compiled backward graph as if it were one op, also running any non-compiled eager ops' .backward() functions.
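The pipeline above can be exercised with a small training loop. This is a minimal sketch, assuming a toy linear-regression setup chosen for illustration. It passes backend="aot_eager" (TorchDynamo plus AOTAutograd, with the captured graphs executed by eager PyTorch) so that it runs without a C++ or Triton toolchain. In practice you would usually keep the default backend.

```python
import torch
import torch.nn.functional as F

# Sketch: train a tiny linear model through torch.compile.
# backend="aot_eager" still captures forward graphs with TorchDynamo and
# backward graphs with AOTAutograd, but runs them with eager PyTorch.
torch.manual_seed(0)
model = torch.nn.Linear(4, 1)
compiled_model = torch.compile(model, backend="aot_eager")
opt = torch.optim.SGD(model.parameters(), lr=0.1)  # parameters are shared with compiled_model

x = torch.randn(64, 4)
y = x.sum(dim=1, keepdim=True)  # target: the exact solution is weight=1, bias=0

losses = []
for step in range(50):
    opt.zero_grad()
    loss = F.mse_loss(compiled_model(x), y)
    loss.backward()  # eager autograd engine runs the compiled backward graph as one op
    opt.step()
    losses.append(loss.item())

print(f"first loss {losses[0]:.4f}, last loss {losses[-1]:.4f}")
```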
Do you support distributed code?
torch.compile supports DistributedDataParallel (DDP). Support for other distributed training libraries is being considered.
The main reason distributed code is challenging with Dynamo is that AOTAutograd unrolls both the forward and backward pass and provides 2 graphs for backends to optimize. This is a problem for distributed code because we'd ideally like to overlap communication operations with computations. Eager PyTorch accomplishes this in different ways for DDP and FSDP: using autograd hooks, module hooks, and modifications/mutations of module states. In a naive application of Dynamo, hooks that should run directly after an operation during backward may be delayed until after the entire compiled region of backward ops, due to how AOTAutograd compiled functions interact with dispatcher hooks.
The basic strategy for optimizing DDP with Dynamo is outlined in distributed.py, where the main idea is to graph break on DDP bucket boundaries.
When each node in DDP needs to synchronize its weights with the other nodes, it organizes its gradients and parameters into buckets. This reduces communication times and allows a node to broadcast a fraction of its gradients to other waiting nodes.
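The bucketing idea can be sketched in plain Python. assign_buckets is a hypothetical helper, not DDP's actual implementation (DDP's real cap is configured through the bucket_cap_mb argument of DistributedDataParallel), and it only assigns parameters to buckets without performing any communication. Parameters are visited in reverse order because, during backward, gradients tend to become ready roughly in reverse order of the forward pass.

```python
from typing import List


def assign_buckets(param_numels: List[int], bytes_per_elem: int, cap_bytes: int) -> List[List[int]]:
    """Hypothetical helper: group parameter indices into buckets of at most cap_bytes.

    A parameter larger than the cap still gets a bucket of its own.
    """
    buckets: List[List[int]] = []
    current: List[int] = []
    current_bytes = 0
    # Walk parameters from last to first, approximating gradient-ready order.
    for idx in reversed(range(len(param_numels))):
        size = param_numels[idx] * bytes_per_elem
        if current and current_bytes + size > cap_bytes:
            buckets.append(current)  # close the full bucket and start a new one
            current, current_bytes = [], 0
        current.append(idx)
        current_bytes += size
    if current:
        buckets.append(current)
    return buckets


# Four fp32 parameters (sizes in elements) and a 1 KiB bucket cap.
print(assign_buckets([100, 50, 200, 30], bytes_per_elem=4, cap_bytes=1024))  # [[3, 2], [1, 0]]
```

Once a bucket's gradients are all ready, it can be communicated while backward continues on earlier layers, which is the overlap that graph breaking on bucket boundaries tries to preserve.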
Graph breaks in distributed code mean you can expect Dynamo and its backends to optimize the compute overhead of a distributed program but not its communication overhead. Graph breaks may interfere with compilation speedups if the reduced graph size robs the compiler of fusion opportunities. However, there are diminishing returns with increasing graph size, since most of the current compute optimizations are local fusions. So in practice this approach may be sufficient.
Do I still need to export whole graphs?
For the vast majority of models you probably don't, and you can use torch.compile() as is. However, there are a few situations where full graphs are necessary, and you can ensure a full graph by simply running torch.compile(..., fullgraph=True). These situations include:
Large-scale training runs, such as $250K+ runs, that require pipeline parallelism and other advanced sharding strategies.
Inference optimizers like TensorRT or AITemplate that rely on fusing much more aggressively than training optimizers.
Mobile training or inference.
Future work will include tracing communication operations into graphs, coordinating these operations with compute optimizations, and optimizing the communication operations.
Why is my code crashing?
If your code ran just fine without torch.compile and started to crash with it enabled, the most important first step is figuring out in which part of the stack the failure occurred. To troubleshoot that, follow the steps below, and only try the next step if the previous one succeeded.
1. torch.compile(..., backend="eager"), which only runs TorchDynamo forward graph capture and then runs the captured graph with PyTorch. If this fails, then there's an issue with TorchDynamo.
2. torch.compile(..., backend="aot_eager"), which runs TorchDynamo to capture a forward graph, and then AOTAutograd to trace the backward graph without any additional backend compiler steps. PyTorch eager will then be used to run the forward and backward graphs. If this fails, then there's an issue with AOTAutograd.
3. torch.compile(..., backend="inductor"), which runs TorchDynamo to capture a forward graph, and then AOTAutograd to trace the backward graph with the TorchInductor compiler. If this fails, then there's an issue with TorchInductor.
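The three steps can be automated with a small helper. first_failing_backend is a hypothetical function written for illustration, not a PyTorch API. It calls torch._dynamo.reset() between attempts so that each backend compiles from scratch instead of reusing cached graphs. The usage line below omits "inductor" so the sketch runs without a compiler toolchain. Append it as the last stage when one is available.

```python
import torch
import torch._dynamo


def first_failing_backend(fn, *args, backends=("eager", "aot_eager", "inductor")):
    """Hypothetical helper: run fn under each backend in order and return the
    name of the first backend that raises, or None if all of them succeed."""
    for backend in backends:
        torch._dynamo.reset()  # drop cached graphs so this backend compiles fresh
        try:
            torch.compile(fn, backend=backend)(*args)
        except Exception as exc:
            print(f"backend={backend!r} failed: {type(exc).__name__}: {exc}")
            return backend
    return None


def model_step(x):
    return torch.relu(x @ x.T).sum()


# Only the toolchain-free stages are run here.
result = first_failing_backend(model_step, torch.randn(8, 8), backends=("eager", "aot_eager"))
print(result)
```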
Why is compilation slow?
Dynamo Compilation: TorchDynamo has a builtin stats function for collecting and displaying the time spent in each compilation phase. These stats can be accessed by calling torch._dynamo.utils.compile_times() after running code compiled with TorchDynamo. By default, this returns a string representation of the compile times spent in each TorchDynamo function by name.
Inductor Compilation: TorchInductor has a builtin stats and trace function for displaying time spent in each compilation phase, output code, output graph visualization and IR dump. Enable it with env TORCH_COMPILE_DEBUG=1 python repro.py. This is a debugging tool designed to make it easier to debug and understand the internals of TorchInductor, and it produces a debug trace made up of several files. Each file in that debug trace can be enabled or disabled via torch._inductor.config.trace.*. The profile and the diagram are both disabled by default, since they are expensive to generate. See the example debug directory output for more examples.
Excessive Recompilation: When TorchDynamo compiles a function (or part of one), it makes certain assumptions about locals and globals in order to allow compiler optimizations, and expresses these assumptions as guards that check particular values at runtime. If any of these guards fail, Dynamo will recompile that function (or part) up to torch._dynamo.config.recompile_limit times. If your program is hitting the cache limit, you will first need to determine which guard is failing and what part of your program is triggering it. Use TORCH_TRACE/tlparse or TORCH_LOGS=recompiles to trace the root of the issue, and check torch.compile Troubleshooting for more details.
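To see guards and recompilation in action, a custom backend can count how often Dynamo hands it a new graph. This is a sketch. counting_backend is a hypothetical name, and dynamic=False is passed to force static shapes so that a new input shape is guaranteed to fail a guard. Running the script with TORCH_LOGS=recompiles should additionally report which guard failed.

```python
import torch

compile_count = 0


def counting_backend(gm: torch.fx.GraphModule, example_inputs):
    """Hypothetical custom backend: count compilations, then run the FX graph as-is."""
    global compile_count
    compile_count += 1
    return gm.forward


@torch.compile(backend=counting_backend, dynamic=False)
def double(x):
    return x * 2


double(torch.randn(3))
double(torch.randn(3))  # same shape: guards pass, the cached graph is reused
double(torch.randn(4))  # new shape: a shape guard fails, so Dynamo recompiles
print(compile_count)  # 2
```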
Why are you recompiling in production?
In some cases, you may not want unexpected compiles after a program has warmed up, for example, if you are serving production traffic in a latency-critical application. For this, TorchDynamo provides an alternate mode where prior compiled graphs are used, but no new ones are generated:
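```python
import torch
import torch._dynamo as dynamo

frozen_toy_example = dynamo.run(toy_example)  # toy_example was compiled and warmed up earlier
frozen_toy_example(torch.randn(10), torch.randn(10))
```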
How are you speeding up my code?
There are 3 major ways to accelerate PyTorch code:
1. Kernel fusion. Vertical fusion fuses sequential operations to avoid excessive reads and writes. For example, fusing 2 subsequent cosines means you can do 1 read and 1 write instead of 2 reads and 2 writes. Horizontal fusion: the simplest example is batching, where a single matrix is multiplied with a batch of examples, but the more general scenario is a grouped GEMM, where a group of matrix multiplications is scheduled together.
2. Out-of-order execution: a general optimization for compilers. By looking ahead at the exact data dependencies within a graph, we can decide on the most opportune time to execute a node and which buffers can be reused.
3. Automatic work placement: similar to the out-of-order execution point, but by matching nodes of a graph to resources like physical hardware or memory, we can design an appropriate schedule.
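Vertical fusion is only possible when the operations end up in the same graph. The sketch below uses a hypothetical inspecting_backend, a custom backend that records the FX graph and runs it unchanged, to confirm that both cosines are captured together. It does not fuse anything itself. To see the fused kernel that Inductor actually generates, run with the default backend and TORCH_LOGS=output_code.

```python
import torch

captured = []


def inspecting_backend(gm: torch.fx.GraphModule, example_inputs):
    """Hypothetical custom backend: record the captured FX graph, run it unchanged."""
    captured.append(gm)
    return gm.forward


@torch.compile(backend=inspecting_backend)
def two_cosines(x):
    return torch.cos(torch.cos(x))


two_cosines(torch.randn(1024))

# Both cos calls should appear as call_function nodes in a single graph,
# which is the precondition for a backend to fuse them into one kernel.
graph = captured[0].graph
cos_nodes = [n for n in graph.nodes if n.op == "call_function" and n.target is torch.cos]
print(len(captured), len(cos_nodes))
print(captured[0].code)
```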
The above are general principles for accelerating PyTorch code, but different backends will each make different tradeoffs on what to optimize. For example, Inductor first takes care of fusing whatever it can and only then generates Triton kernels.
Triton, in addition, offers speedups because of automatic memory coalescing, memory management and scheduling within each Streaming Multiprocessor, and it has been designed to handle tiled computations.
However, regardless of the backend you use, it's best to take a benchmark-and-see approach: try out the PyTorch profiler, visually inspect the generated kernels and try to see what's going on for yourself.
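A minimal profiling sketch. It assumes CPU execution and backend="aot_eager" so that it runs without a toolchain. Remove the backend argument to profile the default Inductor backend instead. The warm-up call keeps compilation time out of the measurement.

```python
import torch
from torch.profiler import ProfilerActivity, profile


def fn(x):
    return torch.cos(torch.cos(x)).sum()


compiled_fn = torch.compile(fn, backend="aot_eager")
x = torch.randn(4096)
compiled_fn(x)  # warm-up: compile outside the profiled region

with profile(activities=[ProfilerActivity.CPU]) as prof:
    for _ in range(10):
        compiled_fn(x)

# Aggregate events by name and show the most expensive ones.
table = prof.key_averages().table(sort_by="cpu_time_total", row_limit=10)
print(table)
```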
Why am I not seeing speedups?
Graph Breaks
The main reason you won't see the speedups you'd like from Dynamo is excessive graph breaks. So what's a graph break?
Given a program like:
```python
def some_fun(x):
    ...

torch.compile(some_fun)(x)
...
```
TorchDynamo will attempt to compile all of the torch/tensor operations within some_fun() into a single FX graph, but it may fail to capture everything into one graph.
Some graph break reasons are insurmountable for TorchDynamo. For example, a call into a C extension other than PyTorch is invisible to TorchDynamo, and the extension could do arbitrary things without TorchDynamo being able to introduce the necessary guards to ensure that the compiled program would be safe to reuse.
To maximize performance, it’s important to have as few graph breaksas possible.
Identifying the cause of a graph break
To identify all graph breaks in a program and the associated reasons for the breaks, torch._dynamo.explain can be used. This tool runs TorchDynamo on the supplied function and aggregates the graph breaks that are encountered. Here is an example usage:
```python
import torch
import torch._dynamo as dynamo

def toy_example(a, b):
    x = a / (torch.abs(a) + 1)
    print("woo")
    if b.sum() < 0:
        b = b * -1
    return x * b

explanation = dynamo.explain(toy_example)(torch.randn(10), torch.randn(10))
print(explanation)
"""
Graph Count: 3
Graph Break Count: 2
Op Count: 5
Break Reasons:
  Break Reason 1:
    Reason: builtin: print [<class 'torch._dynamo.variables.constant.ConstantVariable'>] False
    User Stack:
      <FrameSummary file foo.py, line 5 in toy_example>
  Break Reason 2:
    Reason: generic_jump TensorVariable()
    User Stack:
      <FrameSummary file foo.py, line 6 in torch_dynamo_resume_in_toy_example_at_5>
Ops per Graph:
  ...
Out Guards:
  ...
"""
```
To throw an error on the first graph break encountered, you can disable Python fallbacks by using fullgraph=True. This should be familiar if you've worked with export-based compilers.
```python
def toy_example(a, b):
    ...

torch.compile(toy_example, fullgraph=True, backend=<compiler>)(a, b)
```
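For a concrete, runnable variant, the sketch below uses the toy_example from the explain example (without the print) and chooses backend="eager" as the compiler. The data-dependent branch cannot be captured, so with fullgraph=True the call raises an error instead of silently splitting the graph. The exact exception class is an internal Dynamo type, so the sketch catches Exception.

```python
import torch


def toy_example(a, b):
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:  # data-dependent branch: this would be a graph break
        b = b * -1
    return x * b


strict = torch.compile(toy_example, fullgraph=True, backend="eager")
try:
    strict(torch.randn(10), torch.randn(10))
    raised = False
except Exception as exc:  # Dynamo reports the would-be graph break as an error
    raised = True
    print(type(exc).__name__)
print(raised)
```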
Why didn't my code recompile when I changed it?
If you enabled dynamic shapes by setting env TORCHDYNAMO_DYNAMIC_SHAPES=1 python model.py, then your code won't recompile on shape changes. We've added support for dynamic shapes, which avoids recompilations in the case when shapes vary by less than a factor of 2. This is especially useful in scenarios like varying image sizes in CV or variable sequence length in NLP. In inference scenarios, it's often not possible to know what a batch size will be beforehand, because you take what you can get from different client apps.
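torch.compile also accepts a dynamic argument, which requests dynamic shapes at compile time as an alternative to the environment variable above. The sketch below reuses a hypothetical counting backend to show one compilation serving several input sizes. Sizes 0 and 1 are avoided because Dynamo may specialize on them.

```python
import torch

compile_count = 0


def counting_backend(gm: torch.fx.GraphModule, example_inputs):
    """Hypothetical custom backend: count compilations, then run the FX graph as-is."""
    global compile_count
    compile_count += 1
    return gm.forward


@torch.compile(backend=counting_backend, dynamic=True)
def scale(x):
    return x * 2


for n in (8, 16, 33):
    scale(torch.randn(n))  # the size is symbolic, so no shape guard fails
print(compile_count)  # 1
```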
In general, TorchDynamo tries very hard not to recompile things unnecessarily. For example, if TorchDynamo finds 3 graphs and your change only modified one graph, then only that graph will recompile. So another tip to avoid potentially slow compilation times is to warm up a model by compiling it once, after which subsequent compilations will be much faster. Cold start compile times are still a metric we track visibly.
Why am I getting incorrect results?
Accuracy issues can also be minified if you set the environment variable TORCHDYNAMO_REPRO_LEVEL=4. It operates with a similar git bisect model, and a full repro might be something like TORCHDYNAMO_REPRO_AFTER="aot" TORCHDYNAMO_REPRO_LEVEL=4. The reason we need this is that downstream compilers generate code, whether it's Triton code or the C++ backend, and the numerics from those downstream compilers can be different in subtle ways yet have a dramatic impact on your training stability. So the accuracy debugger is very useful for us to detect bugs in our codegen or with a backend compiler.
If you'd like to ensure that random number generation is the same across both torch and Triton, then you can enable torch._inductor.config.fallback_random = True.
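A quick first check before reaching for the minifier is to compare compiled and eager outputs directly. This sketch uses a hypothetical gelu_mlp function and backend="aot_eager" so that it runs anywhere. With the default backend, the same comparison exercises generated code. torch.testing.assert_close applies dtype-dependent tolerances rather than exact equality, since generated kernels may legitimately reorder floating-point operations.

```python
import torch


def gelu_mlp(x, w1, w2):
    return torch.nn.functional.gelu(x @ w1) @ w2


torch.manual_seed(0)
x, w1, w2 = torch.randn(16, 32), torch.randn(32, 64), torch.randn(64, 8)

expected = gelu_mlp(x, w1, w2)
# aot_eager keeps this sketch toolchain-free; drop the backend argument to check Inductor.
actual = torch.compile(gelu_mlp, backend="aot_eager")(x, w1, w2)

torch.testing.assert_close(actual, expected)
max_abs_diff = (actual - expected).abs().max().item()
print(f"max abs diff: {max_abs_diff:.3e}")
```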
Why am I getting OOMs?
Dynamo is still an alpha product, so there are a few sources of OOMs. If you're seeing an OOM, try disabling the following configurations in this order, and then open an issue on GitHub so we can solve the root problem:
1. If you're using dynamic shapes, try disabling them (we've disabled them by default): env TORCHDYNAMO_DYNAMIC_SHAPES=0 python model.py
2. CUDA graphs with Triton are enabled by default in Inductor, but removing them may alleviate some OOM issues: torch._inductor.config.triton.cudagraphs = False
Does torch.func work with torch.compile (for grad and vmap transforms)?
Applying a torch.func transform to a function that uses torch.compile does work:
```python
import torch

@torch.compile
def f(x):
    # torch.func.grad requires a function with a scalar output
    return torch.sin(x).sum()

def g(x):
    return torch.func.grad(f)(x)

x = torch.randn(2, 3)
g(x)
```
Calling torch.func transform inside of a function handled with torch.compile
Compiling torch.func.grad with torch.compile
```python
import torch

def wrapper_fn(x):
    return torch.func.grad(lambda x: x.sin().sum())(x)

x = torch.randn(3, 3, 3)
grad_x = torch.compile(wrapper_fn)(x)
```
Compiling torch.vmap with torch.compile
```python
import torch

def my_fn(x):
    return torch.vmap(lambda x: x.sum(1))(x)

x = torch.randn(3, 3, 3)
output = torch.compile(my_fn)(x)
```
Compiling functions besides the ones which are supported (escape hatch)
For other transforms, as a workaround, use torch._dynamo.allow_in_graph.
allow_in_graph is an escape hatch. If your code does not work with torch.compile, which introspects Python bytecode, but you believe it will work via a symbolic tracing approach (like jax.jit), then use allow_in_graph.
By using allow_in_graph to annotate a function, you must make sure your code meets the following requirements:
All outputs in your function only depend on the inputs and do not depend on any captured Tensors.
Your function is functional. That is, it does not mutate any state. This may be relaxed; we actually support functions that appear to be functional from the outside: they may have in-place PyTorch operations, but may not mutate global state or inputs to the function.
Your function does not raise data-dependent errors.
```python
import torch

@torch.compile
def f(x):
    return torch._dynamo.allow_in_graph(torch.vmap(torch.sum))(x)

x = torch.randn(2, 3)
f(x)
```
A common pitfall is using allow_in_graph to annotate a function that invokes an nn.Module. This is because the outputs now depend on the parameters of the nn.Module. To get this to work, use torch.func.functional_call to extract the module state.
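A minimal sketch of the functional_call pattern, assuming a small nn.Linear. Passing the parameters in as an explicit dictionary makes the output depend only on the function's inputs, which is the property allow_in_graph requires. stateless_forward is a hypothetical name. Swapping in a dictionary of zeros shows that the module's own parameters are no longer consulted.

```python
import torch
from torch.func import functional_call

model = torch.nn.Linear(3, 2)
# Module state extracted into an explicit input instead of being captured.
params = {name: p.detach() for name, p in model.named_parameters()}


def stateless_forward(params, x):
    # Runs model's forward with the given tensors substituted for its parameters.
    return functional_call(model, params, (x,))


x = torch.randn(4, 3)
out = stateless_forward(params, x)
torch.testing.assert_close(out, model(x).detach())

zeros = {name: torch.zeros_like(p) for name, p in params.items()}
print(stateless_forward(zeros, x))  # zero weight and zero bias give an all-zero output
```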
Does NumPy work with torch.compile?
Starting in 2.1, torch.compile understands native NumPy programs that work on NumPy arrays, and mixed PyTorch-NumPy programs that convert from PyTorch to NumPy and back via x.numpy(), torch.from_numpy, and related functions.
Which NumPy features does torch.compile support?
NumPy within torch.compile follows NumPy 2.0 pre-release.
Generally, torch.compile is able to trace through most NumPy constructions, and when it cannot, it falls back to eager and lets NumPy execute that piece of code. Even then, there are a few features where torch.compile semantics slightly deviate from those of NumPy:
NumPy scalars: We model them as 0-D arrays. That is, np.float32(3) returns a 0-D array under torch.compile. To avoid a graph break, it is best to use this 0-D array. If this breaks your code, you can work around this by casting the NumPy scalar to the relevant Python scalar type bool/int/float.
Negative strides: np.flip and slicing with a negative step return a copy.
Type promotion: NumPy's type promotion will change in NumPy 2.0. The new rules are described in NEP 50. torch.compile implements NEP 50 rather than the current soon-to-be deprecated rules.
{tril,triu}_indices_from/{tril,triu}_indices return arrays rather than a tuple of arrays.
There are other features for which we do not support tracing, and we gracefully fall back to NumPy for their execution:
Non-numeric dtypes like datetimes, strings, chars, void, structured dtypes and recarrays.
Long dtypes np.float128/np.complex256 and some unsigned dtypes np.uint16/np.uint32/np.uint64.
ndarray subclasses.
Masked arrays.
Esoteric ufunc machinery like axes=[(n,k),(k,m)->(n,m)] and ufunc methods (e.g., np.add.reduce).
Sorting / ordering complex64/complex128 arrays.
NumPy np.poly1d and np.polynomial.
Positional out1, out2 args in functions with 2 or more returns (out=tuple does work).
__array_function__, __array_interface__ and __array_wrap__.
ndarray.ctypes attribute.
Can I compile NumPy code using torch.compile?
Of course you can! torch.compile understands NumPy code natively and treats it as if it were PyTorch code. To do so, simply wrap NumPy code with the torch.compile decorator.
```python
import torch
import numpy as np

@torch.compile
def numpy_fn(X: np.ndarray, Y: np.ndarray) -> np.ndarray:
    return np.sum(X[:, :, None] * Y[:, None, :], axis=(-2, -1))

X = np.random.randn(1024, 64)
Y = np.random.randn(1024, 64)
Z = numpy_fn(X, Y)
assert isinstance(Z, np.ndarray)
```
Executing this example with the environment variable TORCH_LOGS=output_code, we can see that torch.compile was able to fuse the multiplication and the sum into one C++ kernel. It was also able to execute them in parallel using OpenMP (native NumPy is single-threaded). This can easily make your NumPy code n times faster, where n is the number of cores in your processor!
Tracing NumPy code this way also supports graph breaks within the compiled code.
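To sanity-check what the example computes: summing the outer product X[i, :, None] * Y[i, None, :] over both trailing axes factorizes into the product of the row sums, so the compiled result can be compared against plain NumPy. The sketch uses backend="eager" (an assumption, to avoid needing a C++ toolchain) and a seeded generator.

```python
import numpy as np
import torch


@torch.compile(backend="eager")  # the default backend works too; "eager" needs no C++ toolchain
def numpy_fn(X: np.ndarray, Y: np.ndarray) -> np.ndarray:
    return np.sum(X[:, :, None] * Y[:, None, :], axis=(-2, -1))


rng = np.random.default_rng(0)
X = rng.standard_normal((1024, 64))
Y = rng.standard_normal((1024, 64))
Z = numpy_fn(X, Y)

# sum_{j,k} X[i, j] * Y[i, k] == (sum_j X[i, j]) * (sum_k Y[i, k])
reference = X.sum(axis=1) * Y.sum(axis=1)
print(isinstance(Z, np.ndarray), np.allclose(Z, reference))
```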
Can I execute NumPy code on CUDA and compute gradients via torch.compile?
Yes, you can! To do so, you may simply execute your code within a torch.device("cuda") context. Consider the example
```python
import torch
import numpy as np

@torch.compile
def numpy_fn(X: np.ndarray, Y: np.ndarray) -> np.ndarray:
    return np.sum(X[:, :, None] * Y[:, None, :], axis=(-2, -1))

X = np.random.randn(1024, 64)
Y = np.random.randn(1024, 64)
with torch.device("cuda"):
    Z = numpy_fn(X, Y)
assert isinstance(Z, np.ndarray)
```
In this example, numpy_fn will be executed in CUDA. For this to be possible, torch.compile automatically moves X and Y from CPU to CUDA, and then it moves the result Z from CUDA to CPU. If we are executing this function several times in the same program run, we may want to avoid all these rather expensive memory copies. To do so, we just need to tweak our numpy_fn so that it accepts CUDA tensors and returns tensors. We can do so by using torch.compiler.wrap_numpy:
```python
import numpy as np
import torch

@torch.compile(fullgraph=True)
@torch.compiler.wrap_numpy
def numpy_fn(X, Y):
    return np.sum(X[:, :, None] * Y[:, None, :], axis=(-2, -1))

X = torch.randn(1024, 64, device="cuda")
Y = torch.randn(1024, 64, device="cuda")
Z = numpy_fn(X, Y)
assert isinstance(Z, torch.Tensor)
assert Z.device.type == "cuda"
```
Here, we explicitly create the tensors in CUDA memory and pass them to the function, which performs all the computations on the CUDA device. wrap_numpy is in charge of marking any torch.Tensor input as an input with np.ndarray semantics at the torch.compile level. Marking tensors inside the compiler is a very cheap operation, so no data copy or data movement happens during runtime.
Using this decorator, we can also differentiate through NumPy code!
```python
@torch.compile(fullgraph=True)
@torch.compiler.wrap_numpy
def numpy_fn(X, Y):
    return np.mean(np.sum(X[:, :, None] * Y[:, None, :], axis=(-2, -1)))

X = torch.randn(1024, 64, device="cuda", requires_grad=True)
Y = torch.randn(1024, 64, device="cuda")
Z = numpy_fn(X, Y)
assert isinstance(Z, torch.Tensor)
Z.backward()
# X.grad now holds the gradient of the computation
print(X.grad)
```
We have been using fullgraph=True because graph breaks are problematic in this context. When a graph break occurs, we need to materialize the NumPy arrays. Since NumPy arrays do not have a notion of device or requires_grad, this information is lost during a graph break.
We cannot propagate gradients through a graph break, as the graph break code may execute arbitrary code that we don't know how to differentiate. On the other hand, in the case of CUDA execution, we can work around this problem as we did in the first example, by using the torch.device("cuda") context manager:
```python
@torch.compile
@torch.compiler.wrap_numpy
def numpy_fn(X, Y):
    prod = X[:, :, None] * Y[:, None, :]
    print("oops, a graph break!")
    return np.sum(prod, axis=(-2, -1))

X = torch.randn(1024, 64, device="cuda")
Y = torch.randn(1024, 64, device="cuda")

with torch.device("cuda"):
    Z = numpy_fn(X, Y)
assert isinstance(Z, torch.Tensor)
assert Z.device.type == "cuda"
```
During the graph break, the intermediary tensors still need to be moved to CPU, but when thetracing is resumed after the graph break, the rest of the graph is still traced on CUDA.Given this CUDA <> CPU and CPU <> CUDA movement, graph breaks are fairly costly in the NumPycontext and should be avoided, but at least they allow tracing through complex pieces of code.
How do I debug NumPy code under torch.compile?
Debugging JIT compiled code is challenging, given the complexity of modern compilers and the daunting errors that they raise. The torch.compile troubleshooting doc contains a few tips and tricks on how to tackle this task.
If the above is not enough to pinpoint the origin of the issue, there are stilla few other NumPy-specific tools we can use. We can discern whether the bugis entirely in the PyTorch code by disabling tracing through NumPy functions:
```python
from torch._dynamo import config

config.trace_numpy = False
```
If the bug lies in the traced NumPy code, we can execute the NumPy code eagerly (without torch.compile) using PyTorch as a backend by importing import torch._numpy as np. This should only be used for debugging purposes and is in no way a replacement for the PyTorch API, as it is much less performant and, as a private API, may change without notice. At any rate, torch._numpy is a Python implementation of NumPy in terms of PyTorch, and it is used internally by torch.compile to transform NumPy code into PyTorch code. It is rather easy to read and modify, so if you find any bug in it, feel free to submit a PR fixing it or simply open an issue.
If the program does work when importing torch._numpy as np, chances are that the bug is in TorchDynamo. If this is the case, please feel free to open an issue with a minimal reproducer.
I torch.compile some NumPy code and I did not see any speed-up.
The best place to start is the tutorial with general advice for how to debug these sorts of torch.compile issues.
Some graph breaks may happen because of the use of unsupported features. See Which NumPy features does torch.compile support?. More generally, it is useful to keep in mind that some widely used NumPy features do not play well with compilers. For example, in-place modifications make reasoning difficult within the compiler and often yield worse performance than their out-of-place counterparts. As such, it is best to avoid them. The same goes for the use of the out= parameter. Instead, prefer out-of-place ops and let torch.compile optimize the memory use. The same goes for data-dependent ops like masked indexing through boolean masks, or data-dependent control flow like if or while constructions.
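As an illustration of preferring out-of-place code, a masked in-place assignment can often be rewritten with np.where, which neither mutates its input nor produces a data-dependent shape. Both function names are hypothetical, and backend="eager" is an assumption made to avoid needing a toolchain.

```python
import numpy as np
import torch


def relu_inplace(x: np.ndarray) -> np.ndarray:
    y = x.copy()
    y[y < 0] = 0  # in-place masked assignment: the pattern to avoid under torch.compile
    return y


@torch.compile(backend="eager")
def relu_out_of_place(x: np.ndarray) -> np.ndarray:
    # np.where keeps the output shape static and does not mutate anything.
    return np.where(x < 0, 0.0, x)


x = np.random.randn(256)
print(np.array_equal(relu_inplace(x), relu_out_of_place(x)))
```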
Which API to use for fine-grained tracing?
In some cases, you might need to exclude small parts of your code from torch.compile compilations. This section provides some of the answers, and you can find more information in TorchDynamo APIs for fine-grained tracing.
How do I graph break on a function?
Graph break on a function is not enough to sufficiently express what you wantPyTorch to do. You need to be more specific about your use case. Some of themost common use cases you might want to consider:
If you want to disable compilation on this function frame and the recursively invoked frames, use torch._dynamo.disable.
If you want a particular operator, such as fbgemm, to use eager mode, use torch._dynamo.disallow_in_graph.
Some of the uncommon use cases include:
If you want to disable TorchDynamo on the function frame but enable it again on the recursively invoked frames, use torch._dynamo.disable(recursive=False).
If you want to prevent inlining of a function frame, use torch._dynamo.graph_break at the beginning of the function you want to prevent inlining.
What's the difference between torch._dynamo.disable and torch._dynamo.disallow_in_graph
Disallow-in-graph works at the level of operators, or more specifically,the operators that you see in the TorchDynamo extracted graphs.
Disable works at the function frame level and decides if TorchDynamoshould look into the function frame or not.
What's the difference between torch._dynamo.disable and torch._dynamo.skip
Note: torch._dynamo.skip is deprecated.
You most likely need torch._dynamo.disable. But in an unlikely scenario, you might need even finer control. Suppose you want to disable the tracing on just the a_fn function, but want to continue the tracing back in aa_fn and ab_fn. The image below demonstrates this use case:

In this case, you can use torch._dynamo.disable(recursive=False). In previous versions, this functionality was provided by torch._dynamo.skip. This is now supported by the recursive flag inside torch._dynamo.disable.