Template Struct Function

Struct Documentation

template<class T>
struct Function

To use custom autograd operations, implement a Function subclass with static forward and backward functions:

forward can take as many arguments as you want and should return either a variable list or a Variable. Any direct Variable arguments will be registered in the graph, but vectors, sets, and other data structures will not be traversed. You can use std::optional<Tensor> as one of the arguments, and it will be registered as a variable in the graph if the argument has a value. forward should take a pointer to torch::autograd::AutogradContext as its first argument. Variables can be saved in the ctx using ctx->save_for_backward (see torch::autograd::AutogradContext::save_for_backward), and other data can be saved in the ctx->saved_data map (see torch::autograd::AutogradContext::saved_data) in the form of <std::string, at::IValue> pairs.

backward should take a pointer to torch::autograd::AutogradContext and a variable list containing as many Variables as there were outputs from forward. It should return as many Variables as there were inputs, each containing the gradient w.r.t. its corresponding input. Variables saved in forward can be accessed with ctx->get_saved_variables (see torch::autograd::AutogradContext::get_saved_variables), and other saved data can be accessed from ctx->saved_data. To enable compiled autograd support (torch.compile for backward) for your custom autograd operation, set MyFunction::is_traceable (see the Function::is_traceable notes below).

For example:

class MyFunction : public Function<MyFunction> {
 public:
  static constexpr bool is_traceable = true;

  static variable_list forward(AutogradContext *ctx, int n, Variable var) {
    // Save data for backward in context
    ctx->saved_data["n"] = n;
    var.mul_(n);
    // Mark var as modified by in-place operation
    ctx->mark_dirty({var});
    return {var};
  }

  static variable_list backward(AutogradContext *ctx, variable_list grad_output) {
    // Use data saved in forward
    auto n = ctx->saved_data["n"].toInt();
    return {grad_output[0] * n};
  }
};

To use MyFunction:

Variable x;
auto y = MyFunction::apply(6, x);

// Example backward call
y[0].sum().backward();

Public Static Functions

template<typename X = T, typename ...Args>
static auto apply(Args&&... args) -> std::enable_if_t<std::is_same_v<X, T>, forward_t<X, Args...>>

Public Static Attributes

static constexpr bool is_traceable = false