
Compiled Autograd: Capturing a larger backward graph for torch.compile#

Created On: Oct 09, 2024 | Last Updated: Oct 23, 2024 | Last Verified: Oct 09, 2024

Author: Simon Fan

What you will learn
  • How compiled autograd interacts with torch.compile

  • How to use the compiled autograd API

  • How to inspect logs using TORCH_LOGS

Prerequisites

  • PyTorch 2.4 or later

Overview#

Compiled Autograd is a torch.compile extension introduced in PyTorch 2.4 that allows the capture of a larger backward graph.

While torch.compile does capture the backward graph, it does so partially. The AOTAutograd component captures the backward graph ahead-of-time, with certain limitations:

  • Graph breaks in the forward lead to graph breaks in the backward

  • Backward hooks are not captured

Compiled Autograd addresses these limitations by directly integrating with the autograd engine, allowing it to capture the full backward graph at runtime. Models affected by either of these two limitations should try Compiled Autograd, and may observe better performance.

However, Compiled Autograd introduces its own limitations:

  • Added runtime overhead at the start of the backward for cache lookup

  • More prone to recompiles and graph breaks in dynamo due to the larger capture

Note

Compiled Autograd is under active development and is not yet compatible with all existing PyTorch features. For the latest status on a particular feature, refer to the Compiled Autograd Landing Page.

Setup#

In this tutorial, we will base our examples on this simple neural network model. It takes a 10-dimensional input vector, processes it through a single linear layer, and outputs another 10-dimensional vector.

import torch

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(10, 10)

    def forward(self, x):
        return self.linear(x)

Basic usage#

Before calling the torch.compile API, make sure to set torch._dynamo.config.compiled_autograd to True:

model = Model()
x = torch.randn(10)

torch._dynamo.config.compiled_autograd = True

@torch.compile
def train(model, x):
    loss = model(x).sum()
    loss.backward()

train(model, x)

In the code above, we create an instance of the Model class and generate a random 10-dimensional tensor x by using torch.randn(10). We define the training loop function train and decorate it with @torch.compile to optimize its execution. When train(model, x) is called:

  • Python Interpreter calls Dynamo, since this call was decorated with @torch.compile.

  • Dynamo intercepts the Python bytecode, simulates its execution, and records the operations into a graph.

  • AOTDispatcher disables hooks and calls the autograd engine to compute gradients for model.linear.weight and model.linear.bias, and records the operations into a graph. Using torch.autograd.Function, AOTDispatcher rewrites the forward and backward implementation of train.

  • Inductor generates a function corresponding to an optimized implementation of the AOTDispatcher forward and backward.

  • Dynamo sets the optimized function to be evaluated next by Python Interpreter.

  • Python Interpreter executes the optimized function, which executes loss = model(x).sum().

  • Python Interpreter executes loss.backward(), calling into the autograd engine, which routes to the Compiled Autograd engine since we set torch._dynamo.config.compiled_autograd=True.

  • Compiled Autograd computes the gradients for model.linear.weight and model.linear.bias, and records the operations into a graph, including any hooks it encounters. During this process, it will record the backward previously rewritten by AOTDispatcher. Compiled Autograd then generates a new function which corresponds to a fully-traced implementation of loss.backward(), and executes it with torch.compile in inference mode.

  • The same steps recursively apply to the Compiled Autograd graph, but this time AOTDispatcher will not need to partition the graph.
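
Once train(model, x) returns, the gradients computed by Compiled Autograd live on the parameters just as they would under eager autograd. As a minimal sanity check (assuming the snippet above has just run):

# Gradients are accumulated into .grad, exactly as in eager mode.
assert model.linear.weight.grad is not None
assert model.linear.bias.grad is not None
print(model.linear.weight.grad.shape)  # torch.Size([10, 10])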

Inspecting the compiled autograd logs#

Run the script with the TORCH_LOGS environment variable:

  • To only print the compiled autograd graph, use TORCH_LOGS="compiled_autograd" python example.py

  • To print the graph with more tensor metadata and recompile reasons, at the cost of performance, use TORCH_LOGS="compiled_autograd_verbose" python example.py
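
As an alternative to the environment variable, you can enable the same logs from Python with torch._logging.set_logs. The artifact keyword below is an assumption and may vary across PyTorch versions; if it is not recognized by your build, use TORCH_LOGS as shown above.

import torch

# Assumed artifact name; intended to be equivalent to TORCH_LOGS="compiled_autograd".
torch._logging.set_logs(compiled_autograd=True)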

Rerun the snippet above; the compiled autograd graph should now be logged to stderr. Certain graph nodes will have names that are prefixed by aot0_; these correspond to the nodes previously compiled ahead of time in AOTAutograd backward graph 0. For example, aot0_view_2 corresponds to view_2 of the AOT backward graph with id=0.

In the image below, the red box encapsulates the AOT backward graph that is captured by torch.compile without Compiled Autograd.

../_images/entire_verbose_log.png

Note

This is the graph on which we will call torch.compile, NOT the optimized graph. Compiled Autograd essentially generates some unoptimized Python code to represent the entire C++ autograd execution.

Compiling the forward and backward pass using different flags#

You can use different compiler configs for the two compilations. For example, the backward may be compiled as a full graph (fullgraph=True) even if there are graph breaks in the forward.

def train(model, x):
    model = torch.compile(model)
    loss = model(x).sum()
    torch._dynamo.config.compiled_autograd = True
    torch.compile(lambda: loss.backward(), fullgraph=True)()

Or you can use the context manager, which will apply to all autograd calls within its scope.

def train(model, x):
    model = torch.compile(model)
    loss = model(x).sum()
    with torch._dynamo.compiled_autograd.enable(torch.compile(fullgraph=True)):
        loss.backward()

Compiled Autograd addresses certain limitations of AOTAutograd#

  1. Graph breaks in the forward pass no longer necessarily lead to graph breaks in the backward pass:

@torch.compile(backend="aot_eager")deffn(x):# 1st graphtemp=x+10torch._dynamo.graph_break()# 2nd graphtemp=temp+10torch._dynamo.graph_break()# 3rd graphreturntemp.sum()x=torch.randn(10,10,requires_grad=True)torch._dynamo.utils.counters.clear()loss=fn(x)# 1. base torch.compileloss.backward(retain_graph=True)assert(torch._dynamo.utils.counters["stats"]["unique_graphs"]==3)torch._dynamo.utils.counters.clear()# 2. torch.compile with compiled autogradwithtorch._dynamo.compiled_autograd.enable(torch.compile(backend="aot_eager")):loss.backward()# single graph for the backwardassert(torch._dynamo.utils.counters["stats"]["unique_graphs"]==1)

In the first torch.compile case, we see that 3 backward graphs were produced due to the 2 graph breaks in the compiled function fn. In the second case, torch.compile with compiled autograd, we see that a full backward graph was traced despite the graph breaks.

Note

It is still possible for Dynamo to graph break when tracing the backward hooks captured by Compiled Autograd.
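
For example, a hook whose body Dynamo cannot trace, such as one that calls print(), may still introduce a graph break in the backward. A minimal sketch (the exact graph break behavior depends on your PyTorch version):

@torch.compile(backend="aot_eager")
def fn(x):
    return x.sum()

x = torch.randn(10, 10, requires_grad=True)

def chatty_hook(grad):
    # print() is generally not traceable by Dynamo, so inlining this hook
    # is expected to introduce a graph break in the backward.
    print("hook fired")
    return grad + 10

x.register_hook(chatty_hook)
loss = fn(x)

with torch._dynamo.compiled_autograd.enable(torch.compile(backend="aot_eager")):
    loss.backward()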

  2. Backward hooks can now be captured:

@torch.compile(backend="aot_eager")deffn(x):returnx.sum()x=torch.randn(10,10,requires_grad=True)x.register_hook(lambdagrad:grad+10)loss=fn(x)withtorch._dynamo.compiled_autograd.enable(torch.compile(backend="aot_eager")):loss.backward()

There should be a call_hook node in the graph, which Dynamo will later inline into the following:

../_images/call_hook_node.png

Common recompilation reasons for Compiled Autograd#

  1. Due to changes in the autograd structure of the loss value:

torch._dynamo.config.compiled_autograd = True
x = torch.randn(10, requires_grad=True)

for op in [torch.add, torch.sub, torch.mul, torch.div]:
    loss = op(x, x).sum()
    torch.compile(lambda: loss.backward(), backend="eager")()

In the example above, we call a different operator on each iteration, leading to loss tracking a different autograd history each time. You should see some recompile messages: Cache miss due to new autograd node.

../_images/recompile_due_to_node.png

  2. Due to tensors changing shapes:

torch._dynamo.config.compiled_autograd = True

for i in [10, 100, 10]:
    x = torch.randn(i, i, requires_grad=True)
    loss = x.sum()
    torch.compile(lambda: loss.backward(), backend="eager")()

In the example above, x changes shape, and compiled autograd will mark x as a dynamic shape tensor after the first change. You should see recompile messages: Cache miss due to changed shapes.

../_images/recompile_due_to_dynamic.png

Conclusion#

In this tutorial, we went over the high-level ecosystem of torch.compile with compiled autograd, the basics of compiled autograd, and a few common recompilation reasons. Stay tuned for deep dives on dev-discuss.