torch.func.grad

torch.func.grad(func, argnums=0, has_aux=False)

grad operator computes gradients of func with respect to the input(s) specified by argnums. This operator can be nested to compute higher-order gradients.

Parameters
  • func (Callable) – A Python function that takes one or more arguments. Must return a single-element Tensor. If has_aux equals True, the function can instead return a tuple of a single-element Tensor and other auxiliary objects: (output, aux).

  • argnums (int or Tuple[int]) – Specifies arguments to compute gradients with respect to. argnums can be a single integer or a tuple of integers; a short sketch of the tuple case follows this list. Default: 0.

  • has_aux (bool) – Flag indicating that func returns a tensor and other auxiliary objects: (output, aux). Default: False.
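As a minimal sketch of the tuple-valued argnums case described above (the lambda and values here are illustrative, not part of the official examples):

>>> import torch
>>> from torch.func import grad
>>> x, y = torch.randn([]), torch.randn([])
>>> # For f(x, y) = x * y, d/dx = y and d/dy = x; a tuple argnums
>>> # returns a tuple of gradients in the same order
>>> gx, gy = grad(lambda x, y: x * y, argnums=(0, 1))(x, y)
>>> assert torch.allclose(gx, y) and torch.allclose(gy, x)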

Returns

Function to compute gradients with respect to its inputs. By default, the output of the function is the gradient tensor(s) with respect to the first argument. If has_aux equals True, a tuple of the gradients and the auxiliary objects is returned. If argnums is a tuple of integers, a tuple of output gradients with respect to each argnums value is returned.

Return type

Callable

Example of using grad:

>>> import torch
>>> from torch.func import grad
>>> x = torch.randn([])
>>> cos_x = grad(lambda x: torch.sin(x))(x)
>>> assert torch.allclose(cos_x, x.cos())
>>>
>>> # Second-order gradients
>>> neg_sin_x = grad(grad(lambda x: torch.sin(x)))(x)
>>> assert torch.allclose(neg_sin_x, -x.sin())

When composed with vmap, grad can be used to compute per-sample gradients:

>>> import torch
>>> from torch.func import grad, vmap
>>> batch_size, feature_size = 3, 5
>>>
>>> def model(weights, feature_vec):
...     # Very simple linear model with activation
...     assert feature_vec.dim() == 1
...     return feature_vec.dot(weights).relu()
>>>
>>> def compute_loss(weights, example, target):
...     y = model(weights, example)
...     return ((y - target) ** 2).mean()  # MSELoss
>>>
>>> weights = torch.randn(feature_size, requires_grad=True)
>>> examples = torch.randn(batch_size, feature_size)
>>> targets = torch.randn(batch_size)
>>> inputs = (weights, examples, targets)
>>> grad_weight_per_example = vmap(grad(compute_loss), in_dims=(None, 0, 0))(
...     *inputs
... )

Example of using grad with has_aux and argnums:

>>> import torch
>>> from torch.func import grad
>>> def my_loss_func(y, y_pred):
...     loss_per_sample = (0.5 * y_pred - y) ** 2
...     loss = loss_per_sample.mean()
...     return loss, (y_pred, loss_per_sample)
>>>
>>> fn = grad(my_loss_func, argnums=(0, 1), has_aux=True)
>>> y_true = torch.rand(4)
>>> y_preds = torch.rand(4, requires_grad=True)
>>> out = fn(y_true, y_preds)
>>> # > output is ((grads w.r.t y_true, grads w.r.t y_preds), (y_pred, loss_per_sample))

Note

Using PyTorch's torch.no_grad together with grad.

Case 1: Using torch.no_grad inside a function:

>>> def f(x):
...     with torch.no_grad():
...         c = x ** 2
...     return x - c

In this case, grad(f)(x) will respect the inner torch.no_grad.
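Concretely, because c is computed under torch.no_grad, it is treated as a constant, so the gradient of f is d/dx (x - c) = 1. A small sketch verifying this, assuming f is defined as above:

>>> import torch
>>> from torch.func import grad
>>> x = torch.randn([])
>>> # c is detached from the graph, so the gradient is identically 1
>>> assert torch.allclose(grad(f)(x), torch.ones_like(x))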

Case 2: Using grad inside the torch.no_grad context manager:

>>> with torch.no_grad():
...     grad(f)(x)

In this case, grad will respect the inner torch.no_grad, but not the outer one. This is because grad is a “function transform”: its result should not depend on the result of a context manager outside of f.
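In other words, an outer torch.no_grad does not disable the transform; a brief illustrative check, continuing with the f defined above:

>>> import torch
>>> from torch.func import grad
>>> x = torch.randn([])
>>> with torch.no_grad():
...     gx = grad(f)(x)  # still computed; only the inner no_grad is honored
>>> assert torch.allclose(gx, torch.ones_like(x))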