
torch.autograd.Function.vmap

static Function.vmap(info, in_dims, *args)

Define the behavior of this autograd.Function under torch.vmap().

For a torch.autograd.Function subclass to support torch.vmap(), you must either override this static method or set generate_vmap_rule to True (you may not do both).
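A minimal sketch of the generate_vmap_rule route, assuming PyTorch 2.0 or later. For torch.func transforms such as torch.vmap(), forward() is written without a ctx argument and a separate setup_context() is defined. The class name MySin is hypothetical. Because forward() and backward() use only PyTorch operations, PyTorch can derive the batching rule itself, so no vmap staticmethod is written.

```python
import torch

class MySin(torch.autograd.Function):
    # Ask PyTorch to derive the vmap rule; do NOT also define vmap().
    generate_vmap_rule = True

    @staticmethod
    def forward(x):
        # torch.func-compatible style: forward takes no ctx.
        return torch.sin(x)

    @staticmethod
    def setup_context(ctx, inputs, output):
        (x,) = inputs
        ctx.save_for_backward(x)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return grad_output * torch.cos(x)

x = torch.randn(3, 4)
# Maps over dim 0 of x; the batching rule is generated automatically.
y = torch.vmap(MySin.apply)(x)
```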

If you choose to override this static method, it must accept:

  • an info object as the first argument. info.batch_size specifies the size of the dimension being vmapped over, while info.randomness is the randomness option passed to torch.vmap().

  • an in_dims tuple as the second argument. For each arg in args, in_dims has a corresponding Optional[int]. It is None if the arg is not a Tensor or if the arg is not being vmapped over; otherwise, it is an integer specifying which dimension of the Tensor is being vmapped over.

  • *args, which are the same as the args to forward().
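A minimal sketch of an overridden vmap staticmethod with the signature described above, assuming PyTorch 2.0 or later. The class Scale (multiply a tensor by a Python float) is hypothetical. Because alpha is not a Tensor, its entry in in_dims is None. The operation is elementwise, so the rule calls Scale.apply on the unwrapped tensor and reports that the batch dimension stays wherever it was in x.

```python
import torch

class Scale(torch.autograd.Function):
    # Hypothetical example: computes x * alpha for a Python float alpha.
    @staticmethod
    def forward(x, alpha):
        return x * alpha

    @staticmethod
    def setup_context(ctx, inputs, output):
        _, alpha = inputs
        ctx.alpha = alpha

    @staticmethod
    def backward(ctx, grad_output):
        # One gradient per forward input; alpha is not a Tensor, so None.
        return grad_output * ctx.alpha, None

    @staticmethod
    def vmap(info, in_dims, x, alpha):
        # in_dims == (x_bdim, None): alpha is a float, never vmapped over.
        x_bdim, _ = in_dims
        # x here is the plain tensor with its batch dim at x_bdim.
        out = Scale.apply(x, alpha)
        # Elementwise op: the output's batch dim is at the same index.
        return out, x_bdim

x = torch.randn(3, 5)
# Vmap over dim 1 of x; torch.vmap places the batch dim at 0 in the result.
y = torch.vmap(lambda t: Scale.apply(t, 2.0), in_dims=1)(x)
```

Returning x_bdim rather than first moving the batch dimension to the front avoids an extra movedim; torch.vmap then places the batch dimension at its own out_dims (0 by default), so y has shape (5, 3).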

The vmap static method returns a tuple (output, out_dims). Like in_dims, out_dims must have the same structure as output and contain one out_dim per output: None if that output does not have the vmapped dimension, otherwise the index of the vmapped dimension in that output.
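A sketch of a multi-output rule, assuming PyTorch 2.0 or later. The Function SumDiff (returning x + y and x - y) and the helper move_bdim_to_front are hypothetical. The rule uses info.batch_size to materialize a batch dimension for an input whose in_dim is None, and it returns out_dims as a tuple mirroring the tuple of outputs.

```python
import torch

def move_bdim_to_front(t, bdim, batch_size):
    # Hypothetical helper: return t with a leading batch dim of size batch_size.
    if bdim is None:
        # Unbatched input: broadcast it across the batch (a view, no copy).
        return t.expand(batch_size, *t.shape)
    return t.movedim(bdim, 0)

class SumDiff(torch.autograd.Function):
    # Hypothetical two-output Function: returns (x + y, x - y).
    @staticmethod
    def forward(x, y):
        return x + y, x - y

    @staticmethod
    def setup_context(ctx, inputs, output):
        pass  # backward needs no saved tensors

    @staticmethod
    def backward(ctx, grad_sum, grad_diff):
        return grad_sum + grad_diff, grad_sum - grad_diff

    @staticmethod
    def vmap(info, in_dims, x, y):
        x_bdim, y_bdim = in_dims
        x = move_bdim_to_front(x, x_bdim, info.batch_size)
        y = move_bdim_to_front(y, y_bdim, info.batch_size)
        out_sum, out_diff = SumDiff.apply(x, y)
        # out_dims mirrors output: one entry per output tensor, both at dim 0.
        return (out_sum, out_diff), (0, 0)

x = torch.randn(5, 3)
y = torch.randn(3)
# x is batched over dim 0; y is shared across the batch (in_dim None).
s, d = torch.vmap(SumDiff.apply, in_dims=(0, None))(x, y)
```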

Please see Extending torch.func with autograd.Function for more details.