Zeroing out gradients in PyTorch#
Created On: Apr 20, 2020 | Last Updated: Apr 28, 2025 | Last Verified: Nov 05, 2024
It is beneficial to zero out gradients when building a neural network. This is because, by default, gradients are accumulated in buffers (i.e., not overwritten) whenever .backward() is called.
Introduction#
When training your neural network, models are able to increase their accuracy through gradient descent. In short, gradient descent is the process of minimizing our loss (or error) by tweaking the weights and biases in our model.
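For intuition, here is a minimal sketch (not part of the recipe's training code) of a single gradient descent step on one weight: the weight is nudged opposite to the gradient of the loss, which reduces the loss.

import torch

w = torch.tensor(2.0, requires_grad=True)   # a single weight
target = torch.tensor(0.0)

loss = (w - target) ** 2                    # squared error: 4.0
loss.backward()                             # d(loss)/dw = 2 * (w - target) = 4.0

lr = 0.1                                    # learning rate
with torch.no_grad():
    w -= lr * w.grad                        # 2.0 - 0.1 * 4.0 = 1.6
print(w.item())                             # 1.6; the loss at the new w is 2.56 < 4.0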
torch.Tensor is the central class of PyTorch. When you create a tensor, if you set its attribute .requires_grad as True, the package tracks all operations on it so that gradients can be computed on subsequent backward passes. The gradient for this tensor will be accumulated into the .grad attribute. The accumulation (or sum) of all the gradients is calculated when .backward() is called on the loss tensor.
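To see this accumulation in isolation, here is a small sketch (separate from the CIFAR10 example below) that calls .backward() twice on the same tensor and then zeroes the buffer:

import torch

w = torch.tensor([1.0, 2.0], requires_grad=True)  # operations on w are tracked

loss = (w * 3).sum()
loss.backward()
print(w.grad)        # tensor([3., 3.])

loss = (w * 3).sum()
loss.backward()
print(w.grad)        # tensor([6., 6.]) -- accumulated, not overwritten

w.grad.zero_()       # zero the buffer before the next backward pass
print(w.grad)        # tensor([0., 0.])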
There are cases where it may be necessary to zero out the gradients of a tensor. For example: when you start your training loop, you should zero out the gradients so that you can perform this tracking correctly. In this recipe, we will learn how to zero out gradients using the PyTorch library. We will demonstrate how to do this by training a neural network on the CIFAR10 dataset built into PyTorch.
Setup#
Since we will be training a model on data in this recipe, if you are in a runnable notebook, it is best to switch the runtime to GPU or TPU. Before we begin, we need to install torch and torchvision if they aren’t already available.
pip install torchvision
Steps#
Steps 1 through 4 set up our data and neural network for training. The process of zeroing out the gradients happens in step 5. If you already have your data and neural network built, skip ahead to step 5.
1. Import all necessary libraries for loading our data
2. Load and normalize the dataset
3. Build the neural network
4. Define the loss function
5. Zero the gradients while training the network
1. Import necessary libraries for loading our data#
For this recipe, we will just be using torch and torchvision to access the dataset.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
2. Load and normalize the dataset#
PyTorch features various built-in datasets (see the Loading Data recipe for more information).
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
3. Build the neural network#
We will use a convolutional neural network. To learn more, see the Defining a Neural Network recipe.
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
4. Define a Loss function and optimizer#
Let’s use a Classification Cross-Entropy loss and SGD with momentum.
net = Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
5. Zero the gradients while training the network#
This is when things start to get interesting. We simply have to loop over our data iterator, feed the inputs to the network, and optimize.
Notice that for each batch of data, we zero out the gradients. This ensures that we aren’t accumulating gradients left over from the previous batch when we train our neural network.
for epoch in range(2):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 2000 == 1999:    # print every 2000 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0

print('Finished Training')
You can also use model.zero_grad(). This is the same as using optimizer.zero_grad() as long as all your model parameters are in that optimizer. Use your best judgment to decide which one to use.
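For reference, here is a short sketch of the two equivalent calls, plus the set_to_none option that zero_grad() accepts in recent PyTorch releases (it defaults to True in newer versions):

net.zero_grad()        # zeroes .grad for every parameter of the model
optimizer.zero_grad()  # zeroes .grad for every parameter registered with this optimizer

# Setting gradients to None instead of zero can save memory and skip a kernel launch:
optimizer.zero_grad(set_to_none=True)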
Congratulations! You have successfully zeroed out gradients in PyTorch.
Learn More#
Take a look at these other recipes to continue your learning: