torch.autograd.Function.forward
static Function.forward(*args, **kwargs)
Define the forward of the custom autograd Function.
This function is to be overridden by all subclasses. There are two ways to define forward:
Usage 1 (Combined forward and ctx):
```python
@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
    pass
```
It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types).
See Combined or separate forward() and setup_context() for more details.
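As a sketch of Usage 1, the hypothetical MyExp below implements exp with a combined forward that receives ctx as its first argument (the class and variable names are illustrative):

```python
import torch

# Hypothetical illustration of Usage 1: forward receives ctx directly.
class MyExp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        result = x.exp()
        ctx.save_for_backward(result)  # save the output for backward
        return result

    @staticmethod
    def backward(ctx, grad_output):
        (result,) = ctx.saved_tensors
        return grad_output * result  # d/dx exp(x) = exp(x)

x = torch.tensor([0.0, 1.0], requires_grad=True)
y = MyExp.apply(x)  # custom Functions are invoked via .apply, not forward()
y.sum().backward()
```

Note that the Function is invoked through MyExp.apply so that autograd records the operation; calling MyExp.forward directly would bypass the graph.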
Usage 2 (Separate forward and ctx):
```python
@staticmethod
def forward(*args: Any, **kwargs: Any) -> Any:
    pass

@staticmethod
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
    pass
```
The forward no longer accepts a ctx argument.
Instead, you must also override the torch.autograd.Function.setup_context() staticmethod to handle setting up the ctx object. output is the output of the forward, inputs are a Tuple of inputs to the forward.
See Extending torch.autograd for more details.
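A minimal sketch of Usage 2, assuming a PyTorch version that supports setup_context (2.0+); MyMul and the variable names are hypothetical:

```python
import torch

# Hypothetical illustration of Usage 2: forward takes no ctx; state is
# attached in the separate setup_context staticmethod instead.
class MyMul(torch.autograd.Function):
    @staticmethod
    def forward(x, y):
        return x * y

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, y = inputs  # inputs is a tuple of the forward's inputs
        ctx.save_for_backward(x, y)

    @staticmethod
    def backward(ctx, grad_output):
        x, y = ctx.saved_tensors
        return grad_output * y, grad_output * x  # product rule

a = torch.tensor(2.0, requires_grad=True)
b = torch.tensor(3.0, requires_grad=True)
MyMul.apply(a, b).backward()
```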
The context can be used to store arbitrary data that can then be retrieved during the backward pass. Tensors should not be stored directly on ctx (though this is not currently enforced for backward compatibility). Instead, tensors should be saved either with ctx.save_for_backward() if they are intended to be used in backward (equivalently, vjp) or with ctx.save_for_forward() if they are intended to be used for jvp.

Return type:
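The save_for_backward/save_for_forward distinction can be sketched as follows; MySquare is a hypothetical Function that saves its input twice, once for reverse mode (backward/vjp) and once for forward mode (jvp):

```python
import torch
import torch.autograd.forward_ad as fwAD

# Hypothetical illustration: tensors saved with save_for_backward feed
# backward, tensors saved with save_for_forward feed jvp; both are
# retrieved through ctx.saved_tensors in the respective method.
class MySquare(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)  # for backward (reverse mode)
        ctx.save_for_forward(x)   # for jvp (forward mode)
        return x * x

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return 2 * x * grad_output

    @staticmethod
    def jvp(ctx, x_tangent):
        (x,) = ctx.saved_tensors
        return 2 * x * x_tangent

# Reverse mode uses the tensor saved with save_for_backward.
x = torch.tensor(3.0, requires_grad=True)
MySquare.apply(x).backward()

# Forward mode uses the tensor saved with save_for_forward.
with fwAD.dual_level():
    dual = fwAD.make_dual(torch.tensor(3.0), torch.tensor(1.0))
    out_primal, out_tangent = fwAD.unpack_dual(MySquare.apply(dual))
```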