torch.compile
- torch.compile(model: Callable[[_InputT], _RetT], *, fullgraph: bool = False, dynamic: Optional[bool] = None, backend: Union[str, Callable] = 'inductor', mode: Optional[str] = None, options: Optional[dict[str, Union[str, int, bool, Callable]]] = None, disable: bool = False) → Callable[[_InputT], _RetT]
- torch.compile(model: None = None, *, fullgraph: bool = False, dynamic: Optional[bool] = None, backend: Union[str, Callable] = 'inductor', mode: Optional[str] = None, options: Optional[dict[str, Union[str, int, bool, Callable]]] = None, disable: bool = False) → Callable[[Callable[[_InputT], _RetT]], Callable[[_InputT], _RetT]]
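The two overloads correspond to the two ways of invoking torch.compile: passing the model/function directly, or calling it with no model to obtain a decorator. A minimal sketch of both call patterns (the function bodies are illustrative placeholders):

    import torch

    # First overload: pass the module/function directly and get back
    # an optimized callable with the same signature.
    compiled_sin = torch.compile(torch.sin)
    y = compiled_sin(torch.randn(8))

    # Second overload: with model=None, torch.compile returns a decorator,
    # so keyword arguments can be fixed at decoration time.
    @torch.compile(dynamic=True)
    def pointwise(x):
        return torch.relu(x) * 2

    z = pointwise(torch.randn(8))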
Optimizes the given model/function using TorchDynamo and the specified backend. If you are compiling a torch.nn.Module, you can also use torch.nn.Module.compile() to compile the module in place without changing its structure.
Concretely, for every frame executed within the compiled region, we will attempt to compile it and cache the compiled result on the code object for future use. A single frame may be compiled multiple times if previous compiled results are not applicable for subsequent calls (this is called a "guard failure"); you can use TORCH_LOGS=guards to debug these situations. Multiple compiled results can be associated with a frame, up to torch._dynamo.config.recompile_limit, which defaults to 8, at which point we will fall back to eager. Note that compile caches are per code object, not per frame; if you dynamically create multiple copies of a function, they will all share the same code cache.
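As a sketch of how guard failures surface in practice (the dtype change below is one example of a guarded input property; the exact log output varies by version), run the following with TORCH_LOGS=guards set in the environment:

    import torch

    @torch.compile
    def f(x):
        return x + 1

    f(torch.randn(4))                          # first compilation
    f(torch.randn(4, dtype=torch.float64))     # dtype guard fails -> recompile
    # After torch._dynamo.config.recompile_limit recompiles (default 8),
    # this frame falls back to eager execution.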
- Parameters
model (Callable or None) – Module/function to optimize
fullgraph (bool) – If False (default), torch.compile attempts to discover compilable regions in the function that it will optimize. If True, then we require that the entire function be capturable into a single graph. If this is not possible (that is, if there are graph breaks), then this will raise an error, as in the sketch after the example at the end of this entry.
dynamic (bool or None) – Use dynamic shape tracing. When this is True, we will up-front attempt to generate a kernel that is as dynamic as possible to avoid recompilations when sizes change. This may not always work, as some operations/optimizations will force specialization; use TORCH_LOGS=dynamic to debug overspecialization. When this is False, we will never generate dynamic kernels; we will always specialize. By default (None), we automatically detect if dynamism has occurred and compile a more dynamic kernel upon recompile.
backend (str or Callable) –
Backend to be used
- "inductor" is the default backend, which is a good balance between performance and overhead
- Non-experimental in-tree backends can be seen with torch._dynamo.list_backends()
- Experimental or debug in-tree backends can be seen with torch._dynamo.list_backends(None)
- To register an out-of-tree custom backend, see: https://pytorch.org/docs/main/torch.compiler_custom_backends.html#registering-custom-backends
mode (str) –
Can be either "default", "reduce-overhead", "max-autotune" or "max-autotune-no-cudagraphs"
- "default" is the default mode, which is a good balance between performance and overhead
- "reduce-overhead" is a mode that reduces the overhead of Python with CUDA graphs, useful for small batches. Reduction of overhead can come at the cost of more memory usage, as we will cache the workspace memory required for the invocation so that we do not have to reallocate it on subsequent runs. Reduction of overhead is not guaranteed to work; today, we only reduce overhead for CUDA-only graphs which do not mutate inputs. There are other circumstances where CUDA graphs are not applicable; use TORCH_LOGS=perf_hints to debug.
- "max-autotune" is a mode that leverages Triton- or template-based matrix multiplications on supported devices and Triton-based convolutions on GPU. It enables CUDA graphs by default on GPU.
- "max-autotune-no-cudagraphs" is a mode similar to "max-autotune" but without CUDA graphs
To see the exact configs that each mode sets, you can call torch._inductor.list_mode_options()
options (dict) –
A dictionary of options to pass to the backend. Some notable ones to try out are
- epilogue_fusion, which fuses pointwise ops into templates. Requires max_autotune to also be set
- max_autotune, which will profile to pick the best matmul configuration
- fallback_random, which is useful when debugging accuracy issues
- shape_padding, which pads matrix shapes to better align loads on GPUs, especially for tensor cores
- triton.cudagraphs, which will reduce the overhead of Python with CUDA graphs
- trace.enabled, which is the most useful debugging flag to turn on
- trace.graph_diagram, which will show you a picture of your graph after fusion
- guard_filter_fn, which controls which dynamo guards are saved with compilations. This is an unsafe feature and there is no backward-compatibility guarantee provided for dynamo guards as data types. For stable helper functions to use, see the documentation in torch.compiler, for example: torch.compiler.skip_guard_on_inbuilt_nn_modules_unsafe, torch.compiler.skip_guard_on_all_nn_modules_unsafe, torch.compiler.keep_tensor_guards_unsafe
For inductor you can see the full list of configs that it supports by calling torch._inductor.list_options()
disable (bool) – Turn torch.compile() into a no-op for testing
Example:
    @torch.compile(options={"triton.cudagraphs": True}, fullgraph=True)
    def foo(x):
        return torch.sin(x) + torch.cos(x)
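As a further sketch of the fullgraph behavior referenced above (a print call is one common source of a graph break; the exact error type may vary by version):

    import torch

    @torch.compile(fullgraph=True)
    def g(x):
        print("side effect")  # graph break: not capturable into a single graph
        return x * 2

    try:
        g(torch.randn(4))
    except Exception as e:
        # With fullgraph=True, the graph break raises instead of
        # falling back to eager around the unsupported call.
        print(type(e).__name__)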