Quickstart
Created On: Feb 09, 2021 | Last Updated: Jan 24, 2025 | Last Verified: Not Verified
This section runs through the API for common tasks in machine learning. Refer to the links in each section to dive deeper.
Working with data
PyTorch has two primitives to work with data: torch.utils.data.DataLoader and torch.utils.data.Dataset. Dataset stores the samples and their corresponding labels, and DataLoader wraps an iterable around the Dataset.
```python
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
```
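Conceptually, a map-style Dataset only has to provide `__len__` and `__getitem__`, and a DataLoader without shuffling walks the indices in order and groups samples into batches. The plain-Python sketch below illustrates that contract. `ToyDataset` and `iterate_batches` are hypothetical names; the real DataLoader additionally collates samples into tensors, supports shuffling and custom samplers, and can load data in worker processes.

```python
class ToyDataset:
    """Hypothetical map-style dataset: (sample, label) pairs accessed by index."""

    def __init__(self, samples, labels):
        assert len(samples) == len(labels)
        self.samples = samples
        self.labels = labels

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx], self.labels[idx]


def iterate_batches(dataset, batch_size):
    """Yield (samples, labels) lists of up to batch_size items, in index order."""
    for start in range(0, len(dataset), batch_size):
        stop = min(start + batch_size, len(dataset))
        items = [dataset[i] for i in range(start, stop)]
        xs = [x for x, _ in items]
        ys = [y for _, y in items]
        yield xs, ys


ds = ToyDataset(samples=list(range(10)), labels=[i % 2 for i in range(10)])
batches = list(iterate_batches(ds, batch_size=4))
print([len(xs) for xs, _ in batches])  # [4, 4, 2]
```

The last batch is smaller because 10 samples do not divide evenly into batches of 4.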
PyTorch offers domain-specific libraries such as TorchText, TorchVision, and TorchAudio, all of which include datasets. For this tutorial, we will be using a TorchVision dataset.
The torchvision.datasets module contains Dataset objects for many real-world vision datasets like CIFAR and COCO (the full list is in the torchvision documentation). In this tutorial, we use the FashionMNIST dataset. Every TorchVision Dataset accepts two arguments, transform and target_transform, to modify the samples and labels respectively.
```python
# Download training data from open datasets.
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)

# Download test data from open datasets.
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)
```
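The `transform` and `target_transform` hooks are applied to each sample at the moment the dataset is indexed. A minimal plain-Python sketch of that pattern follows; `TransformedDataset`, `scale_pixels`, and `one_hot` are hypothetical stand-ins. `scale_pixels` mimics only the value scaling that `ToTensor` performs on 8-bit images (0 to 255 mapped to [0, 1]), not its tensor conversion or channel reordering.

```python
class TransformedDataset:
    """Hypothetical dataset that applies optional per-sample transforms on access."""

    def __init__(self, samples, labels, transform=None, target_transform=None):
        self.samples = samples
        self.labels = labels
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        x, y = self.samples[idx], self.labels[idx]
        if self.transform is not None:
            x = self.transform(x)  # modify the sample
        if self.target_transform is not None:
            y = self.target_transform(y)  # modify the label
        return x, y


def scale_pixels(pixels):
    # Map 8-bit pixel values (0..255) to floats in [0, 1].
    return [p / 255 for p in pixels]


def one_hot(label, num_classes=10):
    # Turn an integer class index into a one-hot list.
    return [1.0 if i == label else 0.0 for i in range(num_classes)]


ds = TransformedDataset([[0, 255, 51]], [3], transform=scale_pixels, target_transform=one_hot)
x, y = ds[0]
print(x)             # [0.0, 1.0, 0.2]
print(y.index(1.0))  # 3
```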
We pass the Dataset as an argument to DataLoader. This wraps an iterable over our dataset and supports automatic batching, sampling, shuffling, and multiprocess data loading. Here we define a batch size of 64, i.e. each element of the dataloader iterable is a batch of 64 samples and their 64 labels.
```python
batch_size = 64

# Create data loaders.
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

for X, y in test_dataloader:
    print(f"Shape of X [N, C, H, W]: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    break
```

```
Shape of X [N, C, H, W]: torch.Size([64, 1, 28, 28])
Shape of y: torch.Size([64]) torch.int64
```
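With `batch_size=64` and the DataLoader default `drop_last=False`, the final batch holds whatever samples remain. FashionMNIST has 60,000 training and 10,000 test images, so the arithmetic below suggests `len(train_dataloader)` should be 938 and `len(test_dataloader)` 157. The helper `batch_layout` is a hypothetical illustration, not a PyTorch API.

```python
def batch_layout(num_samples, batch_size, drop_last=False):
    """Return (number of batches, size of the last batch) for in-order batching."""
    if drop_last:
        # The incomplete trailing batch is discarded, so every batch is full.
        return num_samples // batch_size, batch_size
    num_batches = (num_samples + batch_size - 1) // batch_size  # ceiling division
    last = num_samples - (num_batches - 1) * batch_size
    return num_batches, last


print(batch_layout(60_000, 64))  # (938, 32)
print(batch_layout(10_000, 64))  # (157, 16)
```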
Read more about loading data in PyTorch.
Creating Models
To define a neural network in PyTorch, we create a class that inherits from nn.Module. We define the layers of the network in the __init__ function and specify how data will pass through the network in the forward function. To accelerate operations in the neural network, we move it to the accelerator such as CUDA, MPS, MTIA, or XPU. If the current accelerator is available, we will use it. Otherwise, we use the CPU.
```python
device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
print(f"Using {device} device")

# Define model
class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10)
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

model = NeuralNetwork().to(device)
print(model)
```

```
Using cuda device
NeuralNetwork(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
  )
)
```
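Each `nn.Linear(in_features, out_features)` holds an `out_features × in_features` weight matrix plus `out_features` biases, while `Flatten` and `ReLU` have no parameters. The arithmetic below counts the parameters of the network printed above; in PyTorch the same total should come from `sum(p.numel() for p in model.parameters())`. The helper `linear_params` is hypothetical.

```python
def linear_params(in_features, out_features, bias=True):
    """Parameters in a fully connected layer: weight matrix plus optional bias vector."""
    return in_features * out_features + (out_features if bias else 0)


# Layer sizes taken from the NeuralNetwork definition (28*28 inputs -> 10 logits).
layers = [(28 * 28, 512), (512, 512), (512, 10)]
counts = [linear_params(i, o) for i, o in layers]
print(counts)       # [401920, 262656, 5130]
print(sum(counts))  # 669706
```

Almost 60% of the parameters sit in the first layer, because it connects all 784 input pixels to 512 hidden units.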
Read more about building neural networks in PyTorch.
Optimizing the Model Parameters
To train a model, we need a loss function and an optimizer.
In a single training loop, the model makes predictions on the training dataset (fed to it in batches) and backpropagates the prediction error to adjust the model's parameters.
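```python
# Loss function and optimizer: cross-entropy loss and stochastic gradient descent.
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()  # put layers such as dropout/batch norm in training mode
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if batch % 100 == 0:
            loss, current = loss.item(), (batch + 1) * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")
```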
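The three backpropagation calls have distinct jobs: `loss.backward()` accumulates gradients into each parameter's `.grad`, `optimizer.step()` applies the update, and `optimizer.zero_grad()` resets the accumulated gradients so they do not carry over into the next batch. The sketch below mimics that cycle for plain SGD (no momentum or weight decay) on a single scalar parameter with a hand-written gradient. `Param`, `backward`, `sgd_step`, `zero_grad`, and the quadratic loss are hypothetical illustrations, not PyTorch objects.

```python
class Param:
    """Hypothetical scalar parameter with an accumulated gradient, like a tensor's .grad."""

    def __init__(self, value):
        self.value = value
        self.grad = 0.0


def backward(p, target):
    # Loss is (p - target)**2; its derivative 2 * (p - target) is *added* to p.grad,
    # mirroring how backward() accumulates rather than overwrites gradients.
    p.grad += 2 * (p.value - target)


def sgd_step(p, lr):
    # Plain SGD update: value <- value - lr * grad.
    p.value -= lr * p.grad


def zero_grad(p):
    p.grad = 0.0


w = Param(0.0)
for _ in range(100):
    backward(w, target=3.0)
    sgd_step(w, lr=0.1)
    zero_grad(w)

print(round(w.value, 4))  # 3.0
```

Each step shrinks the distance to the minimum by a factor of 0.8 here; skipping `zero_grad` would make every step use the sum of all previous gradients instead.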
We also check the model’s performance against the test dataset to ensure it is learning.
```python
def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
```
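`test` divides the summed loss by `num_batches` but the count of correct predictions by `size`: with the default mean reduction, the loss function already returns a per-batch average, whereas `correct` counts individual samples. A sample counts as correct when the index of its largest logit equals its label, which is what `pred.argmax(1) == y` checks. A plain-Python sketch of the accuracy part, using hypothetical logits for three samples:

```python
def argmax(values):
    """Index of the largest value (returns the first index on ties)."""
    return max(range(len(values)), key=lambda i: values[i])


def accuracy(batch_logits, labels):
    """Fraction of samples whose highest-scoring class matches the label."""
    correct = sum(argmax(row) == y for row, y in zip(batch_logits, labels))
    return correct / len(labels)


# Hypothetical logits for 3 samples over 4 classes.
logits = [
    [0.1, 2.0, -1.0, 0.3],  # predicts class 1
    [1.5, 0.2, 0.1, 0.0],   # predicts class 0
    [0.0, 0.1, 0.2, 3.0],   # predicts class 3
]
labels = [1, 2, 3]
print(accuracy(logits, labels))  # 0.6666666666666666
```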
The training process is conducted over several iterations (epochs). During each epoch, the model learns parameters to make better predictions. We print the model's accuracy and loss at each epoch; we'd like to see the accuracy increase and the loss decrease with every epoch.
```python
epochs = 5
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)
print("Done!")
```

```
Epoch 1
-------------------------------
loss: 2.303581  [   64/60000]
loss: 2.288507  [ 6464/60000]
loss: 2.266801  [12864/60000]
loss: 2.260141  [19264/60000]
loss: 2.246968  [25664/60000]
loss: 2.212795  [32064/60000]
loss: 2.229716  [38464/60000]
loss: 2.194052  [44864/60000]
loss: 2.195681  [51264/60000]
loss: 2.149111  [57664/60000]
Test Error:
 Accuracy: 30.8%, Avg loss: 2.150481

Epoch 2
-------------------------------
loss: 2.162689  [   64/60000]
loss: 2.149975  [ 6464/60000]
loss: 2.091477  [12864/60000]
loss: 2.106713  [19264/60000]
loss: 2.056189  [25664/60000]
loss: 1.989290  [32064/60000]
loss: 2.022153  [38464/60000]
loss: 1.944219  [44864/60000]
loss: 1.959422  [51264/60000]
loss: 1.860426  [57664/60000]
Test Error:
 Accuracy: 54.4%, Avg loss: 1.875605

Epoch 3
-------------------------------
loss: 1.910749  [   64/60000]
loss: 1.876096  [ 6464/60000]
loss: 1.770178  [12864/60000]
loss: 1.806621  [19264/60000]
loss: 1.694793  [25664/60000]
loss: 1.644445  [32064/60000]
loss: 1.664176  [38464/60000]
loss: 1.575405  [44864/60000]
loss: 1.605776  [51264/60000]
loss: 1.477963  [57664/60000]
Test Error:
 Accuracy: 62.8%, Avg loss: 1.511444

Epoch 4
-------------------------------
loss: 1.581210  [   64/60000]
loss: 1.541737  [ 6464/60000]
loss: 1.405106  [12864/60000]
loss: 1.466240  [19264/60000]
loss: 1.348204  [25664/60000]
loss: 1.346540  [32064/60000]
loss: 1.351477  [38464/60000]
loss: 1.287819  [44864/60000]
loss: 1.324715  [51264/60000]
loss: 1.210392  [57664/60000]
Test Error:
 Accuracy: 63.7%, Avg loss: 1.243052

Epoch 5
-------------------------------
loss: 1.324621  [   64/60000]
loss: 1.303051  [ 6464/60000]
loss: 1.144375  [12864/60000]
loss: 1.244700  [19264/60000]
loss: 1.118607  [25664/60000]
loss: 1.145195  [32064/60000]
loss: 1.159526  [38464/60000]
loss: 1.106616  [44864/60000]
loss: 1.149438  [51264/60000]
loss: 1.054472  [57664/60000]
Test Error:
 Accuracy: 65.0%, Avg loss: 1.077903

Done!
```
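Assuming the loss is cross-entropy (`nn.CrossEntropyLoss`), the first logged loss, 2.303581, is close to ln(10) ≈ 2.302585, which is exactly what cross-entropy gives when a classifier assigns equal scores to all 10 classes, as a freshly initialized network roughly does. The plain-Python sketch below computes per-sample cross-entropy from raw logits (log-softmax followed by the negative log-probability of the true class); it illustrates the formula and is not PyTorch's implementation, which also averages over the batch by default.

```python
import math


def cross_entropy(logits, target):
    """Cross-entropy of one sample: -log(softmax(logits)[target]), computed stably."""
    m = max(logits)  # subtract the max before exponentiating to avoid overflow
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


uniform = [0.0] * 10  # equal scores for all 10 FashionMNIST classes
print(round(cross_entropy(uniform, target=9), 6))  # 2.302585
print(round(math.log(10), 6))                      # 2.302585

confident = [0.0] * 10
confident[9] = 5.0  # much higher score for the correct class
print(cross_entropy(confident, target=9) < cross_entropy(uniform, target=9))  # True
```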
Read more about Training your model.
Saving Models
A common way to save a model is to serialize the internal state dictionary (containing the model parameters).
```python
torch.save(model.state_dict(), "model.pth")
print("Saved PyTorch Model State to model.pth")
```
Saved PyTorch Model State to model.pth
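A state dictionary maps each parameter's dotted name to its tensor. For `NeuralNetwork`, the names combine the attribute `linear_relu_stack` with the positions of the `nn.Linear` layers inside the `nn.Sequential` (0, 2 and 4; `Flatten` and `ReLU` contribute no entries), and each weight has shape `(out_features, in_features)`. The plain-Python sketch below derives the keys and shapes one would expect to see from `{k: tuple(v.shape) for k, v in model.state_dict().items()}`; the helper `expected_state_dict_shapes` is hypothetical.

```python
def expected_state_dict_shapes(prefix, sequential_layers):
    """Map state-dict keys to shapes for Linear layers inside an nn.Sequential.

    sequential_layers lists each position in order: (in, out) for a Linear layer,
    or None for a parameter-free layer such as ReLU.
    """
    shapes = {}
    for idx, layer in enumerate(sequential_layers):
        if layer is None:
            continue  # parameter-free layers contribute no keys
        in_f, out_f = layer
        shapes[f"{prefix}.{idx}.weight"] = (out_f, in_f)
        shapes[f"{prefix}.{idx}.bias"] = (out_f,)
    return shapes


# Positions mirror nn.Sequential(Linear, ReLU, Linear, ReLU, Linear) in NeuralNetwork.
layout = [(784, 512), None, (512, 512), None, (512, 10)]
for key, shape in expected_state_dict_shapes("linear_relu_stack", layout).items():
    print(key, shape)
```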
Loading Models
The process for loading a model includes re-creating the model structure and loading the state dictionary into it.
```python
model = NeuralNetwork().to(device)
model.load_state_dict(torch.load("model.pth", weights_only=True))
```
<All keys matched successfully>
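`load_state_dict` is strict by default: the keys in the saved dictionary must match the keys of the freshly constructed model, and the output above confirms that they do. The sketch below shows that key comparison in plain Python. `compare_keys` is a hypothetical helper, though PyTorch reports mismatches under the same two names, `missing_keys` and `unexpected_keys`.

```python
def compare_keys(model_keys, loaded_keys):
    """Return (missing, unexpected) key lists, in the spirit of a strict state-dict load."""
    model_keys, loaded_keys = set(model_keys), set(loaded_keys)
    missing = sorted(model_keys - loaded_keys)     # expected by the model, absent in the file
    unexpected = sorted(loaded_keys - model_keys)  # present in the file, unknown to the model
    return missing, unexpected


model_keys = ["linear_relu_stack.0.weight", "linear_relu_stack.0.bias"]

# Matching keys: nothing missing or unexpected, so a strict load would succeed.
print(compare_keys(model_keys, model_keys))  # ([], [])

# Keys saved from a wrapped model (nn.DataParallel, for example, adds a "module." prefix).
prefixed = ["module." + k for k in model_keys]
missing, unexpected = compare_keys(model_keys, prefixed)
print(len(missing), len(unexpected))  # 2 2
```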
This model can now be used to make predictions.
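```python
classes = [
    "T-shirt/top",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle boot",
]

model.eval()
x, y = test_data[0][0], test_data[0][1]
with torch.no_grad():
    x = x.to(device)
    # x has shape [1, 28, 28]; Flatten(start_dim=1) treats the leading dimension
    # as the batch, so pred has shape [1, 10] and pred[0] holds the 10 logits.
    pred = model(x)
    predicted, actual = classes[pred[0].argmax(0)], classes[y]
    print(f'Predicted: "{predicted}", Actual: "{actual}"')
```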
Predicted: "Ankle boot", Actual: "Ankle boot"
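`pred[0].argmax(0)` picks the index of the largest of the 10 logits, and that index selects a name from `classes`, whose order follows the FashionMNIST label indices 0 to 9. A plain-Python sketch of the same lookup, applied to a couple of hypothetical logit rows (made-up numbers, not model outputs); `predict_labels` is a hypothetical helper:

```python
classes = [
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",
]


def predict_labels(logit_rows, class_names):
    """Map each row of logits to the class name with the highest score."""
    names = []
    for row in logit_rows:
        best = max(range(len(row)), key=lambda i: row[i])  # argmax over classes
        names.append(class_names[best])
    return names


rows = [
    [0.0] * 9 + [4.2],       # highest score at index 9
    [0.1, 3.0] + [0.0] * 8,  # highest score at index 1
]
print(predict_labels(rows, classes))  # ['Ankle boot', 'Trouser']
```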
Read more about Saving & Loading your model.
Total running time of the script: (0 minutes 36.215 seconds)