torch.nn.utils.parametrizations.orthogonal
- torch.nn.utils.parametrizations.orthogonal(module, name='weight', orthogonal_map=None, *, use_trivialization=True)
Apply an orthogonal or unitary parametrization to a matrix or a batch of matrices.
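Letting $\mathbb{K}$ be $\mathbb{R}$ or $\mathbb{C}$, the parametrized matrix $Q \in \mathbb{K}^{m \times n}$ is orthogonal as

$$
\begin{aligned}
Q^{\text{H}} Q &= \mathrm{I}_n \quad \text{if } m \geq n, \\
Q Q^{\text{H}} &= \mathrm{I}_m \quad \text{if } m < n,
\end{aligned}
$$

where $Q^{\text{H}}$ is the conjugate transpose when $Q$ is complex and the transpose when $Q$ is real-valued, and $\mathrm{I}_n$ is the $n$-dimensional identity matrix. In plain words, $Q$ will have orthonormal columns whenever $m \geq n$ and orthonormal rows otherwise.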
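As a quick sanity check of the two cases above, the following sketch parametrizes one tall and one wide weight and compares the relevant Gram matrix with the identity. The layer sizes and the tolerance are arbitrary choices for illustration; because the weights are real, `.mH` (conjugate transpose) reduces to the plain transpose.

```python
import torch
from torch import nn
from torch.nn.utils.parametrizations import orthogonal

# Tall case: nn.Linear(20, 40) stores a (40, 20) weight, so m = 40 >= n = 20
tall = orthogonal(nn.Linear(20, 40))
Q_tall = tall.weight
# Orthonormal columns: Q^H Q should be the 20 x 20 identity
cols_orthonormal = torch.allclose(Q_tall.mH @ Q_tall, torch.eye(20), atol=1e-4)

# Wide case: nn.Linear(40, 20) stores a (20, 40) weight, so m = 20 < n = 40
wide = orthogonal(nn.Linear(40, 20))
Q_wide = wide.weight
# Orthonormal rows: Q Q^H should be the 20 x 20 identity
rows_orthonormal = torch.allclose(Q_wide @ Q_wide.mH, torch.eye(20), atol=1e-4)

print(Q_tall.shape, cols_orthonormal)
print(Q_wide.shape, rows_orthonormal)
```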
If the tensor has more than two dimensions, we consider it as a batch of matrices of shape (…, m, n).
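To illustrate the batched case, this sketch registers the parametrization on a small hypothetical module (`BatchedWeights`, not part of PyTorch) whose `weight` holds five 4 × 3 matrices, then checks that every matrix in the batch has orthonormal columns. The shapes and the tolerance are arbitrary assumptions.

```python
import torch
from torch import nn
from torch.nn.utils.parametrizations import orthogonal


class BatchedWeights(nn.Module):
    """Hypothetical toy module holding a batch of five 4 x 3 matrices."""

    def __init__(self) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.randn(5, 4, 3))


mod = orthogonal(BatchedWeights())
Q = mod.weight  # shape (5, 4, 3): each trailing 4 x 3 matrix is parametrized
# Here m = 4 >= n = 3, so each matrix should have orthonormal columns
gram = Q.mT @ Q  # batched Q^T Q, shape (5, 3, 3)
batch_ok = torch.allclose(gram, torch.eye(3).expand(5, 3, 3), atol=1e-4)
print(Q.shape, batch_ok)
```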
The matrix $Q$ may be parametrized via three different values of orthogonal_map, in terms of the original tensor:
- "matrix_exp"/"cayley": the matrix_exp() $A \mapsto \exp(A)$ and the Cayley map $A \mapsto (\mathrm{I}_n + A/2)(\mathrm{I}_n - A/2)^{-1}$ are applied to a skew-symmetric $A$ to give an orthogonal matrix.
- "householder": computes a product of Householder reflectors (householder_product()).
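The three constructions listed above can be reproduced by hand with plain tensor operations, which shows why each one yields orthonormal columns. This is an illustrative sketch only and is not claimed to mirror the internals of orthogonal(); the sizes, the seed, the float64 dtype, and the choice of building the Householder vectors from a strictly lower-triangular matrix with $\tau_i = 2 / \lVert v_i \rVert^2$ are assumptions made for clarity.

```python
import torch

torch.manual_seed(0)
n, k = 5, 3
Id = torch.eye(n, dtype=torch.float64)

# Skew-symmetric A (A.T == -A) built from an unconstrained square matrix
X = torch.randn(n, n, dtype=torch.float64)
A = X - X.T

# "matrix_exp": exp(A) is orthogonal whenever A is skew-symmetric
Q_exp = torch.matrix_exp(A)

# "cayley": (I + A/2)(I - A/2)^{-1}; I - A/2 is always invertible because
# the eigenvalues of a real skew-symmetric matrix are purely imaginary
Q_cay = (Id + A / 2) @ torch.linalg.inv(Id - A / 2)

# "householder": product of k reflectors H_i = I - tau_i v_i v_i^T, where
# householder_product takes v_i to have an implicit 1 at position i and the
# entries below it from column i of V (entries on/above the diagonal are ignored)
V = torch.randn(n, k, dtype=torch.float64).tril(diagonal=-1)
tau = 2.0 / (1.0 + (V * V).sum(dim=-2))  # tau_i = 2 / ||v_i||^2 makes H_i a reflector
Q_hh = torch.linalg.householder_product(V, tau)  # first k columns of H_1 ... H_k


def has_orthonormal_cols(Q: torch.Tensor) -> bool:
    """True if Q^T Q equals the identity up to float64 round-off."""
    return torch.allclose(Q.T @ Q, torch.eye(Q.shape[-1], dtype=Q.dtype), atol=1e-10)


print(has_orthonormal_cols(Q_exp), has_orthonormal_cols(Q_cay), has_orthonormal_cols(Q_hh))
```

Note that the matrix exponential and the Cayley map need a square skew-symmetric input, while the Householder product directly produces an $n \times k$ matrix, which is consistent with the remark that the first two are slower for very thin or very wide matrices.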
"matrix_exp"/"cayley" often make the parametrized weight converge faster than "householder", but they are slower to compute for very thin or very wide matrices.

If use_trivialization=True (default), the parametrization implements the “Dynamic Trivialization Framework”, where an extra square matrix $B$ is stored under module.parametrizations.weight[0].base. This helps the convergence of the parametrized layer at the expense of some extra memory use. See Trivializations for Gradient-Based Optimization on Manifolds.

Initial value of $Q$: If the original tensor is not parametrized and use_trivialization=True (default), the initial value of $Q$ is that of the original tensor if it is orthogonal (or unitary in the complex case), and it is orthogonalized via the QR decomposition otherwise (see torch.linalg.qr()). The same happens when it is not parametrized and orthogonal_map="householder", even when use_trivialization=False. Otherwise, the initial value is the result of the composition of all the registered parametrizations applied to the original tensor.

Note
This function is implemented using the parametrization functionality in register_parametrization().
- Parameters
module (nn.Module) – module on which to register the parametrization.
name (str, optional) – name of the tensor to make orthogonal. Default: "weight".
orthogonal_map (str, optional) – One of the following: "matrix_exp", "cayley", "householder". Default: "matrix_exp" if the matrix is square or complex, "householder" otherwise.
use_trivialization (bool, optional) – whether to use the dynamic trivialization framework. Default: True.
- Returns
The original module with an orthogonal parametrization registered to the specified weight.
- Return type
Module
Example:
>>> import torch
>>> from torch import nn
>>> from torch.nn.utils.parametrizations import orthogonal
>>> orth_linear = orthogonal(nn.Linear(20, 40))
>>> orth_linear
ParametrizedLinear(
  in_features=20, out_features=40, bias=True
  (parametrizations): ModuleDict(
    (weight): ParametrizationList(
      (0): _Orthogonal()
    )
  )
)
>>> Q = orth_linear.weight
>>> torch.dist(Q.T @ Q, torch.eye(20))
tensor(4.9332e-07)
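The use_trivialization flag and the initial-value rules described earlier can also be observed directly. This is a sketch under stated assumptions: it reads the extra matrix from module.parametrizations.weight[0].base as documented, it only checks that this matrix is square and orthogonal (its exact size is an implementation detail the text does not state, so it is printed rather than asserted), and the identity-initialized 6 × 6 layer is an arbitrary choice used to show that an already orthogonal tensor is kept.

```python
import torch
from torch import nn
from torch.nn.utils.parametrizations import orthogonal

torch.manual_seed(0)

# use_trivialization=True (default): an extra square matrix is stored as "base"
triv = orthogonal(nn.Linear(20, 40))
base = triv.parametrizations.weight[0].base
print(base.shape)  # square; exact size is an implementation detail
base_orthogonal = torch.allclose(base.mT @ base, torch.eye(base.shape[-1]), atol=1e-4)

# use_trivialization=False with "householder": no extra matrix is kept
no_triv = orthogonal(nn.Linear(20, 40), orthogonal_map="householder", use_trivialization=False)
has_base = hasattr(no_triv.parametrizations.weight[0], "base")

# Initial value: an already orthogonal tensor (here the identity) is kept as-is
eye_layer = nn.Linear(6, 6)
with torch.no_grad():
    eye_layer.weight.copy_(torch.eye(6))
kept = torch.allclose(orthogonal(eye_layer).weight, torch.eye(6), atol=1e-6)

# ... while a non-orthogonal tensor is orthogonalized (QR), so its values change
rand_layer = nn.Linear(6, 6)
W0 = rand_layer.weight.detach().clone()
Q0 = orthogonal(rand_layer).weight
orthogonalized = torch.allclose(Q0.mT @ Q0, torch.eye(6), atol=1e-5) and not torch.allclose(Q0, W0)

print(base_orthogonal, has_base, kept, orthogonalized)
```

The extra matrix is registered as a buffer rather than a parameter in this sketch's reading of the docs, which is why it costs memory but is not directly updated by the optimizer.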