torch.func.linearize
- torch.func.linearize(func, *primals)
Returns the value of func at primals and a linear approximation at primals.
- Parameters
func (Callable) – A Python function that takes one or more arguments.
primals (Tensors) – Positional arguments to func that must all be Tensors. These are the values at which the function is linearly approximated.
- Returns
Returns a (output, jvp_fn) tuple containing the output of func applied to primals and a function that computes the jvp of func evaluated at primals.
- Return type
linearize is useful if jvp is to be computed multiple times at primals. However, to achieve this, linearize saves intermediate computation and has higher memory requirements than directly applying jvp. So, if all the tangents are known, it may be more efficient to compute vmap(jvp) instead of using linearize.
Note
linearize evaluates func twice. Please file an issue for an implementation with a single evaluation.
Example:
>>> import torch
>>> from torch.func import linearize
>>> def fn(x):
...     return x.sin()
...
>>> output, jvp_fn = linearize(fn, torch.zeros(3, 3))
>>> jvp_fn(torch.ones(3, 3))
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])
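The trade-off between reusing jvp_fn and batching jvp with vmap can be sketched as follows. This is a minimal illustration, not a benchmark: the names primal, tangents, and single_jvp are hypothetical and chosen only for this example, and which approach is faster or lighter on memory depends on the function and the number of tangents.

```python
import torch
from torch.func import jvp, linearize, vmap


def fn(x):
    return x.sin()


# Hypothetical inputs for illustration: one primal point and three tangents.
primal = torch.randn(3)
tangents = torch.eye(3)  # each row is one tangent direction

# Option 1: linearize once, then reuse jvp_fn for each tangent.
# Each tangent must match the shape and dtype of the primal.
output, jvp_fn = linearize(fn, primal)
reused = torch.stack([jvp_fn(t) for t in tangents])


# Option 2: all tangents are known up front, so batch jvp with vmap.
def single_jvp(t):
    # jvp returns (output, jvp_out); keep only the jvp part.
    return jvp(fn, (primal,), (t,))[1]


batched = vmap(single_jvp)(tangents)

# For sin, the jvp at primal along tangent t is cos(primal) * t.
expected = torch.cos(primal) * tangents
```

Both reused and batched should agree with expected up to floating-point tolerance. Option 1 fits the case where tangents arrive one at a time; Option 2 avoids holding on to the saved linearization when every tangent is available at once.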