Custom Backends#
Created On: Jun 10, 2025 | Last Updated On: Jun 10, 2025
Overview#
torch.compile provides a straightforward method to enable users to define custom backends.

A backend function has the contract (gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]) -> Callable.

Backend functions can be called by TorchDynamo, the graph tracing component of torch.compile, after tracing an FX graph, and are expected to return a compiled function that is equivalent to the traced FX graph. The returned callable should have the same contract as the forward function of the original torch.fx.GraphModule passed into the backend: (*args: torch.Tensor) -> List[torch.Tensor].
In order for TorchDynamo to call your backend, pass your backend function as the backend kwarg in torch.compile. For example,

```python
import torch

def my_custom_backend(gm, example_inputs):
    return gm.forward

def f(...):
    ...

f_opt = torch.compile(f, backend=my_custom_backend)

@torch.compile(backend=my_custom_backend)
def g(...):
    ...
```
See below for more examples.
Registering Custom Backends#
You can register your backend using the register_backend decorator, for example,

```python
from torch._dynamo import register_backend

@register_backend
def my_compiler(gm, example_inputs):
    ...
```
Besides the register_backend decorator, if your backend is in another Python package, you can also register it through your package's entry points, which provide a way for one package to register a plugin for another.
Hint
You can learn more about entry_points in the Python packaging documentation.
To register your backend through entry_points, you can add your backend function to the torch_dynamo_backends entry point group in the setup.py file of your package like:

```python
...
setup(
    ...
    entry_points={
        'torch_dynamo_backends': [
            'my_compiler = your_module.submodule:my_compiler',
        ]
    },
    ...
)
```
Please replace the my_compiler before = with the name of your backend, and replace the part after = with the module and function name of your backend function. The entry point will be added to your Python environment after the package is installed. When you call torch.compile(model, backend="my_compiler"), PyTorch first searches for a backend named my_compiler that has been registered with register_backend. If not found, it continues to search all backends registered via entry_points.
Registration serves two purposes:
- You can pass a string containing your backend function's name to torch.compile instead of the function itself, for example, torch.compile(model, backend="my_compiler").
- It is required for use with the minifier. Any code generated by the minifier must call your code that registers your backend function, typically through an import statement.
Custom Backends after AOTAutograd#
It is possible to define custom backends that are called by AOTAutograd rather than TorchDynamo. This is useful for 2 main reasons:

- Users can define backends that support model training, as AOTAutograd can generate the backward graph for compilation.
- AOTAutograd produces FX graphs consisting of core Aten ops. As a result, custom backends only need to support the core Aten opset, which is a significantly smaller opset than the entire torch/Aten opset.
Wrap your backend with torch._dynamo.backends.common.aot_autograd and use torch.compile with the backend kwarg as before. Backend functions wrapped by aot_autograd should have the same contract as before.

Backend functions are passed to aot_autograd through the fw_compiler (forward compiler) or bw_compiler (backward compiler) kwargs. If bw_compiler is not specified, the backward compile function defaults to the forward compile function.

One caveat is that AOTAutograd requires compiled functions returned by backends to be "boxed". This can be done by wrapping the compiled function with functorch.compile.make_boxed_func.
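For intuition, the boxed calling convention means the compiled function receives all of its inputs as a single list rather than as positional arguments, which lets the caller clear the list to free inputs early. The following is a rough pure-Python sketch of the idea, not functorch's actual implementation; the _boxed_call marker attribute is an assumption in this sketch:

```python
# Rough sketch of the "boxed" calling convention (not functorch's actual code).
def make_boxed_func_sketch(f):
    def boxed(args):          # inputs arrive as one list...
        return f(*args)       # ...and are unpacked for the wrapped function
    boxed._boxed_call = True  # marker attribute; an assumption in this sketch
    return boxed

def add(x, y):
    return x + y

boxed_add = make_boxed_func_sketch(add)
print(boxed_add([1, 2]))  # prints 3
```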
For example,
```python
from torch._dynamo.backends.common import aot_autograd
from functorch.compile import make_boxed_func

def my_compiler(gm, example_inputs):
    return make_boxed_func(gm.forward)

my_backend = aot_autograd(fw_compiler=my_compiler)  # bw_compiler=my_compiler
model_opt = torch.compile(model, backend=my_backend)
```
Examples#
Debugging Backend#
If you want to better understand what is going on during a compilation, you can create a custom compiler, referred to as a backend in this section, that pretty-prints the FX GraphModule extracted from Dynamo's bytecode analysis and returns a forward() callable.
For example:
```python
from typing import List

import torch

def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    print("my_compiler() called with FX graph:")
    gm.graph.print_tabular()
    return gm.forward  # return a python callable

@torch.compile(backend=my_compiler)
def fn(x, y):
    a = torch.cos(x)
    b = torch.sin(y)
    return a + b

fn(torch.randn(10), torch.randn(10))
```
Running the above example produces the following output:
```
my_compiler() called with FX graph:
opcode         name    target                                                  args        kwargs
-------------  ------  ------------------------------------------------------  ----------  --------
placeholder    x       x                                                       ()          {}
placeholder    y       y                                                       ()          {}
call_function  cos     <built-in method cos of type object at 0x7f1a894649a8>  (x,)        {}
call_function  sin     <built-in method sin of type object at 0x7f1a894649a8>  (y,)        {}
call_function  add     <built-in function add>                                 (cos, sin)  {}
output         output  output                                                  ((add,),)   {}
```
This works for torch.nn.Module as well, as shown below:
```python
from typing import List

import torch

def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    print("my_compiler() called with FX graph:")
    gm.graph.print_tabular()
    return gm.forward  # return a python callable

class MockModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        return self.relu(torch.cos(x))

mod = MockModule()
optimized_mod = torch.compile(mod, backend=my_compiler)
optimized_mod(torch.randn(10))
```
Let’s take a look at one more example with control flow:
```python
from typing import List

import torch

def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    print("my_compiler() called with FX graph:")
    gm.graph.print_tabular()
    return gm.forward  # return a python callable

@torch.compile(backend=my_compiler)
def toy_example(a, b):
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:
        b = b * -1
    return x * b

for _ in range(100):
    toy_example(torch.randn(10), torch.randn(10))
```
Running this example produces the following output:
```
my_compiler() called with FX graph:
opcode         name     target                                                  args              kwargs
-------------  -------  ------------------------------------------------------  ----------------  --------
placeholder    a        a                                                       ()                {}
placeholder    b        b                                                       ()                {}
call_function  abs_1    <built-in method abs of type object at 0x7f8d259298a0>  (a,)              {}
call_function  add      <built-in function add>                                 (abs_1, 1)        {}
call_function  truediv  <built-in function truediv>                             (a, add)          {}
call_method    sum_1    sum                                                     (b,)              {}
call_function  lt       <built-in function lt>                                  (sum_1, 0)        {}
output         output   output                                                  ((truediv, lt),)  {}

my_compiler() called with FX graph:
opcode         name    target                   args         kwargs
-------------  ------  -----------------------  -----------  --------
placeholder    b       b                        ()           {}
placeholder    x       x                        ()           {}
call_function  mul     <built-in function mul>  (b, -1)      {}
call_function  mul_1   <built-in function mul>  (x, mul)     {}
output         output  output                   ((mul_1,),)  {}

my_compiler() called with FX graph:
opcode         name    target                   args       kwargs
-------------  ------  -----------------------  ---------  --------
placeholder    b       b                        ()         {}
placeholder    x       x                        ()         {}
call_function  mul     <built-in function mul>  (x, b)     {}
output         output  output                   ((mul,),)  {}
```

The order of the last two graphs is nondeterministic depending on which one is encountered first by the just-in-time compiler.
Speedy Backend#
Integrating a custom backend that offers superior performance is also easy, and we'll integrate a real one with optimize_for_inference:

```python
from typing import List

import torch

def optimize_for_inference_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    scripted = torch.jit.script(gm)
    return torch.jit.optimize_for_inference(scripted)
```
And then you should be able to optimize any existing code with:
```python
@torch.compile(backend=optimize_for_inference_compiler)
def code_to_accelerate():
    ...
```
Composable Backends#
TorchDynamo includes many backends, which can be listed with torch._dynamo.list_backends(). You can combine these backends together with the following code:

```python
from typing import List

import torch
from torch._dynamo import lookup_backend

def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    try:
        trt_compiled = lookup_backend("tensorrt")(gm, example_inputs)
        if trt_compiled is not None:
            return trt_compiled
    except Exception:
        pass
    # first backend failed, try something else...
    try:
        inductor_compiled = lookup_backend("inductor")(gm, example_inputs)
        if inductor_compiled is not None:
            return inductor_compiled
    except Exception:
        pass
    return gm.forward
```