
torch.autograd.Function.forward

static Function.forward(*args, **kwargs)

Define the forward of the custom autograd Function.

This function must be overridden by all subclasses. There are two ways to define forward:

Usage 1 (Combined forward and ctx):

    from typing import Any

    @staticmethod
    def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
        pass
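As a minimal sketch of Usage 1, the hypothetical MyExp below computes the element-wise exponential. Because forward receives ctx directly, it can save its own output for the backward staticmethod, which is part of the same torch.autograd.Function API but is not described in this section.

```python
import torch
from torch.autograd import Function


class MyExp(Function):  # hypothetical example Function
    # Usage 1: forward takes ctx as its first argument
    @staticmethod
    def forward(ctx, x):
        result = x.exp()
        # d/dx exp(x) = exp(x), so the output is all backward needs
        ctx.save_for_backward(result)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        (result,) = ctx.saved_tensors
        return grad_output * result


x = torch.tensor([0.0, 1.0], requires_grad=True)
y = MyExp.apply(x)  # always call the Function through .apply
y.sum().backward()
```

After the backward call, x.grad equals exp(x), matching the output of forward.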

Usage 2 (Separate forward and ctx):

    from typing import Any, Tuple

    @staticmethod
    def forward(*args: Any, **kwargs: Any) -> Any:
        pass

    @staticmethod
    def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
        pass

  • The forward no longer accepts a ctx argument.

  • Instead, you must also override the torch.autograd.Function.setup_context() staticmethod to handle setting up the ctx object. output is the output of the forward, and inputs is a Tuple of the inputs to the forward.

  • See Extending torch.autograd for more details.
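A comparable sketch of Usage 2, using a hypothetical MyMul that multiplies two tensors: forward takes only the inputs, and setup_context unpacks the inputs tuple to save what backward needs. It assumes PyTorch 2.0 or later, where setup_context is available.

```python
import torch
from torch.autograd import Function


class MyMul(Function):  # hypothetical example Function
    # Usage 2: forward does not receive ctx
    @staticmethod
    def forward(x, y):
        return x * y

    # setup_context gets the forward's inputs (as a tuple) and its output
    @staticmethod
    def setup_context(ctx, inputs, output):
        x, y = inputs
        ctx.save_for_backward(x, y)

    @staticmethod
    def backward(ctx, grad_output):
        x, y = ctx.saved_tensors
        # d(x*y)/dx = y and d(x*y)/dy = x
        return grad_output * y, grad_output * x


a = torch.tensor([2.0, 3.0], requires_grad=True)
b = torch.tensor([4.0, 5.0], requires_grad=True)
MyMul.apply(a, b).sum().backward()
```

Here a.grad holds the values of b and b.grad holds the values of a, as expected for an element-wise product.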

The context can be used to store arbitrary data that can then be retrieved during the backward pass. Tensors should not be stored directly on ctx (though this is not currently enforced, for backward compatibility). Instead, tensors should be saved either with ctx.save_for_backward() if they are intended to be used in backward (equivalently, vjp) or with ctx.save_for_forward() if they are intended to be used in jvp.
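To illustrate the two saving paths, the hypothetical MySquare below saves its input twice: with ctx.save_for_backward() for backward (reverse mode) and with ctx.save_for_forward() for jvp (forward mode). In both staticmethods the saved tensors are read back through ctx.saved_tensors. The forward-mode part uses torch.autograd.forward_ad to attach a tangent to the input. Treat this as a sketch, not a complete treatment of forward-mode AD.

```python
import torch
import torch.autograd.forward_ad as fwAD
from torch.autograd import Function


class MySquare(Function):  # hypothetical example Function
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)  # used by backward (reverse mode)
        ctx.save_for_forward(x)   # used by jvp (forward mode)
        return x * x

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return 2 * x * grad_output

    @staticmethod
    def jvp(ctx, x_tangent):
        # Inside jvp, ctx.saved_tensors returns what save_for_forward stored
        (x,) = ctx.saved_tensors
        return 2 * x * x_tangent


# Forward mode: push an all-ones tangent through the Function
x = torch.tensor([1.0, 2.0, 3.0])
with fwAD.dual_level():
    dual_y = MySquare.apply(fwAD.make_dual(x, torch.ones(3)))
    y_primal, y_tangent = fwAD.unpack_dual(dual_y)

# Reverse mode: ordinary backward pass
x_rev = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
MySquare.apply(x_rev).sum().backward()
```

Since the derivative of x² is 2x, both y_tangent and x_rev.grad come out as [2., 4., 6.].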

Return type:

Any