
Optimizing Model Parameters

Created On: Feb 09, 2021 | Last Updated: Apr 28, 2025 | Last Verified: Nov 05, 2024

Now that we have a model and data, it’s time to train, validate and test our model by optimizing its parameters on our data. Training a model is an iterative process. In each iteration the model makes a guess about the output, calculates the error in its guess (loss), collects the derivatives of the error with respect to its parameters (as we saw in the previous section), and optimizes these parameters using gradient descent. For a more detailed walkthrough of this process, check out this video on backpropagation from 3Blue1Brown.
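As a rough illustration of that iterative process, here is a minimal sketch in plain Python, without PyTorch. It assumes a hypothetical one-parameter model `pred = w * x`, fit with mean squared error to toy data generated by `y = 3 * x`. Each step makes a guess, measures the loss, computes the derivative, and takes a gradient descent step.

```python
# Hypothetical toy data following y = 3 * x; gradient descent should drive w toward 3.
xs = [1.0, 2.0, 3.0, 4.0]
ys = [3.0, 6.0, 9.0, 12.0]


def mse_and_grad(w):
    """Return the mean squared error of pred = w * x and its derivative w.r.t. w."""
    n = len(xs)
    loss = sum((w * x - y) ** 2 for x, y in zip(xs, ys)) / n
    # d/dw of (w*x - y)**2 is 2 * (w*x - y) * x
    grad = sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / n
    return loss, grad


w = 0.0    # initial guess for the parameter
lr = 0.01  # learning rate (step size)
for step in range(200):
    loss, grad = mse_and_grad(w)  # guess, error, and derivative
    w -= lr * grad                # gradient descent update

print(f"w = {w:.4f}, loss = {loss:.6f}")
```

PyTorch automates the derivative step (autograd) and the update step (`torch.optim`), so the rest of this page only has to wire them together.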

Prerequisite Code

We load the code from the previous sections on Datasets & DataLoaders and Build Model.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

train_dataloader = DataLoader(training_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)

class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

model = NeuralNetwork()
```

Hyperparameters

Hyperparameters are adjustable parameters that let you control the model optimization process. Different hyperparameter values can impact model training and convergence rates (read more about hyperparameter tuning).

We define the following hyperparameters for training:
  • Number of Epochs - the number of times to iterate over the dataset

  • Batch Size - the number of data samples propagated through the network before the parameters are updated

  • Learning Rate - how much to update the model’s parameters at each batch/epoch. Smaller values yield slower learning, while large values may result in unpredictable behavior during training.
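To see the learning-rate trade-off in isolation, here is a minimal sketch that runs plain gradient descent on the hypothetical function f(w) = w² with a few learning rates. Each update multiplies w by (1 - 2·lr). A small rate shrinks w slowly, a moderate one converges quickly, and a rate above 1 makes the factor's magnitude exceed 1, so w oscillates and grows instead of converging.

```python
def descend(lr, steps=20, w0=1.0):
    """Run plain gradient descent on f(w) = w**2 and return the final w."""
    w = w0
    for _ in range(steps):
        grad = 2 * w       # derivative of w**2
        w = w - lr * grad  # equivalent to w *= (1 - 2 * lr)
    return w


for lr in (0.01, 0.1, 0.4, 1.1):
    print(f"lr={lr:<5} final w={descend(lr):.6g}")
```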

```python
learning_rate = 1e-3
batch_size = 64
epochs = 5
```
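To see how batch size and epoch count translate into parameter updates, here is a small arithmetic sketch. It assumes the 60,000-sample FashionMNIST training split (the size reported in the training log on this page) and the values defined above. `DataLoader` keeps the final, smaller batch by default (`drop_last=False`), so the batch count rounds up.

```python
import math

num_train_samples = 60_000  # FashionMNIST training split size
batch_size = 64
epochs = 5

# One parameter update per batch; the last batch holds the leftover samples.
batches_per_epoch = math.ceil(num_train_samples / batch_size)
last_batch_size = num_train_samples - (batches_per_epoch - 1) * batch_size
total_updates = batches_per_epoch * epochs

print(batches_per_epoch, last_batch_size, total_updates)
```

The batch count here should agree with `len(train_dataloader)` for `batch_size=64`.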

Optimization Loop

Once we set our hyperparameters, we can then train and optimize our model with an optimization loop. Each iteration of the optimization loop is called an epoch.

Each epoch consists of two main parts:
  • The Train Loop - iterate over the training dataset and try to converge to optimal parameters.

  • The Validation/Test Loop - iterate over the test dataset to check if model performance is improving.

Let’s briefly familiarize ourselves with some of the concepts used in the training loop. Jump ahead to see the Full Implementation of the optimization loop.

Loss Function

When presented with some training data, our untrained network is likely not to give the correct answer. The loss function measures the degree of dissimilarity between the obtained result and the target value, and it is the loss function that we want to minimize during training. To calculate the loss, we make a prediction using the inputs of our given data sample and compare it against the true data label value.

Common loss functions include nn.MSELoss (Mean Squared Error) for regression tasks, and nn.NLLLoss (Negative Log Likelihood) for classification. nn.CrossEntropyLoss combines nn.LogSoftmax and nn.NLLLoss.
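The relationship between these losses can be checked numerically. This sketch uses a hypothetical random batch of 4 samples over 10 classes and compares `nn.CrossEntropyLoss` on raw logits with `nn.NLLLoss` applied to `nn.LogSoftmax(dim=1)` outputs. Both use the default mean reduction.

```python
import torch
from torch import nn

torch.manual_seed(0)
logits = torch.randn(4, 10)           # hypothetical batch: 4 samples, 10 classes
targets = torch.tensor([3, 0, 9, 1])  # integer class labels

ce = nn.CrossEntropyLoss()(logits, targets)
nll = nn.NLLLoss()(nn.LogSoftmax(dim=1)(logits), targets)

same = torch.allclose(ce, nll)
print(ce.item(), nll.item(), same)
```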

We pass our model’s output logits to nn.CrossEntropyLoss, which will normalize the logits and compute the prediction error.

```python
# Initialize the loss function
loss_fn = nn.CrossEntropyLoss()
```
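For a single sample, "normalize the logits and compute the prediction error" amounts to -log(softmax(logits)[target]). By default, `nn.CrossEntropyLoss` averages that quantity over the batch. The pure-Python sketch below uses a hypothetical `cross_entropy` helper. It subtracts the largest logit before exponentiating, a standard trick that leaves the softmax unchanged and avoids overflow.

```python
import math


def cross_entropy(logits, target):
    """Cross-entropy for one sample: logsumexp(logits) - logits[target]."""
    m = max(logits)  # shift for numerical stability; softmax is unchanged
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


print(cross_entropy([2.0, 1.0, 0.1], target=0))  # confident and correct: small loss
print(cross_entropy([0.0, 0.0], target=1))       # equal logits: loss is log(2)
```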

Optimizer

Optimization is the process of adjusting model parameters to reduce model error in each training step. Optimization algorithms define how this process is performed (in this example we use Stochastic Gradient Descent). All optimization logic is encapsulated in the optimizer object. Here, we use the SGD optimizer. PyTorch also provides many other optimizers, such as Adam and RMSprop, that work better for different kinds of models and data.

We initialize the optimizer by registering the model’s parameters that need to be trained, and passing in the learning rate hyperparameter.
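For plain SGD (PyTorch's defaults of `momentum=0` and `weight_decay=0`), `optimizer.step()` moves each parameter by `-lr * grad`. The sketch below checks this on a tiny hypothetical `nn.Linear(3, 1)` stand-in with random data, not on the tutorial's model.

```python
import torch
from torch import nn

torch.manual_seed(0)
toy_model = nn.Linear(3, 1)  # hypothetical stand-in for NeuralNetwork
learning_rate = 1e-3
optimizer = torch.optim.SGD(toy_model.parameters(), lr=learning_rate)

x = torch.randn(8, 3)
y = torch.randn(8, 1)
loss = nn.MSELoss()(toy_model(x), y)
loss.backward()

# What plain SGD should produce for each parameter: p - lr * p.grad
expected = [(p - learning_rate * p.grad).detach().clone() for p in toy_model.parameters()]
optimizer.step()

matches = all(torch.allclose(p, e) for p, e in zip(toy_model.parameters(), expected))
print(matches)
```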

Inside the training loop, optimization happens in three steps:
  • Call optimizer.zero_grad() to reset the gradients of model parameters. Gradients add up by default. To prevent double-counting, we explicitly zero them at each iteration.

  • Backpropagate the prediction loss with a call to loss.backward(). PyTorch deposits the gradients of the loss w.r.t. each parameter into that parameter’s .grad attribute.

  • Once we have our gradients, we call optimizer.step() to adjust the parameters by the gradients collected in the backward pass.
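Why the zeroing matters can be seen with a single hypothetical scalar parameter and the loss `3 * w`, whose gradient is 3. Calling `backward()` twice without resetting adds the second gradient to the first.

```python
import torch

# A single hypothetical parameter; the loss 3 * w has gradient 3.
w = torch.tensor(2.0, requires_grad=True)
optimizer = torch.optim.SGD([w], lr=0.1)

(3 * w).backward()
first_grad = w.grad.item()        # 3.0

(3 * w).backward()                # no zero_grad(): gradients accumulate
accumulated_grad = w.grad.item()  # 3.0 + 3.0 = 6.0

optimizer.zero_grad()             # reset before the next iteration
# Since PyTorch 2.0, zero_grad() defaults to set_to_none=True, so w.grad
# becomes None; older versions fill it with zeros instead.
print(first_grad, accumulated_grad, w.grad)
```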

Full Implementation

We define train_loop, which loops over our optimization code, and test_loop, which evaluates the model’s performance against our test data.

```python
def train_loop(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    # Set the model to training mode - important for batch normalization and dropout layers
    # Unnecessary in this situation but added for best practices
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        # Compute prediction and loss
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()  # reset gradients so they don't accumulate into the next batch

        if batch % 100 == 0:
            # Use the dataloader's own batch size rather than a global variable
            loss, current = loss.item(), batch * dataloader.batch_size + len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")


def test_loop(dataloader, model, loss_fn):
    # Set the model to evaluation mode - important for batch normalization and dropout layers
    # Unnecessary in this situation but added for best practices
    model.eval()
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0

    # Evaluating the model with torch.no_grad() ensures that no gradients are computed during test mode;
    # it also reduces unnecessary gradient computations and memory usage for tensors with requires_grad=True
    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
```
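`test_loop` counts correct predictions by taking the index of the largest logit in each row and comparing it with the label. Here is that same expression applied to hypothetical hand-written logits for 4 samples over 3 classes.

```python
import torch

# Hypothetical logits for a batch of 4 samples over 3 classes.
pred = torch.tensor([[2.0, 0.1, -1.0],
                     [0.0, 3.0,  0.5],
                     [1.0, 1.5,  2.5],
                     [0.2, 0.1,  0.0]])
y = torch.tensor([0, 1, 1, 2])

predicted_classes = pred.argmax(1)  # index of the largest logit in each row
correct = (predicted_classes == y).type(torch.float).sum().item()
accuracy = correct / len(y)
print(predicted_classes.tolist(), correct, accuracy)
```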

We initialize the loss function and optimizer, and pass them to train_loop and test_loop. Feel free to increase the number of epochs to track the model’s improving performance.

```python
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

epochs = 10
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loop(train_dataloader, model, loss_fn, optimizer)
    test_loop(test_dataloader, model, loss_fn)
print("Done!")
```
```
Epoch 1
-------------------------------
loss: 2.310727  [   64/60000]
loss: 2.293469  [ 6464/60000]
loss: 2.276914  [12864/60000]
loss: 2.268131  [19264/60000]
loss: 2.248348  [25664/60000]
loss: 2.215117  [32064/60000]
loss: 2.231262  [38464/60000]
loss: 2.197844  [44864/60000]
loss: 2.194369  [51264/60000]
loss: 2.163816  [57664/60000]
Test Error:
 Accuracy: 42.2%, Avg loss: 2.153352

Epoch 2
-------------------------------
loss: 2.167751  [   64/60000]
loss: 2.152363  [ 6464/60000]
loss: 2.089968  [12864/60000]
loss: 2.116054  [19264/60000]
loss: 2.060365  [25664/60000]
loss: 1.987720  [32064/60000]
loss: 2.039777  [38464/60000]
loss: 1.954339  [44864/60000]
loss: 1.958542  [51264/60000]
loss: 1.901452  [57664/60000]
Test Error:
 Accuracy: 55.4%, Avg loss: 1.885695

Epoch 3
-------------------------------
loss: 1.914735  [   64/60000]
loss: 1.879721  [ 6464/60000]
loss: 1.756200  [12864/60000]
loss: 1.821928  [19264/60000]
loss: 1.703173  [25664/60000]
loss: 1.641386  [32064/60000]
loss: 1.691698  [38464/60000]
loss: 1.586060  [44864/60000]
loss: 1.605827  [51264/60000]
loss: 1.523651  [57664/60000]
Test Error:
 Accuracy: 63.3%, Avg loss: 1.524184

Epoch 4
-------------------------------
loss: 1.580570  [   64/60000]
loss: 1.547215  [ 6464/60000]
loss: 1.392179  [12864/60000]
loss: 1.483722  [19264/60000]
loss: 1.361350  [25664/60000]
loss: 1.338133  [32064/60000]
loss: 1.373251  [38464/60000]
loss: 1.297191  [44864/60000]
loss: 1.320179  [51264/60000]
loss: 1.238795  [57664/60000]
Test Error:
 Accuracy: 64.4%, Avg loss: 1.253644

Epoch 5
-------------------------------
loss: 1.319207  [   64/60000]
loss: 1.307838  [ 6464/60000]
loss: 1.138406  [12864/60000]
loss: 1.255827  [19264/60000]
loss: 1.134329  [25664/60000]
loss: 1.133734  [32064/60000]
loss: 1.173040  [38464/60000]
loss: 1.113333  [44864/60000]
loss: 1.141338  [51264/60000]
loss: 1.071028  [57664/60000]
Test Error:
 Accuracy: 65.3%, Avg loss: 1.083599

Epoch 6
-------------------------------
loss: 1.144013  [   64/60000]
loss: 1.154895  [ 6464/60000]
loss: 0.969636  [12864/60000]
loss: 1.112832  [19264/60000]
loss: 0.993944  [25664/60000]
loss: 0.996532  [32064/60000]
loss: 1.050147  [38464/60000]
loss: 0.995274  [44864/60000]
loss: 1.025120  [51264/60000]
loss: 0.967031  [57664/60000]
Test Error:
 Accuracy: 66.3%, Avg loss: 0.974268

Epoch 7
-------------------------------
loss: 1.022512  [   64/60000]
loss: 1.056700  [ 6464/60000]
loss: 0.854182  [12864/60000]
loss: 1.018504  [19264/60000]
loss: 0.906160  [25664/60000]
loss: 0.901887  [32064/60000]
loss: 0.971663  [38464/60000]
loss: 0.919038  [44864/60000]
loss: 0.946074  [51264/60000]
loss: 0.898707  [57664/60000]
Test Error:
 Accuracy: 67.4%, Avg loss: 0.900728

Epoch 8
-------------------------------
loss: 0.934186  [   64/60000]
loss: 0.989839  [ 6464/60000]
loss: 0.772073  [12864/60000]
loss: 0.952816  [19264/60000]
loss: 0.848014  [25664/60000]
loss: 0.834285  [32064/60000]
loss: 0.918051  [38464/60000]
loss: 0.868596  [44864/60000]
loss: 0.889919  [51264/60000]
loss: 0.850770  [57664/60000]
Test Error:
 Accuracy: 68.7%, Avg loss: 0.848627

Epoch 9
-------------------------------
loss: 0.867301  [   64/60000]
loss: 0.940487  [ 6464/60000]
loss: 0.711356  [12864/60000]
loss: 0.904958  [19264/60000]
loss: 0.806926  [25664/60000]
loss: 0.784890  [32064/60000]
loss: 0.878383  [38464/60000]
loss: 0.833803  [44864/60000]
loss: 0.848429  [51264/60000]
loss: 0.814703  [57664/60000]
Test Error:
 Accuracy: 70.0%, Avg loss: 0.809665

Epoch 10
-------------------------------
loss: 0.814660  [   64/60000]
loss: 0.901449  [ 6464/60000]
loss: 0.664618  [12864/60000]
loss: 0.868369  [19264/60000]
loss: 0.776012  [25664/60000]
loss: 0.747600  [32064/60000]
loss: 0.846622  [38464/60000]
loss: 0.808364  [44864/60000]
loss: 0.816517  [51264/60000]
loss: 0.785885  [57664/60000]
Test Error:
 Accuracy: 71.4%, Avg loss: 0.778833

Done!
```
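The loop above downloads FashionMNIST before it can run. For quick experiments with the same train-then-evaluate structure, here is a minimal sketch on hypothetical synthetic data with no download. The label is 1 when 20 random features sum to a positive number. The small `toy_model` and the helper names `train_one_epoch` and `evaluate` are illustrative, not part of the tutorial.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Hypothetical synthetic data: label is 1 when the 20 features sum to a positive number.
X = torch.randn(768, 20)
y = (X.sum(dim=1) > 0).long()
train_loader = DataLoader(TensorDataset(X[:512], y[:512]), batch_size=64, shuffle=True)
test_loader = DataLoader(TensorDataset(X[512:], y[512:]), batch_size=64)

toy_model = nn.Sequential(nn.Linear(20, 32), nn.ReLU(), nn.Linear(32, 2))
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(toy_model.parameters(), lr=0.1)


def train_one_epoch(loader):
    """One pass over the training data: predict, backpropagate, step, reset."""
    toy_model.train()
    for xb, yb in loader:
        loss = loss_fn(toy_model(xb), yb)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()


def evaluate(loader):
    """Return (average per-sample loss, accuracy) over a loader."""
    toy_model.eval()
    total_loss, correct, count = 0.0, 0, 0
    with torch.no_grad():
        for xb, yb in loader:
            logits = toy_model(xb)
            # Weight each batch's mean loss by its size to get a per-sample average
            total_loss += loss_fn(logits, yb).item() * len(yb)
            correct += (logits.argmax(1) == yb).sum().item()
            count += len(yb)
    return total_loss / count, correct / count


initial_loss, initial_acc = evaluate(test_loader)
for epoch in range(20):
    train_one_epoch(train_loader)
final_loss, final_acc = evaluate(test_loader)
print(f"test loss {initial_loss:.3f} -> {final_loss:.3f}, "
      f"accuracy {initial_acc:.1%} -> {final_acc:.1%}")
```

One design difference from `test_loop`: `evaluate` weights each batch's loss by its size, so it returns the exact per-sample average. Dividing by the number of batches, as `test_loop` does, is a close approximation when the last batch is smaller.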

Further Reading
