
LnStructured

class torch.nn.utils.prune.LnStructured(amount, n, dim=-1)

Prune entire (currently unpruned) channels in a tensor based on their Ln-norm.

Parameters
  • amount (int or float) – quantity of channels to prune. If float, should be between 0.0 and 1.0 and represent the fraction of channels to prune. If int, it represents the absolute number of channels to prune.

  • n (int, float, inf, -inf, 'fro', 'nuc') – See documentation of valid entries for argument p in torch.norm().

  • dim (int, optional) – index of the dim along which we define channels to prune. Default: -1.
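A minimal usage sketch, assuming the module-level helper `torch.nn.utils.prune.ln_structured`, which applies this method to a named parameter. It prunes half of the output channels (dim 0) of a convolution weight by their L2 norm; with a float `amount` of 0.5 and 8 channels, 4 whole filters end up zeroed.

```python
import torch
import torch.nn as nn
from torch.nn.utils import prune

torch.manual_seed(0)

conv = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3)

# Prune half of the output channels (dim=0 of the weight, shape [8, 3, 3, 3])
# by their L2 norm. A float amount is a fraction of channels; an int would be
# an absolute channel count.
prune.ln_structured(conv, name="weight", amount=0.5, n=2, dim=0)

# Each output channel is either fully kept or fully zeroed.
channel_is_zero = (conv.weight.flatten(start_dim=1) == 0).all(dim=1)
print(int(channel_is_zero.sum()))  # 4 of the 8 filters are zeroed
```

Passing `amount=4` instead of `0.5` would prune the same number of channels, expressed as an absolute count.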

classmethod apply(module, name, amount, n, dim, importance_scores=None)

Add pruning on the fly and reparametrization of a tensor.

Adds the forward pre-hook that enables pruning on the fly and the reparametrization of a tensor in terms of the original tensor and the pruning mask.

Parameters
  • module (nn.Module) – module containing the tensor to prune

  • name (str) – parameter name within module on which pruning will act.

  • amount (int or float) – quantity of channels to prune. If float, should be between 0.0 and 1.0 and represent the fraction of channels to prune. If int, it represents the absolute number of channels to prune.

  • n (int, float, inf, -inf, 'fro', 'nuc') – See documentation of valid entries for argument p in torch.norm().

  • dim (int) – index of the dim along which we define channels to prune.

  • importance_scores (torch.Tensor) – tensor of importance scores (of same shape as module parameter) used to compute the mask for pruning. The values in this tensor indicate the importance of the corresponding elements in the parameter being pruned. If unspecified or None, the module parameter will be used in its place.
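The reparametrization described above can be observed by calling the classmethod directly. This is a sketch on a small `nn.Linear`, pruning the one row (output unit, dim 0) with the smallest L1 norm; after `apply`, the original tensor lives in the parameter `weight_orig`, the mask in the buffer `weight_mask`, and the returned object is the pruning method instance registered as the hook.

```python
import torch
import torch.nn as nn
from torch.nn.utils import prune

torch.manual_seed(0)
linear = nn.Linear(in_features=6, out_features=4)

# Attach the pruning method as a forward pre-hook; here 1 of the 4 rows
# (output units, dim=0) with the smallest L1 norm is pruned.
method = prune.LnStructured.apply(linear, "weight", amount=1, n=1, dim=0)

# "weight" is no longer a parameter: it is recomputed from weight_orig and
# weight_mask before every forward pass.
param_names = sorted(name for name, _ in linear.named_parameters())
buffer_names = [name for name, _ in linear.named_buffers()]
zero_rows = int((linear.weight_mask.sum(dim=1) == 0).sum())

print(param_names)   # ['bias', 'weight_orig']
print(buffer_names)  # ['weight_mask']
print(zero_rows)     # 1
```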

apply_mask(module)

Simply handles the multiplication between the parameter being pruned and the generated mask.

Fetches the mask and the original tensor from the module and returns the pruned version of the tensor.

Parameters

module (nn.Module) – module containing the tensor to prune

Returns

pruned version of the input tensor

Return type

pruned_tensor (torch.Tensor)
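As a quick check of what `apply_mask` returns, the sketch below compares it with the elementwise product of the stored `weight_orig` and `weight_mask`, and with the `weight` attribute that `apply` set on the module.

```python
import torch
import torch.nn as nn
from torch.nn.utils import prune

torch.manual_seed(0)
linear = nn.Linear(in_features=5, out_features=4)
method = prune.LnStructured.apply(linear, "weight", amount=0.5, n=2, dim=0)

# apply_mask recomputes the pruned tensor from the stored original and mask;
# the forward pre-hook uses it to refresh `linear.weight`.
pruned = method.apply_mask(linear)
matches_product = torch.equal(pruned, linear.weight_orig * linear.weight_mask)
matches_weight = torch.equal(pruned, linear.weight)

print(matches_product)  # True
print(matches_weight)   # True
```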

compute_mask(t, default_mask)

Compute and return a mask for the input tensor t.

Starting from a base default_mask (which should be a mask of ones if the tensor has not been pruned yet), generate a mask to apply on top of the default_mask by zeroing out the channels along the specified dim with the lowest Ln-norm.
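Written as a formula, the rule above reads roughly as follows. This is our reading, assuming a numeric n and that a float amount is converted to a channel count by rounding; C is the size of t along dim, t_(i) is the i-th slice of t along dim, and D is default_mask.

```latex
\begin{aligned}
s_i &= \bigl\lVert t_{(i)} \bigr\rVert_n, \qquad i = 0, \dots, C-1
      \quad \text{(norm over all dims except \texttt{dim})} \\
p   &= \begin{cases}
         \texttt{amount} & \text{if amount is an int} \\
         \operatorname{round}(\texttt{amount} \cdot C) & \text{if amount is a float}
       \end{cases} \\
K   &= \text{indices of the } C - p \text{ largest values of } s_i \\
M_{(i)} &= D_{(i)} \cdot \mathbf{1}[\, i \in K \,]
\end{aligned}
```

When p = 0, no channel is dropped and the mask is simply D.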

Parameters
  • t (torch.Tensor) – tensor representing the parameter to prune

  • default_mask (torch.Tensor) – base mask from previous pruning iterations that needs to be respected after the new mask is applied. Same dims as t.

Returns

mask to apply to t, of same dims as t

Return type

mask (torch.Tensor)

Raises

IndexError – if self.dim >= len(t.shape)
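The sketch below calls `compute_mask` directly on a small hand-built tensor whose rows (channels along dim 0) have L2 norms of about 1.41, 4.24, 0.71 and 2.83. With `amount=2`, the two lowest-norm rows are zeroed; the second call uses a `dim` that does not exist in a 2-D tensor and triggers the documented `IndexError`.

```python
import torch
from torch.nn.utils import prune

# Four channels (rows, dim=0); row norms differ so the ranking is unambiguous.
t = torch.tensor([[1.0, 1.0],
                  [3.0, 3.0],
                  [0.5, 0.5],
                  [2.0, 2.0]])

method = prune.LnStructured(amount=2, n=2, dim=0)
mask = method.compute_mask(t, default_mask=torch.ones_like(t))
print(mask[:, 0].tolist())  # [0.0, 1.0, 0.0, 1.0]

# dim must index an existing dimension of t; dim=2 is out of range here.
raised = False
try:
    prune.LnStructured(amount=2, n=2, dim=2).compute_mask(t, torch.ones_like(t))
except IndexError:
    raised = True
print(raised)  # True
```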

prune(t, default_mask=None, importance_scores=None)

Compute and return a pruned version of input tensor t, according to the pruning rule specified in compute_mask().

Parameters
  • t (torch.Tensor) – tensor to prune (of same dimensions as default_mask).

  • importance_scores (torch.Tensor) – tensor of importance scores (of same shape as t) used to compute the mask for pruning t. The values in this tensor indicate the importance of the corresponding elements in t. If unspecified or None, the tensor t will be used in its place.

  • default_mask (torch.Tensor, optional) – mask from the previous pruning iteration, if any, to be considered when determining which portion of the tensor pruning should act on. If None, defaults to a mask of ones.

Returns

pruned version of tensor t.
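To show the effect of `importance_scores`, this sketch prunes a tensor of ones (all rows have equal norm, so t alone cannot rank them) using a separate, made-up score tensor. The mask is computed from the scores, so the row with the smallest score norm (row 2) is the one zeroed in the returned tensor.

```python
import torch
from torch.nn.utils import prune

t = torch.ones(4, 3)  # every row has the same norm

# Hypothetical importance scores: row 2 is the least important.
scores = torch.tensor([[4.0, 4.0, 4.0],
                       [3.0, 3.0, 3.0],
                       [1.0, 1.0, 1.0],
                       [2.0, 2.0, 2.0]])

method = prune.LnStructured(amount=1, n=2, dim=0)
# default_mask is None here, so a mask of ones is used as the base.
pruned = method.prune(t, importance_scores=scores)
print(pruned.sum(dim=1).tolist())  # [3.0, 3.0, 0.0, 3.0]
```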

remove(module)

Remove the pruning reparameterization from a module.

The pruned parameter named name remains permanently pruned, and the parameter named name+'_orig' is removed from the parameter list. Similarly, the buffer named name+'_mask' is removed from the buffers.

Note

Pruning itself is NOT undone or reversed!
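The sketch below makes pruning permanent through the module-level helper `torch.nn.utils.prune.remove`, which calls the method's `remove()` and also deletes the forward pre-hook. Afterwards `weight` is an ordinary parameter again, `weight_orig` and `weight_mask` are gone, and the pruned rows stay zero.

```python
import torch
import torch.nn as nn
from torch.nn.utils import prune

torch.manual_seed(0)
linear = nn.Linear(in_features=5, out_features=4)
prune.ln_structured(linear, name="weight", amount=2, n=1, dim=0)

# Fold the mask into the weight and drop the reparametrization.
prune.remove(linear, "weight")

param_names = sorted(name for name, _ in linear.named_parameters())
buffer_names = [name for name, _ in linear.named_buffers()]
# The zeroed rows stay zero: pruning is made permanent, not undone.
zero_rows = int((linear.weight == 0).all(dim=1).sum())

print(param_names)   # ['bias', 'weight']
print(buffer_names)  # []
print(zero_rows)     # 2
```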