
torch.func.linearize

torch.func.linearize(func, *primals)

Returns the value of func at primals and a linear approximation of func at primals.

Parameters
  • func (Callable) – A Python function that takes one or more arguments.

  • primals (Tensors) – Positional arguments to func that must all be Tensors. These are the values at which the function is linearly approximated.

Returns

Returns an (output, jvp_fn) tuple containing the output of func applied to primals and a function that computes the jvp of func evaluated at primals.

Return type

tuple[Any, Callable]
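As a sketch of what the returned pair contains, the snippet below linearizes a two-argument function and checks the result against torch.func.jvp and against the hand-derived derivative. The function f and the random inputs are illustrative assumptions, not part of the API; the point is that output equals f(*primals) and that jvp_fn takes one tangent per primal, in the same order.

```python
import torch
from torch.func import jvp, linearize

# Hypothetical two-argument function, used only for illustration.
def f(x, y):
    return x.sin() * y

x = torch.randn(4)
y = torch.randn(4)

# Linearize f at the primals (x, y).
output, jvp_fn = linearize(f, x, y)

# jvp_fn expects one tangent per primal, matching the primals' shapes.
tx = torch.randn(4)
ty = torch.randn(4)
lin_tangent = jvp_fn(tx, ty)

# Reference: torch.func.jvp runs the forward pass again for these tangents.
ref_out, ref_tangent = jvp(f, (x, y), (tx, ty))

# Hand-derived directional derivative: cos(x) * y * tx + sin(x) * ty.
analytic = x.cos() * y * tx + x.sin() * ty

print(torch.allclose(output, ref_out), torch.allclose(lin_tangent, ref_tangent))
```

Both comparisons are expected to agree up to floating-point tolerance, since linearize and jvp compute the same Jacobian-vector product.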

linearize is useful if the jvp is to be computed multiple times at primals. However, to achieve this, linearize saves intermediate computation and has higher memory requirements than directly applying jvp. So, if all the tangents are known in advance, it may be more efficient to compute vmap(jvp) instead of using linearize.
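The two strategies can be compared side by side. This is a minimal sketch, assuming all tangent directions are known up front and stacked along dimension 0; the helper jvp_tangent is a hypothetical name introduced here. Option A reuses one jvp_fn in a Python loop, while option B batches every tangent through jvp with vmap and does not keep a saved linearization around.

```python
import torch
from torch.func import jvp, linearize, vmap

def fn(x):
    return x.sin()

primal = torch.randn(3)
# Five tangent directions, all known in advance (one per row).
tangents = torch.randn(5, 3)

# Option A: linearize once, then call jvp_fn for each tangent.
_, jvp_fn = linearize(fn, primal)
out_a = torch.stack([jvp_fn(t) for t in tangents])

# Option B: vmap over jvp; hypothetical helper keeps only the tangent output.
def jvp_tangent(t):
    # jvp returns (output, jvp_out); index 1 is the directional derivative.
    return jvp(fn, (primal,), (t,))[1]

out_b = vmap(jvp_tangent)(tangents)

print(torch.allclose(out_a, out_b))
```

Which option is faster or lighter on memory depends on the function and on hardware; the sketch only shows that both produce the same values.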

Note

linearize evaluates func twice. Please file an issue for an implementation with a single evaluation.
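One way to observe the double evaluation is to count Python-level calls with a side effect. This is a sketch; the calls dictionary is an illustrative assumption, not part of torch. The count right after linearize reflects the Note. Whether later jvp_fn calls run fn again is an implementation detail; in current releases jvp_fn appears to replay a saved computation, so the count is not expected to grow, but treat that as an observation rather than a guarantee.

```python
import torch
from torch.func import linearize

# Illustrative call counter (not part of the torch API).
calls = {"fn": 0}

def fn(x):
    calls["fn"] += 1  # side effect: record each Python-level evaluation
    return x.sin()

output, jvp_fn = linearize(fn, torch.zeros(3))
calls_after_linearize = calls["fn"]  # the Note says func is evaluated twice

# Reuse jvp_fn with two different tangents; at x = 0, d/dx sin(x) = 1.
t1 = jvp_fn(torch.ones(3))
t2 = jvp_fn(2 * torch.ones(3))
calls_after_jvp = calls["fn"]

print(calls_after_linearize, calls_after_jvp)
```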

Example:

>>> import torch
>>> from torch.func import linearize
>>> def fn(x):
...     return x.sin()
...
>>> output, jvp_fn = linearize(fn, torch.zeros(3, 3))
>>> jvp_fn(torch.ones(3, 3))
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])