
Learning PyTorch with Examples#

Created On: Mar 24, 2017 | Last Updated: Sep 29, 2025 | Last Verified: Nov 05, 2024

Author: Justin Johnson

Note

This is one of our older PyTorch tutorials. You can view our latest beginner content in Learn the Basics.

This tutorial introduces the fundamental concepts of PyTorch through self-contained examples.

At its core, PyTorch provides two main features:

  • An n-dimensional Tensor, similar to a numpy array but able to run on GPUs

  • Automatic differentiation for building and training neural networks

We will use a problem of fitting \(y=\sin(x)\) with a third order polynomial as our running example. The network will have four parameters, and will be trained with gradient descent, starting from randomly initialized weights, to fit the data by minimizing the squared Euclidean distance between the network output and the true output.
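Written out, the polynomial model and the sum-of-squared-errors loss that the examples minimize are as follows (with \(N = 2000\) sample points \(x_i\) and targets \(y_i\), as in the code):

```latex
\hat{y}_i = a + b\,x_i + c\,x_i^2 + d\,x_i^3,
\qquad
L(a, b, c, d) = \sum_{i=1}^{N} \left(\hat{y}_i - y_i\right)^2
```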

Note

You can browse the individual examples at the end of this page.

To run the tutorials below, make sure you have the torch and numpy packages installed.

Tensors#

Warm-up: numpy#

Before introducing PyTorch, we will first implement the network using numpy.

Numpy provides an n-dimensional array object, and many functions for manipulating these arrays. Numpy is a generic framework for scientific computing; it does not know anything about computation graphs, or deep learning, or gradients. However, we can easily use numpy to fit a third order polynomial to the sine function by manually implementing the forward and backward passes through the network using numpy operations:
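The backward pass implemented by hand in the numpy code is the chain rule applied to the loss \(L=\sum_i(\hat{y}_i-y_i)^2\) with \(\hat{y}_i=a+bx_i+cx_i^2+dx_i^3\). Each gradient then feeds a plain gradient-descent update with learning rate \(\eta\):

```latex
\begin{aligned}
\frac{\partial L}{\partial a} &= \sum_i 2(\hat{y}_i - y_i), &
\frac{\partial L}{\partial b} &= \sum_i 2(\hat{y}_i - y_i)\,x_i, \\
\frac{\partial L}{\partial c} &= \sum_i 2(\hat{y}_i - y_i)\,x_i^2, &
\frac{\partial L}{\partial d} &= \sum_i 2(\hat{y}_i - y_i)\,x_i^3, \\
\theta &\leftarrow \theta - \eta\,\frac{\partial L}{\partial \theta}, &
\theta &\in \{a, b, c, d\}
\end{aligned}
```

In the code, `grad_y_pred` holds the per-sample factor \(2(\hat{y}_i - y_i)\) and `learning_rate` plays the role of \(\eta\).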

```python
# -*- coding: utf-8 -*-
import numpy as np
import math

# Create input and output data
x = np.linspace(-math.pi, math.pi, 2000)
y = np.sin(x)

# Randomly initialize weights
a = np.random.randn()
b = np.random.randn()
c = np.random.randn()
d = np.random.randn()

learning_rate = 1e-6
for t in range(2000):
    # Forward pass: compute predicted y
    # y = a + b x + c x^2 + d x^3
    y_pred = a + b * x + c * x ** 2 + d * x ** 3

    # Compute and print loss
    loss = np.square(y_pred - y).sum()
    if t % 100 == 99:
        print(t, loss)

    # Backprop to compute gradients of a, b, c, d with respect to loss
    grad_y_pred = 2.0 * (y_pred - y)
    grad_a = grad_y_pred.sum()
    grad_b = (grad_y_pred * x).sum()
    grad_c = (grad_y_pred * x ** 2).sum()
    grad_d = (grad_y_pred * x ** 3).sum()

    # Update weights
    a -= learning_rate * grad_a
    b -= learning_rate * grad_b
    c -= learning_rate * grad_c
    d -= learning_rate * grad_d

print(f'Result: y = {a} + {b} x + {c} x^2 + {d} x^3')
```

PyTorch: Tensors#

Numpy is a great framework, but it cannot utilize GPUs to accelerate its numerical computations. For modern deep neural networks, GPUs often provide speedups of 50x or greater, so unfortunately numpy won't be enough for modern deep learning.

Here we introduce the most fundamental PyTorch concept: the Tensor. A PyTorch Tensor is conceptually identical to a numpy array: a Tensor is an n-dimensional array, and PyTorch provides many functions for operating on these Tensors. Behind the scenes, Tensors can keep track of a computational graph and gradients, but they're also useful as a generic tool for scientific computing.
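As a minimal sketch of that correspondence (assuming both `torch` and `numpy` are installed), an array can move between the two libraries with `torch.from_numpy` and `Tensor.numpy()`. On the CPU the two objects share memory, so an in-place change to one is visible through the other:

```python
import numpy as np
import torch

# A numpy array and a Tensor built from it share the same memory on CPU
arr = np.linspace(-1.0, 1.0, 5)
t = torch.from_numpy(arr)

# Tensor operations mirror numpy operations
cubed = t ** 3               # elementwise power, like arr ** 3
total = cubed.sum().item()   # .item() converts a one-element Tensor to a Python number

# In-place modification of the Tensor is visible through the numpy array
t.mul_(2)
print(arr)        # the numpy array now holds the doubled values

# Converting back with .numpy() also shares memory (no copy)
back = t.numpy()
```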

Unlike numpy, PyTorch Tensors can also utilize GPUs to accelerate their numeric computations. To run a PyTorch Tensor on a GPU, you simply need to specify the correct device.
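A minimal sketch of device selection, assuming a CUDA build may or may not be present (the code falls back to the CPU otherwise). Tensors that take part in the same operation must live on the same device, so inputs are either created on the chosen device or moved there:

```python
import torch

# Pick a GPU if one is available, otherwise fall back to the CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Create a Tensor directly on the chosen device...
x = torch.linspace(-1.0, 1.0, 2000, device=device, dtype=torch.float)

# ...or move an existing CPU Tensor there with .to()
w = torch.randn(2000).to(device)

# Both operands are on the same device, so the computation runs there
y = (w * x).sum()
print(y.device)
```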

Here we use PyTorch Tensors to fit a third order polynomial to the sine function. Like the numpy example above, we need to manually implement the forward and backward passes through the network:

```python
# -*- coding: utf-8 -*-
import torch
import math


dtype = torch.float
device = torch.device("cpu")
# device = torch.device("cuda:0") # Uncomment this to run on GPU

# Create input and output data
x = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype)
y = torch.sin(x)

# Randomly initialize weights
a = torch.randn((), device=device, dtype=dtype)
b = torch.randn((), device=device, dtype=dtype)
c = torch.randn((), device=device, dtype=dtype)
d = torch.randn((), device=device, dtype=dtype)

learning_rate = 1e-6
for t in range(2000):
    # Forward pass: compute predicted y
    y_pred = a + b * x + c * x ** 2 + d * x ** 3

    # Compute and print loss
    loss = (y_pred - y).pow(2).sum().item()
    if t % 100 == 99:
        print(t, loss)

    # Backprop to compute gradients of a, b, c, d with respect to loss
    grad_y_pred = 2.0 * (y_pred - y)
    grad_a = grad_y_pred.sum()
    grad_b = (grad_y_pred * x).sum()
    grad_c = (grad_y_pred * x ** 2).sum()
    grad_d = (grad_y_pred * x ** 3).sum()

    # Update weights using gradient descent
    a -= learning_rate * grad_a
    b -= learning_rate * grad_b
    c -= learning_rate * grad_c
    d -= learning_rate * grad_d


print(f'Result: y = {a.item()} + {b.item()} x + {c.item()} x^2 + {d.item()} x^3')
```

Autograd#

PyTorch: Tensors and autograd#

In the above examples, we had to manually implement both the forward and backward passes of our neural network. Manually implementing the backward pass is not a big deal for a small model like our four-parameter polynomial, but can quickly get very hairy for large complex networks.

Thankfully, we can use automatic differentiation to automate the computation of backward passes in neural networks. The autograd package in PyTorch provides exactly this functionality. When using autograd, the forward pass of your network will define a computational graph; nodes in the graph will be Tensors, and edges will be functions that produce output Tensors from input Tensors. Backpropagating through this graph then allows you to easily compute gradients.

This sounds complicated, but it is pretty simple to use in practice. Each Tensor represents a node in a computational graph. If x is a Tensor that has x.requires_grad=True, then x.grad is another Tensor holding the gradient of x with respect to some scalar value.
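A minimal sketch of these attributes on a single scalar, assuming the toy loss \(L = w^2\) at \(w = 3\) (so \(dL/dw = 6\)). It also shows that the result of an operation records the node that produced it in `grad_fn`, and that repeated `backward()` calls *accumulate* into `.grad` rather than overwrite it, which is why the training loops in this tutorial reset the gradients on every iteration:

```python
import torch

# Leaf Tensor we want gradients for
w = torch.tensor(3.0, requires_grad=True)

loss = w ** 2          # the forward pass records a graph node for the power op
print(loss.grad_fn)    # the node that produced `loss`; leaf Tensors have grad_fn None
loss.backward()        # computes dL/dw = 2 * w
first_grad = w.grad.item()

# A second backward pass adds to w.grad instead of replacing it
(w ** 2).backward()
accumulated_grad = w.grad.item()

# Reset the gradient, as the training loops do with `a.grad = None`
w.grad = None
```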

Here we use PyTorch Tensors and autograd to fit a third order polynomial, this time to \(y=e^x\) on the interval \([-1, 1]\); now we no longer need to manually implement the backward pass through the network:

```python
import torch
import math

# We want to be able to train our model on an `accelerator <https://pytorch.org/docs/stable/torch.html#accelerators>`__
# such as CUDA, MPS, MTIA, or XPU. If the current accelerator is available, we will use it. Otherwise, we use the CPU.

dtype = torch.float
device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
print(f"Using {device} device")
torch.set_default_device(device)

# Create Tensors to hold input and outputs.
# By default, requires_grad=False, which indicates that we do not need to
# compute gradients with respect to these Tensors during the backward pass.
x = torch.linspace(-1, 1, 2000, dtype=dtype)
y = torch.exp(x)  # A Taylor expansion would be 1 + x + (1/2) x**2 + (1/3!) x**3 + ...

# Create random Tensors for weights. For a third order polynomial, we need
# 4 weights: y = a + b x + c x^2 + d x^3
# Setting requires_grad=True indicates that we want to compute gradients with
# respect to these Tensors during the backward pass.
a = torch.randn((), dtype=dtype, requires_grad=True)
b = torch.randn((), dtype=dtype, requires_grad=True)
c = torch.randn((), dtype=dtype, requires_grad=True)
d = torch.randn((), dtype=dtype, requires_grad=True)

initial_loss = 1.
learning_rate = 1e-5
for t in range(5000):
    # Forward pass: compute predicted y using operations on Tensors.
    y_pred = a + b * x + c * x ** 2 + d * x ** 3

    # Compute and print loss using operations on Tensors.
    # Now loss is a scalar (0-dimensional) Tensor of shape ()
    # loss.item() gets the scalar value held in the loss.
    loss = (y_pred - y).pow(2).sum()

    # Calculate initial loss, so we can report loss relative to it
    if t == 0:
        initial_loss = loss.item()

    if t % 100 == 99:
        print(f'Iteration t = {t:4d}  loss(t)/loss(0) = {round(loss.item()/initial_loss, 6):10.6f}  a = {a.item():10.6f}  b = {b.item():10.6f}  c = {c.item():10.6f}  d = {d.item():10.6f}')

    # Use autograd to compute the backward pass. This call will compute the
    # gradient of loss with respect to all Tensors with requires_grad=True.
    # After this call a.grad, b.grad, c.grad and d.grad will be Tensors holding
    # the gradient of the loss with respect to a, b, c, d respectively.
    loss.backward()

    # Manually update weights using gradient descent. Wrap in torch.no_grad()
    # because weights have requires_grad=True, but we don't need to track this
    # in autograd.
    with torch.no_grad():
        a -= learning_rate * a.grad
        b -= learning_rate * b.grad
        c -= learning_rate * c.grad
        d -= learning_rate * d.grad

        # Manually zero the gradients after updating weights
        a.grad = None
        b.grad = None
        c.grad = None
        d.grad = None

print(f'Result: y = {a.item()} + {b.item()} x + {c.item()} x^2 + {d.item()} x^3')
```
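To see that autograd produces the same gradients derived by hand in the numpy example, the sketch below (an illustration, not part of the original example; it uses a fixed seed, the CPU, and float64 to keep rounding small) runs one backward pass on the sine data and compares `a.grad` through `d.grad` with the closed-form expressions:

```python
import math
import torch

torch.manual_seed(0)
x = torch.linspace(-math.pi, math.pi, 2000, dtype=torch.float64)
y = torch.sin(x)

# Four scalar weights that autograd will track
a, b, c, d = (torch.randn((), dtype=torch.float64, requires_grad=True) for _ in range(4))

y_pred = a + b * x + c * x ** 2 + d * x ** 3
loss = (y_pred - y).pow(2).sum()
loss.backward()

# Hand-derived gradients, computed exactly as in the numpy example
with torch.no_grad():
    grad_y_pred = 2.0 * (y_pred - y)
    manual = [
        grad_y_pred.sum(),
        (grad_y_pred * x).sum(),
        (grad_y_pred * x ** 2).sum(),
        (grad_y_pred * x ** 3).sum(),
    ]

autograd_grads = [a.grad, b.grad, c.grad, d.grad]
matches = all(torch.allclose(m, g) for m, g in zip(manual, autograd_grads))
print(matches)
```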

PyTorch: Defining new autograd functions#

Under the hood, each primitive autograd operator is really two functions that operate on Tensors. The forward function computes output Tensors from input Tensors. The backward function receives the gradient of the output Tensors with respect to some scalar value, and computes the gradient of the input Tensors with respect to that same scalar value.

In PyTorch we can easily define our own autograd operator by defining a subclass of torch.autograd.Function and implementing the forward and backward functions as static methods. We can then use our new autograd operator by calling its apply method like a function, passing Tensors containing input data.

In this example we define our model as \(y=a+b P_3(c+dx)\) instead of \(y=a+bx+cx^2+dx^3\), where \(P_3(x)=\frac{1}{2}\left(5x^3-3x\right)\) is the Legendre polynomial of degree three. We write our own custom autograd function for computing forward and backward of \(P_3\), and use it to implement our model:
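The backward formula in the custom Function follows from differentiating \(P_3\) and applying the chain rule. Writing \(z = c + dx\) for the input to \(P_3\):

```latex
P_3(z) = \tfrac{1}{2}\left(5z^3 - 3z\right)
\;\Longrightarrow\;
P_3'(z) = \tfrac{1}{2}\left(15z^2 - 3\right) = \tfrac{3}{2}\left(5z^2 - 1\right),
\qquad
\frac{\partial L}{\partial z} = \frac{\partial L}{\partial P_3}\,P_3'(z)
```

Here `grad_output` is \(\partial L/\partial P_3\), and the returned `grad_output * 1.5 * (5 * input ** 2 - 1)` is \(\partial L/\partial z\); autograd then continues the chain rule through \(z = c + dx\) to reach `c` and `d`.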

```python
import torch
import math


class LegendrePolynomial3(torch.autograd.Function):
    """
    We can implement our own custom autograd Functions by subclassing
    torch.autograd.Function and implementing the forward and backward passes
    which operate on Tensors.
    """

    @staticmethod
    def forward(ctx, input):
        """
        In the forward pass we receive a Tensor containing the input and return
        a Tensor containing the output. ctx is a context object that can be used
        to stash information for backward computation. You can cache tensors for
        use in the backward pass using the ``ctx.save_for_backward`` method. Other
        objects can be stored directly as attributes on the ctx object, such as
        ``ctx.my_object = my_object``. Check out `Extending torch.autograd <https://docs.pytorch.org/docs/stable/notes/extending.html#extending-torch-autograd>`_
        for further details.
        """
        ctx.save_for_backward(input)
        return 0.5 * (5 * input ** 3 - 3 * input)

    @staticmethod
    def backward(ctx, grad_output):
        """
        In the backward pass we receive a Tensor containing the gradient of the loss
        with respect to the output, and we need to compute the gradient of the loss
        with respect to the input.
        """
        input, = ctx.saved_tensors
        return grad_output * 1.5 * (5 * input ** 2 - 1)


dtype = torch.float
device = torch.device("cpu")
# device = torch.device("cuda:0")  # Uncomment this to run on GPU

# Create Tensors to hold input and outputs.
# By default, requires_grad=False, which indicates that we do not need to
# compute gradients with respect to these Tensors during the backward pass.
x = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype)
y = torch.sin(x)

# Create Tensors for weights. For this example, we need
# 4 weights: y = a + b * P3(c + d * x). These weights need to be initialized
# not too far from the correct result to ensure convergence.
# Setting requires_grad=True indicates that we want to compute gradients with
# respect to these Tensors during the backward pass.
a = torch.full((), 0.0, device=device, dtype=dtype, requires_grad=True)
b = torch.full((), -1.0, device=device, dtype=dtype, requires_grad=True)
c = torch.full((), 0.0, device=device, dtype=dtype, requires_grad=True)
d = torch.full((), 0.3, device=device, dtype=dtype, requires_grad=True)

# To apply our Function, we use the Function.apply method. We alias this as 'P3'.
P3 = LegendrePolynomial3.apply

learning_rate = 5e-6
for t in range(2000):
    # Forward pass: compute predicted y using operations; we compute
    # P3 using our custom autograd operation.
    y_pred = a + b * P3(c + d * x)

    # Compute and print loss
    loss = (y_pred - y).pow(2).sum()
    if t % 100 == 99:
        print(t, loss.item())

    # Use autograd to compute the backward pass.
    loss.backward()

    # Update weights using gradient descent
    with torch.no_grad():
        a -= learning_rate * a.grad
        b -= learning_rate * b.grad
        c -= learning_rate * c.grad
        d -= learning_rate * d.grad

        # Manually zero the gradients after updating weights
        a.grad = None
        b.grad = None
        c.grad = None
        d.grad = None

print(f'Result: y = {a.item()} + {b.item()} * P3({c.item()} + {d.item()} x)')
```
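When writing a custom backward, it is worth checking it numerically. `torch.autograd.gradcheck` compares the analytical gradient against finite differences and expects double-precision inputs with `requires_grad=True`. A minimal sketch for the `LegendrePolynomial3` Function, repeated here so the snippet is self-contained:

```python
import torch


class LegendrePolynomial3(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return 0.5 * (5 * input ** 3 - 3 * input)

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        return grad_output * 1.5 * (5 * input ** 2 - 1)


# gradcheck needs float64 inputs that require gradients
z = torch.randn(10, dtype=torch.float64, requires_grad=True)

# Returns True if the analytical and numerical Jacobians agree; raises otherwise
ok = torch.autograd.gradcheck(LegendrePolynomial3.apply, (z,), eps=1e-6, atol=1e-4)
print(ok)
```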

nn module#

PyTorch: nn#

Computational graphs and autograd are a very powerful paradigm for defining complex operators and automatically taking derivatives; however, for large neural networks raw autograd can be a bit too low-level.

When building neural networks we frequently think of arranging the computation into layers, some of which have learnable parameters which will be optimized during learning.

In TensorFlow, packages like Keras, TensorFlow-Slim, and TFLearn provide higher-level abstractions over raw computational graphs that are useful for building neural networks.

In PyTorch, the nn package serves this same purpose. The nn package defines a set of Modules, which are roughly equivalent to neural network layers. A Module receives input Tensors and computes output Tensors, but may also hold internal state such as Tensors containing learnable parameters. The nn package also defines a set of useful loss functions that are commonly used when training neural networks.
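A minimal sketch of the state and shapes involved, assuming the same Linear(3, 1) followed by Flatten(0, 1) that the model below uses: the Linear Module stores its learnable `weight` and `bias` as parameters, and `Flatten(0, 1)` turns the `(N, 1)` output into the `(N,)` shape of the targets:

```python
import math
import torch

x = torch.linspace(-math.pi, math.pi, 2000)
# Stack x, x^2, x^3 as three input features: shape (2000, 3)
xx = x.unsqueeze(-1).pow(torch.tensor([1, 2, 3]))

model = torch.nn.Sequential(torch.nn.Linear(3, 1), torch.nn.Flatten(0, 1))

# The Linear Module holds its learnable state as Parameters (requires_grad=True)
shapes = {name: tuple(param.shape) for name, param in model.named_parameters()}
print(shapes)

y_pred = model(xx)
print(xx.shape, y_pred.shape)
```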

In this example we use the nn package to implement our polynomial model network:

```python
# -*- coding: utf-8 -*-
import torch
import math


# Create Tensors to hold input and outputs.
x = torch.linspace(-math.pi, math.pi, 2000)
y = torch.sin(x)

# For this example, the predicted output is a linear function of (x, x^2, x^3), so
# we can consider it as a linear layer neural network. Let's prepare the
# tensor (x, x^2, x^3).
p = torch.tensor([1, 2, 3])
xx = x.unsqueeze(-1).pow(p)

# In the above code, x.unsqueeze(-1) has shape (2000, 1), and p has shape
# (3,), for this case, broadcasting semantics will apply to obtain a tensor
# of shape (2000, 3)

# Use the nn package to define our model as a sequence of layers. nn.Sequential
# is a Module which contains other Modules, and applies them in sequence to
# produce its output. The Linear Module computes output from input using a
# linear function, and holds internal Tensors for its weight and bias.
# The Flatten layer flattens the output of the linear layer to a 1D tensor,
# to match the shape of `y`.
model = torch.nn.Sequential(
    torch.nn.Linear(3, 1),
    torch.nn.Flatten(0, 1)
)

# The nn package also contains definitions of popular loss functions; in this
# case we will use Mean Squared Error (MSE) as our loss function.
# reduction='sum' returns the sum, rather than the mean, of the squared errors.
loss_fn = torch.nn.MSELoss(reduction='sum')

learning_rate = 1e-6
for t in range(2000):

    # Forward pass: compute predicted y by passing x to the model. Module objects
    # override the __call__ operator so you can call them like functions. When
    # doing so you pass a Tensor of input data to the Module and it produces
    # a Tensor of output data.
    y_pred = model(xx)

    # Compute and print loss. We pass Tensors containing the predicted and true
    # values of y, and the loss function returns a Tensor containing the
    # loss.
    loss = loss_fn(y_pred, y)
    if t % 100 == 99:
        print(t, loss.item())

    # Zero the gradients before running the backward pass.
    model.zero_grad()

    # Backward pass: compute gradient of the loss with respect to all the learnable
    # parameters of the model. Internally, the parameters of each Module are stored
    # in Tensors with requires_grad=True, so this call will compute gradients for
    # all learnable parameters in the model.
    loss.backward()

    # Update the weights using gradient descent. Each parameter is a Tensor, so
    # we can access its gradients like we did before.
    with torch.no_grad():
        for param in model.parameters():
            param -= learning_rate * param.grad

# You can access the first layer of `model` like accessing the first item of a list
linear_layer = model[0]

# For linear layer, its parameters are stored as `weight` and `bias`.
print(f'Result: y = {linear_layer.bias.item()} + {linear_layer.weight[:, 0].item()} x + {linear_layer.weight[:, 1].item()} x^2 + {linear_layer.weight[:, 2].item()} x^3')
```

PyTorch: optim#

Up to this point we have updated the weights of our models by manually mutating the Tensors holding learnable parameters inside a torch.no_grad() block. This is not a huge burden for simple optimization algorithms like stochastic gradient descent, but in practice we often train neural networks using more sophisticated optimizers like AdaGrad, RMSProp, Adam, and others.

The optim package in PyTorch abstracts the idea of an optimization algorithm and provides implementations of commonly used optimization algorithms.
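For plain SGD (no momentum, no weight decay), an optimizer step performs the same param -= learning_rate * param.grad update written by hand in the earlier examples. A minimal sketch comparing the two on an arbitrary toy loss (the loss and values are illustrative only):

```python
import torch

torch.manual_seed(0)
lr = 0.1
x = torch.randn(3)

# Two copies of the same starting weights
w_manual = torch.randn(3, requires_grad=True)
w_optim = w_manual.detach().clone().requires_grad_(True)

# Manual gradient-descent update, as in the earlier examples
loss = ((w_manual * x).sum() - 1.0) ** 2
loss.backward()
with torch.no_grad():
    w_manual -= lr * w_manual.grad

# The same update performed by torch.optim.SGD
optimizer = torch.optim.SGD([w_optim], lr=lr)
optimizer.zero_grad()
loss = ((w_optim * x).sum() - 1.0) ** 2
loss.backward()
optimizer.step()

print(torch.allclose(w_manual, w_optim))
```

More sophisticated optimizers such as RMSprop keep extra per-parameter state, which is exactly the bookkeeping the optim package takes off your hands.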

In this example we will use the nn package to define our model as before, but we will optimize the model using the RMSprop algorithm provided by the optim package:

```python
# -*- coding: utf-8 -*-
import torch
import math


# Create Tensors to hold input and outputs.
x = torch.linspace(-math.pi, math.pi, 2000)
y = torch.sin(x)

# Prepare the input tensor (x, x^2, x^3).
p = torch.tensor([1, 2, 3])
xx = x.unsqueeze(-1).pow(p)

# Use the nn package to define our model and loss function.
model = torch.nn.Sequential(
    torch.nn.Linear(3, 1),
    torch.nn.Flatten(0, 1)
)
loss_fn = torch.nn.MSELoss(reduction='sum')

# Use the optim package to define an Optimizer that will update the weights of
# the model for us. Here we will use RMSprop; the optim package contains many other
# optimization algorithms. The first argument to the RMSprop constructor tells the
# optimizer which Tensors it should update.
learning_rate = 1e-3
optimizer = torch.optim.RMSprop(model.parameters(), lr=learning_rate)
for t in range(2000):
    # Forward pass: compute predicted y by passing x to the model.
    y_pred = model(xx)

    # Compute and print loss.
    loss = loss_fn(y_pred, y)
    if t % 100 == 99:
        print(t, loss.item())

    # Before the backward pass, use the optimizer object to zero all of the
    # gradients for the variables it will update (which are the learnable
    # weights of the model). This is because by default, gradients are
    # accumulated in buffers (i.e., not overwritten) whenever .backward()
    # is called. Check out the docs of torch.autograd.backward for more details.
    optimizer.zero_grad()

    # Backward pass: compute gradient of the loss with respect to model
    # parameters
    loss.backward()

    # Calling the step function on an Optimizer makes an update to its
    # parameters
    optimizer.step()


linear_layer = model[0]
print(f'Result: y = {linear_layer.bias.item()} + {linear_layer.weight[:, 0].item()} x + {linear_layer.weight[:, 1].item()} x^2 + {linear_layer.weight[:, 2].item()} x^3')
```

PyTorch: Custom nn Modules#

Sometimes you will want to specify models that are more complex than a sequence of existing Modules; for these cases you can define your own Modules by subclassing nn.Module and defining a forward which receives input Tensors and produces output Tensors using other modules or other autograd operations on Tensors.
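Assigning a `torch.nn.Parameter` as an attribute inside `__init__` is what registers it with the Module, so that `model.parameters()` (and therefore the optimizer) can find it; a plain Tensor attribute is not registered (non-learnable state that should follow the Module is usually declared with `register_buffer` instead). A minimal sketch with a hypothetical `TinyPoly` Module:

```python
import torch


class TinyPoly(torch.nn.Module):  # hypothetical example Module
    def __init__(self):
        super().__init__()
        self.a = torch.nn.Parameter(torch.randn(()))  # registered as a parameter
        self.b = torch.nn.Parameter(torch.randn(()))  # registered as a parameter
        self.offset = torch.zeros(())                 # plain Tensor: not registered

    def forward(self, x):
        return self.a + self.b * x + self.offset


model = TinyPoly()
names = [name for name, _ in model.named_parameters()]
print(names)  # ['a', 'b']
```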

In this example we implement our third order polynomial as a custom Module subclass:

```python
# -*- coding: utf-8 -*-
import torch
import math


class Polynomial3(torch.nn.Module):
    def __init__(self):
        """
        In the constructor we instantiate four parameters and assign them as
        member parameters.
        """
        super().__init__()
        self.a = torch.nn.Parameter(torch.randn(()))
        self.b = torch.nn.Parameter(torch.randn(()))
        self.c = torch.nn.Parameter(torch.randn(()))
        self.d = torch.nn.Parameter(torch.randn(()))

    def forward(self, x):
        """
        In the forward function we accept a Tensor of input data and we must return
        a Tensor of output data. We can use Modules defined in the constructor as
        well as arbitrary operators on Tensors.
        """
        return self.a + self.b * x + self.c * x ** 2 + self.d * x ** 3

    def string(self):
        """
        Just like any class in Python, you can also define custom methods on PyTorch modules
        """
        return f'y = {self.a.item()} + {self.b.item()} x + {self.c.item()} x^2 + {self.d.item()} x^3'


# Create Tensors to hold input and outputs.
x = torch.linspace(-math.pi, math.pi, 2000)
y = torch.sin(x)

# Construct our model by instantiating the class defined above
model = Polynomial3()

# Construct our loss function and an Optimizer. The call to model.parameters()
# in the SGD constructor will contain the learnable parameters (defined
# with torch.nn.Parameter) which are members of the model.
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-6)
for t in range(2000):
    # Forward pass: Compute predicted y by passing x to the model
    y_pred = model(x)

    # Compute and print loss
    loss = criterion(y_pred, y)
    if t % 100 == 99:
        print(t, loss.item())

    # Zero gradients, perform a backward pass, and update the weights.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

print(f'Result: {model.string()}')
```

PyTorch: Control Flow + Weight Sharing#

As an example of dynamic graphs and weight sharing, we implement a very strange model: a third-to-fifth order polynomial that on each forward pass chooses a random number between 3 and 5 and uses that many orders, reusing the same weight multiple times to compute the fourth and fifth order terms.

For this model we can use normal Python control flow to implement the loop, and we can implement weight sharing by simply reusing the same parameter multiple times when defining the forward pass.
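Because the same Parameter can appear several times in one forward pass, autograd sums the gradient contributions from every use. A minimal sketch with a single shared weight `e` used for both the \(x^4\) and \(x^5\) terms, at the illustrative values \(e=2\), \(x=3\), plus a hypothetical helper `extra_terms` whose graph shape is decided by ordinary Python control flow on each call:

```python
import random
import torch

# One shared weight e and a fixed input x (illustrative values)
e = torch.tensor(2.0, requires_grad=True)
x = torch.tensor(3.0)

# Weight sharing: e multiplies both the x^4 and the x^5 term
y = e * x ** 4 + e * x ** 5
y.backward()
print(e.grad)  # dy/de = x^4 + x^5 = 81 + 243 = 324


def extra_terms(x, e):
    """Hypothetical helper: a loop decides how many terms enter the graph."""
    out = torch.zeros(())
    # random.randint(4, 6) is inclusive, so this adds no term, x^4, or x^4 and x^5
    for exp in range(4, random.randint(4, 6)):
        out = out + e * x ** exp
    return out


print(extra_terms(x, e))
```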

We can easily implement this model as a Module subclass:

```python
# -*- coding: utf-8 -*-
import random
import torch
import math


class DynamicNet(torch.nn.Module):
    def __init__(self):
        """
        In the constructor we instantiate five parameters and assign them as members.
        """
        super().__init__()
        self.a = torch.nn.Parameter(torch.randn(()))
        self.b = torch.nn.Parameter(torch.randn(()))
        self.c = torch.nn.Parameter(torch.randn(()))
        self.d = torch.nn.Parameter(torch.randn(()))
        self.e = torch.nn.Parameter(torch.randn(()))

    def forward(self, x):
        """
        For the forward pass of the model, we randomly choose to add no extra
        terms, an x^4 term, or both x^4 and x^5 terms (a polynomial of order
        3, 4, or 5), and reuse the e parameter to compute the contribution of
        these orders.

        Since each forward pass builds a dynamic computation graph, we can use normal
        Python control-flow operators like loops or conditional statements when
        defining the forward pass of the model.

        Here we also see that it is perfectly safe to reuse the same parameter many
        times when defining a computational graph.
        """
        y = self.a + self.b * x + self.c * x ** 2 + self.d * x ** 3
        for exp in range(4, random.randint(4, 6)):
            y = y + self.e * x ** exp
        return y

    def string(self):
        """
        Just like any class in Python, you can also define custom methods on PyTorch modules
        """
        return f'y = {self.a.item()} + {self.b.item()} x + {self.c.item()} x^2 + {self.d.item()} x^3 + {self.e.item()} x^4 ? + {self.e.item()} x^5 ?'


# Create Tensors to hold input and outputs.
x = torch.linspace(-math.pi, math.pi, 2000)
y = torch.sin(x)

# Construct our model by instantiating the class defined above
model = DynamicNet()

# Construct our loss function and an Optimizer. Training this strange model with
# vanilla stochastic gradient descent is tough, so we use momentum
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-8, momentum=0.9)
for t in range(30000):
    # Forward pass: Compute predicted y by passing x to the model
    y_pred = model(x)

    # Compute and print loss
    loss = criterion(y_pred, y)
    if t % 2000 == 1999:
        print(t, loss.item())

    # Zero gradients, perform a backward pass, and update the weights.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

print(f'Result: {model.string()}')
```

Examples#

You can browse the above examples here.

Tensors#

Autograd#

nn module#