PruningContainer
- class torch.nn.utils.prune.PruningContainer(*args)
Container holding a sequence of pruning methods for iterative pruning.
Keeps track of the order in which pruning methods are applied and handles combining successive pruning calls.
Accepts as argument an instance of a BasePruningMethod or an iterable of them.
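In practice a container is rarely built by hand. Here is a minimal sketch, assuming a recent PyTorch release. When the same parameter is pruned twice, `torch.nn.utils.prune` wraps both methods in a `PruningContainer`, and that container can be found among the module's forward pre-hooks. `_forward_pre_hooks` is a private attribute and is read here only for inspection.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
layer = nn.Linear(4, 3)  # weight has 3 * 4 = 12 entries

# Two successive pruning calls on the same parameter.
prune.l1_unstructured(layer, name="weight", amount=0.5)      # prunes 6 of 12
prune.random_unstructured(layer, name="weight", amount=0.5)  # prunes 3 of the 6 left

# Locate the hook acting on "weight" (private attribute, inspection only).
container = next(
    hook for hook in layer._forward_pre_hooks.values()
    if getattr(hook, "_tensor_name", None) == "weight"
)
print(type(container).__name__)                # PruningContainer
print([type(m).__name__ for m in container])  # ['L1Unstructured', 'RandomUnstructured']
print(int((layer.weight_mask == 0).sum()))     # 9
```

The second call only prunes among the entries the first call left untouched, which is why 9 entries (not 12) end up masked.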
- add_pruning_method(method)
Add a child pruning method to the container.
- Parameters
method (subclass of BasePruningMethod) – child pruning method to be added to the container.
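The pruning utilities normally call this for you. Once a container exists, a further pruning call on the same parameter appends the new method to that same container instead of creating another one. This is a minimal sketch, assuming a recent PyTorch release. `weight_hook` is a hypothetical helper written for this example, and it reads the private `_forward_pre_hooks` attribute only for inspection.

```python
import torch.nn as nn
import torch.nn.utils.prune as prune

def weight_hook(module):
    """Hypothetical helper: return the pruning hook attached to module.weight
    (reads the private _forward_pre_hooks dict, inspection only)."""
    return next(h for h in module._forward_pre_hooks.values()
                if getattr(h, "_tensor_name", None) == "weight")

layer = nn.Linear(4, 3)
prune.l1_unstructured(layer, name="weight", amount=0.25)
prune.l1_unstructured(layer, name="weight", amount=0.25)
container = weight_hook(layer)  # PruningContainer holding 2 methods

prune.l1_unstructured(layer, name="weight", amount=0.25)  # third call
print(weight_hook(layer) is container)  # True: the same container was extended
print(len(container))                   # 3
```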
- classmethod apply(module, name, *args, importance_scores=None, **kwargs)
Add pruning on the fly and reparametrization of a tensor.
Adds the forward pre-hook that enables pruning on the fly and the reparametrization of a tensor in terms of the original tensor and the pruning mask.
- Parameters
module (nn.Module) – module containing the tensor to prune
name (str) – parameter name within module on which pruning will act.
args – arguments passed on to a subclass of BasePruningMethod.
importance_scores (torch.Tensor) – tensor of importance scores (of the same shape as the module parameter) used to compute the mask for pruning. The values in this tensor indicate the importance of the corresponding elements in the parameter being pruned. If unspecified or None, the parameter will be used in its place.
kwargs – keyword arguments passed on to a subclass of BasePruningMethod.
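`apply` is normally reached through a concrete subclass. The sketch below assumes a recent PyTorch release and calls `L1Unstructured.apply` with explicit `importance_scores`. The scores are a made-up choice for illustration: each entry's flat index. The sketch then inspects the reparametrization: the original tensor moves to `weight_orig`, the mask is stored as the buffer `weight_mask`, and `weight` is recomputed from the two.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

layer = nn.Linear(4, 3)  # weight shape (3, 4)

# Illustrative scores: entry i (row-major order) gets importance i, so the
# three lowest-scoring entries are the first three of the flattened weight.
scores = torch.arange(12, dtype=torch.float32).reshape(3, 4)
prune.L1Unstructured.apply(layer, "weight", amount=3, importance_scores=scores)

print("weight_orig" in dict(layer.named_parameters()))  # True
print("weight_mask" in dict(layer.named_buffers()))     # True
print(layer.weight_mask.flatten()[:4])                  # tensor([0., 0., 0., 1.])
```

Because the mask is derived from `scores` rather than from the weight values, the pruned positions are fixed regardless of how the layer was initialized.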
- apply_mask(module)
Simply handles the multiplication between the parameter being pruned and the generated mask.
Fetches the mask and the original tensor from the module and returns the pruned version of the tensor.
- Parameters
module (nn.Module) – module containing the tensor to prune
- Returns
pruned version of the input tensor
- Return type
pruned_tensor (torch.Tensor)
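Here is a quick check of what this method returns, as a sketch assuming a recent PyTorch release. For a container built by two pruning calls, the result matches the element-wise product of `weight_orig` and `weight_mask`. The container is again located through the private `_forward_pre_hooks` attribute, for inspection only.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

layer = nn.Linear(4, 3)
prune.l1_unstructured(layer, name="weight", amount=0.5)
prune.l1_unstructured(layer, name="weight", amount=0.5)

# Inspection only: _forward_pre_hooks is a private attribute.
container = next(h for h in layer._forward_pre_hooks.values()
                 if isinstance(h, prune.PruningContainer))

pruned = container.apply_mask(layer)
print(torch.equal(pruned, layer.weight_orig * layer.weight_mask))  # True
```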
- compute_mask(t, default_mask)
Apply the latest method (the one most recently added to the container) by computing the new partial mask and returning its combination with the default_mask.
The new partial mask should be computed on the entries or channels that were not zeroed out by the default_mask.
Which portions of the tensor t the new mask will be calculated from depends on the PRUNING_TYPE (handled by the type handler):
for ‘unstructured’, the mask will be computed from the raveled list of nonmasked entries;
for ‘structured’, the mask will be computed from the nonmaskedchannels in the tensor;
for ‘global’, the mask will be computed across all entries.
- Parameters
t (torch.Tensor) – tensor representing the parameter to prune (of same dimensions as default_mask).
default_mask (torch.Tensor) – mask from previous pruning iteration.
- Returns
new mask that combines the effects of the default_mask and the new mask from the current pruning method (of same dimensions as default_mask and t).
- Return type
mask (torch.Tensor)
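To make the ‘unstructured’ case concrete, here is a hand-written sketch. It is not the library's code. The hypothetical helper `combine_unstructured_l1` selects the entries still set to 1 in `default_mask`, ranks only those by absolute value (an assumed L1 rule), and writes the partial mask back. The last lines compare it with what `torch.nn.utils.prune` produces when a fixed mask (`custom_from_mask`) is followed by `l1_unstructured`, assuming a recent PyTorch release.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

def combine_unstructured_l1(t, default_mask, amount):
    """Hypothetical helper (not part of torch): prune a fraction `amount`
    of the entries that `default_mask` has not already zeroed, ranked by |t|."""
    mask = default_mask.clone().to(dtype=t.dtype)
    keep = default_mask == 1                # entries still unpruned
    remaining = t[keep]                     # raveled list of nonmasked entries
    k = round(amount * remaining.numel())   # number of new entries to prune
    partial = torch.ones_like(remaining)
    if k > 0:
        idx = torch.topk(remaining.abs(), k=k, largest=False).indices
        partial[idx] = 0
    mask[keep] = partial                    # write the partial mask back
    return mask

t = torch.tensor([[0.1, -2.0, 0.3], [4.0, -0.5, 6.0]])
default_mask = torch.tensor([[0.0, 1.0, 1.0], [1.0, 1.0, 1.0]])
new_mask = combine_unstructured_l1(t, default_mask, amount=0.4)
print(new_mask)  # 0.3 and -0.5 are pruned; the already-masked 0.1 is not re-counted

# Same two steps through the library: a fixed mask, then L1 pruning.
layer = nn.Linear(3, 2, bias=False)
with torch.no_grad():
    layer.weight.copy_(t)
prune.custom_from_mask(layer, name="weight", mask=default_mask)
prune.l1_unstructured(layer, name="weight", amount=0.4)
print(torch.equal(layer.weight_mask, new_mask))  # True
```

Note that `amount=0.4` applies to the 5 surviving entries, not to all 6, so exactly 2 new entries are pruned.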
- prune(t, default_mask=None, importance_scores=None)
Compute and return a pruned version of input tensor t according to the pruning rule specified in compute_mask().
- Parameters
t (torch.Tensor) – tensor to prune (of same dimensions as default_mask).
importance_scores (torch.Tensor) – tensor of importance scores (of same shape as t) used to compute the mask for pruning t. The values in this tensor indicate the importance of the corresponding elements in t. If unspecified or None, the tensor t will be used in its place.
default_mask (torch.Tensor, optional) – mask from the previous pruning iteration, if any. It is taken into account when determining which portion of the tensor pruning should act on. If None, defaults to a mask of ones.
- Returns
pruned version of tensor t.
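`prune` can also be called on a standalone tensor, without any module. This is a minimal sketch, assuming a recent PyTorch release. It uses a single `L1Unstructured` instance, which is simpler to construct than a container, and shows how `importance_scores` changes which entries are zeroed.

```python
import torch
import torch.nn.utils.prune as prune

t = torch.tensor([0.5, -3.0, 1.0, -0.2])
method = prune.L1Unstructured(amount=2)  # prune 2 entries

# Ranked by |t|: the two smallest magnitudes (-0.2 and 0.5) are zeroed.
by_magnitude = method.prune(t)

# Ranked by the given scores instead: the last two entries are zeroed.
scores = torch.tensor([4.0, 3.0, 2.0, 1.0])
by_score = method.prune(t, importance_scores=scores)

print(by_magnitude)
print(by_score)
```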
- remove(module)
Remove the pruning reparameterization from a module.
The pruned parameter named name remains permanently pruned, and the parameter named name+'_orig' is removed from the parameter list. Similarly, the buffer named name+'_mask' is removed from the buffers.
Note
Pruning itself is NOT undone or reversed!
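Here is a minimal sketch of making pruning permanent, assuming a recent PyTorch release. `prune.remove` finds the hook on `weight` (here a `PruningContainer`) and calls its `remove`. Afterwards `weight` is an ordinary parameter again, `weight_orig` and `weight_mask` are gone, and the pruned entries stay zero.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

layer = nn.Linear(4, 3)
prune.l1_unstructured(layer, name="weight", amount=0.5)
prune.random_unstructured(layer, name="weight", amount=0.5)
mask = layer.weight_mask.clone()  # keep a copy before the buffer is deleted

prune.remove(layer, "weight")

print(isinstance(layer.weight, nn.Parameter))           # True
print("weight_orig" in dict(layer.named_parameters()))  # False
print("weight_mask" in dict(layer.named_buffers()))     # False
print(bool((layer.weight[mask == 0] == 0).all()))       # True: pruning is not undone
```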