
Introduction to torch.compile

Created On: Mar 15, 2023 | Last Updated: Nov 25, 2025 | Last Verified: Nov 05, 2024

Author: William Wen

torch.compile is the new way to speed up your PyTorch code! torch.compile makes PyTorch code run faster by JIT-compiling PyTorch code into optimized kernels, while requiring minimal code changes.

torch.compile accomplishes this by tracing through your Python code, looking for PyTorch operations. Code that is difficult to trace will result in a graph break, which is a lost optimization opportunity rather than an error or silent incorrectness.

torch.compile is available in PyTorch 2.0 and later.

This introduction covers basic torch.compile usage and demonstrates the advantages of torch.compile over our previous PyTorch compiler solution, TorchScript.

For an end-to-end example on a real model, check out our end-to-end torch.compile tutorial.

To troubleshoot issues and to gain a deeper understanding of how to apply torch.compile to your code, check out the torch.compile programming model.


Required pip dependencies for this tutorial

  • torch>=2.0

  • numpy

  • scipy

System requirements

  • A C++ compiler, such as g++

  • Python development package (python-devel/python-dev)
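Before running the tutorial, a quick way to check the two system requirements is sketched below. It is an illustration under assumptions: the compiler names probed (`g++`, `c++`, `clang++`) and the helper name `check_build_requirements` are hypothetical choices, and torch.compile's CPU backend may locate its compiler differently.

```python
import shutil
import sysconfig
from pathlib import Path


def check_build_requirements() -> dict:
    """Report whether a C++ compiler and the Python headers appear to be installed."""
    # Probe PATH for a few common C++ compiler names (assumed list, not exhaustive).
    compiler = None
    for name in ("g++", "c++", "clang++"):
        compiler = shutil.which(name)
        if compiler is not None:
            break

    # Python.h ships with python-dev / python-devel; sysconfig reports the
    # directory where this interpreter expects its C headers to live.
    include_dir = Path(sysconfig.get_paths()["include"])
    has_python_headers = (include_dir / "Python.h").is_file()

    return {"cxx_compiler": compiler, "python_headers": has_python_headers}


if __name__ == "__main__":
    print(check_build_requirements())
```

A `None` compiler or a `False` header check suggests installing the corresponding package before trying torch.compile on the CPU.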

Basic Usage

We turn on some logging to help us see what torch.compile is doing under the hood in this tutorial. The following code will print out the PyTorch ops that torch.compile traced.

```python
import torch

torch._logging.set_logs(graph_code=True)
```

torch.compile can be applied to an arbitrary Python function, either by calling it on the function or by using it as a decorator.

```python
def foo(x, y):
    a = torch.sin(x)
    b = torch.cos(y)
    return a + b


opt_foo1 = torch.compile(foo)
print(opt_foo1(torch.randn(3, 3), torch.randn(3, 3)))


@torch.compile
def opt_foo2(x, y):
    a = torch.sin(x)
    b = torch.cos(y)
    return a + b


print(opt_foo2(torch.randn(3, 3), torch.randn(3, 3)))
```
```text
TRACED GRAPH
 ===== __compiled_fn_1_6164a48f_b777_4abf_8845_e18d0a75488c =====
 /usr/local/lib/python3.10/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[3, 3][3, 1]cpu", L_y_: "f32[3, 3][3, 1]cpu"):
        l_x_ = L_x_
        l_y_ = L_y_

         # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:74 in foo, code: a = torch.sin(x)
        a: "f32[3, 3][3, 1]cpu" = torch.sin(l_x_);  l_x_ = None

         # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:75 in foo, code: b = torch.cos(y)
        b: "f32[3, 3][3, 1]cpu" = torch.cos(l_y_);  l_y_ = None

         # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:76 in foo, code: return a + b
        add: "f32[3, 3][3, 1]cpu" = a + b;  a = b = None
        return (add,)

tensor([[ 0.0718, -0.1230,  1.5293],
        [-0.1492,  1.3393,  1.5498],
        [-0.0053, -0.7290,  0.4285]])
TRACED GRAPH
 ===== __compiled_fn_3_70c057f0_78eb_4cb4_bed8_84ccd4347302 =====
 /usr/local/lib/python3.10/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[3, 3][3, 1]cpu", L_y_: "f32[3, 3][3, 1]cpu"):
        l_x_ = L_x_
        l_y_ = L_y_

         # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:85 in opt_foo2, code: a = torch.sin(x)
        a: "f32[3, 3][3, 1]cpu" = torch.sin(l_x_);  l_x_ = None

         # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:86 in opt_foo2, code: b = torch.cos(y)
        b: "f32[3, 3][3, 1]cpu" = torch.cos(l_y_);  l_y_ = None

         # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:87 in opt_foo2, code: return a + b
        add: "f32[3, 3][3, 1]cpu" = a + b;  a = b = None
        return (add,)

tensor([[ 1.2133,  1.2286,  1.7181],
        [ 1.3848,  1.5102,  0.4478],
        [-0.4852,  0.5372,  1.4708]])
```

torch.compile is applied recursively, so nested function calls within the top-level compiled function will also be compiled.

```python
def inner(x):
    return torch.sin(x)


@torch.compile
def outer(x, y):
    a = inner(x)
    b = torch.cos(y)
    return a + b


print(outer(torch.randn(3, 3), torch.randn(3, 3)))
```
```text
TRACED GRAPH
 ===== __compiled_fn_5_2af6e2ac_da55_41f2_8a81_294797bdbe4c =====
 /usr/local/lib/python3.10/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[3, 3][3, 1]cpu", L_y_: "f32[3, 3][3, 1]cpu"):
        l_x_ = L_x_
        l_y_ = L_y_

         # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:98 in inner, code: return torch.sin(x)
        a: "f32[3, 3][3, 1]cpu" = torch.sin(l_x_);  l_x_ = None

         # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:104 in outer, code: b = torch.cos(y)
        b: "f32[3, 3][3, 1]cpu" = torch.cos(l_y_);  l_y_ = None

         # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:105 in outer, code: return a + b
        add: "f32[3, 3][3, 1]cpu" = a + b;  a = b = None
        return (add,)

tensor([[1.8491, 1.2302, 1.9142],
        [1.1232, 1.1618, 1.0098],
        [0.3962, 1.5476, 0.4530]])
```
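One way to confirm that `inner` is inlined into the caller's graph, rather than compiled on its own, is to hand torch.compile a custom backend that simply records the graphs it receives. The sketch below relies on the custom-backend convention (a callable that takes an FX `GraphModule` plus example inputs and returns a callable); `recording_backend`, `inner_fn`, and `outer_fn` are hypothetical names. Because the backend returns the graph's `forward` unchanged, nothing is optimized and no C++ compiler is needed.

```python
import torch

# Hypothetical custom backend: remember every FX graph TorchDynamo produces,
# then run that graph as-is (no optimization).
captured_graphs = []


def recording_backend(gm: torch.fx.GraphModule, example_inputs):
    captured_graphs.append(gm)
    return gm.forward


def inner_fn(x):
    return torch.sin(x)


@torch.compile(backend=recording_backend)
def outer_fn(x, y):
    a = inner_fn(x)  # traced into outer_fn's graph, not compiled separately
    b = torch.cos(y)
    return a + b


x, y = torch.randn(3, 3), torch.randn(3, 3)
outer_fn(x, y)
outer_fn(x, y)  # same shapes and dtypes, so the cached graph should be reused

called = [n.target for n in captured_graphs[0].graph.nodes if n.op == "call_function"]
print(len(captured_graphs), called)
```

If `inner` were compiled separately, we would expect more than one recorded graph; the second call should also reuse the cached graph instead of recording a new one.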

We can also optimize torch.nn.Module instances by either calling its .compile() method or by directly torch.compile-ing the module. This is equivalent to torch.compile-ing the module’s __call__ method (which indirectly calls forward).

```python
class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(3, 3)

    def forward(self, x):
        return torch.nn.functional.relu(self.lin(x))


# Option 1: compile the module in place
mod1 = MyModule()
mod1.compile()
print(mod1(torch.randn(3, 3)))

# Option 2: wrap the module with torch.compile
mod2 = MyModule()
mod2 = torch.compile(mod2)
print(mod2(torch.randn(3, 3)))
```
```text
TRACED GRAPH
 ===== __compiled_fn_7_c7a8ff22_adab_4860_9361_8db4207ca87c =====
 /usr/local/lib/python3.10/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, L_self_modules_lin_parameters_weight_: "f32[3, 3][3, 1]cpu", L_self_modules_lin_parameters_bias_: "f32[3][1]cpu", L_x_: "f32[3, 3][3, 1]cpu"):
        l_self_modules_lin_parameters_weight_ = L_self_modules_lin_parameters_weight_
        l_self_modules_lin_parameters_bias_ = L_self_modules_lin_parameters_bias_
        l_x_ = L_x_

         # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:126 in forward, code: return torch.nn.functional.relu(self.lin(x))
        linear: "f32[3, 3][3, 1]cpu" = torch._C._nn.linear(l_x_, l_self_modules_lin_parameters_weight_, l_self_modules_lin_parameters_bias_);  l_x_ = l_self_modules_lin_parameters_weight_ = l_self_modules_lin_parameters_bias_ = None
        relu: "f32[3, 3][3, 1]cpu" = torch.nn.functional.relu(linear);  linear = None
        return (relu,)

tensor([[0.0000, 0.6799, 0.0347],
        [0.1565, 0.0000, 0.1398],
        [0.3636, 0.0000, 0.0199]], grad_fn=<CompiledFunctionBackward>)
tensor([[0.6172, 0.5201, 0.4229],
        [0.4646, 0.0000, 0.0000],
        [0.2731, 0.2055, 0.1687]], grad_fn=<CompiledFunctionBackward>)
```

Demonstrating Speedups

Now let’s demonstrate how torch.compile speeds up a simple PyTorch example. For a demonstration on a more complex model, see our end-to-end torch.compile tutorial.

```python
def foo3(x):
    y = x + 1
    z = torch.nn.functional.relu(y)
    u = z * 2
    return u


opt_foo3 = torch.compile(foo3)


# Returns the result of running `fn()` and the time it took for `fn()` to run,
# in seconds. We use CUDA events and synchronization for the most accurate
# measurements.
def timed(fn):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    result = fn()
    end.record()
    torch.cuda.synchronize()
    return result, start.elapsed_time(end) / 1000


inp = torch.randn(4096, 4096).cuda()
print("compile:", timed(lambda: opt_foo3(inp))[1])
print("eager:", timed(lambda: foo3(inp))[1])
```
```text
TRACED GRAPH
 ===== __compiled_fn_9_765405f9_0051_438e_ae84_884f157ab300 =====
 /usr/local/lib/python3.10/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[4096, 4096][4096, 1]cuda:0"):
        l_x_ = L_x_

         # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:147 in foo3, code: y = x + 1
        y: "f32[4096, 4096][4096, 1]cuda:0" = l_x_ + 1;  l_x_ = None

         # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:148 in foo3, code: z = torch.nn.functional.relu(y)
        z: "f32[4096, 4096][4096, 1]cuda:0" = torch.nn.functional.relu(y);  y = None

         # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:149 in foo3, code: u = z * 2
        u: "f32[4096, 4096][4096, 1]cuda:0" = z * 2;  z = None
        return (u,)

compile: 0.39984454345703124
eager: 0.030113792419433592
```

Notice that torch.compile appears to take a lot longer to complete compared to eager. This is because torch.compile takes extra time to compile the model on the first few executions. torch.compile re-uses compiled code whenever possible, so if we run our optimized model several more times, we should see a significant improvement compared to eager.

```python
# turn off logging for now to prevent spam
torch._logging.set_logs(graph_code=False)

eager_times = []
for i in range(10):
    _, eager_time = timed(lambda: foo3(inp))
    eager_times.append(eager_time)
    print(f"eager time {i}: {eager_time}")
print("~" * 10)

compile_times = []
for i in range(10):
    _, compile_time = timed(lambda: opt_foo3(inp))
    compile_times.append(compile_time)
    print(f"compile time {i}: {compile_time}")
print("~" * 10)

import numpy as np

eager_med = np.median(eager_times)
compile_med = np.median(compile_times)
speedup = eager_med / compile_med
assert speedup > 1
print(
    f"(eval) eager median: {eager_med}, compile median: {compile_med}, speedup: {speedup}x"
)
print("~" * 10)
```
```text
eager time 0: 0.0009062399864196777
eager time 1: 0.0008652160167694091
eager time 2: 0.0008673279881477356
eager time 3: 0.0008683519959449768
eager time 4: 0.0008663039803504944
eager time 5: 0.0008622080087661744
eager time 6: 0.0008642560243606567
eager time 7: 0.0008652799725532531
eager time 8: 0.0008622080087661744
eager time 9: 0.0008693760037422181
~~~~~~~~~~
compile time 0: 0.0005109760165214538
compile time 1: 0.0003665919899940491
compile time 2: 0.0003717119991779327
compile time 3: 0.0003676159977912903
compile time 4: 0.0003665919899940491
compile time 5: 0.0003624959886074066
compile time 6: 0.00036351999640464784
compile time 7: 0.00036351999640464784
compile time 8: 0.00036351999640464784
compile time 9: 0.000364544004201889
~~~~~~~~~~
(eval) eager median: 0.0008657919764518738, compile median: 0.00036556799709796904, speedup: 2.368347293321327x
~~~~~~~~~~
```

And indeed, we can see that running our model with torch.compile results in a significant speedup. The speedup mainly comes from reducing Python overhead and GPU reads/writes, so the observed speedup may vary depending on factors such as model architecture and batch size. For example, if a model’s architecture is simple and the amount of data is large, then the bottleneck would be GPU compute and the observed speedup may be less significant.
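To make the point about GPU reads/writes concrete, here is a back-of-the-envelope estimate for foo3 on the 4096 x 4096 float32 input used above. It assumes an idealized model in which each eager op reads its input from and writes its output to GPU memory exactly once, while a single fused kernel reads x once and writes u once. It ignores caches and kernel-launch overhead and is not a claim about the exact kernels torch.compile generates.

```python
# Idealized memory-traffic model for foo3 (x + 1 -> relu -> * 2).
numel = 4096 * 4096
bytes_per_elem = 4  # float32
tensor_bytes = numel * bytes_per_elem  # 64 MiB per tensor

eager_ops = 3  # x + 1, relu, * 2: each launched as its own kernel in eager mode
eager_traffic = eager_ops * 2 * tensor_bytes  # one read + one write per op
fused_traffic = 2 * tensor_bytes  # read x once, write u once

MiB = 2**20
print(f"eager: {eager_traffic / MiB:.0f} MiB, fused: {fused_traffic / MiB:.0f} MiB")
# eager: 384 MiB, fused: 128 MiB
print(f"traffic ratio: {eager_traffic / fused_traffic:.1f}x")
# traffic ratio: 3.0x
```

Under these assumptions, fusion cuts memory traffic by about 3x, which is of the same order as the roughly 2.4x median speedup measured above for this memory-bound example.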

To see speedups on a real model, check out our end-to-end torch.compile tutorial.
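The `timed` helper above requires a CUDA device. As a hedged sketch for readers without a GPU, the variant below (`timed_any_device` and `median_time` are hypothetical names, not PyTorch APIs) falls back to `time.perf_counter` on the CPU and discards a few warm-up runs before taking the median, so that one-off costs such as torch.compile's first-call compilation do not skew the result. CPU timings will not match the GPU numbers above.

```python
import statistics
import time

import torch


def timed_any_device(fn, device: torch.device):
    """Run fn() once and return (result, elapsed seconds)."""
    if device.type == "cuda":
        # Same approach as `timed`: CUDA events measure GPU time.
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        result = fn()
        end.record()
        torch.cuda.synchronize()
        return result, start.elapsed_time(end) / 1000  # elapsed_time() is in ms
    # Eager CPU ops run synchronously, so a wall-clock timer is enough.
    t0 = time.perf_counter()
    result = fn()
    return result, time.perf_counter() - t0


def median_time(fn, device: torch.device, warmup: int = 3, iters: int = 10) -> float:
    """Median of `iters` timed runs after `warmup` untimed runs."""
    for _ in range(warmup):
        fn()  # absorbs one-off costs such as first-call compilation
    return statistics.median(timed_any_device(fn, device)[1] for _ in range(iters))


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
sample = torch.randn(1024, 1024, device=device)
print("eager median (s):", median_time(lambda: torch.relu(sample + 1) * 2, device))
```

To compare against a compiled function, pass something like `lambda: opt_foo3(sample)`; the warm-up loop then absorbs the compilation cost instead of counting it as a timed run.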

Benefits over TorchScript

Why should we use torch.compile over TorchScript? Primarily, the advantage of torch.compile lies in its ability to handle arbitrary Python code with minimal changes to existing code.

Compare this to TorchScript, which has a tracing mode (torch.jit.trace) and a scripting mode (torch.jit.script). Tracing mode is susceptible to silent incorrectness, while scripting mode requires significant code changes and will raise errors on unsupported Python code.

For example, TorchScript tracing silently fails on data-dependent control flow (the if x.sum() < 0: line below) because only the actual control flow path is traced. In comparison, torch.compile is able to correctly handle it.

```python
def f1(x, y):
    if x.sum() < 0:
        return -y
    return y


# Test that `fn1` and `fn2` return the same result, given the same arguments `args`.
def test_fns(fn1, fn2, args):
    out1 = fn1(*args)
    out2 = fn2(*args)
    return torch.allclose(out1, out2)


inp1 = torch.randn(5, 5)
inp2 = torch.randn(5, 5)

traced_f1 = torch.jit.trace(f1, (inp1, inp2))
print("traced 1, 1:", test_fns(f1, traced_f1, (inp1, inp2)))
print("traced 1, 2:", test_fns(f1, traced_f1, (-inp1, inp2)))

compile_f1 = torch.compile(f1)
print("compile 1, 1:", test_fns(f1, compile_f1, (inp1, inp2)))
print("compile 1, 2:", test_fns(f1, compile_f1, (-inp1, inp2)))
print("~" * 10)
```
```text
/var/lib/workspace/intermediate_source/torch_compile_tutorial.py:239: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
traced 1, 1: True
traced 1, 2: False
compile 1, 1: True
compile 1, 2: True
~~~~~~~~~~
```

TorchScript scripting can handle data-dependent control flow, but it can require major code changes and will raise errors when unsupported Python is used.

In the example below, we forget TorchScript type annotations and we receive a TorchScript error because the input type for argument y, an int, does not match the default argument type, torch.Tensor. In comparison, torch.compile works without requiring any type annotations.

```python
import traceback as tb

torch._logging.set_logs(graph_code=True)


def f2(x, y):
    return x + y


inp1 = torch.randn(5, 5)
inp2 = 3

script_f2 = torch.jit.script(f2)
try:
    script_f2(inp1, inp2)
except Exception:
    tb.print_exc()

compile_f2 = torch.compile(f2)
print("compile 2:", test_fns(f2, compile_f2, (inp1, inp2)))
print("~" * 10)
```
```text
Traceback (most recent call last):
  File "/var/lib/workspace/intermediate_source/torch_compile_tutorial.py", line 288, in <module>
    script_f2(inp1, inp2)
RuntimeError: f2() Expected a value of type 'Tensor (inferred)' for argument 'y' but instead found type 'int'.
Inferred 'y' to be of type 'Tensor' because it was not annotated with an explicit type.
Position: 1
Value: 3
Declaration: f2(Tensor x, Tensor y) -> Tensor
Cast error details: Unable to cast 3 to Tensor
TRACED GRAPH
 ===== __compiled_fn_18_e1844d21_404e_4896_83ec_012999b217ef =====
 /usr/local/lib/python3.10/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[5, 5][5, 1]cpu"):
        l_x_ = L_x_

         # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:280 in f2, code: return x + y
        add: "f32[5, 5][5, 1]cpu" = l_x_ + 3;  l_x_ = None
        return (add,)

compile 2: True
~~~~~~~~~~
```
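For completeness, the TorchScript error message itself points to the fix: annotate `y` explicitly. The sketch below (with `f2_annotated` as a hypothetical name) does that with standard Python type hints, which `torch.jit.script` reads. The trade-off is that the scripted function is now tied to the declared types, whereas torch.compile needed no annotations at all.

```python
import torch


def f2_annotated(x: torch.Tensor, y: int) -> torch.Tensor:
    return x + y


# With explicit hints, TorchScript no longer infers `y` to be a Tensor.
script_f2_annotated = torch.jit.script(f2_annotated)

inp1 = torch.randn(5, 5)
print(torch.allclose(script_f2_annotated(inp1, 3), inp1 + 3))
```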

Graph Breaks

The graph break is one of the most fundamental concepts within torch.compile. It allows torch.compile to handle arbitrary Python code by interrupting compilation, running the unsupported code, then resuming compilation. The term “graph break” comes from the fact that torch.compile attempts to capture and optimize the PyTorch operation graph. When unsupported Python code is encountered, this graph must be “broken”. Graph breaks result in lost optimization opportunities, which may still be undesirable, but this is better than silent incorrectness or a hard crash.

Let’s look at a data-dependent control flow example to better see how graph breaks work.

```python
def bar(a, b):
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:
        b = b * -1
    return x * b


opt_bar = torch.compile(bar)
inp1 = torch.ones(10)
inp2 = torch.ones(10)
opt_bar(inp1, inp2)
opt_bar(inp1, -inp2)
```
```text
TRACED GRAPH
 ===== __compiled_fn_20_c96b9e77_88a8_4ac0_837d_7b322a19a080 =====
 /usr/local/lib/python3.10/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, L_a_: "f32[10][1]cpu", L_b_: "f32[10][1]cpu"):
        l_a_ = L_a_
        l_b_ = L_b_

         # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:312 in bar, code: x = a / (torch.abs(a) + 1)
        abs_1: "f32[10][1]cpu" = torch.abs(l_a_)
        add: "f32[10][1]cpu" = abs_1 + 1;  abs_1 = None
        x: "f32[10][1]cpu" = l_a_ / add;  l_a_ = add = None

         # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:313 in bar, code: if b.sum() < 0:
        sum_1: "f32[][]cpu" = l_b_.sum();  l_b_ = None
        lt: "b8[][]cpu" = sum_1 < 0;  sum_1 = None
        return (lt, x)

TRACED GRAPH
 ===== __compiled_fn_24_c2140ada_e6d9_4cd8_ae20_9c5893eb4e6e =====
 /usr/local/lib/python3.10/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[10][1]cpu", L_b_: "f32[10][1]cpu"):
        l_x_ = L_x_
        l_b_ = L_b_

         # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:315 in torch_dynamo_resume_in_bar_at_313, code: return x * b
        mul: "f32[10][1]cpu" = l_x_ * l_b_;  l_x_ = l_b_ = None
        return (mul,)

TRACED GRAPH
 ===== __compiled_fn_26_12c25458_1ba4_4608_92a0_6e0ab2eeaf6f =====
 /usr/local/lib/python3.10/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, L_b_: "f32[10][1]cpu", L_x_: "f32[10][1]cpu"):
        l_b_ = L_b_
        l_x_ = L_x_

         # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:314 in torch_dynamo_resume_in_bar_at_313, code: b = b * -1
        b: "f32[10][1]cpu" = l_b_ * -1;  l_b_ = None

         # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:315 in torch_dynamo_resume_in_bar_at_313, code: return x * b
        mul_1: "f32[10][1]cpu" = l_x_ * b;  l_x_ = b = None
        return (mul_1,)

tensor([0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
        0.5000])
```

The first time we run bar, we see that torch.compile traced 2 graphs corresponding to the following code (noting that b.sum() < 0 is False):

  1. x = a / (torch.abs(a) + 1); b.sum()

  2. return x * b

The second time we run bar, we take the other branch of the if statement and we get 1 traced graph corresponding to the code b = b * -1; return x * b. We do not see a graph of x = a / (torch.abs(a) + 1) outputted the second time since torch.compile cached this graph from the first run and re-used it.

Let’s investigate by example how TorchDynamo would step through bar. If b.sum() < 0, then TorchDynamo would run graph 1, let Python determine the result of the conditional, then run graph 3 (the b = b * -1; return x * b graph traced on the second run). On the other hand, if not b.sum() < 0, then TorchDynamo would run graph 1, let Python determine the result of the conditional, then run graph 2.
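To make the stitching concrete, the sketch below re-creates the three traced graphs from the log above as ordinary Python functions. `graph_1`, `graph_2`, `graph_3`, and `stitched_bar` are hypothetical names; TorchDynamo's real resume functions are generated bytecode, not code like this. The `if` between the graphs plays the role of the graph break: it runs as plain Python on the boolean tensor that graph 1 returns.

```python
import torch


# Hypothetical hand-written counterparts of the three graphs in the log above.
def graph_1(a, b):
    x = a / (torch.abs(a) + 1)
    return b.sum() < 0, x  # branch condition and x, like `return (lt, x)`


def graph_2(x, b):  # traced when b.sum() < 0 is False
    return x * b


def graph_3(b, x):  # traced when b.sum() < 0 is True
    b = b * -1
    return x * b


def stitched_bar(a, b):
    lt, x = graph_1(a, b)
    # The graph break: ordinary Python reads the 0-d bool tensor and branches.
    if lt:
        return graph_3(b, x)
    return graph_2(x, b)


a = torch.ones(10)
print(stitched_bar(a, torch.ones(10)))   # takes the graph 2 path
print(stitched_bar(a, -torch.ones(10)))  # takes the graph 3 path
```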

We can see all graph breaks by using torch._logging.set_logs(graph_breaks=True).

```python
# Reset to clear the torch.compile cache
torch._dynamo.reset()
opt_bar(inp1, inp2)
opt_bar(inp1, -inp2)
```
```text
TRACED GRAPH
 ===== __compiled_fn_28_0ffdb3a3_72e0_4202_b5e6_2788105d64db =====
 /usr/local/lib/python3.10/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, L_a_: "f32[10][1]cpu", L_b_: "f32[10][1]cpu"):
        l_a_ = L_a_
        l_b_ = L_b_

         # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:312 in bar, code: x = a / (torch.abs(a) + 1)
        abs_1: "f32[10][1]cpu" = torch.abs(l_a_)
        add: "f32[10][1]cpu" = abs_1 + 1;  abs_1 = None
        x: "f32[10][1]cpu" = l_a_ / add;  l_a_ = add = None

         # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:313 in bar, code: if b.sum() < 0:
        sum_1: "f32[][]cpu" = l_b_.sum();  l_b_ = None
        lt: "b8[][]cpu" = sum_1 < 0;  sum_1 = None
        return (lt, x)

TRACED GRAPH
 ===== __compiled_fn_32_b79cdfe5_c13d_4ee7_9bb3_98253e016ed0 =====
 /usr/local/lib/python3.10/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[10][1]cpu", L_b_: "f32[10][1]cpu"):
        l_x_ = L_x_
        l_b_ = L_b_

         # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:315 in torch_dynamo_resume_in_bar_at_313, code: return x * b
        mul: "f32[10][1]cpu" = l_x_ * l_b_;  l_x_ = l_b_ = None
        return (mul,)

TRACED GRAPH
 ===== __compiled_fn_34_6890a3e5_6693_42bc_8261_6e6e92dc8144 =====
 /usr/local/lib/python3.10/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, L_b_: "f32[10][1]cpu", L_x_: "f32[10][1]cpu"):
        l_b_ = L_b_
        l_x_ = L_x_

         # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:314 in torch_dynamo_resume_in_bar_at_313, code: b = b * -1
        b: "f32[10][1]cpu" = l_b_ * -1;  l_b_ = None

         # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:315 in torch_dynamo_resume_in_bar_at_313, code: return x * b
        mul_1: "f32[10][1]cpu" = l_x_ * b;  l_x_ = b = None
        return (mul_1,)

tensor([0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
        0.5000])
```

In order to maximize speedup, graph breaks should be limited. We can force TorchDynamo to raise an error upon the first graph break encountered by using fullgraph=True:

```python
# Reset to clear the torch.compile cache
torch._dynamo.reset()

opt_bar_fullgraph = torch.compile(bar, fullgraph=True)
try:
    opt_bar_fullgraph(torch.randn(10), torch.randn(10))
except Exception:
    tb.print_exc()
```
```text
Traceback (most recent call last):
  File "/var/lib/workspace/intermediate_source/torch_compile_tutorial.py", line 360, in <module>
    opt_bar_fullgraph(torch.randn(10), torch.randn(10))
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 841, in compile_wrapper
    raise e.with_traceback(None) from e.__cause__  # User compiler error
torch._dynamo.exc.Unsupported: Data-dependent branching
  Explanation: Detected data-dependent branching (e.g. `if my_tensor.sum() > 0:`). Dynamo does not support tracing dynamic control flow.
  Hint: This graph break is fundamental - it is unlikely that Dynamo will ever be able to trace through your code. Consider finding a workaround.
  Hint: Use `torch.cond` to express dynamic control flow.

  Developer debug context: attempted to jump with TensorVariable()

 For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0170.html

from user code:
   File "/var/lib/workspace/intermediate_source/torch_compile_tutorial.py", line 313, in bar
    if b.sum() < 0:

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
```

In our example above, we can work around this graph break by replacing the if statement with a torch.cond:

```python
from functorch.experimental.control_flow import cond


@torch.compile(fullgraph=True)
def bar_fixed(a, b):
    x = a / (torch.abs(a) + 1)

    def true_branch(y):
        return y * -1

    def false_branch(y):
        # NOTE: torch.cond doesn't allow aliased outputs
        return y.clone()

    # Only `b` is conditionally negated, as in `bar`; `x` is kept.
    b = cond(b.sum() < 0, true_branch, false_branch, (b,))
    return x * b


bar_fixed(inp1, inp2)
bar_fixed(inp1, -inp2)
```
```text
TRACED GRAPH
 ===== __compiled_fn_37_b0ca4d9f_9ec5_452b_97ca_b2f57bdab312 =====
 /usr/local/lib/python3.10/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, L_a_: "f32[10][1]cpu", L_b_: "f32[10][1]cpu"):
        l_a_ = L_a_
        l_b_ = L_b_

         # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:373 in bar_fixed, code: x = a / (torch.abs(a) + 1)
        abs_1: "f32[10][1]cpu" = torch.abs(l_a_)
        add: "f32[10][1]cpu" = abs_1 + 1;  abs_1 = None
        x: "f32[10][1]cpu" = l_a_ / add;  l_a_ = add = x = None

         # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:382 in bar_fixed, code: x = cond(b.sum() < 0, true_branch, false_branch, (b,))
        sum_1: "f32[][]cpu" = l_b_.sum()
        lt: "b8[][]cpu" = sum_1 < 0;  sum_1 = None

         # File: /usr/local/lib/python3.10/dist-packages/torch/_higher_order_ops/cond.py:186 in cond, code: return cond_op(pred, true_fn, false_fn, operands)
        cond_true_0 = self.cond_true_0
        cond_false_0 = self.cond_false_0
        cond = torch.ops.higher_order.cond(lt, cond_true_0, cond_false_0, (l_b_,));  lt = cond_true_0 = cond_false_0 = None
        x_1: "f32[10][1]cpu" = cond[0];  cond = None

         # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:383 in bar_fixed, code: return x * b
        mul: "f32[10][1]cpu" = x_1 * l_b_;  x_1 = l_b_ = None
        return (mul,)

    class cond_true_0(torch.nn.Module):
        def forward(self, l_b_: "f32[10][1]cpu"):
            l_b__1 = l_b_

             # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:376 in true_branch, code: return y * -1
            mul: "f32[10][1]cpu" = l_b__1 * -1;  l_b__1 = None
            return (mul,)

    class cond_false_0(torch.nn.Module):
        def forward(self, l_b_: "f32[10][1]cpu"):
            l_b__1 = l_b_

             # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:380 in false_branch, code: return y.clone()
            clone: "f32[10][1]cpu" = l_b__1.clone();  l_b__1 = None
            return (clone,)

tensor([-1., -1., -1., -1., -1., -1., -1., -1., -1., -1.])
```
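For this particular branch, a branch-free alternative is `torch.where`, which selects elementwise between two precomputed tensors using a boolean condition (here a 0-d tensor, which broadcasts). This is a different technique from torch.cond: it always evaluates both `-b` and `b`, which is cheap here but could be wasteful if the branches were expensive. `bar_where` is a hypothetical name; the sketch runs eagerly and, since it contains no Python-level branch on tensor data, should also be traceable with `fullgraph=True`.

```python
import torch


def bar_where(a, b):
    x = a / (torch.abs(a) + 1)
    # Pick -b or b from a 0-d boolean tensor; no Python `if` on tensor data.
    b = torch.where(b.sum() < 0, -b, b)
    return x * b


a = torch.ones(10)
print(bar_where(a, torch.ones(10)))
print(bar_where(a, -torch.ones(10)))
```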

In order to serialize graphs or to run graphs on different (i.e. Python-less) environments, consider using torch.export instead (from PyTorch 2.1+). One important restriction is that torch.export does not support graph breaks. Please check the torch.export tutorial for more details on torch.export.

Check out our section on graph breaks in the torch.compile programming model for tips on how to work around graph breaks.

Troubleshooting

Is torch.compile failing to speed up your model? Is compile time unreasonably long? Is your code recompiling excessively? Are you having difficulties dealing with graph breaks? Are you looking for tips on how to best use torch.compile? Or maybe you simply want to learn more about the inner workings of torch.compile?

Check out the torch.compile programming model.

Conclusion

In this tutorial, we introduced torch.compile by covering basic usage, demonstrating speedups over eager mode, comparing to TorchScript, and briefly describing graph breaks.

For an end-to-end example on a real model, check out our end-to-end torch.compile tutorial.

To troubleshoot issues and to gain a deeper understanding of how to apply torch.compile to your code, check out the torch.compile programming model.

We hope that you will give torch.compile a try!

Total running time of the script: (0 minutes 16.352 seconds)