torch.jit.script
- torch.jit.script(obj, optimize=None, _frames_up=0, _rcb=None, example_inputs=None)
Script the function.
Scripting a function or nn.Module will inspect the source code, compile it as TorchScript code using the TorchScript compiler, and return a ScriptModule or ScriptFunction. TorchScript itself is a subset of the Python language, so not all features in Python work, but we provide enough functionality to compute on tensors and do control-dependent operations. For a complete guide, see the TorchScript Language Reference.

Scripting a dictionary or list copies the data inside it into a TorchScript instance that can be subsequently passed by reference between Python and TorchScript with zero copy overhead.
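A minimal sketch of the pass-by-reference behavior described above follows. The names `counts` and `add_entry` are illustrative only. The dict is copied into a TorchScript dict once; after that, a scripted function can mutate it in place and the change is visible from Python.

```python
from typing import Dict

import torch

# Copy a plain Python dict into a TorchScript dict once, up front
counts = torch.jit.script({"a": 1, "b": 2})
print(type(counts))

@torch.jit.script
def add_entry(d: Dict[str, int], key: str, value: int) -> None:
    # Mutates the dict in place; a ScriptDict argument is not copied
    d[key] = value

add_entry(counts, "c", 3)
print(counts["c"])  # 3
print(len(counts))  # 3
```

By contrast, a plain Python dict passed to `add_entry` would be converted on each call, so the original Python object would not see the new entry.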
torch.jit.script can be used as a function for modules, functions, dictionaries and lists, and as a decorator @torch.jit.script for TorchScript classes and functions.
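The two calling styles can be sketched as follows. The names `scale`, `Counter`, and `count_to` are illustrative, not part of the API. The function form compiles an existing callable. The decorator form compiles a function, or a TorchScript class, where it is defined.

```python
import torch

def scale(x: torch.Tensor, factor: float) -> torch.Tensor:
    return x * factor

# Function form: compile an existing function
scripted_scale = torch.jit.script(scale)

# Decorator form on a class: compiled as a TorchScript class
@torch.jit.script
class Counter:
    def __init__(self, start: int):
        self.value = start

    def increment(self) -> int:
        self.value += 1
        return self.value

# Decorator form on a function that uses the TorchScript class
@torch.jit.script
def count_to(n: int) -> int:
    c = Counter(0)
    for _ in range(n):
        c.increment()
    return c.value

print(scripted_scale(torch.ones(2), 3.0))  # tensor([3., 3.])
print(count_to(5))  # 5
```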
- Parameters
obj (Callable, class, or nn.Module) – The nn.Module, function, class type, dictionary, or list to compile.
example_inputs (Union[List[Tuple], Dict[Callable, List[Tuple]], None]) – Provide example inputs to annotate the arguments for a function or nn.Module.
- Returns
If obj is nn.Module, script returns a ScriptModule object. The returned ScriptModule will have the same set of sub-modules and parameters as the original nn.Module. If obj is a standalone function, a ScriptFunction will be returned. If obj is a dict, then script returns an instance of torch._C.ScriptDict. If obj is a list, then script returns an instance of torch._C.ScriptList.
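The return types listed above can be checked with a short sketch. The helper `double` and the sample dict and list are illustrative.

```python
import torch
import torch.nn as nn

def double(x: torch.Tensor) -> torch.Tensor:
    return x * 2

scripted_module = torch.jit.script(nn.Linear(4, 2))
scripted_fn = torch.jit.script(double)
scripted_dict = torch.jit.script({"lr": 0.1})
scripted_list = torch.jit.script([1, 2, 3])

print(isinstance(scripted_module, torch.jit.ScriptModule))  # True
print(isinstance(scripted_fn, torch.jit.ScriptFunction))  # True
print(isinstance(scripted_dict, torch._C.ScriptDict))  # True
print(isinstance(scripted_list, torch._C.ScriptList))  # True

# The ScriptModule exposes the same parameters as the original module
print([n for n, _ in scripted_module.named_parameters()])  # ['weight', 'bias']
```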
- Scripting a function
The @torch.jit.script decorator will construct a ScriptFunction by compiling the body of the function.
Example (scripting a function):
```python
import torch

@torch.jit.script
def foo(x, y):
    if x.max() > y.max():
        r = x
    else:
        r = y
    return r

print(type(foo))  # torch.jit.ScriptFunction

# See the compiled graph as Python code
print(foo.code)

# Call the function using the TorchScript interpreter
foo(torch.ones(2, 2), torch.ones(2, 2))
```
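The sketch below shows why compiling the source matters for control-dependent code. It contrasts scripting with torch.jit.trace, which records only the operations executed for its example inputs. The function `larger_minus_smaller` is illustrative. Tracing it is expected to emit a TracerWarning about the data-dependent `if`.

```python
import torch

def larger_minus_smaller(x, y):
    # Data-dependent branch: which side runs depends on tensor values
    if x.max() > y.max():
        r = x - y
    else:
        r = y - x
    return r

a = torch.ones(2)
b = torch.zeros(2)

scripted = torch.jit.script(larger_minus_smaller)
# Tracing with (a, b) records only the `x - y` branch
traced = torch.jit.trace(larger_minus_smaller, (a, b))

# Swap the arguments so that the `else` branch should run
print(scripted(b, a))  # tensor([1., 1.]): both branches were compiled
print(traced(b, a))    # tensor([-1., -1.]): the trace replays `x - y`
```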
- Scripting a function using example_inputs
Example inputs can be used to annotate a function's arguments.
Example (annotating a function before scripting):
```python
import torch

def test_sum(a, b):
    return a + b

# Annotate the arguments to be int
scripted_fn = torch.jit.script(test_sum, example_inputs=[(3, 4)])

print(type(scripted_fn))  # torch.jit.ScriptFunction

# See the compiled graph as Python code
print(scripted_fn.code)

# Call the function using the TorchScript interpreter
scripted_fn(20, 100)
```
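Unannotated arguments of a scripted function are assumed to be Tensors, which is why test_sum needs annotating before it can accept ints. Profile-directed typing through example_inputs relies, as far as we know for current releases, on the third-party MonkeyType package. Without it, PyTorch warns and falls back to the Tensor default.

The sketch below shows a dependency-free alternative that writes the annotations by hand. The function names are illustrative.

```python
import torch

def sum_default(a, b):
    return a + b

def sum_annotated(a: int, b: int) -> int:
    return a + b

scripted_default = torch.jit.script(sum_default)
scripted_annotated = torch.jit.script(sum_annotated)

# Unannotated arguments were compiled as Tensor
print(scripted_default.code)
print(scripted_default(torch.tensor(20), torch.tensor(100)))  # tensor(120)

# Hand-written annotations give the int signature example_inputs is meant to infer
print(scripted_annotated(20, 100))  # 120
```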
- Scripting an nn.Module
Scripting an nn.Module by default will compile the forward method and recursively compile any methods, submodules, and functions called by forward. If an nn.Module only uses features supported in TorchScript, no changes to the original module code should be necessary. script will construct a ScriptModule that has copies of the attributes, parameters, and methods of the original module.
Example (scripting a simple module with a Parameter):
```python
import torch

class MyModule(torch.nn.Module):
    def __init__(self, N, M):
        super().__init__()
        # This parameter will be copied to the new ScriptModule
        self.weight = torch.nn.Parameter(torch.rand(N, M))

        # When this submodule is used, it will be compiled
        self.linear = torch.nn.Linear(N, M)

    def forward(self, input):
        output = self.weight.mv(input)

        # This calls the `forward` method of the `nn.Linear` module, which will
        # cause the `self.linear` submodule to be compiled to a `ScriptModule` here
        output = self.linear(output)
        return output

scripted_module = torch.jit.script(MyModule(2, 3))
```
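The following usage sketch repeats the module above so that the block runs on its own. With N=2 and M=3, weight has shape (2, 3), so mv expects an input of length 3. The Linear(2, 3) layer then maps the result back to length 3. The final checks show that the Parameter and the compiled submodule carried over.

```python
import torch

class MyModule(torch.nn.Module):
    def __init__(self, N, M):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.rand(N, M))
        self.linear = torch.nn.Linear(N, M)

    def forward(self, input):
        output = self.weight.mv(input)
        output = self.linear(output)
        return output

scripted_module = torch.jit.script(MyModule(2, 3))

out = scripted_module(torch.rand(3))
print(out.shape)  # torch.Size([3])

# Parameters of the module and of its submodule are all present
print(sorted(n for n, _ in scripted_module.named_parameters()))
# ['linear.bias', 'linear.weight', 'weight']

# The submodule was compiled as well
print(isinstance(scripted_module.linear, torch.jit.ScriptModule))  # True
```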
Example (scripting a module with traced submodules):
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MyModule(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # torch.jit.trace produces ScriptModules for conv1 and conv2
        self.conv1 = torch.jit.trace(nn.Conv2d(1, 20, 5), torch.rand(1, 1, 16, 16))
        self.conv2 = torch.jit.trace(nn.Conv2d(20, 20, 5), torch.rand(1, 20, 16, 16))

    def forward(self, input):
        input = F.relu(self.conv1(input))
        input = F.relu(self.conv2(input))
        return input

scripted_module = torch.jit.script(MyModule())
```
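The sketch below shows why this mix can be useful. The `GatedConv` class and its `use_relu` flag are illustrative. A traced submodule with no data-dependent control flow sits inside a scripted forward that does branch. The branch is compiled rather than frozen to one path.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GatedConv(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # Tracing is fine here: a conv has no data-dependent control flow
        self.conv = torch.jit.trace(nn.Conv2d(1, 20, 5), torch.rand(1, 1, 16, 16))

    def forward(self, x: torch.Tensor, use_relu: bool) -> torch.Tensor:
        out = self.conv(x)
        # Scripted: this branch is kept as real control flow
        if use_relu:
            out = F.relu(out)
        return out

scripted = torch.jit.script(GatedConv())
x = torch.rand(1, 1, 16, 16)
print(scripted(x, True).shape)  # torch.Size([1, 20, 12, 12])
print(bool((scripted(x, True) >= 0).all()))  # True
```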
To compile a method other than forward (and recursively compile anything it calls), add the @torch.jit.export decorator to the method. To opt out of compilation, use @torch.jit.ignore or @torch.jit.unused.
Example (an exported and ignored method in a module):
```python
import torch
import torch.nn as nn

class MyModule(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    @torch.jit.export
    def some_entry_point(self, input):
        return input + 10

    @torch.jit.ignore
    def python_only_fn(self, input):
        # This function won't be compiled, so any
        # Python APIs can be used
        import pdb
        pdb.set_trace()

    def forward(self, input):
        if self.training:
            self.python_only_fn(input)
        return input * 99

scripted_module = torch.jit.script(MyModule())
print(scripted_module.some_entry_point(torch.randn(2, 2)))
print(scripted_module(torch.randn(2, 2)))
```
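For comparison, the sketch below uses @torch.jit.unused instead of @torch.jit.ignore. The class name `DebugModule` is illustrative. An ignored method stays a call back into Python. An unused method keeps its annotated signature, but TorchScript replaces its body with one that raises, so calling it fails at runtime instead of entering the debugger.

```python
import torch
import torch.nn as nn

class DebugModule(nn.Module):
    @torch.jit.unused
    def debug_only(self, input: torch.Tensor) -> torch.Tensor:
        # Not compiled: TorchScript substitutes a body that raises
        import pdb
        pdb.set_trace()
        return input

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if self.training:
            return self.debug_only(input)
        return input * 99

scripted_module = torch.jit.script(DebugModule())

scripted_module.eval()
print(scripted_module(torch.ones(2)))  # tensor([99., 99.])

scripted_module.train()
try:
    scripted_module(torch.ones(2))
except Exception as e:
    print("unused method was called:", type(e).__name__)
```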
Example (annotating forward of an nn.Module using example_inputs):
```python
import torch
import torch.nn as nn
from typing import List, NamedTuple

class MyModule(NamedTuple):
    result: List[int]

class TestNNModule(torch.nn.Module):
    def forward(self, a) -> MyModule:
        result = MyModule(result=a)
        return result

pdt_model = TestNNModule()

# Runs the pdt_model in eager mode with the inputs provided and annotates the arguments of forward
scripted_model = torch.jit.script(pdt_model, example_inputs={pdt_model: [([10, 20],)]})

# Run the scripted_model with actual inputs
print(scripted_model([20]))
```
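The same module can also be scripted without example_inputs by annotating forward explicitly, which avoids running the model eagerly first. The sketch below assumes the argument should be `List[int]`, which is what the example input above suggests. The names `Result` and `AnnotatedModule` are illustrative.

```python
from typing import List, NamedTuple

import torch

class Result(NamedTuple):
    result: List[int]

class AnnotatedModule(torch.nn.Module):
    def forward(self, a: List[int]) -> Result:
        return Result(result=a)

scripted_model = torch.jit.script(AnnotatedModule())
out = scripted_model([20])
print(out[0])  # [20]
```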