torch.nn.utils.prune.global_unstructured
- torch.nn.utils.prune.global_unstructured(parameters, pruning_method, importance_scores=None, **kwargs)
Globally prunes tensors corresponding to all parameters in parameters by applying the specified pruning_method. Modifies modules in place by:
  - adding a named buffer called name+'_mask' corresponding to the binary mask applied to the parameter name by the pruning method.
  - replacing the parameter name by its pruned version, while the original (unpruned) parameter is stored in a new parameter named name+'_orig'.
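The following minimal sketch checks the two in-place changes described above on a small two-layer model. It assumes a recent PyTorch release, and the model and amount=10 are illustrative choices. After pruning, the layer exposes a weight_orig parameter and a weight_mask buffer. Its weight attribute is no longer a parameter but a plain tensor equal to their product.

```python
import torch
from torch import nn
from torch.nn.utils import prune

torch.manual_seed(0)
net = nn.Sequential(nn.Linear(10, 4), nn.Linear(4, 1))

# Prune the 10 smallest-magnitude weights, ranked across both layers at once.
prune.global_unstructured(
    [(net[0], "weight"), (net[1], "weight")],
    pruning_method=prune.L1Unstructured,
    amount=10,
)

layer = net[0]
# The unpruned tensor is now the parameter "weight_orig" ...
param_names = sorted(n for n, _ in layer.named_parameters())
# ... and the binary mask is stored in the buffer "weight_mask".
buffer_names = [n for n, _ in layer.named_buffers()]
print(param_names)   # ['bias', 'weight_orig']
print(buffer_names)  # ['weight_mask']

# "weight" is a plain tensor attribute equal to weight_orig * weight_mask.
print(isinstance(layer.weight, nn.Parameter))                            # False
print(torch.equal(layer.weight, layer.weight_orig * layer.weight_mask))  # True
```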
- Parameters
parameters (Iterable of (module, name) tuples) – parameters of the model to prune in a global fashion, i.e. by aggregating all weights prior to deciding which ones to prune. module must be of type nn.Module, and name must be a string.
pruning_method (function) – a valid pruning function from this module, or a custom one implemented by the user that satisfies the implementation guidelines and has PRUNING_TYPE='unstructured'.
importance_scores (dict) – a dictionary mapping (module, name) tuples to the corresponding parameter’s importance scores tensor. The tensor should be the same shape as the parameter, and is used for computing the mask for pruning. If unspecified or None, the parameter will be used in place of its importance scores.
kwargs – other keyword arguments such as: amount (int or float): quantity of parameters to prune across the specified parameters. If float, should be between 0.0 and 1.0 and represent the fraction of parameters to prune. If int, it represents the absolute number of parameters to prune.
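The next minimal sketch exercises the importance_scores and amount arguments, assuming a recent PyTorch release. The score for the first layer is hypothetical: weight magnitude scaled by a random per-input factor. The second layer has no dictionary entry, so its weight serves as its own score. A float amount of 0.25 over the 44 pooled weights should prune round(0.25 * 44) = 11 of them.

```python
import torch
from torch import nn
from torch.nn.utils import prune

torch.manual_seed(0)
first, second = nn.Linear(10, 4), nn.Linear(4, 1)
params = [(first, "weight"), (second, "weight")]

# Hypothetical score for the first layer: |weight| scaled by a random
# per-input factor. It must have the same shape as first.weight, (4, 10).
input_scale = torch.rand(10)
scores = {(first, "weight"): first.weight.detach().abs() * input_scale}
# (second, "weight") has no entry, so its own weight tensor serves as its score.

# Float amount: prune 25% of the 40 + 4 = 44 pooled weights, i.e. 11 of them.
prune.global_unstructured(
    params,
    pruning_method=prune.L1Unstructured,
    importance_scores=scores,
    amount=0.25,
)
n_pruned = sum(int((module.weight_mask == 0).sum()) for module, _ in params)
print(n_pruned)  # 11
```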
- Raises
TypeError – if the pruning method's PRUNING_TYPE != 'unstructured'
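The type check can be observed with a built-in structured method. This sketch assumes prune.LnStructured, whose PRUNING_TYPE is 'structured'. The extra keyword arguments n and dim are forwarded to its constructor together with amount. In this sketch, run against recent PyTorch releases, the error is raised before any mask is attached to the modules.

```python
import torch
from torch import nn
from torch.nn.utils import prune

net = nn.Sequential(nn.Linear(10, 4), nn.Linear(4, 1))
params = [(net[0], "weight"), (net[1], "weight")]

# LnStructured prunes whole channels, so its PRUNING_TYPE is "structured".
print(prune.LnStructured.PRUNING_TYPE)  # structured

raised = False
try:
    # n and dim are forwarded to the LnStructured constructor with amount.
    prune.global_unstructured(
        params, pruning_method=prune.LnStructured, amount=0.5, n=2, dim=0
    )
except TypeError as err:
    raised = True
    print("TypeError:", err)

print(raised)                          # True
# In this sketch no mask has been attached to the modules.
print(hasattr(net[0], "weight_mask"))  # False
```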
Note
Since global structured pruning doesn’t make much sense unless the norm is normalized by the size of the parameter, we now limit the scope of global pruning to unstructured methods.
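The next sketch is a rough numerical illustration of the note. It uses two hypothetical weight matrices and plain tensor operations, not the prune API. When every weight has the same scale, the raw L2 norm of an output channel grows roughly with the square root of its size. A single global threshold on raw channel norms would therefore tend to remove channels from the narrower layer first. Dividing by the square root of the channel size puts both layers on a comparable scale.

```python
import torch

torch.manual_seed(0)
# Two hypothetical weight matrices with the same per-weight scale (0.1) but
# different channel sizes: 16 vs 1024 weights per output channel (row).
narrow = torch.randn(8, 16) * 0.1
wide = torch.randn(8, 1024) * 0.1

# Raw L2 norm per channel, averaged over channels: grows roughly like sqrt(size).
raw_narrow = narrow.norm(p=2, dim=1).mean().item()
raw_wide = wide.norm(p=2, dim=1).mean().item()
print(f"raw:        {raw_narrow:.3f} vs {raw_wide:.3f}")  # wide is ~8x larger

# Dividing by sqrt(channel size) puts both layers on a comparable scale.
norm_narrow = raw_narrow / narrow.shape[1] ** 0.5
norm_wide = raw_wide / wide.shape[1] ** 0.5
print(f"normalized: {norm_narrow:.3f} vs {norm_wide:.3f}")  # both close to 0.1
```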
Examples
>>> import torch
>>> from torch import nn
>>> from torch.nn.utils import prune
>>> from collections import OrderedDict
>>> net = nn.Sequential(
...     OrderedDict(
...         [
...             ("first", nn.Linear(10, 4)),
...             ("second", nn.Linear(4, 1)),
...         ]
...     )
... )
>>> parameters_to_prune = (
...     (net.first, "weight"),
...     (net.second, "weight"),
... )
>>> prune.global_unstructured(
...     parameters_to_prune,
...     pruning_method=prune.L1Unstructured,
...     amount=10,
... )
>>> print(sum(torch.nn.utils.parameters_to_vector(net.buffers()) == 0))
tensor(10)
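The pruning_method parameter also accepts a user-defined method. This sketch assumes the usual pattern for such a method: subclass prune.BasePruningMethod, set PRUNING_TYPE = 'unstructured', and implement compute_mask(t, default_mask). ThresholdPruning and its threshold argument are hypothetical. Keyword arguments passed to global_unstructured (here threshold) are forwarded to the method's constructor. The tensor t is built from the flattened importance scores of all listed parameters, so a single threshold applies across both layers.

```python
import torch
from torch import nn
from torch.nn.utils import prune


class ThresholdPruning(prune.BasePruningMethod):
    """Hypothetical method: prune every entry whose |score| is below threshold."""

    PRUNING_TYPE = "unstructured"  # required by global_unstructured

    def __init__(self, threshold):
        super().__init__()
        self.threshold = threshold

    def compute_mask(self, t, default_mask):
        # Keep previously pruned entries pruned; zero out small-magnitude ones.
        return default_mask * (t.abs() >= self.threshold).to(default_mask.dtype)


net = nn.Sequential(nn.Linear(10, 4), nn.Linear(4, 1))
with torch.no_grad():
    net[0].weight.fill_(0.5)
    net[0].weight[0, :3] = 0.01  # three tiny weights in the first layer
    net[1].weight.fill_(0.5)
    net[1].weight[0, 0] = -0.02  # one tiny weight in the second layer

# threshold=0.1 is forwarded to ThresholdPruning(threshold=0.1).
prune.global_unstructured(
    [(net[0], "weight"), (net[1], "weight")],
    pruning_method=ThresholdPruning,
    threshold=0.1,
)
total = sum(int((m.weight_mask == 0).sum()) for m in net)
print(total)  # 4
```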