
torch.jit.trace_module

torch.jit.trace_module(mod, inputs, optimize=None, check_trace=True, check_inputs=None, check_tolerance=1e-05, strict=True, _force_outplace=False, _module_class=None, _compilation_unit=<torch.jit.CompilationUnit object>, example_inputs_is_kwarg=False, _store_inputs=True)

Trace a module and return an executable ScriptModule that will be optimized using just-in-time compilation.

When a module is passed to torch.jit.trace, only the forward method is run and traced. With trace_module, you can specify a dictionary mapping method names to example inputs to trace (see the inputs argument below).
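A minimal sketch of the difference, assuming a hypothetical `TwoHeads` module (not part of PyTorch): passing the module to `torch.jit.trace` traces only `forward`, while `trace_module` traces every method named as a key in its dict and compiles them into one ScriptModule.

```python
import torch
import torch.nn as nn


class TwoHeads(nn.Module):
    """Hypothetical module with `forward` plus one extra method."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, x):
        return self.linear(x)

    def scaled(self, x):
        return 2.0 * self.linear(x)


model = TwoHeads().eval()
x = torch.rand(3, 4)

# torch.jit.trace on a module runs and traces only `forward`.
traced_forward_only = torch.jit.trace(model, x)

# trace_module with a single "forward" key traces the same method...
traced_same = torch.jit.trace_module(model, {"forward": x})

# ...and each additional key traces another method into the same ScriptModule.
traced_both = torch.jit.trace_module(model, {"forward": x, "scaled": x})

print(torch.allclose(traced_forward_only(x), traced_same(x)))  # True
print(torch.allclose(traced_both.scaled(x), model.scaled(x)))  # True
```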

See torch.jit.trace for more information on tracing.

Parameters
  • mod (torch.nn.Module) – A torch.nn.Module containing methods whose names are specified in inputs. The given methods will be compiled as part of a single ScriptModule.

  • inputs (dict) – A dict containing sample inputs indexed by method names in mod. While tracing, the inputs are passed to the methods whose names correspond to the dict's keys. For example: {'forward': example_forward_input, 'method2': example_method2_input}
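The values in `inputs` can take more than one form. In this sketch, which uses a hypothetical `Scorer` module, a single tensor serves as the only positional argument of `forward`, while a tuple supplies the positional arguments of a method that takes two tensors.

```python
import torch
import torch.nn as nn


class Scorer(nn.Module):
    """Hypothetical module: `forward` takes one tensor, `combine` takes two."""

    def __init__(self) -> None:
        super().__init__()
        self.proj = nn.Linear(8, 8)

    def forward(self, x):
        return self.proj(x)

    def combine(self, a, b):
        return self.proj(a) + b


model = Scorer().eval()
x = torch.rand(2, 8)
a, b = torch.rand(2, 8), torch.rand(2, 8)

inputs = {
    "forward": x,       # a single tensor is used as the only positional argument
    "combine": (a, b),  # a tuple is unpacked as positional arguments
}
traced = torch.jit.trace_module(model, inputs)

print(torch.allclose(traced(x), model(x)))                        # True
print(torch.allclose(traced.combine(a, b), model.combine(a, b)))  # True
```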

Keyword Arguments
  • check_trace (bool, optional) – Check if the same inputs run through traced code produce the same outputs. Default: True. You might want to disable this if, for example, your network contains non-deterministic ops or if you are sure that the network is correct despite a checker failure.

  • check_inputs (list of dicts, optional) – A list of dicts of input arguments that should be used to check the trace against what is expected. Each dict is equivalent to a set of input arguments that would be specified in inputs. For best results, pass in a set of checking inputs representative of the space of shapes and types of inputs you expect the network to see. If not specified, the original inputs are used for checking.

  • check_tolerance (float, optional) – Floating-point comparison tolerance to use in the checker procedure. This can be used to relax the checker strictness in the event that results diverge numerically for a known reason, such as operator fusion.

  • example_inputs_is_kwarg (bool, optional) – Indicates whether the example inputs are a pack of keyword arguments. Default: False.
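The keyword arguments above can be combined as in the following sketch, which assumes two hypothetical modules, `Blur` and `Gate`. For `Blur`, `check_inputs` is a list of dicts shaped like `inputs`, so the checker also runs `forward` on a 16x16 image in addition to the 8x8 example; `check_tolerance` is passed at its default value only to show where it goes. For `Gate`, `example_inputs_is_kwarg=True` makes the `"forward"` value a dict of keyword arguments keyed by `forward`'s parameter names.

```python
import torch
import torch.nn as nn


class Blur(nn.Module):
    """Hypothetical module with a single size-preserving convolution."""

    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(1, 1, 3, padding=1)

    def forward(self, x):
        return self.conv(x)


class Gate(nn.Module):
    """Hypothetical module whose forward takes two named arguments."""

    def forward(self, x, mask):
        return x * mask


blur = Blur().eval()

# Each dict in check_inputs mirrors the structure of `inputs`.
traced_blur = torch.jit.trace_module(
    blur,
    {"forward": torch.rand(1, 1, 8, 8)},
    check_inputs=[{"forward": torch.rand(1, 1, 16, 16)}],
    check_tolerance=1e-05,
)

# The dict value for "forward" is interpreted as keyword arguments.
x, mask = torch.rand(4), torch.ones(4)
traced_gate = torch.jit.trace_module(
    Gate(),
    {"forward": {"x": x, "mask": mask}},
    example_inputs_is_kwarg=True,
)

y = torch.rand(1, 1, 16, 16)
print(torch.allclose(traced_blur(y), blur(y)))      # True
print(torch.equal(traced_gate(x, mask), x * mask))  # True
```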

Returns

A ScriptModule object containing the traced code for each method named in inputs. The returned ScriptModule will have the same set of sub-modules and parameters as mod.
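To inspect the returned object, this sketch reuses the `Net` module from the example below and checks that the result is a `ScriptModule`, that the extra traced method can be called on it, and that the parameters of `mod` carry over under the same names.

```python
import torch
import torch.nn as nn


class Net(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(1, 1, 3)

    def forward(self, x):
        return self.conv(x)

    def weighted_kernel_sum(self, weight):
        return weight * self.conv.weight


net = Net()
x = torch.rand(1, 1, 3, 3)
w = torch.rand(1, 1, 3, 3)
traced = torch.jit.trace_module(net, {"forward": x, "weighted_kernel_sum": w})

print(isinstance(traced, torch.jit.ScriptModule))       # True
print([name for name, _ in traced.named_parameters()])  # ['conv.weight', 'conv.bias']
print(torch.allclose(traced.weighted_kernel_sum(w), net.weighted_kernel_sum(w)))  # True
```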

Example (tracing a module with multiple methods):

```python
import torch
import torch.nn as nn


class Net(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(1, 1, 3)

    def forward(self, x):
        return self.conv(x)

    def weighted_kernel_sum(self, weight):
        return weight * self.conv.weight


n = Net()
example_weight = torch.rand(1, 1, 3, 3)
example_forward_input = torch.rand(1, 1, 3, 3)

# Trace a specific method and construct `ScriptModule` with
# a single `forward` method
module = torch.jit.trace(n.forward, example_forward_input)

# Trace a module (implicitly traces `forward`) and construct a
# `ScriptModule` with a single `forward` method
module = torch.jit.trace(n, example_forward_input)

# Trace specific methods on a module (specified in `inputs`), constructs
# a `ScriptModule` with `forward` and `weighted_kernel_sum` methods
inputs = {
    "forward": example_forward_input,
    "weighted_kernel_sum": example_weight,
}
module = torch.jit.trace_module(n, inputs)
```