RMSNorm
class torch.nn.modules.normalization.RMSNorm(normalized_shape, eps=None, elementwise_affine=True, device=None, dtype=None)
Applies Root Mean Square Layer Normalization over a mini-batch of inputs.
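This layer implements the operation as described in the paper Root Mean Square Layer Normalization:
$$y_i = \frac{x_i}{\mathrm{RMS}(x)} * \gamma_i, \quad \text{where} \quad \mathrm{RMS}(x) = \sqrt{\epsilon + \frac{1}{n} \sum_{i=1}^{n} x_i^2}$$
Here $n$ is the number of elements in the normalized dimensions and $\gamma$ is the learnable per-element weight.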
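The operation can be checked by hand on a single vector. The following is a minimal plain-Python sketch of the formula, not the PyTorch implementation; `rms_norm_1d` is a hypothetical helper, the optional `weight` list stands in for the per-element weights, and the default `eps=1e-6` is an arbitrary choice for the sketch rather than the module's default.

```python
import math
from typing import List, Optional, Sequence


def rms_norm_1d(
    x: Sequence[float],
    eps: float = 1e-6,
    weight: Optional[Sequence[float]] = None,
) -> List[float]:
    """Hypothetical reference helper: y_i = x_i / RMS(x) * gamma_i."""
    n = len(x)
    # RMS(x) = sqrt(eps + (1/n) * sum(x_i^2)); eps sits inside the square root.
    rms = math.sqrt(eps + sum(v * v for v in x) / n)
    # No weight given: gamma_i = 1, matching the ones initialization of the affine weights.
    gamma = weight if weight is not None else [1.0] * n
    return [v / rms * g for v, g in zip(x, gamma)]


# With eps = 0, RMS([3, 4]) = sqrt((9 + 16) / 2) = sqrt(12.5), about 3.5355.
y = rms_norm_1d([3.0, 4.0], eps=0.0)
print(y)  # roughly [0.8485, 1.1314]
```

Unlike LayerNorm, no mean is subtracted: the input is only rescaled, so with `eps=0` and unit weights the output has a mean square of exactly 1.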
The RMS is taken over the last D dimensions, where D is the number of dimensions of normalized_shape. For example, if normalized_shape is (3, 5) (a 2-dimensional shape), the RMS is computed over the last 2 dimensions of the input.
- Parameters:
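normalized_shape (int or list or torch.Size) – input shape from an expected input of size
$$[* \times \text{normalized\_shape}[0] \times \text{normalized\_shape}[1] \times \ldots \times \text{normalized\_shape}[-1]]$$
If a single integer is used, it is treated as a singleton list, and this module will normalize over the last dimension, which is expected to be of that specific size.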
eps (Optional[float]) – a value added inside the square root of the denominator, RMS(x), for numerical stability. Default: torch.finfo(x.dtype).eps, the machine epsilon of the input's dtype.
elementwise_affine (bool) – a boolean value that, when set to True, gives this module learnable per-element affine weights initialized to ones. Default: True.
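The parameter defaults above can be summarized in a short sketch of how the constructor arguments resolve. It is plain Python and hypothetical: `resolve_rmsnorm_args` is not a PyTorch function, dtypes are passed by name as an assumption, and the weight is returned as a flat list, whereas the module stores its weight as a tensor of shape normalized_shape.

```python
from typing import Dict, List, Optional, Sequence, Tuple, Union

# Machine epsilon of common floating-point dtypes (the values torch.finfo(dtype).eps reports).
_FINFO_EPS: Dict[str, float] = {
    "float16": 2.0 ** -10,
    "bfloat16": 2.0 ** -7,
    "float32": 2.0 ** -23,
    "float64": 2.0 ** -52,
}


def resolve_rmsnorm_args(
    normalized_shape: Union[int, Sequence[int]],
    eps: Optional[float],
    elementwise_affine: bool,
    dtype_name: str = "float32",
) -> Tuple[Tuple[int, ...], float, Optional[List[float]]]:
    """Hypothetical helper mirroring the documented defaults."""
    # A single integer is treated as a singleton list.
    if isinstance(normalized_shape, int):
        normalized_shape = (normalized_shape,)
    shape = tuple(normalized_shape)
    # eps=None falls back to the machine epsilon of the input dtype.
    resolved_eps = _FINFO_EPS[dtype_name] if eps is None else eps
    # One weight per normalized element, initialized to ones (flat here for simplicity).
    numel = 1
    for d in shape:
        numel *= d
    weight = [1.0] * numel if elementwise_affine else None
    return shape, resolved_eps, weight
```

Under these assumptions, `resolve_rmsnorm_args(5, None, True)` returns `((5,), 2**-23, [1.0] * 5)` for a float32 input, which roughly corresponds to `nn.RMSNorm(5)` normalizing over a last dimension of size 5.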
- Shape:
Input: (N, *)
Output: (N, *) (same shape as input)
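To see which elements share one RMS and why the output keeps the input's shape, here is a plain-Python sketch that works on a flat, row-major buffer. The helper `rms_norm_last_dims` and the flat-buffer representation are assumptions for illustration; PyTorch operates on tensors directly.

```python
import math
from typing import List, Sequence, Tuple


def rms_norm_last_dims(
    data: Sequence[float],
    input_shape: Sequence[int],
    normalized_shape: Sequence[int],
    eps: float = 1e-6,
) -> Tuple[List[float], Tuple[int, ...]]:
    """Hypothetical helper: RMS-normalize a flat, row-major buffer over its last D dims."""
    d = len(normalized_shape)
    # The trailing D dimensions of the input must equal normalized_shape.
    if tuple(input_shape[len(input_shape) - d:]) != tuple(normalized_shape):
        raise ValueError("trailing input dims must match normalized_shape")
    if len(data) != math.prod(input_shape):
        raise ValueError("data length does not match input_shape")
    group = math.prod(normalized_shape)  # number of elements per RMS computation
    out: List[float] = []
    # In row-major order, each slice over the last D dims is a contiguous chunk.
    for start in range(0, len(data), group):
        chunk = data[start:start + group]
        rms = math.sqrt(eps + sum(v * v for v in chunk) / group)
        out.extend(v / rms for v in chunk)
    # Every element is rescaled in place, so the output shape equals the input shape.
    return out, tuple(input_shape)


# Same shapes as the example below: input (2, 2, 3), normalized_shape [2, 3].
out, shape = rms_norm_last_dims([float(v) for v in range(1, 13)], (2, 2, 3), (2, 3))
print(shape)  # (2, 2, 3)
```

With normalized_shape (2, 3), the 12 input values form two groups of 6, and each group is divided by its own RMS.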
Examples:
>>> import torch
>>> from torch import nn
>>> rms_norm = nn.RMSNorm([2, 3])
>>> input = torch.randn(2, 2, 3)
>>> rms_norm(input)