NestedIOFunction#
- class torch.autograd.function.NestedIOFunction(*args, **kwargs)[source]#
This class is here only for backward compatibility reasons. Use
Function instead of this for any new use case.
- static jvp(ctx, *grad_inputs)[source]#
Define a formula for differentiating the operation with forward mode automatic differentiation.
This function is to be overridden by all subclasses. It must accept a context
ctx as the first argument, followed by as many inputs as the forward() got (None will be passed in for non-tensor inputs of the forward function), and it should return as many tensors as there were outputs to forward(). Each argument is the gradient w.r.t. the given input, and each returned value should be the gradient w.r.t. the corresponding output. If an output is not a Tensor or the function is not differentiable with respect to that output, you can just pass None as the gradient for that output.
You can use the ctx object to pass any value from the forward to this function.
- Return type
Any
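For illustration only, a minimal sketch of the jvp contract, assuming a hypothetical ScaleBy function that multiplies its tensor input by an integer constant (backward-mode methods omitted):
>>> # Hypothetical example; ScaleBy is not part of this class's API
>>> class ScaleBy(torch.autograd.Function):
>>>     @staticmethod
>>>     def forward(ctx, x, k: int):
>>>         ctx.k = k  # non-tensor state needed by jvp
>>>         return x * k
>>>
>>>     @staticmethod
>>>     def jvp(ctx, x_t, _):
>>>         # One argument per forward input (None for the non-tensor k),
>>>         # one returned tangent per forward output.
>>>         return x_t * ctx.k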
- save_for_forward(*tensors)[source]#
Save given tensors for a future call to jvp().
save_for_forward should be called at most once, in either the setup_context() or forward() methods, and all arguments should be tensors.
In jvp(), saved objects can be accessed through the saved_tensors attribute.
Arguments can also be None. This is a no-op.
See Extending torch.autograd for more details on how to use this method.
Example:
>>> class Func(torch.autograd.Function):
>>>     @staticmethod
>>>     def forward(ctx, x: torch.Tensor, y: torch.Tensor, z: int):
>>>         ctx.save_for_backward(x, y)
>>>         ctx.save_for_forward(x, y)
>>>         ctx.z = z
>>>         return x * y * z
>>>
>>>     @staticmethod
>>>     def jvp(ctx, x_t, y_t, _):
>>>         x, y = ctx.saved_tensors
>>>         z = ctx.z
>>>         return z * (y * x_t + x * y_t)
>>>
>>>     @staticmethod
>>>     def vjp(ctx, grad_out):
>>>         x, y = ctx.saved_tensors
>>>         z = ctx.z
>>>         return z * grad_out * y, z * grad_out * x, None
>>>
>>> a = torch.tensor(1., requires_grad=True, dtype=torch.double)
>>> t = torch.tensor(1., dtype=torch.double)
>>> b = torch.tensor(2., requires_grad=True, dtype=torch.double)
>>> c = 4
>>>
>>> with fwAD.dual_level():
>>>     a_dual = fwAD.make_dual(a, t)
>>>     d = Func.apply(a_dual, b, c)
- property saved_tensors#
See Function.saved_tensors().
- set_materialize_grads(value)[source]#
Set whether to materialize grad tensors. Default is True.
This should be called only from either the setup_context() or forward() methods.
If True, undefined grad tensors will be expanded to tensors full of zeros prior to calling the backward() and jvp() methods.
Example:
>>> class SimpleFunc(Function):
>>>     @staticmethod
>>>     def forward(ctx, x):
>>>         return x.clone(), x.clone()
>>>
>>>     @staticmethod
>>>     @once_differentiable
>>>     def backward(ctx, g1, g2):
>>>         return g1 + g2  # No check for None necessary
>>>
>>> # We modify SimpleFunc to handle non-materialized grad outputs
>>> class Func(Function):
>>>     @staticmethod
>>>     def forward(ctx, x):
>>>         ctx.set_materialize_grads(False)
>>>         ctx.save_for_backward(x)
>>>         return x.clone(), x.clone()
>>>
>>>     @staticmethod
>>>     @once_differentiable
>>>     def backward(ctx, g1, g2):
>>>         x, = ctx.saved_tensors
>>>         grad_input = torch.zeros_like(x)
>>>         if g1 is not None:  # We must check for None now
>>>             grad_input += g1
>>>         if g2 is not None:
>>>             grad_input += g2
>>>         return grad_input
>>>
>>> a = torch.tensor(1., requires_grad=True)
>>> b, _ = Func.apply(a)  # induces g2 to be undefined
- static setup_context(ctx, inputs, output)[source]#
There are two ways to define the forward pass of an autograd.Function.
Either:
1. Override forward with the signature forward(ctx, *args, **kwargs). setup_context is not overridden. Setting up the ctx for backward happens inside the forward.
2. Override forward with the signature forward(*args, **kwargs) and override setup_context. Setting up the ctx for backward happens inside setup_context (as opposed to inside the forward); see the sketch below.
See torch.autograd.Function.forward() and Extending torch.autograd for more details.
- Return type
Any
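A rough sketch of the second style (the MulConstant name is illustrative, not part of this class): forward takes no ctx, and setup_context receives the forward inputs and output to populate the ctx for backward:
>>> # Hypothetical example of the setup_context style
>>> class MulConstant(torch.autograd.Function):
>>>     @staticmethod
>>>     def forward(x, constant):
>>>         # No ctx argument in this style
>>>         return x * constant
>>>
>>>     @staticmethod
>>>     def setup_context(ctx, inputs, output):
>>>         # inputs is the tuple of arguments given to forward(),
>>>         # output is what forward() returned
>>>         x, constant = inputs
>>>         ctx.constant = constant
>>>
>>>     @staticmethod
>>>     def backward(ctx, grad_output):
>>>         # One gradient per forward input; the non-tensor constant gets None
>>>         return grad_output * ctx.constant, None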
- static vjp(ctx, *grad_outputs)[source]#
Define a formula for differentiating the operation with backward mode automatic differentiation.
This function is to be overridden by all subclasses. (Defining this function is equivalent to defining the vjp function.)
It must accept a context ctx as the first argument, followed by as many outputs as the forward() returned (None will be passed in for non-tensor outputs of the forward function), and it should return as many tensors as there were inputs to forward(). Each argument is the gradient w.r.t. the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor not requiring grads, you can just pass None as the gradient for that input.
The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs the gradient computed w.r.t. the output.
- Return type
Any
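A brief illustrative sketch of this contract (MyMul is a hypothetical name), including the needs_input_grad check described above:
>>> # Hypothetical example, not part of this class's API
>>> class MyMul(torch.autograd.Function):
>>>     @staticmethod
>>>     def forward(ctx, x, y):
>>>         ctx.save_for_backward(x, y)
>>>         return x * y
>>>
>>>     @staticmethod
>>>     def vjp(ctx, grad_output):
>>>         x, y = ctx.saved_tensors
>>>         # Return one gradient per forward input; skip the work for
>>>         # inputs that do not require gradients.
>>>         grad_x = grad_output * y if ctx.needs_input_grad[0] else None
>>>         grad_y = grad_output * x if ctx.needs_input_grad[1] else None
>>>         return grad_x, grad_y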
- static vmap(info, in_dims, *args)[source]#
Define the behavior for this autograd.Function underneath torch.vmap().
For a torch.autograd.Function() to support torch.vmap(), you must either override this static method, or set generate_vmap_rule to True (you may not do both).
If you choose to override this staticmethod: it must accept
- an info object as the first argument. info.batch_size specifies the size of the dimension being vmapped over, while info.randomness is the randomness option passed to torch.vmap().
- an in_dims tuple as the second argument. For each arg in args, in_dims has a corresponding Optional[int]. It is None if the arg is not a Tensor or if the arg is not being vmapped over; otherwise, it is an integer specifying what dimension of the Tensor is being vmapped over.
- *args, which is the same as the args to forward().
The return of the vmap staticmethod is a tuple of (output, out_dims). Similar to in_dims, out_dims should be of the same structure as output and contain one out_dim per output that specifies if the output has the vmapped dimension and what index it is in.
Please see Extending torch.func with autograd.Function for more details.
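As a rough sketch, assuming a hypothetical elementwise MySin function (an illustrative name), the vmap staticmethod can apply the op to the batched tensor directly and report that the batch dimension is unchanged:
>>> # Hypothetical example, not part of this class's API
>>> class MySin(torch.autograd.Function):
>>>     @staticmethod
>>>     def forward(x):
>>>         return torch.sin(x)
>>>
>>>     @staticmethod
>>>     def setup_context(ctx, inputs, output):
>>>         x, = inputs
>>>         ctx.save_for_backward(x)
>>>
>>>     @staticmethod
>>>     def backward(ctx, grad_output):
>>>         x, = ctx.saved_tensors
>>>         return grad_output * torch.cos(x)
>>>
>>>     @staticmethod
>>>     def vmap(info, in_dims, x):
>>>         # sin is elementwise, so the batched input can be used as-is;
>>>         # the output keeps its batch dimension where the input had it.
>>>         return torch.sin(x), in_dims[0]
With a rule like this in place, a call such as torch.vmap(MySin.apply)(torch.randn(3, 4)) should be routed to the staticmethod above instead of raising an error.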