Lankinen
Posted on • Edited on

Neural Network from Scratch Using PyTorch

In this article I show how to build a neural network from scratch. The example is simple and short to make it easier to understand, but I haven't taken any shortcuts that hide details.

Looking for the TensorFlow version of this tutorial? Go here.

import torch
import matplotlib.pyplot as plt

First we create some simple data. x is a tensor holding a single sample with two features, and the model will learn to predict one value, y.

x = torch.tensor([[1.,2.]])
x.shape
CONSOLE: torch.Size([1, 2])

y = 5.

The parameters are initialized from a normal distribution with mean 0 and variance 1.

def initalize_parameters(size, variance=1.0):
    return (torch.randn(size) * variance).requires_grad_()

first_layer_output_size = 3

weights_1 = initalize_parameters((x.shape[1], first_layer_output_size))
weights_1, weights_1.shape
CONSOLE: (tensor([[ 0.3575, -1.6650,  1.1152],
                  [-0.2687, -0.6715, -1.2855]], requires_grad=True),
          torch.Size([2, 3]))

bias_1 = initalize_parameters(1)
bias_1, bias_1.shape
CONSOLE: (tensor([-2.5051], requires_grad=True),
          torch.Size([1]))

weights_2 = initalize_parameters((first_layer_output_size, 1))
weights_2, weights_2.shape
CONSOLE: (tensor([[-0.9567],
                  [-1.6121],
                  [ 0.6514]], requires_grad=True),
          torch.Size([3, 1]))

bias_2 = initalize_parameters([1])
bias_2, bias_2.shape
CONSOLE: (tensor([0.2285], requires_grad=True),
          torch.Size([1]))
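For comparison (this is not part of the original tutorial), PyTorch's built-in torch.nn.Linear bundles this kind of weight and bias into a single layer. Note that it stores the weight transposed, with shape (out_features, in_features), keeps one bias per output unit, and uses its own default initialization instead of the plain normal distribution above. A minimal sketch, with layer_1 and layer_2 as illustrative names:

import torch.nn as nn

layer_1 = nn.Linear(2, 3)
layer_2 = nn.Linear(3, 1)
print(layer_1.weight.shape, layer_1.bias.shape)   # torch.Size([3, 2]) torch.Size([3])
print(layer_2.weight.shape, layer_2.bias.shape)   # torch.Size([1, 3]) torch.Size([1])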

The neural network consists of two linear functions with one non-linear function (a ReLU, implemented as an element-wise max with 0) between them.

def simple_neural_network(xb):
    # linear (1,2 @ 2,3 = 1,3)
    l1 = xb @ weights_1 + bias_1
    # non-linear
    l2 = l1.max(torch.tensor(0.0))
    # linear (1,3 @ 3,1 = 1,1)
    l3 = l2 @ weights_2 + bias_2
    return l3
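The same architecture can also be written with PyTorch's built-in modules. This is only a sketch for comparison: the name builtin_model is hypothetical, and the built-in layers manage their own (differently initialized) parameters, so it is not connected to the hand-made weights above.

import torch.nn as nn

builtin_model = nn.Sequential(
    nn.Linear(2, 3),   # maps (1, 2) -> (1, 3)
    nn.ReLU(),         # same non-linearity as l1.max(torch.tensor(0.0))
    nn.Linear(3, 1),   # maps (1, 3) -> (1, 1)
)
print(builtin_model(x).shape)   # torch.Size([1, 1])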

The loss function measures how close the predictions are to the real values.

def loss_func(preds, yb):
    # Mean Squared Error (MSE)
    return ((preds-yb)**2).mean()
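For reference, this is the same quantity that torch.nn.functional.mse_loss computes. A quick sanity check with made-up numbers (assuming the target is passed as a tensor):

import torch.nn.functional as F

preds = torch.tensor([[4.2]])
target = torch.tensor([[5.0]])
print(loss_func(preds, target))    # tensor(0.6400)
print(F.mse_loss(preds, target))   # tensor(0.6400)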

The learning rate scales down the gradient, making sure the parameters are not changed too much in each step.

lr = 10E-4

A helper function that updates the parameters and then clears the gradient.

def update_params(a):
    a.data -= a.grad * lr
    a.grad = None
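This manual update is exactly what plain stochastic gradient descent does, so the same step could be delegated to torch.optim.SGD. A sketch for comparison only; the tutorial keeps the manual update_params version:

from torch import optim

optimizer = optim.SGD([weights_1, bias_1, weights_2, bias_2], lr=lr)

# Inside the training loop, these two calls would replace the four update_params calls:
#   optimizer.step()        # parameter -= lr * gradient, for every parameter
#   optimizer.zero_grad()   # clear the gradients before the next iteration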

Training consists of three simple steps:

  1. Make a prediction
  2. Calculate how good the prediction was compared to the real value (calling backward() on the loss makes PyTorch compute the gradients for us, so we don't have to derive them by hand)
  3. Update the parameters by subtracting the gradient times the learning rate

The loop keeps taking steps until the loss drops to 0.1 or below. Finally, it plots how the loss changed over the steps.

losses = []
while(len(losses) == 0 or losses[-1] > 0.1):
    # 1. predict
    preds = simple_neural_network(x)
    # 2. loss
    loss = loss_func(preds, y)
    loss.backward()
    # 3. update parameters
    update_params(weights_1)
    update_params(bias_1)
    update_params(weights_2)
    update_params(bias_2)
    losses.append(loss.item())  # .item() stores a plain float so the list can be plotted

plt.plot(list(range(len(losses))), losses)
plt.ylabel('loss (MSE)')
plt.xlabel('steps')
plt.show()

Loss plot
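As a quick check (not in the original post), the final prediction should now be close to the target: an MSE of at most 0.1 means the prediction is within about √0.1 ≈ 0.32 of 5.

print(simple_neural_network(x))   # a value close to 5, e.g. tensor([[4.9...]], grad_fn=<AddBackward0>)
print(losses[-1])                 # the last recorded loss, at most 0.1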

How many steps it takes to get the loss under 0.1 varies a lot from run to run, because the parameters start from random values.
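If you want reproducible runs, you can seed PyTorch's random number generator before calling initalize_parameters, so every run starts from the same weights (the seed value 42 is arbitrary):

torch.manual_seed(42)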

Source Code on GitHub
