
torch.nn.utils.spectral_norm

torch.nn.utils.spectral_norm(module, name='weight', n_power_iterations=1, eps=1e-12, dim=None)

Apply spectral normalization to a parameter in the given module.

Spectral normalization stabilizes the training of discriminators (critics) in Generative Adversarial Networks (GANs) by rescaling the weight tensor by the spectral norm σ of the weight matrix, which is estimated with the power iteration method. If the weight tensor has more than two dimensions, it is reshaped to 2D for the power iteration. This is implemented via a hook that computes the spectral norm and rescales the weight before every forward() call.
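As a rough illustration of what the hook computes, here is a minimal plain-Python sketch. It is not PyTorch's implementation: it assumes the weight is given as nested lists with outputs along the first axis (the default dim=0), and the helper names (reshape_to_2d, estimate_spectral_norm, spectral_normalize) are hypothetical. The vector u is meant to be kept between calls, which is what makes a small number of power iterations per forward() call workable.

```python
import math
from typing import List, Tuple

Vector = List[float]
Matrix = List[List[float]]


def _normalize(vec: Vector, eps: float) -> Vector:
    # Divide by the L2 norm, clamped below by eps so a zero vector cannot divide by zero.
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / max(norm, eps) for x in vec]


def _matvec(mat: Matrix, vec: Vector) -> Vector:
    return [sum(a * b for a, b in zip(row, vec)) for row in mat]


def _flatten(x) -> Vector:
    if isinstance(x, list):
        return [y for item in x for y in _flatten(item)]
    return [x]


def reshape_to_2d(weight: list) -> Matrix:
    # Hypothetical helper: keep dim 0 (outputs) as rows, flatten all other dims into columns.
    return [_flatten(row) for row in weight]


def estimate_spectral_norm(
    mat: Matrix, u: Vector, n_power_iterations: int = 1, eps: float = 1e-12
) -> Tuple[float, Vector]:
    # Power iteration: alternate v = W^T u / ||W^T u|| and u = W v / ||W v||.
    if n_power_iterations < 1:
        raise ValueError("n_power_iterations must be positive")
    mat_t = [list(col) for col in zip(*mat)]
    for _ in range(n_power_iterations):
        v = _normalize(_matvec(mat_t, u), eps)
        u = _normalize(_matvec(mat, v), eps)
    # sigma = u^T W v approximates the largest singular value of W.
    sigma = sum(a * b for a, b in zip(u, _matvec(mat, v)))
    return sigma, u


def spectral_normalize(
    weight: list, u: Vector, n_power_iterations: int = 1, eps: float = 1e-12
) -> Tuple[Matrix, Vector, float]:
    # Returns W / sigma (in its 2D form) plus the updated u, which a caller would reuse next time.
    mat = reshape_to_2d(weight)
    sigma, u = estimate_spectral_norm(mat, u, n_power_iterations, eps)
    return [[x / sigma for x in row] for row in mat], u, sigma


W = [[3.0, 0.0], [0.0, 1.0]]
W_sn, u, sigma = spectral_normalize(W, u=[1.0, 1.0], n_power_iterations=20)
print(round(sigma, 6))  # 3.0
```

With the largest singular value of W equal to 3, the normalized matrix has spectral norm close to 1.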

See Spectral Normalization for Generative Adversarial Networks.

Parameters
  • module (nn.Module) – the module containing the parameter to normalize

  • name (str, optional) – name of the weight parameter

  • n_power_iterations (int, optional) – number of power iterations used to calculate the spectral norm

  • eps (float, optional) – epsilon for numerical stability in calculating norms

  • dim (int, optional) – dimension corresponding to the number of outputs. The default is 0, except for modules that are instances of ConvTranspose{1,2,3}d, for which it is 1.

Returns

The original module with the spectral norm hook

Return type

T_module

Note

This function has been reimplemented as torch.nn.utils.parametrizations.spectral_norm() using the new parametrization functionality in torch.nn.utils.parametrize.register_parametrization(). Please use the newer version. This function will be deprecated in a future version of PyTorch.
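A minimal sketch of the newer, parametrization-based version applied to the same kind of layer. Instead of a weight_orig parameter and a forward hook, the unnormalized tensor is stored under m.parametrizations.weight.original and m.weight is recomputed when accessed.

```python
import torch
import torch.nn as nn
from torch.nn.utils import parametrize
from torch.nn.utils.parametrizations import spectral_norm

m = spectral_norm(nn.Linear(20, 40))
print(parametrize.is_parametrized(m, "weight"))  # True

# The unnormalized tensor lives here; m.weight is derived from it on access.
print(m.parametrizations.weight.original.shape)  # torch.Size([40, 20])

y = m(torch.randn(8, 20))
print(y.shape)  # torch.Size([8, 40])

# Unlike the hook-based function, no weight_orig attribute is added to m.
print(hasattr(m, "weight_orig"))  # False
```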

Example:

>>> m = spectral_norm(nn.Linear(20, 40))
>>> m
Linear(in_features=20, out_features=40, bias=True)
>>> m.weight_u.size()
torch.Size([40])
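Extending the example above, this sketch shows what the function changes on the module: the trainable tensor becomes weight_orig, weight_u and weight_v are buffers, and torch.nn.utils.remove_spectral_norm can fold the current normalized weight back into an ordinary weight parameter. The input batch size of 8 is an arbitrary choice.

```python
import torch
import torch.nn as nn
from torch.nn.utils import remove_spectral_norm, spectral_norm

m = spectral_norm(nn.Linear(20, 40))
# The trainable parameter is now weight_orig; weight_u and weight_v are buffers.
params = sorted(name for name, _ in m.named_parameters())
buffers = sorted(name for name, _ in m.named_buffers())
print(params)   # ['bias', 'weight_orig']
print(buffers)  # ['weight_u', 'weight_v']

# The pre-forward hook recomputes m.weight from weight_orig before this call.
y = m(torch.randn(8, 20))
print(y.shape)  # torch.Size([8, 40])

# Replace the hook with a plain parameter holding the current normalized weight.
remove_spectral_norm(m)
params_after = sorted(name for name, _ in m.named_parameters())
print(params_after)  # ['bias', 'weight']
```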