torch.autograd.Function.vmap
- static Function.vmap(info, in_dims, *args)
Define the behavior for this autograd.Function underneath torch.vmap().

For a torch.autograd.Function to support torch.vmap(), you must either override this static method or set generate_vmap_rule to True (you may not do both).

If you choose to override this static method, it must accept:
- an info object as the first argument. info.batch_size specifies the size of the dimension being vmapped over, while info.randomness is the randomness option passed to torch.vmap().
- an in_dims tuple as the second argument. For each arg in args, in_dims has a corresponding Optional[int]. It is None if the arg is not a Tensor or if the arg is not being vmapped over; otherwise, it is an integer specifying which dimension of the Tensor is being vmapped over.
- *args, which is the same as the args to forward().
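As a minimal sketch of such an override (the class MyMul is hypothetical, and it assumes x and y always have the same per-sample shape), the rule below reads one entry of in_dims per argument, moves each vmapped dimension to the front, materializes a batch dimension of size info.batch_size for any input whose entry is None, and then calls the Function itself on the batched tensors. It uses the forward-without-ctx plus setup_context style, which torch.func transforms require.

```python
import torch


class MyMul(torch.autograd.Function):
    # torch.func transforms need a forward without ctx plus a separate
    # setup_context staticmethod.
    @staticmethod
    def forward(x, y):
        # Hypothetical example: assumes x and y have the same shape.
        return x * y

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, y = inputs
        ctx.save_for_backward(x, y)

    @staticmethod
    def backward(ctx, grad_output):
        x, y = ctx.saved_tensors
        return grad_output * y, grad_output * x

    @staticmethod
    def vmap(info, in_dims, x, y):
        # One Optional[int] per argument to forward().
        x_bdim, y_bdim = in_dims

        def to_front(t, bdim):
            if bdim is None:
                # Not vmapped over: add a leading batch dim of size
                # info.batch_size (expand returns a view, no copy).
                return t.expand(info.batch_size, *t.shape)
            # Vmapped over: move the batch dim to position 0.
            return t.movedim(bdim, 0)

        x = to_front(x, x_bdim)
        y = to_front(y, y_bdim)
        # Both tensors now have identical shapes, so the elementwise
        # forward can run on them directly.
        out = MyMul.apply(x, y)
        # The output's batch dimension is at index 0.
        return out, 0


x = torch.randn(4, 3)
y = torch.randn(3)
# y is not vmapped over, so the rule receives in_dims == (0, None).
out = torch.vmap(MyMul.apply, in_dims=(0, None))(x, y)
# out has shape (4, 3) and matches x * y.
```

Calling MyMul.apply inside the rule works because the tensors the rule receives are already unwrapped from the current vmap level. If forward were built only from PyTorch operations, as here, setting generate_vmap_rule = True on the class would be the alternative that lets PyTorch derive this rule instead.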
The vmap static method must return a tuple (output, out_dims). Similar to in_dims, out_dims should have the same structure as output and contain one out_dim per output, specifying whether that output has the vmapped dimension and, if so, at which index. See Extending torch.func with autograd.Function for more details.
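For a Function with several outputs, out_dims mirrors the output tuple. In the hypothetical SinCos below, both operations are elementwise, so each output keeps the batch dimension exactly where the input had it; the sketch therefore skips movedim and reports x_bdim as the out_dim of both outputs.

```python
import torch


class SinCos(torch.autograd.Function):
    # Hypothetical two-output Function, in the setup_context style.
    @staticmethod
    def forward(x):
        return x.sin(), x.cos()

    @staticmethod
    def setup_context(ctx, inputs, output):
        (x,) = inputs
        ctx.save_for_backward(x)

    @staticmethod
    def backward(ctx, grad_sin, grad_cos):
        (x,) = ctx.saved_tensors
        # d/dx sin(x) = cos(x), d/dx cos(x) = -sin(x)
        return grad_sin * x.cos() - grad_cos * x.sin()

    @staticmethod
    def vmap(info, in_dims, x):
        (x_bdim,) = in_dims
        # Elementwise ops preserve dimension positions, so the batch
        # dimension of each output is still at index x_bdim.
        sin_x, cos_x = SinCos.apply(x)
        # out_dims has the same structure as the output tuple.
        return (sin_x, cos_x), (x_bdim, x_bdim)


x = torch.randn(3, 5)
# Vmap over dim 1: the rule receives in_dims == (1,).
s, c = torch.vmap(SinCos.apply, in_dims=1)(x)
# s and c have shape (5, 3) and match x.T.sin() and x.T.cos().
```

torch.vmap then moves each output's batch dimension from the reported out_dim to the position requested by its own out_dims argument (0 by default), which is why the results are (5, 3) rather than (3, 5).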