torch.nn.utils.prune.random_structured

torch.nn.utils.prune.random_structured(module, name, amount, dim)

Prune tensor by removing random channels along the specified dimension.

Prunes the tensor corresponding to the parameter called name in module by removing the specified amount of (currently unpruned) channels along the specified dim, selected at random. Modifies module in place (and also returns the modified module) by:

  1. adding a named buffer called name+'_mask' corresponding to the binary mask applied to the parameter name by the pruning method.

  2. replacing the parameter name by its pruned version, while the original (unpruned) parameter is stored in a new parameter named name+'_orig'.
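The two in-place changes listed above can be observed directly on a freshly pruned layer. The sketch below is a minimal check, assuming a recent PyTorch release; the layer size and amount are arbitrary illustrative choices. It lists the module's buffers and parameters after pruning and confirms that the pruned weight equals the elementwise product of weight_orig and weight_mask.

```python
import torch
from torch import nn
from torch.nn.utils import prune

layer = nn.Linear(5, 3)  # weight has shape (3, 5)
prune.random_structured(layer, "weight", amount=2, dim=1)

# Step 1: the binary mask is registered as a buffer named "weight_mask".
buffer_names = [n for n, _ in layer.named_buffers()]
print(buffer_names)  # ['weight_mask']

# Step 2: the original, unpruned tensor is now the parameter "weight_orig",
# and "weight" is no longer registered as a parameter.
param_names = sorted(n for n, _ in layer.named_parameters())
print(param_names)  # ['bias', 'weight_orig']

# The pruned "weight" is the original tensor multiplied by the mask.
print(torch.equal(layer.weight, layer.weight_orig * layer.weight_mask))  # True
```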

Parameters
  • module (nn.Module) – module containing the tensor to prune

  • name (str) – parameter name within module on which pruning will act.

  • amount (int or float) – quantity of channels to prune along dim. If float, should be between 0.0 and 1.0 and represent the fraction of channels to prune. If int, it represents the absolute number of channels to prune.

  • dim (int) – index of the dim along which we define channels to prune.
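To make the amount and dim semantics concrete, the sketch below prunes the weight of an nn.Linear(10, 4) layer, whose weight has shape (4, 10), once with an int amount along dim=0 (rows) and once with a float amount along dim=1 (columns). The helper zeroed_channels is hypothetical, defined here only for illustration and not part of torch.nn.utils.prune; it counts the slices along dim whose entries are all zero. The sketch assumes no unpruned channel is exactly zero by chance, which the default random initialization makes practically impossible.

```python
import torch
from torch import nn
from torch.nn.utils import prune


def zeroed_channels(weight: torch.Tensor, dim: int) -> int:
    """Hypothetical helper: count slices along `dim` whose entries are all zero."""
    # Move `dim` to the front and flatten the rest, so each row is one channel.
    per_channel = weight.detach().movedim(dim, 0).reshape(weight.shape[dim], -1)
    return int((per_channel == 0).all(dim=1).sum())


# int amount: remove exactly 2 of the 4 rows (dim=0).
a = prune.random_structured(nn.Linear(10, 4), "weight", amount=2, dim=0)
print(zeroed_channels(a.weight, dim=0))  # 2

# float amount: remove half of the 10 columns (dim=1), i.e. 0.5 * 10 = 5 columns.
b = prune.random_structured(nn.Linear(10, 4), "weight", amount=0.5, dim=1)
print(zeroed_channels(b.weight, dim=1))  # 5
```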

Returns

the input module, modified (i.e. pruned) in place

Return type

module (nn.Module)
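Because the pruning happens in place, the returned object is the very module that was passed in, and it remains an ordinary nn.Module that can be called as before. A minimal sketch, assuming a recent PyTorch release; the layer size and input batch are arbitrary:

```python
import torch
from torch import nn
from torch.nn.utils import prune

layer = nn.Linear(5, 3)
returned = prune.random_structured(layer, "weight", amount=1, dim=0)

# The function prunes `layer` in place and hands back that same object.
print(returned is layer)  # True

# The pruned module can still be called like any other nn.Module.
out = returned(torch.randn(2, 5))
print(out.shape)  # torch.Size([2, 3])
```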

Examples

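>>> import torch
>>> from torch import nn
>>> from torch.nn.utils import prune
>>> # nn.Linear(5, 3) has a (3, 5) weight; dim=1 indexes its 5 columns
>>> m = prune.random_structured(nn.Linear(5, 3), "weight", amount=3, dim=1)
>>> columns_pruned = int((torch.sum(m.weight, dim=0) == 0).sum())
>>> print(columns_pruned)
3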