
Build the Neural Network

Created On: Feb 09, 2021 | Last Updated: Jan 24, 2025 | Last Verified: Not Verified

Neural networks consist of layers/modules that perform operations on data. The `torch.nn` namespace provides all the building blocks you need to build your own neural network. Every module in PyTorch subclasses `nn.Module`. A neural network is itself a module that consists of other modules (layers). This nested structure makes it easy to build and manage complex architectures.
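As a small illustration of this nesting, the sketch below defines two hypothetical modules, `Block` and `TinyNet` (neither is part of `torch.nn`). Assigning a module to an attribute inside `__init__` is enough for the parent to register it as a child, which `named_children()` and `named_modules()` then report.

```python
import torch
from torch import nn


class Block(nn.Module):
    """Hypothetical sub-module: one linear layer followed by ReLU."""

    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        self.act = nn.ReLU()

    def forward(self, x):
        return self.act(self.linear(x))


class TinyNet(nn.Module):
    """Hypothetical network built from two Block modules and an output layer."""

    def __init__(self):
        super().__init__()
        self.block1 = Block(8, 16)  # registered automatically as a child module
        self.block2 = Block(16, 16)
        self.out = nn.Linear(16, 2)

    def forward(self, x):
        return self.out(self.block2(self.block1(x)))


net = TinyNet()

# Direct children only
child_names = [name for name, _ in net.named_children()]
print(child_names)  # ['block1', 'block2', 'out']

# Every module, recursively; the root module itself has the empty name ''
all_names = [name for name, _ in net.named_modules()]
print(all_names)
```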

In the following sections, we’ll build a neural network to classify images in the FashionMNIST dataset.

```python
import os
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
```

Get Device for Training

We want to be able to train our model on an accelerator such as CUDA, MPS, MTIA, or XPU. If an accelerator is available, we use it; otherwise, we fall back to the CPU.

```python
device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
print(f"Using {device} device")
```
Using cuda device
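`torch.accelerator` is a relatively recent addition to PyTorch. If you are on an older release that lacks it, a rough equivalent can be sketched by probing the CUDA and MPS backends directly. The helper name `pick_device` is hypothetical, and the fallback branch only knows about CUDA and MPS, not every accelerator type that `torch.accelerator` recognizes.

```python
import torch


def pick_device() -> str:
    """Hypothetical helper: return a device string, preferring an accelerator."""
    # Newer PyTorch releases expose a backend-agnostic accelerator API.
    if hasattr(torch, "accelerator"):
        if torch.accelerator.is_available():
            return torch.accelerator.current_accelerator().type
        return "cpu"
    # Fallback for older releases: check CUDA, then Apple's MPS backend.
    if torch.cuda.is_available():
        return "cuda"
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        return "mps"
    return "cpu"


device = pick_device()
print(f"Using {device} device")
```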

Define the Class

We define our neural network by subclassing `nn.Module` and initialize the neural network layers in `__init__`. Every `nn.Module` subclass implements the operations on input data in the `forward` method.

```python
class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits
```

We create an instance of `NeuralNetwork`, move it to the `device`, and print its structure.

```python
model = NeuralNetwork().to(device)
print(model)
```
```
NeuralNetwork(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
  )
)
```

To use the model, we pass it the input data. This executes the model's `forward` method, along with some background operations such as any registered hooks. Do not call `model.forward()` directly!
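One of those background operations is running hooks. The sketch below registers a forward hook with `register_forward_hook` on a small `nn.Linear` that stands in for a full model, and records each time the hook fires. The hook function name `record_call` is hypothetical. Calling the module triggers the hook, while calling `.forward()` directly skips it silently.

```python
import torch
from torch import nn

layer = nn.Linear(4, 2)  # stand-in for a full model
calls = []


def record_call(module, inputs, output):
    # Forward hooks receive the module, its positional inputs, and its output.
    calls.append(tuple(output.shape))


handle = layer.register_forward_hook(record_call)

x = torch.rand(3, 4)
layer(x)          # goes through nn.Module.__call__, so the hook runs
layer.forward(x)  # bypasses __call__, so the hook does NOT run
print(calls)      # [(3, 2)]

handle.remove()   # detach the hook once it is no longer needed
```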

Calling the model on the input returns a 2-dimensional tensor: dim=0 corresponds to each sample in the batch, and dim=1 holds the 10 raw predicted values (logits) for that sample, one per class. We get the prediction probabilities by passing the logits through an instance of the `nn.Softmax` module.

```python
X = torch.rand(1, 28, 28, device=device)
logits = model(X)
pred_probab = nn.Softmax(dim=1)(logits)
y_pred = pred_probab.argmax(1)
print(f"Predicted class: {y_pred}")
```
Predicted class: tensor([6], device='cuda:0')
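To see both dimensions more clearly, the sketch below feeds a batch of 4 random images instead of 1. It rebuilds the same layer stack as `NeuralNetwork` as a plain `nn.Sequential` on the CPU so the snippet runs on its own; the batch size of 4 and the `batch_*` names are arbitrary choices for illustration.

```python
import torch
from torch import nn

# Same layers as NeuralNetwork above, rebuilt here so the snippet is self-contained.
demo_model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 512), nn.ReLU(),
    nn.Linear(512, 512), nn.ReLU(),
    nn.Linear(512, 10),
)

batch = torch.rand(4, 28, 28)            # a batch of 4 random "images"
batch_logits = demo_model(batch)
print(batch_logits.shape)                # torch.Size([4, 10]): dim=0 is the sample, dim=1 the class

batch_probs = nn.Softmax(dim=1)(batch_logits)
print(batch_probs.sum(dim=1))            # each row sums to 1, up to rounding
batch_pred = batch_probs.argmax(1)       # one predicted class index per sample
print(batch_pred.shape)                  # torch.Size([4])
```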

Model Layers

Let's break down the layers in the FashionMNIST model. To illustrate them, we will take a sample minibatch of 3 images of size 28x28 and see what happens to it as we pass it through the network.

```python
input_image = torch.rand(3, 28, 28)
print(input_image.size())
```
torch.Size([3, 28, 28])

nn.Flatten

We initialize the `nn.Flatten` layer to convert each 2D 28x28 image into a contiguous array of 784 pixel values (the minibatch dimension at dim=0 is maintained).
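The flattening step can be sketched as follows. The name `flat_image` matches the variable that the `nn.Linear` example in this section consumes; the comparison with `torch.flatten(..., start_dim=1)` is added here for illustration.

```python
import torch
from torch import nn

input_image = torch.rand(3, 28, 28)  # minibatch of 3 images of size 28x28

flatten = nn.Flatten()               # defaults: start_dim=1, end_dim=-1
flat_image = flatten(input_image)
print(flat_image.size())             # torch.Size([3, 784])

# nn.Flatten() keeps dim 0 (the batch) and merges every dimension after it,
# which gives the same result as torch.flatten with start_dim=1.
assert torch.equal(flat_image, torch.flatten(input_image, start_dim=1))
```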

torch.Size([3, 784])

nn.Linear

The linear layer is a module that applies a linear transformation on the input using its stored weights and biases.

```python
layer1 = nn.Linear(in_features=28*28, out_features=20)
hidden1 = layer1(flat_image)
print(hidden1.size())
```
torch.Size([3, 20])
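Concretely, `nn.Linear` computes y = x Wᵀ + b, where `weight` has shape `(out_features, in_features)` and `bias` has shape `(out_features,)`. The sketch below checks this against a manual matrix product; the layer sizes mirror `layer1`, but the input `x` is a fresh random tensor standing in for `flat_image`.

```python
import torch
from torch import nn

layer = nn.Linear(in_features=28 * 28, out_features=20)
x = torch.rand(3, 28 * 28)  # stands in for flat_image

# weight is (out_features, in_features), bias is (out_features,)
print(layer.weight.shape, layer.bias.shape)  # torch.Size([20, 784]) torch.Size([20])

manual = x @ layer.weight.T + layer.bias     # y = x W^T + b
print(torch.allclose(layer(x), manual, atol=1e-5))  # True
```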

nn.ReLU

Non-linear activations are what create the complex mappings between the model's inputs and outputs. They are applied after linear transformations to introduce nonlinearity, helping neural networks learn a wide variety of phenomena.

In this model, we use `nn.ReLU` between our linear layers, but there are other activations you can use to introduce non-linearity in your model.

```python
print(f"Before ReLU: {hidden1}\n\n")
hidden1 = nn.ReLU()(hidden1)
print(f"After ReLU: {hidden1}")
```
```
Before ReLU: tensor([[ 0.0794, -0.2958,  0.2315, -0.3651, -0.1669, -0.0367, -0.1708, -0.1830,
          0.3132, -0.0524, -0.2309, -0.0977, -0.1087, -0.2144, -0.0453, -0.5924,
          0.0176, -0.2197, -0.3919, -0.5935],
        [ 0.0810,  0.0553,  0.2831, -0.1680, -0.3074, -0.0408, -0.1673, -0.0401,
          0.2754,  0.0323,  0.0523,  0.1238,  0.2902,  0.2540,  0.2785, -0.6839,
          0.1530, -0.5762, -0.3334, -0.6541],
        [ 0.2876,  0.3310,  0.0731,  0.0408,  0.0061, -0.1462, -0.3694, -0.1568,
          0.4456, -0.0628, -0.0456,  0.1272,  0.2312,  0.0054,  0.3196, -0.3090,
          0.1041, -0.3894, -0.0039, -0.7682]], grad_fn=<AddmmBackward0>)


After ReLU: tensor([[0.0794, 0.0000, 0.2315, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.3132,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0176, 0.0000,
         0.0000, 0.0000],
        [0.0810, 0.0553, 0.2831, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.2754,
         0.0323, 0.0523, 0.1238, 0.2902, 0.2540, 0.2785, 0.0000, 0.1530, 0.0000,
         0.0000, 0.0000],
        [0.2876, 0.3310, 0.0731, 0.0408, 0.0061, 0.0000, 0.0000, 0.0000, 0.4456,
         0.0000, 0.0000, 0.1272, 0.2312, 0.0054, 0.3196, 0.0000, 0.1041, 0.0000,
         0.0000, 0.0000]], grad_fn=<ReluBackward0>)
```
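ReLU replaces negative entries with zero, i.e. ReLU(x) = max(0, x), which is why every negative value above became 0.0000 while the positive values passed through unchanged. The sketch below checks that on a small hand-picked tensor `h` and, for comparison, applies two other element-wise activations that `torch.nn` provides, `nn.Tanh` and `nn.Sigmoid`. Which activation suits a given model is a design choice this example does not address.

```python
import torch
from torch import nn

h = torch.tensor([[-1.5, 0.0, 2.0],
                  [0.3, -0.2, -4.0]])

relu_out = nn.ReLU()(h)
print(relu_out)  # negatives become 0, positives are unchanged

# ReLU(x) = max(0, x), so it matches clamping at zero
assert torch.equal(relu_out, h.clamp(min=0))

# Other element-wise activations available in torch.nn
print(nn.Tanh()(h))     # values squashed into (-1, 1)
print(nn.Sigmoid()(h))  # values squashed into (0, 1)
```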

nn.Sequential

`nn.Sequential` is an ordered container of modules. The data is passed through all the modules in the same order as defined. You can use sequential containers to put together a quick network like `seq_modules`.
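A minimal sketch of such a quick network, `seq_modules`, follows. It chains a flatten step, a 784-to-20 linear layer, a ReLU, and a 20-to-10 output layer, reusing the layer sizes from the examples in this section; the final size of 10 gives one value per FashionMNIST class.

```python
import torch
from torch import nn

flatten = nn.Flatten()
layer1 = nn.Linear(in_features=28 * 28, out_features=20)

# Modules run in exactly the order they are listed here
seq_modules = nn.Sequential(
    flatten,
    layer1,
    nn.ReLU(),
    nn.Linear(20, 10),
)

input_image = torch.rand(3, 28, 28)
logits = seq_modules(input_image)
print(logits.size())  # torch.Size([3, 10])

# Sub-modules can be indexed by position, like a list
print(seq_modules[1] is layer1)  # True
```

Because `nn.Sequential` only feeds each module's output into the next, a network that needs branching or several inputs is typically written as an `nn.Module` subclass instead, as `NeuralNetwork` is.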

nn.Softmax

The last linear layer of the neural network returns logits, raw values in (-∞, ∞), which are passed to the `nn.Softmax` module. The logits are scaled to values in [0, 1] representing the model's predicted probabilities for each class. The `dim` parameter indicates the dimension along which the values must sum to 1.
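Softmax maps each logit z_i to exp(z_i) / Σ_j exp(z_j), with the sum taken along `dim`. The sketch below applies `nn.Softmax(dim=1)` to a made-up batch of logits (2 samples, 3 classes), checks the result against that formula, and confirms that each row, i.e. each sample, sums to 1. The values in `example_logits` are invented for illustration.

```python
import torch
from torch import nn

# Hypothetical logits for a batch of 2 samples and 3 classes
example_logits = torch.tensor([[2.0, 1.0, 0.1],
                               [-1.0, 0.0, 3.0]])

softmax = nn.Softmax(dim=1)
example_probs = softmax(example_logits)

# The same computation written out: exp(z_i) / sum_j exp(z_j) along dim=1
manual = example_logits.exp() / example_logits.exp().sum(dim=1, keepdim=True)
print(torch.allclose(example_probs, manual))  # True

print(example_probs.sum(dim=1))     # each row sums to 1, up to rounding
print(example_probs.argmax(dim=1))  # most likely class per sample
```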

Model Parameters

Many layers inside a neural network are parameterized, i.e. they have associated weights and biases that are optimized during training. Subclassing `nn.Module` automatically tracks all fields defined inside your model object and makes all parameters accessible through your model's `parameters()` or `named_parameters()` methods.

In this example, we iterate over each parameter, and print its size and a preview of its values.

```python
print(f"Model structure: {model}\n\n")

for name, param in model.named_parameters():
    print(f"Layer: {name} | Size: {param.size()} | Values : {param[:2]}\n")
```
```
Model structure: NeuralNetwork(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
  )
)


Layer: linear_relu_stack.0.weight | Size: torch.Size([512, 784]) | Values : tensor([[ 0.0225, -0.0159, -0.0165,  ..., -0.0131,  0.0188,  0.0095],
        [-0.0151,  0.0143, -0.0073,  ...,  0.0034, -0.0109,  0.0252]],
       device='cuda:0', grad_fn=<SliceBackward0>)

Layer: linear_relu_stack.0.bias | Size: torch.Size([512]) | Values : tensor([ 1.5040e-03, -6.6417e-06], device='cuda:0', grad_fn=<SliceBackward0>)

Layer: linear_relu_stack.2.weight | Size: torch.Size([512, 512]) | Values : tensor([[ 0.0314, -0.0190, -0.0007,  ...,  0.0144, -0.0253, -0.0336],
        [-0.0358,  0.0239, -0.0153,  ...,  0.0169,  0.0122, -0.0136]],
       device='cuda:0', grad_fn=<SliceBackward0>)

Layer: linear_relu_stack.2.bias | Size: torch.Size([512]) | Values : tensor([-0.0085, -0.0340], device='cuda:0', grad_fn=<SliceBackward0>)

Layer: linear_relu_stack.4.weight | Size: torch.Size([10, 512]) | Values : tensor([[-0.0035, -0.0396, -0.0116,  ...,  0.0267,  0.0337, -0.0235],
        [-0.0015,  0.0171,  0.0347,  ..., -0.0136,  0.0214,  0.0150]],
       device='cuda:0', grad_fn=<SliceBackward0>)

Layer: linear_relu_stack.4.bias | Size: torch.Size([10]) | Values : tensor([0.0297, 0.0366], device='cuda:0', grad_fn=<SliceBackward0>)
```
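A common follow-up is counting the trainable parameters. The sketch below rebuilds the same layer stack as `NeuralNetwork`, as a plain `nn.Sequential` on the CPU so it runs on its own, and sums `numel()` over `parameters()`. Working it out by hand from the sizes printed above gives (784·512 + 512) + (512·512 + 512) + (512·10 + 10) = 401,920 + 262,656 + 5,130 = 669,706. `Flatten` and `ReLU` contribute no parameters.

```python
import torch
from torch import nn

# Same layers as NeuralNetwork, rebuilt so the snippet is self-contained
stack = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 512), nn.ReLU(),
    nn.Linear(512, 512), nn.ReLU(),
    nn.Linear(512, 10),
)

# parameters() yields every weight and bias tensor registered in the sub-modules
total = sum(p.numel() for p in stack.parameters())
trainable = sum(p.numel() for p in stack.parameters() if p.requires_grad)
print(total, trainable)  # 669706 669706
```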
