RandomStructured
- class torch.nn.utils.prune.RandomStructured(amount, dim=-1)
Prune entire (currently unpruned) channels in a tensor at random.
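- Parameters
amount (int or float) – quantity of parameters to prune. If float, should be between 0.0 and 1.0 and represent the fraction of parameters to prune. If int, it represents the absolute number of parameters to prune.
dim (int, optional) – index of the dim along which we define channels to prune. Default: -1.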
- classmethod apply(module, name, amount, dim=-1)
Add pruning on the fly and reparametrization of a tensor.
Adds the forward pre-hook that enables pruning on the fly and the reparametrization of a tensor in terms of the original tensor and the pruning mask.
- Parameters
module (nn.Module) – module containing the tensor to prune
name (str) – parameter name within module on which pruning will act.
amount (int or float) – quantity of parameters to prune. If float, should be between 0.0 and 1.0 and represent the fraction of parameters to prune. If int, it represents the absolute number of parameters to prune.
dim (int, optional) – index of the dim along which we define channels to prune. Default: -1.
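A minimal sketch of calling `apply` directly, assuming a recent PyTorch install; the layer sizes, seed, and `amount=0.5` are arbitrary choices for illustration. It prunes half of the output channels (dim 0) of a `Conv2d` weight and then checks the reparametrization described above: the original tensor lives on as the parameter `weight_orig`, the mask as the buffer `weight_mask`, and `weight` becomes the masked product.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)  # make the random channel choice repeatable

# Conv2d weight has shape (out_channels, in_channels, kH, kW) = (8, 3, 3, 3).
conv = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3)

# Prune a fraction 0.5 of the channels along dim 0, i.e. 4 output channels.
method = prune.RandomStructured.apply(conv, "weight", amount=0.5, dim=0)

params = dict(conv.named_parameters())   # now holds "weight_orig" and "bias"
buffers = dict(conv.named_buffers())     # now holds "weight_mask"

# "weight" is recomputed as weight_orig * weight_mask; count fully zeroed channels.
zeroed = (conv.weight.flatten(start_dim=1) == 0).all(dim=1)
print(sorted(params), sorted(buffers), int(zeroed.sum()))
```

The module-level helper `prune.random_structured(conv, "weight", amount=0.5, dim=0)` wraps the same call and returns the module rather than the pruning-method instance.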
- apply_mask(module)
Simply handles the multiplication between the parameter being pruned and the generated mask.
Fetches the mask and the original tensor from the module and returns the pruned version of the tensor.
- Parameters
module (nn.Module) – module containing the tensor to prune
- Returns
pruned version of the input tensor
- Return type
pruned_tensor (torch.Tensor)
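The sketch below, assuming a recent PyTorch install and an arbitrary `Linear` layer, shows that `apply_mask` is just the elementwise product of the stored `weight_orig` and `weight_mask`. It relies on `apply` returning the pruning-method instance it registers as a hook.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
linear = nn.Linear(in_features=6, out_features=4)  # weight shape (4, 6)

# An int amount is an absolute count: prune 2 of the 4 rows (dim 0).
method = prune.RandomStructured.apply(linear, "weight", amount=2, dim=0)

# apply_mask fetches weight_orig and weight_mask from the module and multiplies them.
recomputed = method.apply_mask(linear)
manual = linear.weight_orig * linear.weight_mask

print(torch.equal(recomputed, manual))         # True
print(torch.equal(recomputed, linear.weight))  # True
```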
- compute_mask(t, default_mask)
Compute and return a mask for the input tensor t. Starting from a base default_mask (which should be a mask of ones if the tensor has not been pruned yet), generate a random mask to apply on top of the default_mask by randomly zeroing out channels along the specified dim of the tensor.
- Parameters
t (torch.Tensor) – tensor representing the parameter to prune
default_mask (torch.Tensor) – base mask from previous pruning iterations, which needs to be respected after the new mask is applied. Same dims as t.
- Returns
mask to apply to t, of same dims as t
- Return type
mask (torch.Tensor)
- Raises
IndexError – if self.dim >= len(t.shape)
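A minimal sketch of calling `compute_mask` on a bare tensor, assuming a recent PyTorch install; the tensor shape, seed, and the hand-made `default_mask` (standing in for an earlier pruning step) are illustrative assumptions. It checks three documented properties: whole channels along `dim` are zeroed, entries already zeroed in `default_mask` stay zeroed, and an out-of-range `dim` raises `IndexError`.

```python
import torch
from torch.nn.utils.prune import RandomStructured

torch.manual_seed(0)
t = torch.randn(5, 3)

# Stand-in for a mask from a previous pruning iteration: one entry already zeroed.
default_mask = torch.ones_like(t)
default_mask[0, 0] = 0

method = RandomStructured(amount=2, dim=0)
mask = method.compute_mask(t, default_mask=default_mask)

# Exactly 2 whole rows (channels along dim 0) are zeroed by the random mask.
fully_zero_rows = (mask == 0).all(dim=1)
# The previous mask is respected: entries zeroed there remain zeroed.
prior_respected = bool((mask[default_mask == 0] == 0).all())

# dim=2 is out of range for a 2-D tensor, so compute_mask raises IndexError.
raised = False
try:
    RandomStructured(amount=1, dim=2).compute_mask(t, default_mask=torch.ones_like(t))
except IndexError:
    raised = True

print(int(fully_zero_rows.sum()), prior_respected, raised)
```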
- prune(t, default_mask=None, importance_scores=None)
Compute and return a pruned version of input tensor t, according to the pruning rule specified in compute_mask().
- Parameters
t (torch.Tensor) – tensor to prune (of same dimensions as default_mask).
importance_scores (torch.Tensor) – tensor of importance scores (of same shape as t) used to compute the mask for pruning t. The values in this tensor indicate the importance of the corresponding elements in the t that is being pruned. If unspecified or None, the tensor t will be used in its place.
default_mask (torch.Tensor, optional) – mask from previous pruning iteration, if any. To be considered when determining what portion of the tensor pruning should act on. If None, defaults to a mask of ones.
- Returns
pruned version of tensor t.
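A short sketch of `prune` with a float `amount`, assuming a recent PyTorch install; the input values and seed are arbitrary. With `amount=0.5` and 4 columns along `dim=1`, the fraction works out to 2 columns. Because the input has no zeros, any zero column in the result must come from pruning; the remaining columns are returned unchanged.

```python
import torch
from torch.nn.utils.prune import RandomStructured

torch.manual_seed(0)
t = torch.arange(1.0, 13.0).reshape(3, 4)  # no zeros, so zeros can only come from pruning

# Float amount = fraction: 0.5 of the 4 columns along dim=1 -> 2 columns.
method = RandomStructured(amount=0.5, dim=1)
pruned = method.prune(t)  # default_mask=None -> ones; importance_scores=None -> t

zero_cols = (pruned == 0).all(dim=0)
print(int(zero_cols.sum()))                                  # 2
print(torch.equal(pruned[:, ~zero_cols], t[:, ~zero_cols]))  # True
```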
- remove(module)
Remove the pruning reparameterization from a module.
The pruned parameter named name remains permanently pruned, and the parameter named name+'_orig' is removed from the parameter list. Similarly, the buffer named name+'_mask' is removed from the buffers.
Note
Pruning itself is NOT undone or reversed!
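A minimal sketch of making pruning permanent, assuming a recent PyTorch install. It goes through the module-level `torch.nn.utils.prune.remove(module, name)` function, which, as I understand it, looks up the pruning hook for `name`, calls this `remove` method, and also deletes the forward pre-hook. Calling the method directly would leave the hook in place. The layer size and seed are arbitrary.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
linear = nn.Linear(in_features=5, out_features=4)
prune.RandomStructured.apply(linear, "weight", amount=0.5, dim=0)
pruned_weight = linear.weight.detach().clone()  # snapshot of the masked weight

# Drop the reparametrization: weight_orig and weight_mask go away.
prune.remove(linear, "weight")

param_names = sorted(n for n, _ in linear.named_parameters())
buffer_names = sorted(n for n, _ in linear.named_buffers())
print(param_names)   # ['bias', 'weight']
print(buffer_names)  # []
# The zeros survive: pruning is not undone.
print(torch.equal(linear.weight.detach(), pruned_weight))  # True
```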