Reasoning about Shapes in PyTorch
Created On: Mar 27, 2023 | Last Updated: Mar 27, 2023 | Last Verified: Not Verified
When writing models with PyTorch, it is common for the parameters of a given layer to depend on the shape of the output of the previous layer. For example, the in_features of an nn.Linear layer must match the size(-1) of the input. For some layers, the shape computation involves complex equations, such as convolution operations.
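To make the convolution case concrete, here is a minimal sketch of the output-size formula given in the nn.Conv2d documentation (the conv2d_out_size helper is illustrative, not part of PyTorch):

import math

# Hypothetical helper implementing the H_out formula from the nn.Conv2d docs:
# H_out = floor((H_in + 2*padding - dilation*(kernel_size - 1) - 1) / stride + 1)
def conv2d_out_size(h_in, kernel_size, stride=1, padding=0, dilation=1):
    return math.floor((h_in + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1)

print(conv2d_out_size(10, kernel_size=2))  # 9, matching the Conv2d example below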
One way around this is to run the forward pass with random inputs, but this is wasteful in terms of memory and compute.
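A sketch of that wasteful approach might look like the following (the sizes here are illustrative):

import torch

conv = torch.nn.Conv2d(3, 5, 2)
dummy = torch.rand(2, 3, 10, 10)  # real memory is allocated and real compute runs
print(conv(dummy).shape)          # torch.Size([2, 5, 9, 9])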
Instead, we can make use of the meta device to determine the output shapes of a layer without materializing any data.
import torch
import timeit

t = torch.rand(2, 3, 10, 10, device="meta")
conv = torch.nn.Conv2d(3, 5, 2, device="meta")
start = timeit.default_timer()
out = conv(t)
end = timeit.default_timer()

print(out)
print(f"Time taken: {end-start}")
tensor(..., device='meta', size=(2, 5, 9, 9), grad_fn=<ConvolutionBackward0>)
Time taken: 0.0003385929999240034
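Although no storage was allocated, the meta tensor still carries full metadata, so properties such as shape, dtype, and device can be inspected as usual:

print(out.shape)   # torch.Size([2, 5, 9, 9])
print(out.dtype)   # torch.float32
print(out.device)  # meta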
Observe that since data is not materialized, passing arbitrarily large inputs will not significantly alter the time taken for shape computation.
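For example, continuing the snippet above, a batch of 2 ** 10 inputs of size 3 x 2 ** 16 x 2 ** 16, whose data alone would occupy roughly 50 TB if materialized, produces the output below (a sketch; the layer sizes are chosen to match that output):

t_large = torch.rand(2 ** 10, 3, 2 ** 16, 2 ** 16, device="meta")
conv = torch.nn.Conv2d(3, 5, 2, device="meta")
start = timeit.default_timer()
out = conv(t_large)
end = timeit.default_timer()

print(out)
print(f"Time taken: {end-start}")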
tensor(..., device='meta', size=(1024, 5, 65535, 65535), grad_fn=<ConvolutionBackward0>)
Time taken: 9.937699996953597e-05
Consider an arbitrary network such as the following:
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
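As a sanity check, the 16 * 5 * 5 passed to fc1 follows from tracing a 32x32 input by hand: each 5x5 convolution (stride 1, no padding) shrinks a side by 4, and each pooling step halves it:

h = 32
h = (h - 5 + 1) // 2  # conv1 then pool: 14
h = (h - 5 + 1) // 2  # conv2 then pool: 5
print(16 * h * h)     # 400, matching fc1's in_features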
We can view the intermediate shapes within an entire network by registering a forward hook to each layer that prints the shape of the output.
def fw_hook(module, input, output):
    print(f"Shape of output to {module} is {output.shape}.")

# Any tensor created within this torch.device context manager will be
# on the meta device.
with torch.device("meta"):
    net = Net()
    inp = torch.randn((1024, 3, 32, 32))

for name, layer in net.named_modules():
    layer.register_forward_hook(fw_hook)

out = net(inp)
Shape of output to Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1)) is torch.Size([1024, 6, 28, 28]).
Shape of output to MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) is torch.Size([1024, 6, 14, 14]).
Shape of output to Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1)) is torch.Size([1024, 16, 10, 10]).
Shape of output to MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) is torch.Size([1024, 16, 5, 5]).
Shape of output to Linear(in_features=400, out_features=120, bias=True) is torch.Size([1024, 120]).
Shape of output to Linear(in_features=120, out_features=84, bias=True) is torch.Size([1024, 84]).
Shape of output to Linear(in_features=84, out_features=10, bias=True) is torch.Size([1024, 10]).
Shape of output to Net(
  (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
) is torch.Size([1024, 10]).
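One practical note: register_forward_hook returns a handle, so if the hooks are only needed for this one shape-tracing pass, they can be detached afterwards. A sketch:

handles = [layer.register_forward_hook(fw_hook) for _, layer in net.named_modules()]
out = net(inp)  # prints the shapes once
for h in handles:
    h.remove()  # subsequent forward passes run silently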