
Autograd in C++ Frontend

Created On: Apr 01, 2020 | Last Updated: Jan 21, 2025 | Last Verified: Not Verified

The autograd package is crucial for building highly flexible and dynamic neural networks in PyTorch. Most of the autograd APIs in the PyTorch Python frontend are also available in the C++ frontend, allowing easy translation of autograd code from Python to C++.

In this tutorial, we explore several examples of doing autograd in the PyTorch C++ frontend. Note that this tutorial assumes you already have a basic understanding of autograd in the Python frontend. If that’s not the case, please first read Autograd: Automatic Differentiation.

Basic autograd operations

(Adapted from this tutorial)

Create a tensor and set torch::requires_grad() to track computation with it:

auto x = torch::ones({2, 2}, torch::requires_grad());
std::cout << x << std::endl;

Out:

 1  1
 1  1
[ CPUFloatType{2,2} ]

Do a tensor operation:

auto y = x + 2;
std::cout << y << std::endl;

Out:

 3  3
 3  3
[ CPUFloatType{2,2} ]

y was created as a result of an operation, so it has a grad_fn.

std::cout << y.grad_fn()->name() << std::endl;

Out:

AddBackward1

Do more operations on y:

auto z = y * y * 3;
auto out = z.mean();

std::cout << z << std::endl;
std::cout << z.grad_fn()->name() << std::endl;
std::cout << out << std::endl;
std::cout << out.grad_fn()->name() << std::endl;

Out:

 27  27
 27  27
[ CPUFloatType{2,2} ]
MulBackward1
27
[ CPUFloatType{} ]
MeanBackward0

.requires_grad_(...) changes an existing tensor’s requires_grad flag in-place.

auto a = torch::randn({2, 2});
a = ((a * 3) / (a - 1));
std::cout << a.requires_grad() << std::endl;

a.requires_grad_(true);
std::cout << a.requires_grad() << std::endl;

auto b = (a * a).sum();
std::cout << b.grad_fn()->name() << std::endl;

Out:

false
true
SumBackward0

Let’s backprop now. Because out contains a single scalar, out.backward() is equivalent to out.backward(torch::tensor(1.)).

out.backward();

Print gradients d(out)/dx

std::cout << x.grad() << std::endl;

Out:

 4.5000  4.5000
 4.5000  4.5000
[ CPUFloatType{2,2} ]

You should have gotten a matrix of 4.5. For an explanation of how we arrive at this value, please see the corresponding section in this tutorial.
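
In brief: out = (1/4) * sum_i z_i, where z_i = 3 * (x_i + 2)^2, so d(out)/dx_i = (3/2) * (x_i + 2), which is (3/2) * 3 = 4.5 at x_i = 1.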

Now let’s take a look at an example of vector-Jacobian product:

x = torch::randn(3, torch::requires_grad());

y = x * 2;
while (y.norm().item<double>() < 1000) {
  y = y * 2;
}

std::cout << y << std::endl;
std::cout << y.grad_fn()->name() << std::endl;

Out:

-1021.4020
  314.6695
 -613.4944
[ CPUFloatType{3} ]
MulBackward1

If we want the vector-Jacobian product, pass the vector to backward as an argument:

auto v = torch::tensor({0.1, 1.0, 0.0001}, torch::kFloat);
y.backward(v);

std::cout << x.grad() << std::endl;

Out:

  102.4000
 1024.0000
    0.1024
[ CPUFloatType{3} ]

You can also stop autograd from tracking history on tensors that require gradients, either by putting torch::NoGradGuard in a code block:

std::cout << x.requires_grad() << std::endl;
std::cout << x.pow(2).requires_grad() << std::endl;

{
  torch::NoGradGuard no_grad;
  std::cout << x.pow(2).requires_grad() << std::endl;
}

Out:

true
true
false

Or by using .detach() to get a new tensor with the same content but that does not require gradients:

std::cout << x.requires_grad() << std::endl;
y = x.detach();
std::cout << y.requires_grad() << std::endl;
std::cout << x.eq(y).all().item<bool>() << std::endl;

Out:

true
false
true

For more information on C++ tensor autograd APIs such as grad / requires_grad / is_leaf / backward / detach / detach_ / register_hook / retain_grad, please see the corresponding C++ API docs.
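
As a quick sketch of how a few of these APIs (is_leaf, retain_grad, register_hook) fit together; this is an illustrative example rather than an excerpt from the docs, and it assumes the same #include <torch/torch.h> setup as the snippets above:

auto w = torch::randn({2, 2}, torch::requires_grad());
auto h = (w * w).sum();

std::cout << w.is_leaf() << std::endl;  // true: w was created directly by the user
std::cout << h.is_leaf() << std::endl;  // false: h is the result of an operation

h.retain_grad();  // keep the gradient of the non-leaf tensor h after backward
w.register_hook([](torch::Tensor grad) {
  return grad * 2;  // hooks can inspect or rescale the incoming gradient
});

h.backward();
std::cout << w.grad() << std::endl;  // 4 * w: the usual 2 * w, doubled by the hook
std::cout << h.grad() << std::endl;  // 1, available because of retain_grad()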

Computing higher-order gradients in C++

One of the applications of higher-order gradients is calculating a gradient penalty. Let’s see an example of it using torch::autograd::grad:

#include <torch/torch.h>

auto model = torch::nn::Linear(4, 3);

auto input = torch::randn({3, 4}).requires_grad_(true);
auto output = model(input);

// Calculate loss
auto target = torch::randn({3, 3});
auto loss = torch::nn::MSELoss()(output, target);

// Use norm of gradients as penalty
auto grad_output = torch::ones_like(output);
auto gradient = torch::autograd::grad({output}, {input},
                                      /*grad_outputs=*/{grad_output},
                                      /*create_graph=*/true)[0];
auto gradient_penalty = torch::pow((gradient.norm(2, /*dim=*/1) - 1), 2).mean();

// Add gradient penalty to loss
auto combined_loss = loss + gradient_penalty;
combined_loss.backward();

std::cout << input.grad() << std::endl;

Out:

-0.1042 -0.0638  0.0103  0.0723
-0.2543 -0.1222  0.0071  0.0814
-0.1683 -0.1052  0.0355  0.1024
[ CPUFloatType{3,4} ]

Please see the documentation for torch::autograd::backward (link) and torch::autograd::grad (link) for more information on how to use them.
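
As a minimal sketch of how the two free functions relate (illustrative only, using fresh tensors rather than the ones from the snippet above):

auto t = torch::randn({2, 2}, torch::requires_grad());
auto s = (t * t).sum();

// torch::autograd::backward accumulates gradients into t.grad(),
// just like s.backward() would.
torch::autograd::backward({s});
std::cout << t.grad() << std::endl;  // 2 * t

// torch::autograd::grad returns the gradients instead of accumulating them.
auto s2 = (t * t * t).sum();
auto grads = torch::autograd::grad({s2}, {t});
std::cout << grads[0] << std::endl;  // 3 * t^2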

Using custom autograd function in C++

(Adapted from this tutorial)

Adding a new elementary operation to torch::autograd requires implementing a new torch::autograd::Function subclass for each operation. torch::autograd::Functions are what torch::autograd uses to compute the results and gradients and to encode the operation history. Every new function requires you to implement two methods, forward and backward; please see this link for the detailed requirements.

Below you can find code for a Linear function from torch::nn:

#include <torch/torch.h>

using namespace torch::autograd;

// Inherit from Function
class LinearFunction : public Function<LinearFunction> {
 public:
  // Note that both forward and backward are static functions

  // bias is an optional argument
  static torch::Tensor forward(
      AutogradContext *ctx,
      torch::Tensor input,
      torch::Tensor weight,
      torch::Tensor bias = torch::Tensor()) {
    ctx->save_for_backward({input, weight, bias});
    auto output = input.mm(weight.t());
    if (bias.defined()) {
      output += bias.unsqueeze(0).expand_as(output);
    }
    return output;
  }

  static tensor_list backward(AutogradContext *ctx, tensor_list grad_outputs) {
    auto saved = ctx->get_saved_variables();
    auto input = saved[0];
    auto weight = saved[1];
    auto bias = saved[2];

    auto grad_output = grad_outputs[0];
    auto grad_input = grad_output.mm(weight);
    auto grad_weight = grad_output.t().mm(input);
    auto grad_bias = torch::Tensor();
    if (bias.defined()) {
      grad_bias = grad_output.sum(0);
    }

    return {grad_input, grad_weight, grad_bias};
  }
};

Then, we can use the LinearFunction in the following way:

auto x = torch::randn({2, 3}).requires_grad_();
auto weight = torch::randn({4, 3}).requires_grad_();
auto y = LinearFunction::apply(x, weight);
y.sum().backward();

std::cout << x.grad() << std::endl;
std::cout << weight.grad() << std::endl;

Out:

 0.5314  1.2807  1.4864
 0.5314  1.2807  1.4864
[ CPUFloatType{2,3} ]
 3.7608  0.9101  0.0073
 3.7608  0.9101  0.0073
 3.7608  0.9101  0.0073
 3.7608  0.9101  0.0073
[ CPUFloatType{4,3} ]

Here, we give an additional example of a function that is parametrized by non-tensor arguments:

#include <torch/torch.h>

using namespace torch::autograd;

class MulConstant : public Function<MulConstant> {
 public:
  static torch::Tensor forward(AutogradContext *ctx, torch::Tensor tensor, double constant) {
    // ctx is a context object that can be used to stash information
    // for backward computation
    ctx->saved_data["constant"] = constant;
    return tensor * constant;
  }

  static tensor_list backward(AutogradContext *ctx, tensor_list grad_outputs) {
    // We return as many input gradients as there were arguments.
    // Gradients of non-tensor arguments to forward must be `torch::Tensor()`.
    return {grad_outputs[0] * ctx->saved_data["constant"].toDouble(), torch::Tensor()};
  }
};

Then, we can use the MulConstant in the following way:

auto x = torch::randn({2}).requires_grad_();
auto y = MulConstant::apply(x, 5.5);
y.sum().backward();

std::cout << x.grad() << std::endl;

Out:

 5.5000
 5.5000
[ CPUFloatType{2} ]

For more information on torch::autograd::Function, please see its documentation.

Translating autograd code from Python to C++

At a high level, the easiest way to use autograd in C++ is to have working autograd code in Python first, and then translate your autograd code from Python to C++ using the following table:

Python                       | C++
-----------------------------|----------------------------------------
torch.autograd.backward      | torch::autograd::backward (link)
torch.autograd.grad          | torch::autograd::grad (link)
torch.Tensor.detach          | torch::Tensor::detach (link)
torch.Tensor.detach_         | torch::Tensor::detach_ (link)
torch.Tensor.backward        | torch::Tensor::backward (link)
torch.Tensor.register_hook   | torch::Tensor::register_hook (link)
torch.Tensor.requires_grad   | torch::Tensor::requires_grad_ (link)
torch.Tensor.retain_grad     | torch::Tensor::retain_grad (link)
torch.Tensor.grad            | torch::Tensor::grad (link)
torch.Tensor.grad_fn         | torch::Tensor::grad_fn (link)
torch.Tensor.set_data        | torch::Tensor::set_data (link)
torch.Tensor.data            | torch::Tensor::data (link)
torch.Tensor.output_nr       | torch::Tensor::output_nr (link)
torch.Tensor.is_leaf         | torch::Tensor::is_leaf (link)

After translation, most of your Python autograd code should just work in C++. If that’s not the case, please file a bug report at GitHub issues and we will fix it as soon as possible.

Conclusion

You should now have a good overview of PyTorch’s C++ autograd API. You can find the code examples displayed in this note here. As always, if you run into any problems or have questions, you can use our forum or GitHub issues to get in touch.