
PyTorch: Control Flow + Weight Sharing

Created On: Mar 24, 2017 | Last Updated: Dec 28, 2021 | Last Verified: Nov 05, 2024

To showcase the power of PyTorch dynamic graphs, we will implement a very strange model: a third-to-fifth order polynomial that on each forward pass chooses a random degree between 3 and 5 and uses that many orders, reusing the same weight multiple times to compute the fourth and fifth order terms.
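To make the control flow concrete before the full listing: random.randint(4, 6) is inclusive on both ends, so the loop bound is 4, 5, or 6, and range(4, k) then contributes zero, one, or two extra terms on top of the cubic. A small standalone sketch (not part of the tutorial code) enumerating the three possible outcomes:

# The model's loop bound k comes from random.randint(4, 6), which is
# inclusive on both ends; range(4, k) yields the extra exponents.
for k in (4, 5, 6):
    print(k, list(range(4, k)))
# 4 []      -> third order polynomial (no extra terms)
# 5 [4]     -> fourth order (adds e * x^4)
# 6 [4, 5]  -> fifth order (adds e * x^4 and e * x^5)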

import random
import math

import torch


class DynamicNet(torch.nn.Module):
    def __init__(self):
        """
        In the constructor we instantiate five parameters and assign them as members.
        """
        super().__init__()
        self.a = torch.nn.Parameter(torch.randn(()))
        self.b = torch.nn.Parameter(torch.randn(()))
        self.c = torch.nn.Parameter(torch.randn(()))
        self.d = torch.nn.Parameter(torch.randn(()))
        self.e = torch.nn.Parameter(torch.randn(()))

    def forward(self, x):
        """
        For the forward pass of the model, we randomly choose how many of the
        fourth and fifth order terms to include and reuse the e parameter to
        compute the contribution of these orders.

        Since each forward pass builds a dynamic computation graph, we can use normal
        Python control-flow operators like loops or conditional statements when
        defining the forward pass of the model.

        Here we also see that it is perfectly safe to reuse the same parameter many
        times when defining a computational graph.
        """
        y = self.a + self.b * x + self.c * x ** 2 + self.d * x ** 3
        for exp in range(4, random.randint(4, 6)):
            y = y + self.e * x ** exp
        return y

    def string(self):
        """
        Just like any class in Python, you can also define custom methods on PyTorch modules.
        """
        return (f'y = {self.a.item()} + {self.b.item()} x + {self.c.item()} x^2 '
                f'+ {self.d.item()} x^3 + {self.e.item()} x^4 ? + {self.e.item()} x^5 ?')


# Create Tensors to hold input and outputs.
x = torch.linspace(-math.pi, math.pi, 2000)
y = torch.sin(x)

# Construct our model by instantiating the class defined above.
model = DynamicNet()

# Construct our loss function and an Optimizer. Training this strange model with
# vanilla stochastic gradient descent is tough, so we use momentum.
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-8, momentum=0.9)
for t in range(30000):
    # Forward pass: compute predicted y by passing x to the model.
    y_pred = model(x)

    # Compute and print loss.
    loss = criterion(y_pred, y)
    if t % 2000 == 1999:
        print(t, loss.item())

    # Zero gradients, perform a backward pass, and update the weights.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

print(f'Result: {model.string()}')
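The forward docstring notes that reusing a parameter is safe: autograd simply sums the gradient contributions from every place the parameter appears in the graph. A minimal standalone sketch of that behavior (not part of the tutorial code):

import torch

e = torch.nn.Parameter(torch.tensor(2.0))
x = torch.tensor(3.0)

# Reuse e for both the fourth and fifth order terms, as DynamicNet may do.
y = e * x ** 4 + e * x ** 5
y.backward()

# dy/de = x^4 + x^5 = 81 + 243 = 324; both uses accumulate into the same .grad.
print(e.grad)  # tensor(324.)

This is also why a single optimizer over model.parameters() suffices above: a, b, c, d, and e are assigned as torch.nn.Parameter attributes in the constructor, so the module registers all five automatically.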