torch.cuda.make_graphed_callables

torch.cuda.make_graphed_callables(callables: Union[Module, Callable[..., object]], sample_args: tuple[torch.Tensor, ...], num_warmup_iters: int = 3, allow_unused_input: bool = False, pool: Optional[_POOL_HANDLE] = None) → Union[Module, Callable[..., object]] [source]
torch.cuda.make_graphed_callables(callables: tuple[Union[torch.nn.modules.module.Module, Callable[..., object]], ...], sample_args: tuple[tuple[torch.Tensor, ...], ...], num_warmup_iters: int = 3, allow_unused_input: bool = False, pool: Optional[_POOL_HANDLE] = None) → tuple[Union[torch.nn.modules.module.Module, Callable[..., object]], ...]

Accepts callables (functions or nn.Modules) and returns graphed versions.

Each graphed callable’s forward pass runs its source callable’s forward CUDA work as a CUDA graph inside a single autograd node.

The graphed callable’s forward pass also appends a backward node to the autograd graph. During backward, this node runs the callable’s backward work as a CUDA graph.

Therefore, each graphed callable should be a drop-in replacement for its source callable in an autograd-enabled training loop.

See Partial-network capture for detailed use and constraints.

If you pass a tuple of several callables, their captures will use the same memory pool. See Graph memory management for when this is appropriate.
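For orientation, here is a minimal sketch of graphing a single nn.Module and using it as a drop-in replacement in a training step. It assumes a CUDA device is present (CUDA graphs require one); the function name graphed_mlp_demo is just an illustrative wrapper, not part of the API.

```python
import torch

def graphed_mlp_demo():
    # CUDA graphs require a GPU; bail out gracefully without one.
    if not torch.cuda.is_available():
        return None
    model = torch.nn.Linear(16, 4).cuda()
    # sample_args must be a tuple containing only Tensors, with the
    # requires_grad state each real input will have at training time.
    x = torch.randn(8, 16, device="cuda")
    graphed_model = torch.cuda.make_graphed_callables(model, (x,))
    # Drop-in replacement: forward and backward replay the captured graphs.
    out = graphed_model(torch.randn(8, 16, device="cuda"))
    out.sum().backward()
    return out.shape

graphed_mlp_demo()
```

Note that the replacement call passes its argument in the same order and format as sample_args, as required.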

Parameters
  • callables (torch.nn.Module or Python function, or tuple of these) – Callable or callables to graph. See Graph memory management for when passing a tuple of callables is appropriate. If you pass a tuple of callables, their order in the tuple must be the same order they’ll run in the live workload.

  • sample_args (tuple of Tensors, or tuple of tuples of Tensors) – Sample args for each callable. If a single callable was passed, sample_args must be a single tuple of argument Tensors. If a tuple of callables was passed, sample_args must be a tuple of tuples of argument Tensors.

  • num_warmup_iters (int) – The number of warmup iterations. Currently, DistributedDataParallel needs 11 iterations for warm up. Default: 3.

  • allow_unused_input (bool) – If False, specifying inputs that were not used when computing outputs (and therefore their grad is always zero) is an error. Defaults to False.

  • pool (optional) – Token (returned by graph_pool_handle() or other_Graph_instance.pool()) that hints this graph may share memory with the indicated pool. See Graph memory management.
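The tuple-of-callables form can be sketched as follows, with two modules that run sequentially in the live workload. This is a minimal sketch assuming a CUDA device; graphed_pair_demo is an illustrative name, and the second sample input is created with requires_grad=True because the real input (the first module's output) will require grad.

```python
import torch

def graphed_pair_demo():
    if not torch.cuda.is_available():
        return None  # CUDA graphs require a GPU
    a = torch.nn.Linear(16, 16).cuda()
    b = torch.nn.Linear(16, 4).cuda()
    xa = torch.randn(8, 16, device="cuda")
    # b's real input is a's output, which requires grad, so the sample must too.
    xb = torch.randn(8, 16, device="cuda", requires_grad=True)
    # Order in the tuple must match execution order in the live workload;
    # both captures share one memory pool.
    ga, gb = torch.cuda.make_graphed_callables((a, b), ((xa,), (xb,)))
    y = gb(ga(torch.randn(8, 16, device="cuda")))
    return y.shape

graphed_pair_demo()
```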

Note

The requires_grad state of each Tensor in sample_args must match the state that’s expected for the corresponding real input in the training loop.

Warning

This API is in beta and may change in future releases.

Warning

sample_args for each callable must contain only Tensors. Other types are not allowed.

Warning

Returned callables do not support higher order differentiation (e.g., double backward).

Warning

In any Module passed to make_graphed_callables(), only parameters may be trainable. Buffers must have requires_grad=False.

Warning

After you pass a torch.nn.Module through make_graphed_callables(), you may not add or remove any of that Module’s parameters or buffers.

Warning

torch.nn.Modules passed to make_graphed_callables() must not have module hooks registered on them at the time they are passed. However, registering hooks on modules after passing them through make_graphed_callables() is allowed.

Warning

When running a graphed callable, you must pass its arguments in the same order and format they appeared in that callable’s sample_args.

Warning

Automatic mixed precision is supported in make_graphed_callables() only with caching disabled. The context manager torch.cuda.amp.autocast() must have cache_enabled=False.
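A minimal sketch of the autocast constraint above, assuming a CUDA device: the graphing call and the graphed callable both run under torch.cuda.amp.autocast with cache_enabled=False. The name graphed_amp_demo is illustrative only.

```python
import torch

def graphed_amp_demo():
    if not torch.cuda.is_available():
        return None  # CUDA graphs require a GPU
    model = torch.nn.Linear(16, 4).cuda()
    x = torch.randn(8, 16, device="cuda")
    # Autocast's cache must be disabled when graphing, per the warning above.
    with torch.cuda.amp.autocast(cache_enabled=False):
        graphed = torch.cuda.make_graphed_callables(model, (x,))
        out = graphed(torch.randn(8, 16, device="cuda"))
    return out.dtype

graphed_amp_demo()
```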