1414# limitations under the License.
1515
1616import enum
17- from typing import Optional ,Tuple ,Union
17+ from types import EllipsisType # https://stackoverflow.com/a/66636313
18+ from typing import Optional ,Tuple ,TypeAlias ,Union ,cast
1819
1920import torch
2021from torch import nn
2425
2526class RMSNorm (nn .Module ):
2627
28+ _ARGUMENT_NOT_SPECIFIED_SENTINEL = ...
29+ _ArgumentNotSpecifiedSentinelType :TypeAlias = EllipsisType
30+
2731def __init__ (
2832self ,
2933* ,
@@ -48,12 +52,19 @@ def __init__(
4852def forward (
4953self ,
5054hidden_states :torch .Tensor ,
51- residual :Optional [torch .Tensor ]= ...,
52- )-> Union [torch .Tensor ,Tuple [torch .Tensor ,torch .Tensor ]]:
55+ residual :Union [
56+ Optional [torch .Tensor ],
57+ _ArgumentNotSpecifiedSentinelType ]= _ARGUMENT_NOT_SPECIFIED_SENTINEL ,
58+ )-> Union [torch .Tensor ,Tuple [torch .Tensor ,Optional [torch .Tensor ]]]:
59+ return_residual = True
60+ if residual is self ._ARGUMENT_NOT_SPECIFIED_SENTINEL :
61+ return_residual = False
62+ residual = None
63+
5364if IS_FLASHINFER_AVAILABLE :
5465from ..custom_ops import (flashinfer_fused_add_rmsnorm ,
5566flashinfer_rmsnorm )
56- if isinstance ( residual , torch . Tensor ) :
67+ if residual is not None :
5768flashinfer_fused_add_rmsnorm (hidden_states ,residual ,
5869self .weight ,self .variance_epsilon )
5970else :
@@ -62,7 +73,7 @@ def forward(
6273else :
6374input_dtype = hidden_states .dtype
6475hidden_states = hidden_states .to (torch .float32 )
65- if isinstance ( residual , torch . Tensor ) :
76+ if residual is not None :
6677hidden_states = hidden_states + residual .to (torch .float32 )
6778residual = hidden_states .to (input_dtype )
6879
@@ -71,20 +82,22 @@ def forward(
7182self .variance_epsilon )
7283hidden_states = self .weight * hidden_states .to (input_dtype )
7384
74- if residual is ... :
75- return hidden_states
85+ if return_residual :
86+ return hidden_states , cast ( Optional [ torch . Tensor ], residual )
7687else :
77- return hidden_states , residual
88+ return hidden_states
7889
7990def skip_forward (
8091self ,
8192hidden_states :torch .Tensor ,
82- residual :Optional [torch .Tensor ]= ...,
83- )-> Union [torch .Tensor ,Tuple [torch .Tensor ,torch .Tensor ]]:
84- if residual is ...:
93+ residual :Union [
94+ Optional [torch .Tensor ],
95+ _ArgumentNotSpecifiedSentinelType ]= _ARGUMENT_NOT_SPECIFIED_SENTINEL ,
96+ )-> Union [torch .Tensor ,Tuple [torch .Tensor ,Optional [torch .Tensor ]]]:
97+ if residual is self ._ARGUMENT_NOT_SPECIFIED_SENTINEL :
8598return hidden_states
8699else :
87- return hidden_states ,residual
100+ return hidden_states ,cast ( Optional [ torch . Tensor ], residual )
88101
89102
90103class GroupRMSNormKernelSelection (enum .Enum ):