Template Struct Function
Defined in File custom_function.h
Struct Documentation
- template <class T>
  struct Function

To use custom autograd operations, implement a Function subclass with static forward and backward functions:
forward can take as many arguments as you want and should return either a variable list or a Variable. Any direct Variable arguments will be registered in the graph, but vectors/sets or any other data structures will not be traversed. You can use std::optional<Tensor> as one of the arguments, and it will be registered as a variable in the graph if the argument has a value. forward should take a pointer to torch::autograd::AutogradContext as the first argument. Variables can be saved in the ctx using ctx->save_for_backward (see torch::autograd::AutogradContext::save_for_backward), and other data can be saved in the ctx->saved_data map (see torch::autograd::AutogradContext::saved_data) in the form of <std::string, at::IValue> pairs (a sketch of the save_for_backward pattern follows the usage example below).

backward should take a pointer to torch::autograd::AutogradContext and a variable list containing as many Variables as there were outputs from forward as arguments. It should return as many Variables as there were inputs, each containing the gradient w.r.t. its corresponding input. Variables saved in forward can be accessed with ctx->get_saved_variables (see torch::autograd::AutogradContext::get_saved_variables), and other saved data can be accessed from ctx->saved_data.

To enable compiled autograd support (torch.compile for backward) for your custom autograd operation, set MyFunction::is_traceable (see the Function::is_traceable notes below).

For example:
class MyFunction : public Function<MyFunction> {
 public:
  static constexpr bool is_traceable = true;

  static variable_list forward(AutogradContext *ctx, int n, Variable var) {
    // Save data for backward in context
    ctx->saved_data["n"] = n;
    var.mul_(n);
    // Mark var as modified by inplace operation
    ctx->mark_dirty({var});
    return {var};
  }

  static variable_list backward(AutogradContext *ctx, variable_list grad_output) {
    // Use data saved in forward
    auto n = ctx->saved_data["n"].toInt();
    return {grad_output[0] * n};
  }
};
To use MyFunction:

Variable x;
auto y = MyFunction::apply(6, x);

// Example backward call
y[0].sum().backward();
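The example above stores only non-variable data in ctx->saved_data. When backward needs the input Variables themselves, they should be saved with ctx->save_for_backward and retrieved with ctx->get_saved_variables, as described above. Below is a minimal sketch of that pattern, using a hypothetical MyMul op that also takes a std::optional<Tensor> argument; the class name, its math, and the "has_bias" key are illustrative, not part of the library:

#include <torch/torch.h>

using namespace torch::autograd;

// Hypothetical op computing x * y (+ bias, if given), saving its
// Variable inputs for the backward pass.
class MyMul : public Function<MyMul> {
 public:
  static Variable forward(
      AutogradContext *ctx,
      Variable x,
      Variable y,
      std::optional<torch::Tensor> bias) {
    // Variables needed by backward go through save_for_backward
    ctx->save_for_backward({x, y});
    // Plain (non-variable) data can still go in saved_data
    ctx->saved_data["has_bias"] = bias.has_value();
    auto out = x * y;
    if (bias.has_value()) {
      out = out + *bias;
    }
    return out;
  }

  static variable_list backward(AutogradContext *ctx, variable_list grad_output) {
    // Retrieve the Variables saved in forward
    auto saved = ctx->get_saved_variables();
    auto x = saved[0];
    auto y = saved[1];
    auto grad = grad_output[0];
    auto has_bias = ctx->saved_data["has_bias"].toBool();
    // One gradient per forward input, in order; an undefined Variable
    // signals "no gradient" for the absent optional bias
    return {grad * y, grad * x, has_bias ? grad : Variable()};
  }
};

A usage sketch:

auto a = torch::randn({2, 2}, torch::requires_grad());
auto b = torch::randn({2, 2}, torch::requires_grad());
auto out = MyMul::apply(a, b, std::nullopt);
out.sum().backward(); // a.grad() == b, b.grad() == a

Returning a default-constructed Variable() from backward is the conventional way to indicate that an input receives no gradient, used here when the optional bias was not passed.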
Public Static Attributes
- static constexpr bool is_traceable = false