torch.jit.trace

torch.jit.trace(func, example_inputs=None, optimize=None, check_trace=True, check_inputs=None, check_tolerance=1e-05, strict=True, _force_outplace=False, _module_class=None, _compilation_unit=<torch.jit.CompilationUnit object>, example_kwarg_inputs=None, _store_inputs=True)

Trace a function and return an executable or ScriptFunction that will be optimized using just-in-time compilation.

Tracing is ideal for code that operates only on Tensors and lists, dictionaries, and tuples of Tensors.

Using torch.jit.trace and torch.jit.trace_module, you can turn an existing module or Python function into a TorchScript ScriptFunction or ScriptModule. You must provide example inputs, and we run the function, recording the operations performed on all the tensors.

  • The resulting recording of a standalone function produces ScriptFunction.

  • The resulting recording of nn.Module.forward or nn.Module produces ScriptModule.

This module also contains any parameters that the original module had.
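As a minimal sketch of the two cases above, the snippet below traces a standalone function and a small module and checks the type of each result. The names `double_plus` and `TinyNet` are hypothetical and exist only for illustration.

```python
import torch
import torch.nn as nn


def double_plus(x, y):
    # Hypothetical standalone function that uses only tensor operations.
    return 2 * x + y


class TinyNet(nn.Module):
    # Hypothetical module with a single parameterized layer.
    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, x):
        return self.linear(x)


# Tracing a function records its tensor operations into a ScriptFunction.
traced_fn = torch.jit.trace(double_plus, (torch.rand(3), torch.rand(3)))

# Tracing a module records its `forward` method into a ScriptModule.
traced_mod = torch.jit.trace(TinyNet(), torch.rand(1, 4))

print(isinstance(traced_fn, torch.jit.ScriptFunction))
print(isinstance(traced_mod, torch.jit.ScriptModule))
```

Both traced objects are called like the originals, for example `traced_fn(a, b)` or `traced_mod(batch)`.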

Warning

Tracing only correctly records functions and modules which are not data dependent (e.g., do not have conditionals on data in tensors) and do not have any untracked external dependencies (e.g., perform input/output or access global variables). Tracing only records operations done when the given function is run on the given tensors. Therefore, the returned ScriptModule will always run the same traced graph on any input. This has some important implications when your module is expected to run different sets of operations, depending on the input and/or the module state. For example,

  • Tracing will not record any control-flow like if-statements or loops. When this control-flow is constant across your module, this is fine and it often inlines the control-flow decisions. But sometimes the control-flow is actually part of the model itself. For instance, a recurrent network is a loop over the (possibly dynamic) length of an input sequence.

  • In the returned ScriptModule, operations that have different behaviors in training and eval modes will always behave as if they are in the mode that was active during tracing, no matter which mode the ScriptModule is in.

In cases like these, tracing would not be appropriate and scripting is a better choice. If you trace such models, you may silently get incorrect results on subsequent invocations of the model. The tracer will try to emit warnings when doing something that may cause an incorrect trace to be produced.
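The two pitfalls described in this warning can be reproduced with a short sketch. It assumes a hypothetical function `branch_on_data` whose result depends on tensor values, and a plain `nn.Dropout` layer traced in eval mode; neither comes from the original examples.

```python
import warnings

import torch
import torch.nn as nn


def branch_on_data(x):
    # Hypothetical function: the branch taken depends on the values in `x`.
    if x.sum() > 0:
        return x * 2
    return x * -1


with warnings.catch_warnings():
    # The tracer warns about the tensor-to-bool conversion in the `if`;
    # the warning is silenced here only to keep the output short.
    warnings.simplefilter("ignore")
    traced_branch = torch.jit.trace(branch_on_data, torch.ones(3))

negative = -torch.ones(3)
eager_out = branch_on_data(negative)   # eager code takes the `x * -1` branch
traced_out = traced_branch(negative)   # the trace replays the recorded `x * 2`
print(torch.equal(eager_out, traced_out))

# Mode-dependent behavior: trace dropout in eval mode, then switch the
# traced module to training mode and run it again.
dropout = nn.Dropout(p=0.5).eval()
traced_dropout = torch.jit.trace(dropout, torch.ones(1000))
traced_dropout.train()
ones = torch.ones(1000)
print(torch.equal(traced_dropout(ones), ones))
```

In this sketch the eager and traced results of `branch_on_data` are expected to differ for the negative input, and the traced dropout is expected to keep its eval-mode behavior (returning its input unchanged) even after `train()` is called.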

Parameters

func (callable or torch.nn.Module) – A Python function or torch.nn.Module that will be run with example_inputs. func arguments and return values must be tensors or (possibly nested) tuples that contain tensors. When a module is passed to torch.jit.trace, only the forward method is run and traced (see torch.jit.trace_module for details).

Keyword Arguments
  • example_inputs (tuple or torch.Tensor or None, optional) – A tuple of example inputs that will be passed to the function while tracing. Default: None. Either this argument or example_kwarg_inputs should be specified. The resulting trace can be run with inputs of different types and shapes assuming the traced operations support those types and shapes. example_inputs may also be a single Tensor in which case it is automatically wrapped in a tuple. When the value is None, example_kwarg_inputs should be specified.

  • check_trace (bool, optional) – Check if the same inputs run through traced code produce the same outputs. Default: True. You might want to disable this if, for example, your network contains non-deterministic ops or if you are sure that the network is correct despite a checker failure.

  • check_inputs (list of tuples, optional) – A list of tuples of input arguments that should be used to check the trace against what is expected. Each tuple is equivalent to a set of input arguments that would be specified in example_inputs. For best results, pass in a set of checking inputs representative of the space of shapes and types of inputs you expect the network to see. If not specified, the original example_inputs are used for checking.

  • check_tolerance (float, optional) – Floating-point comparison tolerance to use in the checker procedure. This can be used to relax the checker strictness in the event that results diverge numerically for a known reason, such as operator fusion.

  • strict (bool, optional) – Whether to run the tracer in strict mode. Default: True. Only turn this off when you want the tracer to record your mutable container types (currently list/dict) and you are sure that the container you are using in your problem is a constant structure and does not get used as control flow (if, for) conditions.

  • example_kwarg_inputs (dict, optional) – A dict of keyword arguments of example inputs that will be passed to the function while tracing. Default: None. Either this argument or example_inputs should be specified. The dict is unpacked by the argument names of the traced function. If the keys of the dict do not match the traced function's argument names, a runtime exception will be raised.
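The keyword arguments above can be exercised together in one short sketch. The functions `scale`, `weighted_sum`, and `summary` are hypothetical, and the `example_kwarg_inputs` call assumes a PyTorch release that supports that argument.

```python
import torch


def scale(x):
    return x * 2 + 1


def weighted_sum(x, y):
    return 2 * x + y


def summary(x):
    # Returns a dict, a mutable container that strict mode does not record.
    return {"mean": x.mean(), "max": x.max()}


# example_inputs: a single tensor is wrapped in a tuple automatically.
traced_scale = torch.jit.trace(scale, torch.rand(3))
# The trace also runs on other shapes that the recorded ops support.
out_2d = traced_scale(torch.ones(2, 5))

# check_inputs / check_tolerance: validate the trace against extra inputs
# of different shapes, with a looser comparison tolerance.
check_inputs = [
    (torch.rand(4), torch.rand(4)),
    (torch.rand(2, 3), torch.rand(2, 3)),
]
traced_checked = torch.jit.trace(
    weighted_sum,
    (torch.rand(3), torch.rand(3)),
    check_inputs=check_inputs,
    check_tolerance=1e-4,
)

# example_kwarg_inputs: the dict keys must match the argument names.
traced_kwargs = torch.jit.trace(
    weighted_sum,
    example_kwarg_inputs={"x": torch.rand(3), "y": torch.rand(3)},
)
kw_out = traced_kwargs(x=torch.ones(3), y=torch.ones(3))

# strict=False: allow the dict output, whose structure is constant here.
# The tracer may still emit a warning about the dict output.
traced_summary = torch.jit.trace(summary, torch.rand(4), strict=False)
stats = traced_summary(torch.tensor([1.0, 2.0, 3.0]))
print(out_2d.shape, kw_out, stats)
```

Passing `check_inputs` with several shapes is a cheap way to catch traces that accidentally specialize on the shape of the example input.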

Returns

If func is nn.Module or forward of nn.Module, trace returns a ScriptModule object with a single forward method containing the traced code. The returned ScriptModule will have the same set of sub-modules and parameters as the original nn.Module. If func is a standalone function, trace returns ScriptFunction.
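A quick way to see that the returned ScriptModule carries over the parameters and sub-modules is to compare their names with those of the original module. This is a sketch only; the `Block` module is hypothetical.

```python
import torch
import torch.nn as nn


class Block(nn.Module):
    # Hypothetical module with one parameterized sub-module.
    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(1, 1, 3)
        self.act = nn.ReLU()

    def forward(self, x):
        return self.act(self.conv(x))


net = Block()
traced = torch.jit.trace(net, torch.rand(1, 1, 8, 8))

# Collect parameter and child-module names from both versions.
original_params = sorted(name for name, _ in net.named_parameters())
traced_params = sorted(name for name, _ in traced.named_parameters())
traced_children = [name for name, _ in traced.named_children()]

print(original_params)
print(traced_params)
print(traced_children)
```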

Example (tracing a function):

    import torch

    def foo(x, y):
        return 2 * x + y

    # Run `foo` with the provided inputs and record the tensor operations
    traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3)))

    # `traced_foo` can now be run with the TorchScript interpreter or saved
    # and loaded in a Python-free environment
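As a small follow-up sketch, the traced function can be called like the original and its recorded graph can be printed for inspection. The input values below are arbitrary.

```python
import torch


def foo(x, y):
    return 2 * x + y


traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3)))

# Call the traced function on new inputs, as with the original function.
a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([10.0, 20.0, 30.0])
result = traced_foo(a, b)
print(result)

# Inspect the operations that were recorded during tracing.
graph_text = str(traced_foo.graph)
print(graph_text)
```

Reading the printed graph is a useful check that the trace contains the operations you expect and nothing that was baked in by accident.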

Example (tracing an existing module):

    import torch
    import torch.nn as nn

    class Net(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv = nn.Conv2d(1, 1, 3)

        def forward(self, x):
            return self.conv(x)

    n = Net()
    example_weight = torch.rand(1, 1, 3, 3)
    example_forward_input = torch.rand(1, 1, 3, 3)

    # Trace a specific method and construct `ScriptModule` with
    # a single `forward` method
    module = torch.jit.trace(n.forward, example_forward_input)

    # Trace a module (implicitly traces `forward`) and construct a
    # `ScriptModule` with a single `forward` method
    module = torch.jit.trace(n, example_forward_input)
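A traced module is typically saved and reloaded before deployment. The sketch below round-trips a traced `Net` through an in-memory buffer with `torch.jit.save` and `torch.jit.load`; a file path would work the same way, and the input shapes are chosen arbitrarily.

```python
import io

import torch
import torch.nn as nn


class Net(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(1, 1, 3)

    def forward(self, x):
        return self.conv(x)


net = Net().eval()
traced = torch.jit.trace(net, torch.rand(1, 1, 3, 3))

# Serialize the traced module to a buffer and load it back.
buffer = io.BytesIO()
torch.jit.save(traced, buffer)
buffer.seek(0)
loaded = torch.jit.load(buffer)

# The reloaded module should match the eager module, here on an input
# with a different batch size and spatial size than the example input.
x = torch.rand(2, 1, 5, 5)
with torch.no_grad():
    matches = torch.allclose(loaded(x), net(x))
print(matches)
```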