PyTorch 2.0 NNModule Support#
Created On: Apr 06, 2023 | Last Updated On: Jun 10, 2025
Author: Will Constable
torch.compile has special handling for torch.nn.Module objects, tracing them differently than it traces arbitrary Python classes, with the intent of producing faster code by making assumptions about the structure.
This doc describes some of the tradeoffs or edge cases that come up due to this specialization.
NNModule Hooks Support#
Previously, torch.compile had no support for hooks on nn.Modules, and if hooks were registered they would simply be ignored in the compiled program. Indeed many users do not use nn.Module hooks at all, or only use them for debug workflows, but there are valid use cases for composing nn.Module hooks with torch.compile.
Hooks that are orchestrated via the nn.Module.__call__ implementation include _forward_pre_hooks, _forward_hooks, _backward_pre_hooks, and _backward_hooks, and will be referred to as 'call hooks'. These hooks are partially supported by torch.compile with limitations described below.
Another category of hooks includes _state_dict_hooks and its pre and load_ variants, which are still unsupported by torch.compile.
nn.Module.__call__ Hooks Usage and limitations#
By default, torch.compile will trace the contents of nn.Module.__call__, which means it will encounter and run forward/pre-forward hooks. If you install hooks before calling torch.compile and then do not remove or alter the hooks later, your use case should be supported by default.
Backward/pre-backward hooks are generally also supported, with similar caveats: currently, graph breaks occur in dynamo when accessing backward_hooks dicts, which is probably avoidable with some work. Graph breaks also impact the timing of firing backward hooks, since graph segments are run as autograd functions which produce all their grads at the same time. Even if it were possible for dynamo to avoid graph-breaking on the presence of backward hooks, we would still expect the backward hooks for a series of modules to all fire together after the whole compiled graph's backward has run.
hooks on 'allowed modules'
torch.compile treats common modules such as torch.conv, as well as modules that are difficult to trace, specially by allowing them to be called opaquely in the dynamo graph instead of being traced into by dynamo. For such modules, hooks currently trigger a graph break so that the affected modules run outside of dynamo. Depending on the model, this could introduce a significant performance regression, and additional work is required to improve this support.
skip_nnmodule_hook_guards
By default, torch._dynamo.config.skip_nnmodule_hook_guards is set to True, meaning no guards will be installed on each nn.Module hook dictionary. This improves runtime by reducing guard execution time, at the cost of not noticing if any hook dict is changed after compilation.
If you want to be able to remove or modify hooks after compilation and have torch.compile react appropriately (by recompiling), then you need to set skip_nnmodule_hook_guards=False and expect a runtime penalty for the added guards.
TODO: confirm if backward/pre_backward hooks are working or not and document accordingly
state_dict Hooks#
State dict hooks are not yet supported in torch.compile.
TODO: warn_once if graph-breaking on hooks. warn_once to point to this doc if hooks are present.