
Dynamo Overview#

Created On: Jun 13, 2025 | Last Updated On: Jun 13, 2025

Before you read this section, read torch.compiler.

TorchDynamo (or simply Dynamo) is a Python-level Just-In-Time (JIT) compiler designed to make unmodified PyTorch programs faster. Dynamo hooks into the frame evaluation API in CPython (PEP 523) to dynamically modify Python bytecode right before it is executed. It rewrites Python bytecode to extract sequences of PyTorch operations into an FX Graph, which is then compiled with a customizable backend. It creates this FX Graph through bytecode analysis and is designed to mix Python execution with compiled backends to get the best of both worlds: usability and performance.

Dynamo makes it easy to experiment with different compiler backends to make PyTorch code faster with a single-line decorator, torch._dynamo.optimize(), which is wrapped for convenience by torch.compile().
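As a minimal sketch of that workflow, the snippet below compiles a small function with the built-in "eager" backend (a debugging backend that captures the graph but runs it without further optimization, so no GPU or C++ toolchain is needed); the compiled function is a drop-in replacement for the original:

```python
import torch

def fn(x):
    # a small function Dynamo can capture as a single graph
    return torch.sin(x) + x

# "eager" is a built-in debugging backend: Dynamo captures the
# FX graph but then simply runs it without further optimization
compiled_fn = torch.compile(fn, backend="eager")

x = torch.randn(8)
# the compiled function produces the same results as the original
print(torch.allclose(compiled_fn(x), fn(x)))
```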

The following diagram demonstrates how PyTorch works with torch.compile and without it:

_images/TorchDynamo.png

TorchInductor is one of the backends supported by Dynamo; it compiles the FX Graph produced by Dynamo into Triton code for GPUs or C++/OpenMP code for CPUs. We have a training performance dashboard that provides performance comparisons for different training backends. You can read more in the TorchInductor post on PyTorch dev-discuss.

For an in-depth overview, read the sections below, watch the deep-dive video,and check out the dev-discuss topics.

Dynamo Internals#

Authors: Jason Ansel and Kaichao You

This section goes over some of the Dynamo internals and demonstrates how Dynamo works under the hood.

What is a guard?#

Dynamo operates just-in-time and specializes graphs based on dynamic properties. Below is a basic example of how to use Dynamo. One can decorate a function or a method using torchdynamo.optimize to enable Dynamo optimization:

from typing import List

import torch
from torch import _dynamo as torchdynamo


def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    print("my_compiler() called with FX graph:")
    gm.graph.print_tabular()
    return gm.forward  # return a python callable


@torchdynamo.optimize(my_compiler)
def toy_example(a, b):
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:
        b = b * -1
    return x * b


for _ in range(100):
    toy_example(torch.randn(10), torch.randn(10))

For example, the first graph compiled from toy_example has the following guards:

GUARDS:
hasattr(L['a'], '_dynamo_dynamic_indices') == False
hasattr(L['b'], '_dynamo_dynamic_indices') == False
utils_device.CURRENT_DEVICE == None
___skip_backend_check() or ___current_backend() == ___lookup_backend(140355900538256)
check_tensor(L['a'], Tensor, DispatchKeySet(CPU, BackendSelect, ADInplaceOrView, AutogradCPU), torch.float32, device=None, requires_grad=False, size=[10], stride=[1])
check_tensor(L['b'], Tensor, DispatchKeySet(CPU, BackendSelect, ADInplaceOrView, AutogradCPU), torch.float32, device=None, requires_grad=False, size=[10], stride=[1])

If any of those guards fail, the graph will be recaptured and recompiled. The interesting guard there is check_tensor, which checks the following torch.Tensor properties:

  • Python class of the tensor (tensor subclassing, etc)

  • dtype

  • device

  • requires_grad

  • dispatch_key (with thread-local includes/excludes applied)

  • ndim

  • sizes*

  • strides*

The full specialization mode allows the backend compiler to assume anentirely static graph. Unfortunately, most backends require this.Operators which return dynamic shapes will trigger a graph break whennot in dynamic shape mode.

What is Dynamo doing?#

If you want to understand better what Dynamo is doing, you can run your code with:

TORCH_LOGS="+dynamo,guards,bytecode"

If you are not familiar with Python bytecode, you can add a decompiler hook to decompile the bytecode into human-readable source code. One available tool is depyf. If you don't have depyf already installed, run pip install depyf. Then, add the following code to install decompilation hooks before you run any code.

import depyf
depyf.install()

This code triggers useful (but spammy) printouts.

For example, the printouts for the first graph in the toy_example are:

__compiled_fn_0 <eval_with_key>.1
opcode         name     target                                                  args              kwargs
-------------  -------  ------------------------------------------------------  ----------------  --------
placeholder    a        a                                                       ()                {}
placeholder    b        b                                                       ()                {}
call_function  abs_1    <built-in method abs of type object at 0x7f9ca082f8a0>  (a,)              {}
call_function  add      <built-in function add>                                 (abs_1, 1)        {}
call_function  truediv  <built-in function truediv>                             (a, add)          {}
call_method    sum_1    sum                                                     (b,)              {}
call_function  lt       <built-in function lt>                                  (sum_1, 0)        {}
output         output   output                                                  ((truediv, lt),)  {}

ORIGINAL BYTECODE toy_example example.py line 12
 14           0 LOAD_FAST                0 (a)
              2 LOAD_GLOBAL              0 (torch)
              4 LOAD_METHOD              1 (abs)
              6 LOAD_FAST                0 (a)
              8 CALL_METHOD              1
             10 LOAD_CONST               1 (1)
             12 BINARY_ADD
             14 BINARY_TRUE_DIVIDE
             16 STORE_FAST               2 (x)

 15          18 LOAD_FAST                1 (b)
             20 LOAD_METHOD              2 (sum)
             22 CALL_METHOD              0
             24 LOAD_CONST               2 (0)
             26 COMPARE_OP               0 (<)
             28 POP_JUMP_IF_FALSE       19 (to 38)

 16          30 LOAD_FAST                1 (b)
             32 LOAD_CONST               3 (-1)
             34 BINARY_MULTIPLY
             36 STORE_FAST               1 (b)

 17      >>  38 LOAD_FAST                2 (x)
             40 LOAD_FAST                1 (b)
             42 BINARY_MULTIPLY
             44 RETURN_VALUE

MODIFIED BYTECODE toy_example example.py line 12
 12           0 LOAD_GLOBAL              3 (__compiled_fn_0)
              2 LOAD_FAST                0 (a)
              4 LOAD_FAST                1 (b)
              6 CALL_FUNCTION            2
              8 UNPACK_SEQUENCE          2
             10 STORE_FAST               2 (x)
             12 POP_JUMP_IF_FALSE       12 (to 24)
             14 LOAD_GLOBAL              4 (__resume_at_30_1)
             16 LOAD_FAST                1 (b)
             18 LOAD_FAST                2 (x)
             20 CALL_FUNCTION            2
             22 RETURN_VALUE
         >>  24 LOAD_GLOBAL              5 (__resume_at_38_2)
             26 LOAD_FAST                1 (b)
             28 LOAD_FAST                2 (x)
             30 CALL_FUNCTION            2
             32 RETURN_VALUE

possible source code:
def toy_example(a, b):
    __temp_1 = __compiled_fn_0(a, b)
    x = __temp_1[0]
    if __temp_1[1]:
        return __resume_at_30_1(b, x)
    return __resume_at_38_2(b, x)

If you find the decompiled code is wrong, please submit an issue at https://github.com/youkaichao/depyf/issues.

At the top you can see the FX graph. Next, you see the original bytecode of the function, followed by the modified bytecode generated by Dynamo, and the decompiled source code for reference. Finally, you see the guards, which we covered above.

In the modified bytecode, __compiled_fn_0 is the return value of my_compiler() (the compiled graph). __resume_at_30_1 and __resume_at_38_2 are both generated continuation functions that pick up execution after a graph break (at bytecode offsets 30 and 38). Each of these functions takes the form:

__resume_at_<offset>:
    ... restore stack state if needed ...
    JUMP_ABSOLUTE <offset> into toy_example
    ... original bytecode of toy_example ...

By generating this resume_at function, we force the remainder of the function to be executed in a new Python frame, which recursively triggers Dynamo to restart its capture once execution reaches that point for the first time.
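The effect of this transformation can be illustrated in plain Python. The names below mirror the generated functions above, but operate on ordinary floats rather than tensors; this is an analogy for the structure, not Dynamo's actual output:

```python
def compiled_fn_0(a, b):
    # analogue of the captured graph: compute everything up to the branch
    x = a / (abs(a) + 1)
    return x, b < 0

def resume_at_30(b, x):
    # continuation for the branch-taken path
    b = b * -1
    return x * b

def resume_at_38(b, x):
    # continuation for the fall-through path
    return x * b

def toy_example(a, b):
    x, cond = compiled_fn_0(a, b)
    if cond:                      # the graph break: plain Python takes the branch
        return resume_at_30(b, x)
    return resume_at_38(b, x)

print(toy_example(2.0, -4.0))
print(toy_example(2.0, 3.0))
```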

How to inspect artifacts generated by Dynamo?#

To inspect the artifacts generated by Dynamo, there is an API torch._dynamo.eval_frame._debug_get_cache_entry_list that retrieves compiled code and guards out of a function's __code__ object. A compiled function can have several cache entries, and each cache entry consists of a generated function to check guards and a types.CodeType object holding the code to be executed if the guarding conditions are satisfied.

from torch._dynamo.eval_frame import _debug_get_cache_entry_list, innermost_fn

cache_entries = _debug_get_cache_entry_list(innermost_fn(toy_example))
cache_entry = cache_entries[0]
guard, code = cache_entry.check_fn, cache_entry.code
# the guard takes the local variables of an input frame, and tells
# whether a re-compilation should be triggered
import dis
dis.dis(guard)
dis.dis(code)

If you know Python bytecode, you can understand the above output.

For the guard function, there is no need to inspect the bytecode. We can directly access its guarding conditions:

for code_part in guard.code_parts:
    print(code_part)

The output is:

___guarded_code.valid
___check_global_state()
hasattr(L['a'], '_dynamo_dynamic_indices') == False
hasattr(L['b'], '_dynamo_dynamic_indices') == False
utils_device.CURRENT_DEVICE == None
___skip_backend_check() or ___current_backend() == ___lookup_backend(140215810860528)
___check_tensors(L['a'], L['b'], tensor_check_names=tensor_check_names)

Only when all the conditions are satisfied does the guard function return true, and the compiled code is executed.
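Conceptually, the guard function just evaluates each of these conditions against the frame's locals. A plain-Python sketch (the expressions below are hypothetical stand-ins in the style of guard.code_parts, not real Dynamo guard expressions):

```python
def make_guard(code_parts):
    # evaluate every guard expression against the frame locals dict L
    def guard(L):
        return all(eval(part, {"L": L}) for part in code_parts)
    return guard

# hypothetical guard expressions specializing on argument types
guard = make_guard([
    "isinstance(L['a'], float)",
    "isinstance(L['b'], float)",
])

print(guard({'a': 1.0, 'b': 2.0}))  # all conditions hold
print(guard({'a': 1, 'b': 2.0}))    # a type changed: guard fails
```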

For the compiled code, we cannot directly access its source, so we have to decompile it:

from depyf import decompile

print(decompile(code))

The output is:

def toy_example(a, b):
    __temp_1 = __compiled_fn_0(a, b)
    x = __temp_1[0]
    if __temp_1[1]:
        return __resume_at_30_1(b, x)
    return __resume_at_38_2(b, x)

Some names referenced in the code are:

  • Compiled functions, stored in the global namespace of the module containing the original function toy_example. These include names like __compiled_fn_0 / __resume_at_30_1 / __resume_at_38_2.

  • Closure variables used for checking guards. The names can be accessed from guard.__code__.co_freevars, and the values are stored in guard.__closure__. These include names like ___guarded_code / ___is_grad_enabled / ___are_deterministic_algorithms_enabled / ___is_torch_function_enabled / utils_device / ___check_tensors / tensor_check_names.

  • Argument L of the guard function. This is a dict mapping the names of the arguments of toy_example to their values. It is only available when the function is called, which is where the frame evaluation API comes into play. In short, L is a dict with the structure {'a': value_a, 'b': value_b}. Therefore, you can see the code uses L['a'] to refer to the input variable a.

The graph break is visible in the code of the compiled toy_example, where the Python interpreter has to select the next graph to execute.

Note that we pass a simple my_compiler function as the backend compiler; therefore the subgraph code __resume_at_38_2, __resume_at_30_1, and __compiled_fn_0 remain Python code. This can also be inspected (please ignore the function name, and only use the function signature and function body code):

print("source code of __compiled_fn_0:")
print(innermost_fn(__compiled_fn_0).__self__.code)
print("=" * 60)
print("source code of __resume_at_30_1:")
print(decompile(__resume_at_30_1))
print("=" * 60)
print("source code of __resume_at_38_2:")
print(decompile(__resume_at_38_2))
source code of __compiled_fn_0:
def forward(self, L_a_ : torch.Tensor, L_b_ : torch.Tensor):
    l_a_ = L_a_
    l_b_ = L_b_
    abs_1 = torch.abs(l_a_)
    add = abs_1 + 1;  abs_1 = None
    truediv = l_a_ / add;  l_a_ = add = None
    sum_1 = l_b_.sum();  l_b_ = None
    lt = sum_1 < 0;  sum_1 = None
    return (truediv, lt)

# To see more debug info, please use ``graph_module.print_readable()``
============================================================
source code of __resume_at_30_1:
def <resume in toy_example>(b, x):
    b = b * -1
    return x * b

============================================================
source code of __resume_at_38_2:
def <resume in toy_example>(b, x):
    return x * b

However, if we use other backends like the built-in inductor, the subgraph code will be compiled into CUDA kernels for GPUs or C++ code for CPUs.

To summarize, the compiled code is conceptually equivalent to the code below:

def compiled_example(a, b):
    L = {'a': a, 'b': b}
    for guard, code in get_cache_entries():
        if guard(L):
            return code(a, b)
    recompile_and_add_another_cache_entry()
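A runnable miniature of this dispatch loop, with a hypothetical "compiler" that specializes on the argument types (everything here is illustrative; real Dynamo stores its cache entries on the function's __code__ object and generates far richer guards):

```python
cache_entries = []  # list of (guard, code) pairs

def recompile_and_add_cache_entry(a, b):
    # "compile" by specializing on the argument types seen,
    # and install a guard that checks for those same types
    ta, tb = type(a), type(b)
    guard = lambda L, ta=ta, tb=tb: type(L['a']) is ta and type(L['b']) is tb
    code = lambda a, b: a * b     # stand-in for the compiled code
    cache_entries.append((guard, code))
    return code(a, b)

def compiled_example(a, b):
    L = {'a': a, 'b': b}
    for guard, code in cache_entries:
        if guard(L):              # guards pass: reuse the cached compilation
            return code(a, b)
    return recompile_and_add_cache_entry(a, b)

print(compiled_example(2, 3))      # first call: compile
print(compiled_example(4, 5))      # same types: cache hit
print(compiled_example(2.0, 3.0))  # new types: recompile
print(len(cache_entries))
```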

The following diagram demonstrates how torch.compile transforms and optimizes user-written code: it first extracts computation graphs from the user-written function, compiles these graphs into optimized functions, and then assembles them into a new function that is functionally equivalent to the user-written code but optimized for computation speed.

_images/flowchart.jpg

To learn more about how all this is implemented internally, see Dynamo Deep-Dive.