
Pruning Tutorial

Created On: Jul 22, 2019 | Last Updated: Nov 02, 2023 | Last Verified: Nov 05, 2024

Author: Michela Paganini

State-of-the-art deep learning techniques rely on over-parametrized models that are hard to deploy. On the contrary, biological neural networks are known to use efficient sparse connectivity. Identifying optimal techniques to compress models by reducing the number of parameters in them is important in order to reduce memory, battery, and hardware consumption without sacrificing accuracy. This in turn allows you to deploy lightweight models on device, and guarantee privacy with private on-device computation. On the research front, pruning is used to investigate the differences in learning dynamics between over-parametrized and under-parametrized networks, to study the role of lucky sparse subnetworks and initializations ("lottery tickets") as a destructive neural architecture search technique, and more.

In this tutorial, you will learn how to use torch.nn.utils.prune to sparsify your neural networks, and how to extend it to implement your own custom pruning technique.

Requirements

"torch>=1.4.0a0+8e8a5e0"

import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F

Create a model

In this tutorial, we use the LeNet architecture from LeCun et al., 1998.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        # 1 input image channel, 6 output channels, 5x5 square conv kernel
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 5x5 image dimension
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, int(x.nelement() / x.shape[0]))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


model = LeNet().to(device=device)

Inspect a Module

Let’s inspect the (unpruned) conv1 layer in our LeNet model. It will contain two parameters, weight and bias, and no buffers, for now.
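
module = model.conv1
print(list(module.named_parameters()))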

[('weight', Parameter containing:tensor([[[[-0.0272,  0.0551, -0.1933,  0.0165,  0.0877],          [-0.0308, -0.1279,  0.0496, -0.0701, -0.1179],          [ 0.0031, -0.1734, -0.0843, -0.0390, -0.0425],          [-0.1958, -0.1451,  0.1107,  0.1285, -0.1281],          [ 0.0379,  0.0496,  0.1430,  0.1946, -0.0207]]],        [[[ 0.0242, -0.1359,  0.0350,  0.0532, -0.1349],          [ 0.0802, -0.0730, -0.0084,  0.1603,  0.1444],          [ 0.1351, -0.0441,  0.0896,  0.1218,  0.0088],          [-0.0450, -0.0324, -0.0172,  0.0480,  0.1602],          [ 0.0754,  0.0280,  0.0084,  0.1256, -0.1584]]],        [[[-0.0690, -0.0977, -0.0530, -0.0642, -0.1816],          [ 0.0379, -0.1806, -0.0858,  0.1302,  0.1291],          [ 0.1237, -0.0264,  0.1423,  0.1394, -0.1692],          [ 0.1917, -0.0523,  0.1784,  0.1189, -0.1059],          [-0.0274, -0.0100, -0.1361, -0.0004,  0.1947]]],        [[[-0.0536, -0.0251,  0.0947,  0.0941, -0.0293],          [ 0.1725, -0.0306,  0.0916, -0.1544,  0.0521],          [ 0.1784,  0.1610,  0.1548, -0.1838, -0.0674],          [-0.0988, -0.1755,  0.0338,  0.0050, -0.1058],          [-0.0048,  0.1586,  0.1444, -0.0452, -0.1875]]],        [[[-0.1729, -0.0184, -0.1283,  0.0416, -0.0771],          [-0.1250, -0.0326, -0.1453, -0.0840,  0.1750],          [ 0.1824, -0.1408,  0.0157, -0.0611, -0.1136],          [-0.0942, -0.0245, -0.1901, -0.0286,  0.0286],          [-0.1583, -0.1605,  0.1254,  0.0133, -0.1682]]],        [[[-0.0387,  0.0259, -0.1654,  0.0895,  0.0093],          [-0.1045,  0.1651,  0.0966, -0.1185, -0.0338],          [ 0.0213, -0.0885, -0.0413, -0.0325, -0.0550],          [ 0.1380,  0.0995,  0.1335, -0.1250, -0.0416],          [ 0.1243,  0.0719,  0.0360,  0.0076,  0.0581]]]], device='cuda:0',       requires_grad=True)), ('bias', Parameter containing:tensor([ 0.1648,  0.1975, -0.1674, -0.0989,  0.0567,  0.1619], device='cuda:0',       requires_grad=True))]
print(list(module.named_buffers()))
[]

Pruning a Module

To prune a module (in this example, the conv1 layer of our LeNet architecture), first select a pruning technique among those available in torch.nn.utils.prune (or implement your own by subclassing BasePruningMethod). Then, specify the module and the name of the parameter to prune within that module. Finally, using the adequate keyword arguments required by the selected pruning technique, specify the pruning parameters.

In this example, we will prune at random 30% of the connections in the parameter named weight in the conv1 layer. The module is passed as the first argument to the function; name identifies the parameter within that module using its string identifier; and amount indicates either the percentage of connections to prune (if it is a float between 0. and 1.), or the absolute number of connections to prune (if it is a non-negative integer).

prune.random_unstructured(module, name="weight", amount=0.3)
Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
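
As an aside, the two flavors of amount can be compared side by side; the following is a minimal sketch on throwaway layers (not part of the tutorial’s running example, and the variable names here are arbitrary):

# Hedged sketch: fractional vs. absolute `amount`, on scratch layers so
# the tutorial's `module` is left untouched.
scratch_a = nn.Linear(4, 4)  # weight has 16 entries
prune.random_unstructured(scratch_a, name="weight", amount=0.25)
print(int((scratch_a.weight == 0).sum()))  # expect 4 (25% of 16)

scratch_b = nn.Linear(4, 4)
prune.random_unstructured(scratch_b, name="weight", amount=3)
print(int((scratch_b.weight == 0).sum()))  # expect exactly 3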

Pruning acts by removing weight from the parameters and replacing it with a new parameter called weight_orig (i.e. appending "_orig" to the initial parameter name). weight_orig stores the unpruned version of the tensor. The bias was not pruned, so it will remain intact.
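
print(list(module.named_parameters()))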

[('bias', Parameter containing:tensor([ 0.1648,  0.1975, -0.1674, -0.0989,  0.0567,  0.1619], device='cuda:0',       requires_grad=True)), ('weight_orig', Parameter containing:tensor([[[[-0.0272,  0.0551, -0.1933,  0.0165,  0.0877],          [-0.0308, -0.1279,  0.0496, -0.0701, -0.1179],          [ 0.0031, -0.1734, -0.0843, -0.0390, -0.0425],          [-0.1958, -0.1451,  0.1107,  0.1285, -0.1281],          [ 0.0379,  0.0496,  0.1430,  0.1946, -0.0207]]],        [[[ 0.0242, -0.1359,  0.0350,  0.0532, -0.1349],          [ 0.0802, -0.0730, -0.0084,  0.1603,  0.1444],          [ 0.1351, -0.0441,  0.0896,  0.1218,  0.0088],          [-0.0450, -0.0324, -0.0172,  0.0480,  0.1602],          [ 0.0754,  0.0280,  0.0084,  0.1256, -0.1584]]],        [[[-0.0690, -0.0977, -0.0530, -0.0642, -0.1816],          [ 0.0379, -0.1806, -0.0858,  0.1302,  0.1291],          [ 0.1237, -0.0264,  0.1423,  0.1394, -0.1692],          [ 0.1917, -0.0523,  0.1784,  0.1189, -0.1059],          [-0.0274, -0.0100, -0.1361, -0.0004,  0.1947]]],        [[[-0.0536, -0.0251,  0.0947,  0.0941, -0.0293],          [ 0.1725, -0.0306,  0.0916, -0.1544,  0.0521],          [ 0.1784,  0.1610,  0.1548, -0.1838, -0.0674],          [-0.0988, -0.1755,  0.0338,  0.0050, -0.1058],          [-0.0048,  0.1586,  0.1444, -0.0452, -0.1875]]],        [[[-0.1729, -0.0184, -0.1283,  0.0416, -0.0771],          [-0.1250, -0.0326, -0.1453, -0.0840,  0.1750],          [ 0.1824, -0.1408,  0.0157, -0.0611, -0.1136],          [-0.0942, -0.0245, -0.1901, -0.0286,  0.0286],          [-0.1583, -0.1605,  0.1254,  0.0133, -0.1682]]],        [[[-0.0387,  0.0259, -0.1654,  0.0895,  0.0093],          [-0.1045,  0.1651,  0.0966, -0.1185, -0.0338],          [ 0.0213, -0.0885, -0.0413, -0.0325, -0.0550],          [ 0.1380,  0.0995,  0.1335, -0.1250, -0.0416],          [ 0.1243,  0.0719,  0.0360,  0.0076,  0.0581]]]], device='cuda:0',       requires_grad=True))]

The pruning mask generated by the pruning technique selected above is saved as a module buffer named weight_mask (i.e. appending "_mask" to the initial parameter name).

print(list(module.named_buffers()))
[('weight_mask', tensor([[[[1., 1., 1., 1., 1.],          [1., 1., 1., 1., 1.],          [1., 1., 0., 1., 1.],          [1., 0., 0., 1., 1.],          [1., 1., 1., 1., 1.]]],        [[[0., 1., 1., 0., 1.],          [0., 0., 1., 0., 1.],          [1., 0., 1., 1., 1.],          [1., 1., 1., 1., 0.],          [0., 1., 0., 1., 1.]]],        [[[0., 1., 1., 1., 1.],          [0., 1., 1., 1., 1.],          [1., 1., 1., 1., 1.],          [1., 1., 1., 1., 1.],          [0., 1., 0., 0., 1.]]],        [[[1., 0., 0., 1., 0.],          [0., 0., 0., 0., 1.],          [1., 1., 1., 1., 0.],          [1., 1., 1., 0., 0.],          [1., 0., 0., 0., 0.]]],        [[[0., 1., 1., 0., 1.],          [1., 1., 1., 1., 1.],          [1., 1., 1., 1., 1.],          [0., 0., 1., 1., 1.],          [1., 0., 0., 1., 0.]]],        [[[1., 1., 1., 1., 1.],          [1., 1., 1., 0., 0.],          [0., 0., 1., 1., 1.],          [1., 0., 1., 1., 1.],          [0., 1., 1., 1., 0.]]]], device='cuda:0'))]

For the forward pass to work without modification, the weight attribute needs to exist. The pruning techniques implemented in torch.nn.utils.prune compute the pruned version of the weight (by combining the mask with the original parameter) and store it in the attribute weight. Note that this is no longer a parameter of the module; it is now simply an attribute.
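
print(module.weight)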

tensor([[[[-0.0272,  0.0551, -0.1933,  0.0165,  0.0877],          [-0.0308, -0.1279,  0.0496, -0.0701, -0.1179],          [ 0.0031, -0.1734, -0.0000, -0.0390, -0.0425],          [-0.1958, -0.0000,  0.0000,  0.1285, -0.1281],          [ 0.0379,  0.0496,  0.1430,  0.1946, -0.0207]]],        [[[ 0.0000, -0.1359,  0.0350,  0.0000, -0.1349],          [ 0.0000, -0.0000, -0.0084,  0.0000,  0.1444],          [ 0.1351, -0.0000,  0.0896,  0.1218,  0.0088],          [-0.0450, -0.0324, -0.0172,  0.0480,  0.0000],          [ 0.0000,  0.0280,  0.0000,  0.1256, -0.1584]]],        [[[-0.0000, -0.0977, -0.0530, -0.0642, -0.1816],          [ 0.0000, -0.1806, -0.0858,  0.1302,  0.1291],          [ 0.1237, -0.0264,  0.1423,  0.1394, -0.1692],          [ 0.1917, -0.0523,  0.1784,  0.1189, -0.1059],          [-0.0000, -0.0100, -0.0000, -0.0000,  0.1947]]],        [[[-0.0536, -0.0000,  0.0000,  0.0941, -0.0000],          [ 0.0000, -0.0000,  0.0000, -0.0000,  0.0521],          [ 0.1784,  0.1610,  0.1548, -0.1838, -0.0000],          [-0.0988, -0.1755,  0.0338,  0.0000, -0.0000],          [-0.0048,  0.0000,  0.0000, -0.0000, -0.0000]]],        [[[-0.0000, -0.0184, -0.1283,  0.0000, -0.0771],          [-0.1250, -0.0326, -0.1453, -0.0840,  0.1750],          [ 0.1824, -0.1408,  0.0157, -0.0611, -0.1136],          [-0.0000, -0.0000, -0.1901, -0.0286,  0.0286],          [-0.1583, -0.0000,  0.0000,  0.0133, -0.0000]]],        [[[-0.0387,  0.0259, -0.1654,  0.0895,  0.0093],          [-0.1045,  0.1651,  0.0966, -0.0000, -0.0000],          [ 0.0000, -0.0000, -0.0413, -0.0325, -0.0550],          [ 0.1380,  0.0000,  0.1335, -0.1250, -0.0416],          [ 0.0000,  0.0719,  0.0360,  0.0076,  0.0000]]]], device='cuda:0',       grad_fn=<MulBackward0>)
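
You can verify this re-parametrization directly; a minimal check (assuming the module is still in the pruned state shown above):

# The pruned `weight` attribute is simply the element-wise product of the
# original parameter and the mask; this should print True.
print(torch.equal(module.weight, module.weight_orig * module.weight_mask))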

Finally, pruning is applied prior to each forward pass using PyTorch’s forward_pre_hooks. Specifically, when the module is pruned, as we have done here, it will acquire a forward_pre_hook for each parameter associated with it that gets pruned. In this case, since we have so far only pruned the original parameter named weight, only one hook will be present.

print(module._forward_pre_hooks)
OrderedDict([(0, <torch.nn.utils.prune.RandomUnstructured object at 0x7f7425943910>)])

For completeness, we can now prune the bias too, to see how the parameters, buffers, hooks, and attributes of the module change. Just for the sake of trying out another pruning technique, here we prune the 3 smallest entries in the bias by L1 norm, as implemented in the l1_unstructured pruning function.

prune.l1_unstructured(module, name="bias", amount=3)
Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))

We now expect the named parameters to include both weight_orig (from before) and bias_orig. The buffers will include weight_mask and bias_mask. The pruned versions of the two tensors will exist as module attributes, and the module will now have two forward_pre_hooks.
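
print(list(module.named_parameters()))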

[('weight_orig', Parameter containing:tensor([[[[-0.0272,  0.0551, -0.1933,  0.0165,  0.0877],          [-0.0308, -0.1279,  0.0496, -0.0701, -0.1179],          [ 0.0031, -0.1734, -0.0843, -0.0390, -0.0425],          [-0.1958, -0.1451,  0.1107,  0.1285, -0.1281],          [ 0.0379,  0.0496,  0.1430,  0.1946, -0.0207]]],        [[[ 0.0242, -0.1359,  0.0350,  0.0532, -0.1349],          [ 0.0802, -0.0730, -0.0084,  0.1603,  0.1444],          [ 0.1351, -0.0441,  0.0896,  0.1218,  0.0088],          [-0.0450, -0.0324, -0.0172,  0.0480,  0.1602],          [ 0.0754,  0.0280,  0.0084,  0.1256, -0.1584]]],        [[[-0.0690, -0.0977, -0.0530, -0.0642, -0.1816],          [ 0.0379, -0.1806, -0.0858,  0.1302,  0.1291],          [ 0.1237, -0.0264,  0.1423,  0.1394, -0.1692],          [ 0.1917, -0.0523,  0.1784,  0.1189, -0.1059],          [-0.0274, -0.0100, -0.1361, -0.0004,  0.1947]]],        [[[-0.0536, -0.0251,  0.0947,  0.0941, -0.0293],          [ 0.1725, -0.0306,  0.0916, -0.1544,  0.0521],          [ 0.1784,  0.1610,  0.1548, -0.1838, -0.0674],          [-0.0988, -0.1755,  0.0338,  0.0050, -0.1058],          [-0.0048,  0.1586,  0.1444, -0.0452, -0.1875]]],        [[[-0.1729, -0.0184, -0.1283,  0.0416, -0.0771],          [-0.1250, -0.0326, -0.1453, -0.0840,  0.1750],          [ 0.1824, -0.1408,  0.0157, -0.0611, -0.1136],          [-0.0942, -0.0245, -0.1901, -0.0286,  0.0286],          [-0.1583, -0.1605,  0.1254,  0.0133, -0.1682]]],        [[[-0.0387,  0.0259, -0.1654,  0.0895,  0.0093],          [-0.1045,  0.1651,  0.0966, -0.1185, -0.0338],          [ 0.0213, -0.0885, -0.0413, -0.0325, -0.0550],          [ 0.1380,  0.0995,  0.1335, -0.1250, -0.0416],          [ 0.1243,  0.0719,  0.0360,  0.0076,  0.0581]]]], device='cuda:0',       requires_grad=True)), ('bias_orig', Parameter containing:tensor([ 0.1648,  0.1975, -0.1674, -0.0989,  0.0567,  0.1619], device='cuda:0',       requires_grad=True))]
print(list(module.named_buffers()))
[('weight_mask', tensor([[[[1., 1., 1., 1., 1.],          [1., 1., 1., 1., 1.],          [1., 1., 0., 1., 1.],          [1., 0., 0., 1., 1.],          [1., 1., 1., 1., 1.]]],        [[[0., 1., 1., 0., 1.],          [0., 0., 1., 0., 1.],          [1., 0., 1., 1., 1.],          [1., 1., 1., 1., 0.],          [0., 1., 0., 1., 1.]]],        [[[0., 1., 1., 1., 1.],          [0., 1., 1., 1., 1.],          [1., 1., 1., 1., 1.],          [1., 1., 1., 1., 1.],          [0., 1., 0., 0., 1.]]],        [[[1., 0., 0., 1., 0.],          [0., 0., 0., 0., 1.],          [1., 1., 1., 1., 0.],          [1., 1., 1., 0., 0.],          [1., 0., 0., 0., 0.]]],        [[[0., 1., 1., 0., 1.],          [1., 1., 1., 1., 1.],          [1., 1., 1., 1., 1.],          [0., 0., 1., 1., 1.],          [1., 0., 0., 1., 0.]]],        [[[1., 1., 1., 1., 1.],          [1., 1., 1., 0., 0.],          [0., 0., 1., 1., 1.],          [1., 0., 1., 1., 1.],          [0., 1., 1., 1., 0.]]]], device='cuda:0')), ('bias_mask', tensor([1., 1., 1., 0., 0., 0.], device='cuda:0'))]
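print(module.bias)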
tensor([ 0.1648,  0.1975, -0.1674, -0.0000,  0.0000,  0.0000], device='cuda:0',
       grad_fn=<MulBackward0>)
print(module._forward_pre_hooks)
OrderedDict([(0, <torch.nn.utils.prune.RandomUnstructured object at 0x7f7425943910>), (1, <torch.nn.utils.prune.L1Unstructured object at 0x7f7425943a30>)])

Iterative Pruning

The same parameter in a module can be pruned multiple times, with the effect of the various pruning calls being equal to the combination of the various masks applied in series. The combination of a new mask with the old mask is handled by the PruningContainer’s compute_mask method.

Say, for example, that we now want to further prune module.weight, this time using structured pruning along the 0th axis of the tensor (the 0th axis corresponds to the output channels of the convolutional layer and has dimensionality 6 for conv1), based on the channels’ L2 norm. This can be achieved using the ln_structured function, with n=2 and dim=0.

prune.ln_structured(module, name="weight", amount=0.5, n=2, dim=0)

# As we can verify, this will zero out all the connections corresponding to
# 50% (3 out of 6) of the channels, while preserving the action of the
# previous mask.
print(module.weight)
tensor([[[[-0.0272,  0.0551, -0.1933,  0.0165,  0.0877],          [-0.0308, -0.1279,  0.0496, -0.0701, -0.1179],          [ 0.0031, -0.1734, -0.0000, -0.0390, -0.0425],          [-0.1958, -0.0000,  0.0000,  0.1285, -0.1281],          [ 0.0379,  0.0496,  0.1430,  0.1946, -0.0207]]],        [[[ 0.0000, -0.0000,  0.0000,  0.0000, -0.0000],          [ 0.0000, -0.0000, -0.0000,  0.0000,  0.0000],          [ 0.0000, -0.0000,  0.0000,  0.0000,  0.0000],          [-0.0000, -0.0000, -0.0000,  0.0000,  0.0000],          [ 0.0000,  0.0000,  0.0000,  0.0000, -0.0000]]],        [[[-0.0000, -0.0977, -0.0530, -0.0642, -0.1816],          [ 0.0000, -0.1806, -0.0858,  0.1302,  0.1291],          [ 0.1237, -0.0264,  0.1423,  0.1394, -0.1692],          [ 0.1917, -0.0523,  0.1784,  0.1189, -0.1059],          [-0.0000, -0.0100, -0.0000, -0.0000,  0.1947]]],        [[[-0.0000, -0.0000,  0.0000,  0.0000, -0.0000],          [ 0.0000, -0.0000,  0.0000, -0.0000,  0.0000],          [ 0.0000,  0.0000,  0.0000, -0.0000, -0.0000],          [-0.0000, -0.0000,  0.0000,  0.0000, -0.0000],          [-0.0000,  0.0000,  0.0000, -0.0000, -0.0000]]],        [[[-0.0000, -0.0184, -0.1283,  0.0000, -0.0771],          [-0.1250, -0.0326, -0.1453, -0.0840,  0.1750],          [ 0.1824, -0.1408,  0.0157, -0.0611, -0.1136],          [-0.0000, -0.0000, -0.1901, -0.0286,  0.0286],          [-0.1583, -0.0000,  0.0000,  0.0133, -0.0000]]],        [[[-0.0000,  0.0000, -0.0000,  0.0000,  0.0000],          [-0.0000,  0.0000,  0.0000, -0.0000, -0.0000],          [ 0.0000, -0.0000, -0.0000, -0.0000, -0.0000],          [ 0.0000,  0.0000,  0.0000, -0.0000, -0.0000],          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000]]]], device='cuda:0',       grad_fn=<MulBackward0>)

The corresponding hook will now be of type torch.nn.utils.prune.PruningContainer, and will store the history of pruning applied to the weight parameter.

for hook in module._forward_pre_hooks.values():
    if hook._tensor_name == "weight":  # select out the correct hook
        break

print(list(hook))  # pruning history in the container
[<torch.nn.utils.prune.RandomUnstructured object at 0x7f7425943910>, <torch.nn.utils.prune.LnStructured object at 0x7f74259437c0>]

Serializing a pruned model

All relevant tensors, including the mask buffers and the original parameters used to compute the pruned tensors, are stored in the model’s state_dict and can therefore be easily serialized and saved, if needed.

print(model.state_dict().keys())
odict_keys(['conv1.weight_orig', 'conv1.bias_orig', 'conv1.weight_mask', 'conv1.bias_mask', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias'])
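
For instance, a save/load round trip preserves the pruned state. The following is a minimal sketch (the file name is arbitrary):

# Hedged sketch: the mask buffers and *_orig parameters travel with the
# state_dict, so the pruned re-parametrization survives serialization.
torch.save(model.state_dict(), "pruned_lenet.pth")
state = torch.load("pruned_lenet.pth")
print("conv1.weight_mask" in state)  # expect True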

Remove pruning re-parametrization

To make the pruning permanent, i.e. to remove the re-parametrization in terms of weight_orig and weight_mask and to remove the forward_pre_hook, we can use the remove functionality from torch.nn.utils.prune. Note that this doesn’t undo the pruning as if it had never happened; rather, it makes it permanent by reassigning the pruned version of the tensor to the parameter weight.

Prior to removing the re-parametrization:
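
print(list(module.named_parameters()))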

[('weight_orig', Parameter containing:tensor([[[[-0.0272,  0.0551, -0.1933,  0.0165,  0.0877],          [-0.0308, -0.1279,  0.0496, -0.0701, -0.1179],          [ 0.0031, -0.1734, -0.0843, -0.0390, -0.0425],          [-0.1958, -0.1451,  0.1107,  0.1285, -0.1281],          [ 0.0379,  0.0496,  0.1430,  0.1946, -0.0207]]],        [[[ 0.0242, -0.1359,  0.0350,  0.0532, -0.1349],          [ 0.0802, -0.0730, -0.0084,  0.1603,  0.1444],          [ 0.1351, -0.0441,  0.0896,  0.1218,  0.0088],          [-0.0450, -0.0324, -0.0172,  0.0480,  0.1602],          [ 0.0754,  0.0280,  0.0084,  0.1256, -0.1584]]],        [[[-0.0690, -0.0977, -0.0530, -0.0642, -0.1816],          [ 0.0379, -0.1806, -0.0858,  0.1302,  0.1291],          [ 0.1237, -0.0264,  0.1423,  0.1394, -0.1692],          [ 0.1917, -0.0523,  0.1784,  0.1189, -0.1059],          [-0.0274, -0.0100, -0.1361, -0.0004,  0.1947]]],        [[[-0.0536, -0.0251,  0.0947,  0.0941, -0.0293],          [ 0.1725, -0.0306,  0.0916, -0.1544,  0.0521],          [ 0.1784,  0.1610,  0.1548, -0.1838, -0.0674],          [-0.0988, -0.1755,  0.0338,  0.0050, -0.1058],          [-0.0048,  0.1586,  0.1444, -0.0452, -0.1875]]],        [[[-0.1729, -0.0184, -0.1283,  0.0416, -0.0771],          [-0.1250, -0.0326, -0.1453, -0.0840,  0.1750],          [ 0.1824, -0.1408,  0.0157, -0.0611, -0.1136],          [-0.0942, -0.0245, -0.1901, -0.0286,  0.0286],          [-0.1583, -0.1605,  0.1254,  0.0133, -0.1682]]],        [[[-0.0387,  0.0259, -0.1654,  0.0895,  0.0093],          [-0.1045,  0.1651,  0.0966, -0.1185, -0.0338],          [ 0.0213, -0.0885, -0.0413, -0.0325, -0.0550],          [ 0.1380,  0.0995,  0.1335, -0.1250, -0.0416],          [ 0.1243,  0.0719,  0.0360,  0.0076,  0.0581]]]], device='cuda:0',       requires_grad=True)), ('bias_orig', Parameter containing:tensor([ 0.1648,  0.1975, -0.1674, -0.0989,  0.0567,  0.1619], device='cuda:0',       requires_grad=True))]
print(list(module.named_buffers()))
[('weight_mask', tensor([[[[1., 1., 1., 1., 1.],          [1., 1., 1., 1., 1.],          [1., 1., 0., 1., 1.],          [1., 0., 0., 1., 1.],          [1., 1., 1., 1., 1.]]],        [[[0., 0., 0., 0., 0.],          [0., 0., 0., 0., 0.],          [0., 0., 0., 0., 0.],          [0., 0., 0., 0., 0.],          [0., 0., 0., 0., 0.]]],        [[[0., 1., 1., 1., 1.],          [0., 1., 1., 1., 1.],          [1., 1., 1., 1., 1.],          [1., 1., 1., 1., 1.],          [0., 1., 0., 0., 1.]]],        [[[0., 0., 0., 0., 0.],          [0., 0., 0., 0., 0.],          [0., 0., 0., 0., 0.],          [0., 0., 0., 0., 0.],          [0., 0., 0., 0., 0.]]],        [[[0., 1., 1., 0., 1.],          [1., 1., 1., 1., 1.],          [1., 1., 1., 1., 1.],          [0., 0., 1., 1., 1.],          [1., 0., 0., 1., 0.]]],        [[[0., 0., 0., 0., 0.],          [0., 0., 0., 0., 0.],          [0., 0., 0., 0., 0.],          [0., 0., 0., 0., 0.],          [0., 0., 0., 0., 0.]]]], device='cuda:0')), ('bias_mask', tensor([1., 1., 1., 0., 0., 0.], device='cuda:0'))]
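print(module.weight)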
tensor([[[[-0.0272,  0.0551, -0.1933,  0.0165,  0.0877],          [-0.0308, -0.1279,  0.0496, -0.0701, -0.1179],          [ 0.0031, -0.1734, -0.0000, -0.0390, -0.0425],          [-0.1958, -0.0000,  0.0000,  0.1285, -0.1281],          [ 0.0379,  0.0496,  0.1430,  0.1946, -0.0207]]],        [[[ 0.0000, -0.0000,  0.0000,  0.0000, -0.0000],          [ 0.0000, -0.0000, -0.0000,  0.0000,  0.0000],          [ 0.0000, -0.0000,  0.0000,  0.0000,  0.0000],          [-0.0000, -0.0000, -0.0000,  0.0000,  0.0000],          [ 0.0000,  0.0000,  0.0000,  0.0000, -0.0000]]],        [[[-0.0000, -0.0977, -0.0530, -0.0642, -0.1816],          [ 0.0000, -0.1806, -0.0858,  0.1302,  0.1291],          [ 0.1237, -0.0264,  0.1423,  0.1394, -0.1692],          [ 0.1917, -0.0523,  0.1784,  0.1189, -0.1059],          [-0.0000, -0.0100, -0.0000, -0.0000,  0.1947]]],        [[[-0.0000, -0.0000,  0.0000,  0.0000, -0.0000],          [ 0.0000, -0.0000,  0.0000, -0.0000,  0.0000],          [ 0.0000,  0.0000,  0.0000, -0.0000, -0.0000],          [-0.0000, -0.0000,  0.0000,  0.0000, -0.0000],          [-0.0000,  0.0000,  0.0000, -0.0000, -0.0000]]],        [[[-0.0000, -0.0184, -0.1283,  0.0000, -0.0771],          [-0.1250, -0.0326, -0.1453, -0.0840,  0.1750],          [ 0.1824, -0.1408,  0.0157, -0.0611, -0.1136],          [-0.0000, -0.0000, -0.1901, -0.0286,  0.0286],          [-0.1583, -0.0000,  0.0000,  0.0133, -0.0000]]],        [[[-0.0000,  0.0000, -0.0000,  0.0000,  0.0000],          [-0.0000,  0.0000,  0.0000, -0.0000, -0.0000],          [ 0.0000, -0.0000, -0.0000, -0.0000, -0.0000],          [ 0.0000,  0.0000,  0.0000, -0.0000, -0.0000],          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000]]]], device='cuda:0',       grad_fn=<MulBackward0>)

After removing the re-parametrization:
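
prune.remove(module, 'weight')
print(list(module.named_parameters()))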

[('bias_orig', Parameter containing:tensor([ 0.1648,  0.1975, -0.1674, -0.0989,  0.0567,  0.1619], device='cuda:0',       requires_grad=True)), ('weight', Parameter containing:tensor([[[[-0.0272,  0.0551, -0.1933,  0.0165,  0.0877],          [-0.0308, -0.1279,  0.0496, -0.0701, -0.1179],          [ 0.0031, -0.1734, -0.0000, -0.0390, -0.0425],          [-0.1958, -0.0000,  0.0000,  0.1285, -0.1281],          [ 0.0379,  0.0496,  0.1430,  0.1946, -0.0207]]],        [[[ 0.0000, -0.0000,  0.0000,  0.0000, -0.0000],          [ 0.0000, -0.0000, -0.0000,  0.0000,  0.0000],          [ 0.0000, -0.0000,  0.0000,  0.0000,  0.0000],          [-0.0000, -0.0000, -0.0000,  0.0000,  0.0000],          [ 0.0000,  0.0000,  0.0000,  0.0000, -0.0000]]],        [[[-0.0000, -0.0977, -0.0530, -0.0642, -0.1816],          [ 0.0000, -0.1806, -0.0858,  0.1302,  0.1291],          [ 0.1237, -0.0264,  0.1423,  0.1394, -0.1692],          [ 0.1917, -0.0523,  0.1784,  0.1189, -0.1059],          [-0.0000, -0.0100, -0.0000, -0.0000,  0.1947]]],        [[[-0.0000, -0.0000,  0.0000,  0.0000, -0.0000],          [ 0.0000, -0.0000,  0.0000, -0.0000,  0.0000],          [ 0.0000,  0.0000,  0.0000, -0.0000, -0.0000],          [-0.0000, -0.0000,  0.0000,  0.0000, -0.0000],          [-0.0000,  0.0000,  0.0000, -0.0000, -0.0000]]],        [[[-0.0000, -0.0184, -0.1283,  0.0000, -0.0771],          [-0.1250, -0.0326, -0.1453, -0.0840,  0.1750],          [ 0.1824, -0.1408,  0.0157, -0.0611, -0.1136],          [-0.0000, -0.0000, -0.1901, -0.0286,  0.0286],          [-0.1583, -0.0000,  0.0000,  0.0133, -0.0000]]],        [[[-0.0000,  0.0000, -0.0000,  0.0000,  0.0000],          [-0.0000,  0.0000,  0.0000, -0.0000, -0.0000],          [ 0.0000, -0.0000, -0.0000, -0.0000, -0.0000],          [ 0.0000,  0.0000,  0.0000, -0.0000, -0.0000],          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000]]]], device='cuda:0',       requires_grad=True))]
print(list(module.named_buffers()))
[('bias_mask', tensor([1., 1., 1., 0., 0., 0.], device='cuda:0'))]
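
Note that the bias is still re-parametrized: its mask buffer remains. If you want to make that pruning permanent too, the same call applies (a small extension of the example above, not part of the original walkthrough):

# Make the bias pruning permanent as well; no buffers should remain afterwards.
prune.remove(module, 'bias')
print(list(module.named_buffers()))  # expect []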

Pruning multiple parameters in a model

By specifying the desired pruning technique and parameters, we can easily prune multiple tensors in a network, perhaps according to their type, as we will see in this example.

new_model = LeNet()
for name, module in new_model.named_modules():
    # prune 20% of connections in all 2D-conv layers
    if isinstance(module, torch.nn.Conv2d):
        prune.l1_unstructured(module, name='weight', amount=0.2)
    # prune 40% of connections in all linear layers
    elif isinstance(module, torch.nn.Linear):
        prune.l1_unstructured(module, name='weight', amount=0.4)

print(dict(new_model.named_buffers()).keys())  # to verify that all masks exist
dict_keys(['conv1.weight_mask', 'conv2.weight_mask', 'fc1.weight_mask', 'fc2.weight_mask', 'fc3.weight_mask'])

Global pruning

So far, we only looked at what is usually referred to as "local" pruning, i.e. the practice of pruning tensors in a model one by one, by comparing the statistics (weight magnitude, activation, gradient, etc.) of each entry exclusively to the other entries in that tensor. However, a common and perhaps more powerful technique is to prune the model all at once, by removing (for example) the lowest 20% of connections across the whole model, instead of removing the lowest 20% of connections in each layer. This is likely to result in different pruning percentages per layer. Let’s see how to do that using global_unstructured from torch.nn.utils.prune.

model = LeNet()

parameters_to_prune = (
    (model.conv1, 'weight'),
    (model.conv2, 'weight'),
    (model.fc1, 'weight'),
    (model.fc2, 'weight'),
    (model.fc3, 'weight'),
)

prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)

Now we can check the sparsity induced in every pruned parameter, which willnot be equal to 20% in each layer. However, the global sparsity will be(approximately) 20%.

print("Sparsity in conv1.weight:{:.2f}%".format(100.*float(torch.sum(model.conv1.weight==0))/float(model.conv1.weight.nelement())))print("Sparsity in conv2.weight:{:.2f}%".format(100.*float(torch.sum(model.conv2.weight==0))/float(model.conv2.weight.nelement())))print("Sparsity in fc1.weight:{:.2f}%".format(100.*float(torch.sum(model.fc1.weight==0))/float(model.fc1.weight.nelement())))print("Sparsity in fc2.weight:{:.2f}%".format(100.*float(torch.sum(model.fc2.weight==0))/float(model.fc2.weight.nelement())))print("Sparsity in fc3.weight:{:.2f}%".format(100.*float(torch.sum(model.fc3.weight==0))/float(model.fc3.weight.nelement())))print("Global sparsity:{:.2f}%".format(100.*float(torch.sum(model.conv1.weight==0)+torch.sum(model.conv2.weight==0)+torch.sum(model.fc1.weight==0)+torch.sum(model.fc2.weight==0)+torch.sum(model.fc3.weight==0))/float(model.conv1.weight.nelement()+model.conv2.weight.nelement()+model.fc1.weight.nelement()+model.fc2.weight.nelement()+model.fc3.weight.nelement())))
Sparsity in conv1.weight: 6.67%
Sparsity in conv2.weight: 12.58%
Sparsity in fc1.weight: 22.13%
Sparsity in fc2.weight: 12.39%
Sparsity in fc3.weight: 12.98%
Global sparsity: 20.00%

Extending torch.nn.utils.prune with custom pruning functions

To implement your own pruning function, you can extend the nn.utils.prune module by subclassing the BasePruningMethod base class, the same way all other pruning methods do. The base class implements the following methods for you: __call__, apply_mask, apply, prune, and remove. Beyond some special cases, you shouldn’t have to reimplement these methods for your new pruning technique. You will, however, have to implement __init__ (the constructor), and compute_mask (the instructions on how to compute the mask for the given tensor according to the logic of your pruning technique). In addition, you will have to specify which type of pruning this technique implements (supported options are global, structured, and unstructured). This is needed to determine how to combine masks in the case in which pruning is applied iteratively. In other words, when pruning a pre-pruned parameter, the current pruning technique is expected to act on the unpruned portion of the parameter. Specifying the PRUNING_TYPE will enable the PruningContainer (which handles the iterative application of pruning masks) to correctly identify the slice of the parameter to prune.

Let’s assume, for example, that you want to implement a pruning technique that prunes every other entry in a tensor (or, if the tensor has previously been pruned, in the remaining unpruned portion of the tensor). This will be of PRUNING_TYPE='unstructured' because it acts on individual connections in a layer and not on entire units/channels ('structured'), or across different parameters ('global').

class FooBarPruningMethod(prune.BasePruningMethod):
    """Prune every other entry in a tensor
    """
    PRUNING_TYPE = 'unstructured'

    def compute_mask(self, t, default_mask):
        mask = default_mask.clone()
        mask.view(-1)[::2] = 0
        return mask

Now, to apply this to a parameter in an nn.Module, you should also provide a simple function that instantiates the method and applies it.

def foobar_unstructured(module, name):
    """Prunes tensor corresponding to parameter called `name` in `module`
    by removing every other entry in the tensors.
    Modifies module in place (and also returns the modified module)
    by:
    1) adding a named buffer called `name+'_mask'` corresponding to the
    binary mask applied to the parameter `name` by the pruning method.
    The parameter `name` is replaced by its pruned version, while the
    original (unpruned) parameter is stored in a new parameter named
    `name+'_orig'`.

    Args:
        module (nn.Module): module containing the tensor to prune
        name (string): parameter name within `module` on which pruning
            will act.

    Returns:
        module (nn.Module): modified (i.e. pruned) version of the input
            module

    Examples:
        >>> m = nn.Linear(3, 4)
        >>> foobar_unstructured(m, name='bias')
    """
    FooBarPruningMethod.apply(module, name)
    return module

Let’s try it out!

model = LeNet()
foobar_unstructured(model.fc3, name='bias')

print(model.fc3.bias_mask)
tensor([0., 1., 0., 1., 0., 1., 0., 1., 0., 1.])

Total running time of the script: (0 minutes 0.542 seconds)