torch.jit.fork
- torch.jit.fork(func, *args, **kwargs)
Create an asynchronous task executing func and a reference to the value of the result of this execution.
fork will return immediately, so the return value of func may not have been computed yet. To force completion of the task and access the return value, invoke torch.jit.wait on the Future. fork invoked with a func which returns T is typed as torch.jit.Future[T]. fork calls can be arbitrarily nested, and may be invoked with positional and keyword arguments.

Asynchronous execution will only occur when run in TorchScript. If run in pure Python, fork will not execute in parallel. fork will also not execute in parallel when invoked while tracing; however, the fork and wait calls will be captured in the exported IR graph.
Warning
fork tasks will execute non-deterministically. We recommend only spawning parallel fork tasks for pure functions that do not modify their inputs, module attributes, or global state.
- Parameters
func (callable or torch.nn.Module) – A Python function or torch.nn.Module that will be invoked. If executed in TorchScript, it will execute asynchronously; otherwise it will not. Traced invocations of fork will be captured in the IR.
*args – positional arguments to invoke func with.
**kwargs – keyword arguments to invoke func with.
- Returns
a reference to the execution of func. The value T can only be accessed by forcing completion of func through torch.jit.wait.
- Return type
torch.jit.Future[T]
Example (fork a free function):
```python
import torch
from torch import Tensor

def foo(a: Tensor, b: int) -> Tensor:
    return a + b

def bar(a):
    fut: torch.jit.Future[Tensor] = torch.jit.fork(foo, a, b=2)
    return torch.jit.wait(fut)

script_bar = torch.jit.script(bar)
input = torch.tensor(2)
# only the scripted version executes asynchronously
assert script_bar(input) == bar(input)
# trace is not run asynchronously, but fork is captured in IR
graph = torch.jit.trace(bar, (input,)).graph
assert "fork" in str(graph)
```
Example (fork a module method):
```python
import torch
from torch import Tensor

class AddMod(torch.nn.Module):
    def forward(self, a: Tensor, b: int):
        return a + b

class Mod(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.mod = AddMod()

    def forward(self, input):
        # Fork the submodule call on this forward's input; b is passed by keyword
        fut = torch.jit.fork(self.mod, input, b=2)
        return torch.jit.wait(fut)

input = torch.tensor(2)
mod = Mod()
assert mod(input) == torch.jit.script(mod).forward(input)
```