torch.jit.trace_module
- torch.jit.trace_module(mod, inputs, optimize=None, check_trace=True, check_inputs=None, check_tolerance=1e-05, strict=True, _force_outplace=False, _module_class=None, _compilation_unit=<torch.jit.CompilationUnit object>, example_inputs_is_kwarg=False, _store_inputs=True)
Trace a module and return an executable ScriptModule that will be optimized using just-in-time compilation. When a module is passed to torch.jit.trace, only the forward method is run and traced. With trace_module, you can specify a dictionary of method names to example inputs to trace (see the inputs argument below). See torch.jit.trace for more information on tracing.
- Parameters
  - mod (torch.nn.Module) – A torch.nn.Module containing methods whose names are specified in inputs. The given methods will be compiled as a part of a single ScriptModule.
  - inputs (dict) – A dict containing sample inputs indexed by method names in mod. The inputs will be passed to methods whose names correspond to inputs’ keys while tracing. For example: {'forward': example_forward_input, 'method2': example_method2_input}
- Keyword Arguments
  - check_trace (bool, optional) – Check if the same inputs run through traced code produce the same outputs. Default: True. You might want to disable this if, for example, your network contains non-deterministic ops or if you are sure that the network is correct despite a checker failure.
  - check_inputs (list of dicts, optional) – A list of dicts of input arguments that should be used to check the trace against what is expected. Each dict is equivalent to a set of input arguments that would be specified in inputs. For best results, pass in a set of checking inputs representative of the space of shapes and types of inputs you expect the network to see. If not specified, the original inputs are used for checking.
  - check_tolerance (float, optional) – Floating-point comparison tolerance to use in the checker procedure. This can be used to relax the checker strictness in the event that results diverge numerically for a known reason, such as operator fusion.
  - example_inputs_is_kwarg (bool, optional) – Indicates whether the example inputs are a pack of keyword arguments. Default: False.
- Returns
  A ScriptModule object containing one traced method for each method name given in inputs. The returned ScriptModule will have the same set of sub-modules and parameters as mod.
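The sketch below exercises the keyword arguments and inspects the return value. It traces the Net class from the example that follows, copied here so the block runs on its own, and passes an explicit check_inputs list. Each check entry is assumed to be a dict keyed by method name, like inputs. The shapes and the check_tolerance value are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn


class Net(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(1, 1, 3)

    def forward(self, x):
        return self.conv(x)

    def weighted_kernel_sum(self, weight):
        return weight * self.conv.weight


n = Net()
inputs = {
    "forward": torch.rand(1, 1, 3, 3),
    "weighted_kernel_sum": torch.rand(1, 1, 3, 3),
}

# Each entry has the same keys as `inputs`. Varying the batch and spatial
# sizes checks the traced `forward` on shapes other than the example's.
check_inputs = [
    {"forward": torch.rand(2, 1, 8, 8), "weighted_kernel_sum": torch.rand(1, 1, 3, 3)},
    {"forward": torch.rand(4, 1, 5, 5), "weighted_kernel_sum": torch.rand(1, 1, 3, 3)},
]

traced = torch.jit.trace_module(
    n,
    inputs,
    check_inputs=check_inputs,
    check_tolerance=1e-4,  # looser than the 1e-05 default
)

# The result is a ScriptModule with one traced method per key in `inputs`.
print(isinstance(traced, torch.jit.ScriptModule))  # True
print(traced(torch.rand(4, 1, 5, 5)).shape)  # torch.Size([4, 1, 3, 3])

w = torch.rand(1, 1, 3, 3)
print(torch.allclose(traced.weighted_kernel_sum(w), n.weighted_kernel_sum(w)))  # True

# Parameter names mirror those of the eager module.
traced_names = [name for name, _ in traced.named_parameters()]
eager_names = [name for name, _ in n.named_parameters()]
print(traced_names == eager_names)
```

Setting check_trace=False instead would skip this checking step entirely, for example when the traced methods contain non-deterministic ops.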
Example (tracing a module with multiple methods):
```python
import torch
import torch.nn as nn


class Net(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(1, 1, 3)

    def forward(self, x):
        return self.conv(x)

    def weighted_kernel_sum(self, weight):
        return weight * self.conv.weight


n = Net()
example_weight = torch.rand(1, 1, 3, 3)
example_forward_input = torch.rand(1, 1, 3, 3)

# Trace a specific method and construct `ScriptModule` with
# a single `forward` method
module = torch.jit.trace(n.forward, example_forward_input)

# Trace a module (implicitly traces `forward`) and construct a
# `ScriptModule` with a single `forward` method
module = torch.jit.trace(n, example_forward_input)

# Trace specific methods on a module (specified in `inputs`), constructs
# a `ScriptModule` with `forward` and `weighted_kernel_sum` methods
inputs = {
    "forward": example_forward_input,
    "weighted_kernel_sum": example_weight,
}
module = torch.jit.trace_module(n, inputs)
```
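The example above gives each method a single tensor. For methods with several parameters, there are two options for the example value:

- A tuple, which is unpacked into positional arguments.
- A dict of keyword arguments, used together with example_inputs_is_kwarg=True.

Below is a minimal sketch of both. It uses a hypothetical TwoArgs module whose name, methods and parameter are made up for illustration.

```python
import torch
import torch.nn as nn


class TwoArgs(nn.Module):
    """Hypothetical module whose traced methods each take two tensors."""

    def __init__(self) -> None:
        super().__init__()
        self.scale = nn.Parameter(torch.ones(3))

    def forward(self, x, y):
        return x * self.scale + y

    def blend(self, a, b):
        return (a + b) * self.scale


m = TwoArgs()
x, y = torch.rand(2, 3), torch.rand(2, 3)

# Tuples are unpacked into positional arguments of each traced method.
traced_pos = torch.jit.trace_module(m, {"forward": (x, y), "blend": (x, y)})

# A dict value plus example_inputs_is_kwarg=True is treated as keyword
# arguments; its keys must match the method's parameter names.
traced_kw = torch.jit.trace_module(
    m, {"forward": {"x": x, "y": y}}, example_inputs_is_kwarg=True
)

print(torch.allclose(traced_pos(x, y), m(x, y)))  # True
print(torch.allclose(traced_pos.blend(x, y), m.blend(x, y)))  # True
print(torch.allclose(traced_kw(x, y), m(x=x, y=y)))  # True
```

Without the flag, a dict value would be passed to the method as a single positional argument rather than being spread into keyword arguments.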