
Note

Go to the end to download the full example code.

PyTorch: Custom nn Modules

Created On: Dec 03, 2020 | Last Updated: Aug 31, 2022 | Last Verified: Nov 05, 2024

A third order polynomial, trained to predict \(y = \sin(x)\) from \(-\pi\) to \(\pi\) by minimizing squared Euclidean distance.
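As a point of reference (not stated in the original tutorial): because \(\sin\) is an odd function and the interval \([-\pi, \pi]\) is symmetric, the least-squares fit should drive the even-degree coefficients \(a\) and \(c\) toward zero, while \(b\) and \(d\) approximate a cubic fit of \(\sin(x)\), loosely resembling the truncated Taylor series \(\sin(x) \approx x - \frac{x^3}{6}\).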

This implementation defines the model as a custom Module subclass. Whenever you want a model more complex than a simple sequence of existing Modules, you will need to define your model this way; a sketch of a Module composed of sub-Modules follows the complete example below.

```python
import torch
import math


class Polynomial3(torch.nn.Module):
    def __init__(self):
        """
        In the constructor we instantiate four parameters and assign them as
        member parameters.
        """
        super().__init__()
        self.a = torch.nn.Parameter(torch.randn(()))
        self.b = torch.nn.Parameter(torch.randn(()))
        self.c = torch.nn.Parameter(torch.randn(()))
        self.d = torch.nn.Parameter(torch.randn(()))

    def forward(self, x):
        """
        In the forward function we accept a Tensor of input data and we must return
        a Tensor of output data. We can use Modules defined in the constructor as
        well as arbitrary operators on Tensors.
        """
        return self.a + self.b * x + self.c * x ** 2 + self.d * x ** 3

    def string(self):
        """
        Just like any class in Python, you can also define custom methods on PyTorch Modules.
        """
        return f'y = {self.a.item()} + {self.b.item()} x + {self.c.item()} x^2 + {self.d.item()} x^3'


# Create Tensors to hold input and outputs.
x = torch.linspace(-math.pi, math.pi, 2000)
y = torch.sin(x)

# Construct our model by instantiating the class defined above.
model = Polynomial3()

# Construct our loss function and an Optimizer. The call to model.parameters()
# in the SGD constructor will contain the learnable parameters (defined
# with torch.nn.Parameter) which are members of the model.
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-6)

for t in range(2000):
    # Forward pass: Compute predicted y by passing x to the model.
    y_pred = model(x)

    # Compute and print loss.
    loss = criterion(y_pred, y)
    if t % 100 == 99:
        print(t, loss.item())

    # Zero gradients, perform a backward pass, and update the weights.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

print(f'Result: {model.string()}')
```
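As a quick sanity check (not part of the original example), you can confirm that the four scalars were registered as learnable parameters: assigning a torch.nn.Parameter to an attribute in __init__ is exactly what makes model.parameters() find it. A minimal sketch, assuming the Polynomial3 class defined above is in scope:

```python
import torch

# Hypothetical inspection snippet: list what the Module registered.
model = Polynomial3()
for name, param in model.named_parameters():
    # Expect the names 'a', 'b', 'c', 'd', each a scalar tensor
    # with requires_grad=True.
    print(name, param.shape, param.requires_grad)
```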
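The same pattern scales beyond a single flat Module: any Module (a built-in layer, a torch.nn.Sequential, or another custom class) assigned as an attribute is registered recursively, so its parameters appear in the parent's model.parameters(). A minimal sketch of such a composed Module, with the TwoStage name and layer shapes chosen arbitrarily for illustration:

```python
import torch


class TwoStage(torch.nn.Module):
    # Hypothetical composed model: a Linear stage feeding our Polynomial3.
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(1, 1)  # built-in sub-Module, registered automatically
        self.poly = Polynomial3()            # custom sub-Module from above

    def forward(self, x):
        # Sub-Module calls can be mixed with arbitrary operators on Tensors.
        h = self.linear(x.unsqueeze(-1)).squeeze(-1)
        return self.poly(h)


model = TwoStage()
# Parameters of both sub-Modules are visible to the optimizer:
optimizer = torch.optim.SGD(model.parameters(), lr=1e-6)
```

Because registration is recursive, no extra bookkeeping is needed when nesting Modules; the optimizer above updates the Linear weights and the polynomial coefficients together.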