@@ -96,7 +96,7 @@ def trtllmgen_maybe_get_cached_w3_w1_permute_indices(
9696torch .Tensor ],
9797epilogue_tile_m :int ,
9898num_elts_per_sf :Union [None ,int ]= None )-> torch .Tensor :
99- key = (dst_w3_w1_weight .shape ,"w31" )
99+ key = (dst_w3_w1_weight .shape ,"w31" , int ( num_elts_per_sf or - 1 ) )
100100if key not in cache_permute_indices :
101101# Get permute indices and chain them together
102102permute0 = get_reorder_rows_for_gated_act_gemm_row_indices (
@@ -122,7 +122,7 @@ def trtllmgen_maybe_get_cached_w2_permute_indices(
122122torch .Tensor ],
123123epilogue_tile_m :int ,
124124num_elts_per_sf :Union [None ,int ]= None )-> torch .Tensor :
125- key = (dst_w2_weight .shape ,"w2" )
125+ key = (dst_w2_weight .shape ,"w2" , int ( num_elts_per_sf or - 1 ) )
126126if key not in cache_permute_indices :
127127if num_elts_per_sf is None :
128128permute_indices = (get_shuffle_matrix_a_row_indices (
@@ -1478,11 +1478,15 @@ class NVFP4FusedMoEMethod(FusedMoEMethodBase):
14781478 Base class for NVFP4 fused MoE methods for all backends.
14791479 """
14801480
1481- def create_weights (self ,module :torch .nn .Module ,weight_dtype ,
1482- weight_vec_size ,block_scales_dtype ,
1483- block_scales_vec_size ):
1481+ def create_weights (self ,
1482+ module :torch .nn .Module ,
1483+ weight_dtype ,
1484+ weight_vec_size ,
1485+ block_scales_dtype ,
1486+ block_scales_vec_size ,
1487+ scaling_vector_size = 16 ):
14841488
1485- module .scaling_vector_size = 16
1489+ module .scaling_vector_size = scaling_vector_size
14861490# Divide by 16 because we use int64 to pack 16 fp4 values
14871491w3_w1_weight_shape = (module .expert_size_per_partition ,
14881492module .intermediate_size_per_partition * 2 ,
@@ -1893,9 +1897,12 @@ def load_expert_w2_weight(self, module: torch.nn.Module,
18931897non_blocking = True )
18941898
18951899def load_expert_w3_w1_weight_scale_nvfp4 (
1896- self ,module :torch .nn .Module ,w1_weight_scale :torch .Tensor ,
1900+ self ,
1901+ module :torch .nn .Module ,
1902+ w1_weight_scale :torch .Tensor ,
18971903w3_weight_scale :torch .Tensor ,
1898- dst_w3_w1_weight_scale :torch .Tensor ):
1904+ dst_w3_w1_weight_scale :torch .Tensor ,
1905+ num_elts_per_sf :int = 16 ):
18991906device = dst_w3_w1_weight_scale .device
19001907assert device .type == "cuda"
19011908w1_weight_scale = load_weight_shard (w1_weight_scale ,
@@ -1933,7 +1940,7 @@ def load_expert_w3_w1_weight_scale_nvfp4(
19331940dst_w3_w1_weight_scale .view (float4_sf_dtype ),
19341941self ._cache_permute_indices ,
19351942epilogue_tile_m ,
1936- num_elts_per_sf = 16 )
1943+ num_elts_per_sf = num_elts_per_sf )
19371944
19381945# Shuffle the weight according to permute indices
19391946w3_w1_weight_scale = torch .ops .trtllm .shuffle_matrix (
@@ -1949,9 +1956,11 @@ def load_expert_w3_w1_weight_scale_nvfp4(
19491956processed_w3_w1_weight_scale .view (
19501957self .block_scales_dtype ).reshape (orig_shape ))
19511958
1952- def load_expert_w2_weight_scale_nvfp4 (self ,module :torch .nn .Module ,
1959+ def load_expert_w2_weight_scale_nvfp4 (self ,
1960+ module :torch .nn .Module ,
19531961w2_weight_scale :torch .Tensor ,
1954- dst_w2_weight_scale :torch .Tensor ):
1962+ dst_w2_weight_scale :torch .Tensor ,
1963+ num_elts_per_sf :int = 16 ):
19551964device = dst_w2_weight_scale .device
19561965assert device .type == "cuda"
19571966w2_weight_scale = load_weight_shard (w2_weight_scale ,
@@ -1976,7 +1985,7 @@ def load_expert_w2_weight_scale_nvfp4(self, module: torch.nn.Module,
19761985dst_w2_weight_scale .view (float4_sf_dtype ),
19771986self ._cache_permute_indices ,
19781987epilogue_tile_m ,
1979- num_elts_per_sf = 16 )
1988+ num_elts_per_sf = num_elts_per_sf )
19801989
19811990# Shuffle the weight according to permute indices
19821991w_shuffled = torch .ops .trtllm .shuffle_matrix (
@@ -1998,6 +2007,56 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict):
19982007non_blocking = True )
19992008
20002009
class W4A8NVFP4FP8TRTLLMGenFusedMoEMethod(NVFP4TRTLLMGenFusedMoEMethod):
    """W4A8 fused-MoE method (NVFP4 weights, FP8 activations) for the
    TRTLLM-Gen backend.

    Differs from the plain NVFP4 method in two ways:
      * one block scale factor covers 32 FP4 elements instead of 16, and
      * weight block scales are pre-divided by the FP4 max value (6.0) —
        with the per-tensor alphas multiplied by 6.0 to compensate —
        because the kernel converts nvfp4 to e4m3 before the matmul, so
        the scale factors must fit the e4m3 range [0, 448/6].
    """

    # One scale factor per 32 FP4 elements (the NVFP4 default is 16).
    NUM_ELTS_PER_SF = 32
    # Maximum representable FP4 magnitude; used to move the block scales
    # into e4m3 range while keeping scale * alpha products unchanged.
    FP4_MAX = 6.0

    def create_weights(self, module: torch.nn.Module):
        """Create W4A8 weights/scales on `module` and register fc31_scale_c.

        NOTE(review): calls NVFP4FusedMoEMethod.create_weights directly,
        bypassing the immediate parent — presumably so the W4A8
        scaling-vector size (32) can be supplied; confirm against the
        parent's create_weights.
        """
        # Number of 4-bit elements packed into one storage element.
        weight_vec_size = torch.iinfo(self.weight_dtype).bits // 4
        block_scales_vec_size = 1

        NVFP4FusedMoEMethod.create_weights(self, module, self.weight_dtype,
                                           weight_vec_size,
                                           self.block_scales_dtype,
                                           block_scales_vec_size,
                                           self.NUM_ELTS_PER_SF)

        # Per-expert float32 scale applied to the fc31 (w3/w1) GEMM output.
        fc31_scale_c = nn.Parameter(torch.ones(
            module.expert_size_per_partition, dtype=torch.float32),
                                    requires_grad=False)
        module.register_parameter("fc31_scale_c", fc31_scale_c)

        self.setup_quant_scales(module)

    def load_expert_w3_w1_weight_scale_nvfp4(
            self, module: torch.nn.Module, w1_weight_scale: torch.Tensor,
            w3_weight_scale: torch.Tensor,
            dst_w3_w1_weight_scale: torch.Tensor):
        """Load fused w3/w1 block scales using the W4A8 scale-vector size."""
        return super().load_expert_w3_w1_weight_scale_nvfp4(
            module, w1_weight_scale, w3_weight_scale, dst_w3_w1_weight_scale,
            self.NUM_ELTS_PER_SF)

    def load_expert_w2_weight_scale_nvfp4(self, module: torch.nn.Module,
                                          w2_weight_scale: torch.Tensor,
                                          dst_w2_weight_scale: torch.Tensor):
        """Load w2 block scales using the W4A8 scale-vector size."""
        return super().load_expert_w2_weight_scale_nvfp4(
            module, w2_weight_scale, dst_w2_weight_scale,
            self.NUM_ELTS_PER_SF)

    def load_all_fp4_weight_scales_and_alphas(
            self, module: torch.nn.Module, weights: Dict,
            load_expert_ids: List[int], dst_w3_w1_weight_scale: torch.Tensor,
            dst_w2_weight_scale: torch.Tensor, dst_fc31_alpha: torch.Tensor,
            dst_fc2_alpha: torch.Tensor):
        """Load all scales/alphas, then rescale them for the FP8 kernel path.

        Mutates all four `dst_*` tensors in place.
        """
        super().load_all_fp4_weight_scales_and_alphas(
            module, weights, load_expert_ids, dst_w3_w1_weight_scale,
            dst_w2_weight_scale, dst_fc31_alpha, dst_fc2_alpha)
        # The kernel we use will convert nvfp4 to e4m3 before matmul,
        # so the range of the scale factor can only be [0, 448/6].  Divide
        # the block scales by FP4_MAX and fold the same factor back into
        # the alphas so the overall dequantization product is unchanged.
        dst_w3_w1_weight_scale.copy_(
            (dst_w3_w1_weight_scale.to(torch.float32) /
             self.FP4_MAX).to(torch.float8_e4m3fn))
        dst_w2_weight_scale.copy_(
            (dst_w2_weight_scale.to(torch.float32) /
             self.FP4_MAX).to(torch.float8_e4m3fn))
        dst_fc31_alpha.copy_(dst_fc31_alpha * self.FP4_MAX)
        dst_fc2_alpha.copy_(dst_fc2_alpha * self.FP4_MAX)
2058+
2059+
20012060def _get_weight_alignment (weight_alignment ,scaling_vector_size ,tp_size ,
20022061shard_dim_size ):
20032062