
torch.nn.utils.prune.global_unstructured

torch.nn.utils.prune.global_unstructured(parameters, pruning_method, importance_scores=None, **kwargs)

Globally prunes tensors corresponding to all parameters in parameters by applying the specified pruning_method.

Modifies modules in place by:

  1. adding a named buffer called name+'_mask' corresponding to the binary mask applied to the parameter name by the pruning method.

  2. replacing the parameter name by its pruned version, while the original (unpruned) parameter is stored in a new parameter named name+'_orig'.
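A minimal sketch of what these two in-place modifications look like on a small model. The two-layer nn.Sequential and the manual seed are illustrative assumptions, not part of this API. After pruning, the unpruned tensor is reachable as the parameter weight_orig, the mask as the buffer weight_mask, and weight is no longer a registered parameter but the product of the two:

```python
import torch
from torch import nn
from torch.nn.utils import prune

torch.manual_seed(0)  # illustrative: fixes the random initial weights

# Illustrative two-layer model (same layer shapes as the example below).
net = nn.Sequential(nn.Linear(10, 4), nn.Linear(4, 1))
prune.global_unstructured(
    [(net[0], "weight"), (net[1], "weight")],
    pruning_method=prune.L1Unstructured,
    amount=10,
)

first = net[0]
# Step 2: the original (unpruned) tensor is now the parameter "weight_orig".
print(sorted(name for name, _ in first.named_parameters()))
# Step 1: the binary mask is registered as the buffer "weight_mask".
print(sorted(name for name, _ in first.named_buffers()))
# "weight" is now a plain tensor attribute holding the pruned values.
print(isinstance(first.weight, nn.Parameter))
print(torch.equal(first.weight, first.weight_orig * first.weight_mask))
```

Because the mask is a buffer, it is saved in the module's state_dict alongside weight_orig.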

Parameters
  • parameters (Iterable of (module, name) tuples) – parameters of the model to prune in a global fashion, i.e. by aggregating all weights prior to deciding which ones to prune. module must be of type nn.Module, and name must be a string.

  • pruning_method (function) – a valid pruning function from this module, or a custom one implemented by the user that satisfies the implementation guidelines and has PRUNING_TYPE='unstructured'.

  • importance_scores (dict) – a dictionary mapping (module, name) tuples to the corresponding parameter’s importance scores tensor. The tensor should be the same shape as the parameter and is used for computing the mask for pruning. If unspecified or None, the parameter itself will be used in place of its importance scores.

  • kwargs – other keyword arguments, such as:
      amount (int or float): quantity of parameters to prune across the specified parameters. If float, it should be between 0.0 and 1.0 and represents the fraction of parameters to prune. If int, it represents the absolute number of parameters to prune.
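A sketch of how the global ranking, a float amount, and importance_scores interact. The helpers make_net and pruned_per_layer and the hand-picked weights and scores are hypothetical, chosen only so that the outcome is easy to predict. With amount=0.1 over the 44 pooled weights, current PyTorch rounds the count to 4 entries; all of them fall in the second layer because its weights are the smallest in magnitude. Supplying importance_scores that rank the first layer lowest moves the pruning there instead, even though the weights themselves are unchanged:

```python
import torch
from torch import nn
from torch.nn.utils import prune


def make_net():
    """Hypothetical model whose weights make the global ranking obvious."""
    net = nn.Sequential(nn.Linear(10, 4), nn.Linear(4, 1))
    with torch.no_grad():
        net[0].weight.fill_(1.0)  # 40 large-magnitude weights
        net[1].weight.copy_(torch.tensor([[0.01, 0.02, 0.03, 0.04]]))  # 4 small ones
    return net


def pruned_per_layer(net):
    """Number of zeros in each layer's weight_mask buffer."""
    return [int((layer.weight_mask == 0).sum()) for layer in net]


# Float amount: 10% of the 44 pooled weights, i.e. round(4.4) = 4 entries.
net = make_net()
params = [(net[0], "weight"), (net[1], "weight")]
prune.global_unstructured(params, pruning_method=prune.L1Unstructured, amount=0.1)
print(pruned_per_layer(net))  # [0, 4]

# importance_scores: hypothetical scores that rank the first layer lowest.
net = make_net()
params = [(net[0], "weight"), (net[1], "weight")]
scores = {
    (net[0], "weight"): torch.zeros(4, 10),  # same shape as net[0].weight
    (net[1], "weight"): torch.ones(1, 4),    # same shape as net[1].weight
}
prune.global_unstructured(
    params, pruning_method=prune.L1Unstructured, importance_scores=scores, amount=4
)
print(pruned_per_layer(net))  # [4, 0]
```

Per-layer pruning with the same amount would instead remove entries from both layers; pooling first is what lets one layer end up fully pruned and the other untouched.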

Raises

TypeError – if the PRUNING_TYPE of pruning_method is not 'unstructured'
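A short sketch of the TypeError case, using an illustrative two-layer model. prune.RandomStructured is one of this module's structured methods (its PRUNING_TYPE is 'structured'), so it is rejected here. In the implementation we checked, the test runs before any mask is attached, so the modules are left unmodified:

```python
import torch
from torch import nn
from torch.nn.utils import prune

net = nn.Sequential(nn.Linear(10, 4), nn.Linear(4, 1))  # illustrative model
params = [(net[0], "weight"), (net[1], "weight")]

rejected = False
try:
    # RandomStructured declares PRUNING_TYPE = "structured".
    prune.global_unstructured(params, pruning_method=prune.RandomStructured, amount=0.5)
except TypeError as err:
    rejected = True
    print("rejected:", err)

# No "weight_mask" buffer was added, and "weight" is still a regular parameter.
print(hasattr(net[0], "weight_mask"))
```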

Note

Since global structured pruning doesn’t make much sense unless the norm of each structure is normalized by the size of the parameter it belongs to, we now limit the scope of global pruning to unstructured methods.
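Because only unstructured methods are accepted, a custom pruning_method has to declare PRUNING_TYPE = 'unstructured'. A minimal sketch following the BasePruningMethod subclassing pattern from the PyTorch pruning tutorial; EveryOtherPruning is a hypothetical method that zeroes every other entry, and the model is illustrative. In the current implementation, its compute_mask receives the pooled scores of all listed parameters as one flattened tensor, so a single mask is computed across both layers and then split back per parameter:

```python
import torch
from torch import nn
from torch.nn.utils import prune


class EveryOtherPruning(prune.BasePruningMethod):
    """Hypothetical method: prune every other entry of the pooled scores."""

    PRUNING_TYPE = "unstructured"  # anything else makes global_unstructured raise

    def compute_mask(self, t, default_mask):
        mask = default_mask.clone()
        mask.view(-1)[::2] = 0  # zero positions 0, 2, 4, ... of the flattened tensor
        return mask


net = nn.Sequential(nn.Linear(10, 4), nn.Linear(4, 1))  # illustrative model
params = [(net[0], "weight"), (net[1], "weight")]
# No kwargs are needed: EveryOtherPruning() takes no constructor arguments.
prune.global_unstructured(params, pruning_method=EveryOtherPruning)

# 44 pooled weights -> 22 zeros: 20 in the first layer, 2 in the second.
print([int((layer.weight_mask == 0).sum()) for layer in net])
```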

Examples

>>> import torch
>>> from torch import nn
>>> from torch.nn.utils import prune
>>> from collections import OrderedDict
>>> net = nn.Sequential(
...     OrderedDict(
...         [
...             ("first", nn.Linear(10, 4)),
...             ("second", nn.Linear(4, 1)),
...         ]
...     )
... )
>>> parameters_to_prune = (
...     (net.first, "weight"),
...     (net.second, "weight"),
... )
>>> prune.global_unstructured(
...     parameters_to_prune,
...     pruning_method=prune.L1Unstructured,
...     amount=10,
... )
>>> print(sum(torch.nn.utils.parameters_to_vector(net.buffers()) == 0))
tensor(10)