torch.nn.utils.weight_norm#
- torch.nn.utils.weight_norm(module, name='weight', dim=0)[source]#
Apply weight normalization to a parameter in the given module.
Weight normalization is a reparameterization that decouples the magnitude of a weight tensor from its direction. This replaces the parameter specified by name (e.g. 'weight') with two parameters: one specifying the magnitude (e.g. 'weight_g') and one specifying the direction (e.g. 'weight_v'). Weight normalization is implemented via a hook that recomputes the weight tensor from the magnitude and direction before every forward() call.
By default, with dim=0, the norm is computed independently per output channel/plane. To compute a norm over the entire weight tensor, use dim=None.
See https://arxiv.org/abs/1602.07868
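The per-channel behavior of dim=0 can be checked directly: the stored weight equals g * v / ||v||, with the norm of v taken over every dimension except dim 0. A minimal sketch, assuming a Linear layer (the deprecation warning on this legacy API is expected):

```python
import torch
import torch.nn as nn
from torch.nn.utils import weight_norm

m = weight_norm(nn.Linear(20, 40), name='weight')

# With dim=0, weight_g holds one magnitude per output channel (row),
# so for a (40, 20) Linear weight the norm of weight_v is taken over dim 1.
norm_v = m.weight_v.norm(dim=1, keepdim=True)   # shape: (40, 1)
recomputed = m.weight_g * m.weight_v / norm_v   # g * v / ||v||

assert torch.allclose(m.weight, recomputed, atol=1e-6)
```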
Warning
This function is deprecated. Use torch.nn.utils.parametrizations.weight_norm() which uses the modern parametrization API. The new weight_norm is compatible with state_dict generated from the old weight_norm.
Migration guide:
- The magnitude (weight_g) and direction (weight_v) are now expressed as parametrizations.weight.original0 and parametrizations.weight.original1 respectively. If this is bothering you, please comment on pytorch/pytorch#102999
- To remove the weight normalization reparametrization, use torch.nn.utils.parametrize.remove_parametrizations().
- The weight is no longer recomputed once at module forward; instead, it will be recomputed on every access. To restore the old behavior, use torch.nn.utils.parametrize.cached() before invoking the module in question.
- Parameters
  - module (Module) – containing module
  - name (str, optional) – name of weight parameter
  - dim (int, optional) – dimension over which to compute the norm
- Returns
The original module with the weight norm hook
- Return type
T_module
Example:
>>> m = weight_norm(nn.Linear(20, 40), name='weight')
>>> m
Linear(in_features=20, out_features=40, bias=True)
>>> m.weight_g.size()
torch.Size([40, 1])
>>> m.weight_v.size()
torch.Size([40, 20])