torch.nn.utils.spectral_norm#
- torch.nn.utils.spectral_norm(module, name='weight', n_power_iterations=1, eps=1e-12, dim=None)[source]#
Apply spectral normalization to a parameter in the given module.
Spectral normalization stabilizes the training of discriminators (critics) in Generative Adversarial Networks (GANs) by rescaling the weight tensor with the spectral norm of the weight matrix, calculated using the power iteration method. If the dimension of the weight tensor is greater than 2, it is reshaped to 2D in the power iteration method to get the spectral norm. This is implemented via a hook that calculates the spectral norm and rescales the weight before every forward() call. See Spectral Normalization for Generative Adversarial Networks.
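The power iteration mentioned above can be sketched in plain Python. This is a minimal illustration of how the largest singular value of a 2D weight matrix is estimated, not PyTorch's actual implementation; the function name `spectral_norm_estimate` and the diagonal test matrix are invented for this example:

```python
import math

def matvec(A, x):
    # A @ x for a matrix given as a list of rows
    return [sum(a * b for a, b in zip(row, x)) for row in A]

def matvec_t(A, x):
    # A.T @ x
    return [sum(A[i][j] * x[i] for i in range(len(A))) for j in range(len(A[0]))]

def normalize(x, eps=1e-12):
    # eps plays the same numerical-stability role as the `eps` parameter above
    n = math.sqrt(sum(v * v for v in x)) + eps
    return [v / n for v in x]

def spectral_norm_estimate(A, n_power_iterations=50):
    # Alternately update the left (u) and right (v) singular vectors,
    # then estimate sigma = u^T A v. Assumes n_power_iterations >= 1.
    u = normalize([1.0] * len(A))
    for _ in range(n_power_iterations):
        v = normalize(matvec_t(A, u))
        u = normalize(matvec(A, v))
    Av = matvec(A, v)
    return sum(ui * avi for ui, avi in zip(u, Av))

# diag(2, 1) has singular values {2, 1}
print(spectral_norm_estimate([[2.0, 0.0], [0.0, 1.0]]))  # close to 2.0
```

Dividing the weight by this estimate is what constrains the layer's Lipschitz constant, which is the property the GAN critic relies on.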
- Parameters
module (nn.Module) – containing module
name (str, optional) – name of weight parameter
n_power_iterations (int, optional) – number of power iterations to calculate spectral norm
eps (float, optional) – epsilon for numerical stability in calculating norms
dim (int, optional) – dimension corresponding to the number of outputs; the default is 0, except for modules that are instances of ConvTranspose{1,2,3}d, where it is 1
- Returns
The original module with the spectral norm hook
- Return type
T_module
Note
This function has been reimplemented as torch.nn.utils.parametrizations.spectral_norm() using the new parametrization functionality in torch.nn.utils.parametrize.register_parametrization(). Please use the newer version. This function will be deprecated in a future version of PyTorch.

Example:
>>> m = spectral_norm(nn.Linear(20, 40))
>>> m
Linear(in_features=20, out_features=40, bias=True)
>>> m.weight_u.size()
torch.Size([40])
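As the note above recommends, the same wrapping can be done with the newer parametrization-based API. A minimal sketch (assuming PyTorch is installed; with the parametrization, the normalized weight is recomputed on every access, and in training mode each access also runs the power-iteration update):

```python
import torch
import torch.nn as nn
from torch.nn.utils.parametrizations import spectral_norm

torch.manual_seed(0)
m = spectral_norm(nn.Linear(20, 40))

# Each access to m.weight in training mode advances the power iteration,
# so after several accesses the weight's spectral norm approaches 1.
for _ in range(20):
    w = m.weight

print(torch.linalg.matrix_norm(w.detach(), 2))  # close to 1.0
```

Unlike the hook-based function documented here, the parametrized module stores the unnormalized weight under `m.parametrizations.weight`, and removing the constraint goes through torch.nn.utils.parametrize.remove_parametrizations().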