Commit c1feaee

sychen52 authored and dominicshanshan committed

[OMNIML-2336][feat] add W4A8 NVFP4 FP8 fused moe (NVIDIA#7968)

Signed-off-by: Shiyang Chen <shiychen@nvidia.com>

1 parent 542f497 · commit c1feaee

File tree

6 files changed: +312 −26 lines

tensorrt_llm/_torch/modules/fused_moe/create_moe.py

Lines changed: 1 addition & 0 deletions

@@ -41,6 +41,7 @@ def get_moe_cls(
             quant_config.quant_mode.has_fp8_block_scales()
             or quant_config.quant_mode.has_nvfp4()
             or quant_config.quant_mode.has_w4a16_mxfp4()
+            or quant_config.quant_mode.has_w4a8_nvfp4_fp8()
             or quant_config.quant_mode.has_w4a8_mxfp4_fp8()
             or quant_config.quant_mode.has_w4a8_mxfp4_mxfp8()):
             return TRTLLMGenFusedMoE
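
For orientation: "W4A8 NVFP4 FP8" means 4-bit NVFP4 weights (with FP8 e4m3 block scale factors, as the quantization.py hunks below show) and FP8 activations quantized per tensor. This hunk only adds the new mode to the set that resolves to the TRTLLM-Gen backend. A toy sketch of that gate (stand-in names, not tensorrt_llm's actual QuantMode API; the fallback class name is assumed for illustration):

# Toy stand-in for the quant-mode gate in get_moe_cls; the real code queries
# quant_config.quant_mode. Any listed mode selects TRTLLMGenFusedMoE.
SUPPORTED_TRTLLM_GEN_MODES = {
    "fp8_block_scales", "nvfp4", "w4a16_mxfp4",
    "w4a8_nvfp4_fp8",  # added by this commit
    "w4a8_mxfp4_fp8", "w4a8_mxfp4_mxfp8",
}

def get_moe_cls_name(mode: str) -> str:
    # "CutlassFusedMoE" is an assumed fallback for this sketch only.
    return ("TRTLLMGenFusedMoE" if mode in SUPPORTED_TRTLLM_GEN_MODES
            else "CutlassFusedMoE")

assert get_moe_cls_name("w4a8_nvfp4_fp8") == "TRTLLMGenFusedMoE"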

tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py

Lines changed: 46 additions & 3 deletions

@@ -15,6 +15,7 @@
     NVFP4TRTLLMGenFusedMoEMethod,
     W4A8MXFP4FP8TRTLLMGenFusedMoEMethod,
     W4A8MXFP4MXFP8TRTLLMGenFusedMoEMethod,
+    W4A8NVFP4FP8TRTLLMGenFusedMoEMethod,
     W4A16MXFP4TRTLLMGenFusedMoEMethod)
 from .routing import BaseMoeRoutingMethod, DeepSeekV3MoeRoutingMethod

@@ -111,7 +112,7 @@ def __init__(

     def _check_configs(self):
         assert self.has_deepseek_fp8_block_scales \
-            or self.has_nvfp4 or self.has_w4a16_mxfp4 \
+            or self.has_nvfp4 or self.has_w4a16_mxfp4 or self.has_w4a8_nvfp4_fp8 \
             or self.has_w4a8_mxfp4_fp8 or self.has_w4a8_mxfp4_mxfp8, "TRTLLMGenFusedMoE only supports fp8_block_scaling, nvfp4, w4a16_mxfp4, w4a8_mxfp4_fp8 and w4a8_mxfp4_mxfp8 dtypes."

         if self.bias or self.swiglu_alpha is not None or self.swiglu_beta is not None or self.swiglu_limit is not None:
@@ -125,6 +126,8 @@ def _get_quant_method(self):
             return NVFP4TRTLLMGenFusedMoEMethod()
         elif self.quant_config.layer_quant_mode.has_w4a16_mxfp4():
             return W4A16MXFP4TRTLLMGenFusedMoEMethod()
+        elif self.quant_config.layer_quant_mode.has_w4a8_nvfp4_fp8():
+            return W4A8NVFP4FP8TRTLLMGenFusedMoEMethod()
         elif self.quant_config.layer_quant_mode.has_w4a8_mxfp4_fp8():
             return W4A8MXFP4FP8TRTLLMGenFusedMoEMethod()
         elif self.quant_config.layer_quant_mode.has_w4a8_mxfp4_mxfp8():
@@ -147,8 +150,8 @@ def create_weights(self):
         self._weights_created = True
         self._check_configs()

-        # TODO: FIX this.
-        if (self.has_w4a16_mxfp4 or self.has_w4a8_mxfp4_fp8
+        if (self.has_w4a16_mxfp4 or self.has_w4a8_nvfp4_fp8
+                or self.has_w4a8_mxfp4_fp8
                 or self.has_w4a8_mxfp4_mxfp8) and not self.bias:
             self.w3_w1_bias = nn.Parameter(torch.zeros(
                 (self.w3_w1_weight.shape[0], self.w3_w1_weight.shape[1]),
@@ -378,6 +381,46 @@ def forward_impl(
             )
             final_hidden_states = final_hidden_states[:, :self.hidden_size].contiguous()
+        elif self.has_w4a8_nvfp4_fp8:
+
+            if not run_post_quant_allgather:
+                hidden_states_fp8, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor(
+                    x, 1.0 / self.fc31_input_scale)
+            else:
+                hidden_states_fp8 = x
+
+            outputs = torch.ops.trtllm.fp8_fp4_block_scale_moe_runner(
+                router_logits,
+                routing_bias,
+                hidden_states_fp8,
+                self.w3_w1_weight,
+                self.w3_w1_weight_scale.view(torch.float8_e4m3fn),
+                self.w2_weight,
+                self.w2_weight_scale.view(torch.float8_e4m3fn),
+                self.fc31_scale_c.data,
+                self.fc31_alpha.data,
+                self.fc2_alpha.data,
+                self.num_slots,
+                top_k,
+                n_group,
+                topk_group,
+                self.intermediate_size_per_partition,
+                self.slot_start,  # local_expert_start; use ep_rank if stride != 1
+                self.expert_size_per_partition,  # local_expert_size
+                routed_scaling_factor,
+                self.routing_method.routing_method_type,
+                do_finalize=do_finalize,
+                act_type=0,
+                topk_ids=token_selected_experts,
+                topk_weights=token_final_scales,
+            )
+
+            if not do_finalize:
+                assert not self.reduce_results, "reduce_results must be False when do_finalize is False"
+                return outputs
+            else:
+                final_hidden_states = outputs[0]
         elif self.has_w4a8_mxfp4_fp8:
             pad_size = self.w3_w1_weight.shape[-1] * 2 - x.shape[-1]
             if not run_post_quant_allgather:
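
In the new has_w4a8_nvfp4_fp8 branch, activations are quantized once per tensor to FP8 e4m3 with a precalibrated scale before the fused kernel runs (unless a post-quant allgather already produced FP8 input). Below is a minimal PyTorch emulation of that step, a sketch assuming the op's second argument is the multiplier applied before casting; the real torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor is a fused CUDA kernel, not this code:

import torch

def quantize_e4m3_per_tensor(x: torch.Tensor, inv_scale: float) -> torch.Tensor:
    # Emulation only: apply the precalibrated multiplier, clamp to the e4m3
    # maximum of +/-448, and cast to FP8.
    return (x.float() * inv_scale).clamp(-448.0, 448.0).to(torch.float8_e4m3fn)

# Under the usual TRT-LLM convention, fc31_input_scale ~= amax / 448, so its
# reciprocal maps the calibrated activation range onto e4m3 (an assumption
# here, not verified against the fused op).
x = torch.randn(4, 8)
fc31_input_scale = 0.05
hidden_states_fp8 = quantize_e4m3_per_tensor(x, 1.0 / fc31_input_scale)
print(hidden_states_fp8.dtype)  # torch.float8_e4m3fn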

tensorrt_llm/_torch/modules/fused_moe/interface.py

Lines changed: 6 additions & 0 deletions

@@ -301,6 +301,12 @@ def has_nvfp4(self):
         return self.quant_config is not None and self.quant_config.layer_quant_mode.has_nvfp4(
         )

+    @property
+    def has_w4a8_nvfp4_fp8(self):
+        assert self._weights_created
+        return self.quant_config is not None and self.quant_config.layer_quant_mode.has_w4a8_nvfp4_fp8(
+        )
+
     @property
     def has_w4a8_mxfp4_fp8(self):
         assert self._weights_created

tensorrt_llm/_torch/modules/fused_moe/quantization.py

Lines changed: 71 additions & 12 deletions

@@ -96,7 +96,7 @@ def trtllmgen_maybe_get_cached_w3_w1_permute_indices(
                                       torch.Tensor],
         epilogue_tile_m: int,
         num_elts_per_sf: Union[None, int] = None) -> torch.Tensor:
-    key = (dst_w3_w1_weight.shape, "w31")
+    key = (dst_w3_w1_weight.shape, "w31", int(num_elts_per_sf or -1))
     if key not in cache_permute_indices:
         # Get permute indices and chain them together
         permute0 = get_reorder_rows_for_gated_act_gemm_row_indices(
@@ -122,7 +122,7 @@ def trtllmgen_maybe_get_cached_w2_permute_indices(
                                       torch.Tensor],
         epilogue_tile_m: int,
         num_elts_per_sf: Union[None, int] = None) -> torch.Tensor:
-    key = (dst_w2_weight.shape, "w2")
+    key = (dst_w2_weight.shape, "w2", int(num_elts_per_sf or -1))
     if key not in cache_permute_indices:
         if num_elts_per_sf is None:
             permute_indices = (get_shuffle_matrix_a_row_indices(
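
The two one-line changes above fix a latent cache collision: the permute indices depend on the scale-factor group size (num_elts_per_sf), but the old cache key was shape-only, so entries computed for 16-element groups could be wrongly reused for the 32-element groups this commit introduces. A standalone toy (hypothetical helper names, not the repo's functions) showing why the extra key component matters:

_cache_permute_indices = {}

def cached_indices(shape: tuple, tag: str, num_elts_per_sf=None):
    # int(num_elts_per_sf or -1) folds None into a distinct sentinel,
    # mirroring the fix above.
    key = (shape, tag, int(num_elts_per_sf or -1))
    if key not in _cache_permute_indices:
        _cache_permute_indices[key] = f"indices({shape}, sf_group={num_elts_per_sf})"
    return _cache_permute_indices[key]

a = cached_indices((128, 64), "w31", 16)
b = cached_indices((128, 64), "w31", 32)
assert a != b  # with a shape-only key, the second call would hit the stale entry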
@@ -1478,11 +1478,15 @@ class NVFP4FusedMoEMethod(FusedMoEMethodBase):
     Base class for NVFP4 fused MoE methods for all backends.
     """

-    def create_weights(self, module: torch.nn.Module, weight_dtype,
-                       weight_vec_size, block_scales_dtype,
-                       block_scales_vec_size):
+    def create_weights(self,
+                       module: torch.nn.Module,
+                       weight_dtype,
+                       weight_vec_size,
+                       block_scales_dtype,
+                       block_scales_vec_size,
+                       scaling_vector_size=16):

-        module.scaling_vector_size = 16
+        module.scaling_vector_size = scaling_vector_size
         # Divide by 16 because we use int64 to pack 16 fp4 values
         w3_w1_weight_shape = (module.expert_size_per_partition,
                               module.intermediate_size_per_partition * 2,
@@ -1893,9 +1897,12 @@ def load_expert_w2_weight(self, module: torch.nn.Module,
                            non_blocking=True)

     def load_expert_w3_w1_weight_scale_nvfp4(
-            self, module: torch.nn.Module, w1_weight_scale: torch.Tensor,
+            self,
+            module: torch.nn.Module,
+            w1_weight_scale: torch.Tensor,
             w3_weight_scale: torch.Tensor,
-            dst_w3_w1_weight_scale: torch.Tensor):
+            dst_w3_w1_weight_scale: torch.Tensor,
+            num_elts_per_sf: int = 16):
         device = dst_w3_w1_weight_scale.device
         assert device.type == "cuda"
         w1_weight_scale = load_weight_shard(w1_weight_scale,
@@ -1933,7 +1940,7 @@ def load_expert_w3_w1_weight_scale_nvfp4(
             dst_w3_w1_weight_scale.view(float4_sf_dtype),
             self._cache_permute_indices,
             epilogue_tile_m,
-            num_elts_per_sf=16)
+            num_elts_per_sf=num_elts_per_sf)

         # Shuffle the weight according to permute indices
         w3_w1_weight_scale = torch.ops.trtllm.shuffle_matrix(
@@ -1949,9 +1956,11 @@ def load_expert_w3_w1_weight_scale_nvfp4(
             processed_w3_w1_weight_scale.view(
                 self.block_scales_dtype).reshape(orig_shape))

-    def load_expert_w2_weight_scale_nvfp4(self, module: torch.nn.Module,
+    def load_expert_w2_weight_scale_nvfp4(self,
+                                          module: torch.nn.Module,
                                           w2_weight_scale: torch.Tensor,
-                                          dst_w2_weight_scale: torch.Tensor):
+                                          dst_w2_weight_scale: torch.Tensor,
+                                          num_elts_per_sf: int = 16):
         device = dst_w2_weight_scale.device
         assert device.type == "cuda"
         w2_weight_scale = load_weight_shard(w2_weight_scale,
@@ -1976,7 +1985,7 @@ def load_expert_w2_weight_scale_nvfp4(self, module: torch.nn.Module,
             dst_w2_weight_scale.view(float4_sf_dtype),
             self._cache_permute_indices,
             epilogue_tile_m,
-            num_elts_per_sf=16)
+            num_elts_per_sf=num_elts_per_sf)

         # Shuffle the weight according to permute indices
         w_shuffled = torch.ops.trtllm.shuffle_matrix(
@@ -1998,6 +2007,56 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict):
                            non_blocking=True)


+class W4A8NVFP4FP8TRTLLMGenFusedMoEMethod(NVFP4TRTLLMGenFusedMoEMethod):
+
+    def create_weights(self, module: torch.nn.Module):
+        weight_vec_size = torch.iinfo(self.weight_dtype).bits // 4
+        block_scales_vec_size = 1
+
+        NVFP4FusedMoEMethod.create_weights(self, module, self.weight_dtype,
+                                           weight_vec_size,
+                                           self.block_scales_dtype,
+                                           block_scales_vec_size, 32)
+
+        fc31_scale_c = nn.Parameter(torch.ones(module.expert_size_per_partition,
+                                               dtype=torch.float32),
+                                    requires_grad=False)
+        module.register_parameter("fc31_scale_c", fc31_scale_c)
+
+        self.setup_quant_scales(module)
+
+    def load_expert_w3_w1_weight_scale_nvfp4(
+            self, module: torch.nn.Module, w1_weight_scale: torch.Tensor,
+            w3_weight_scale: torch.Tensor,
+            dst_w3_w1_weight_scale: torch.Tensor):
+        return super().load_expert_w3_w1_weight_scale_nvfp4(
+            module, w1_weight_scale, w3_weight_scale, dst_w3_w1_weight_scale,
+            32)
+
+    def load_expert_w2_weight_scale_nvfp4(self, module: torch.nn.Module,
+                                          w2_weight_scale: torch.Tensor,
+                                          dst_w2_weight_scale: torch.Tensor):
+        return super().load_expert_w2_weight_scale_nvfp4(
+            module, w2_weight_scale, dst_w2_weight_scale, 32)
+
+    def load_all_fp4_weight_scales_and_alphas(
+            self, module: torch.nn.Module, weights: Dict,
+            load_expert_ids: List[int], dst_w3_w1_weight_scale: torch.Tensor,
+            dst_w2_weight_scale: torch.Tensor, dst_fc31_alpha: torch.Tensor,
+            dst_fc2_alpha: torch.Tensor):
+        super().load_all_fp4_weight_scales_and_alphas(
+            module, weights, load_expert_ids, dst_w3_w1_weight_scale,
+            dst_w2_weight_scale, dst_fc31_alpha, dst_fc2_alpha)
+        # The kernel we use will convert nvfp4 to e4m3 before matmul,
+        # so the range of the scale factor can only be [0, 448/6].
+        dst_w3_w1_weight_scale.copy_((dst_w3_w1_weight_scale.to(torch.float32) /
+                                      6.0).to(torch.float8_e4m3fn))
+        dst_w2_weight_scale.copy_((dst_w2_weight_scale.to(torch.float32) /
+                                   6.0).to(torch.float8_e4m3fn))
+        dst_fc31_alpha.copy_(dst_fc31_alpha * 6.0)
+        dst_fc2_alpha.copy_(dst_fc2_alpha * 6.0)
+
+
 def _get_weight_alignment(weight_alignment, scaling_vector_size, tp_size,
                           shard_dim_size):
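
A quick check of the 6.0 factors at the end of the new class: E2M1 (NVFP4) values peak at ±6 and e4m3 at ±448, so a block scale that the kernel applies while up-converting nvfp4 to e4m3 must stay within 448/6 ≈ 74.7 to avoid overflow. Dividing each stored scale by 6 while multiplying the per-GEMM alpha by 6 keeps the effective dequantization product unchanged. A worked sketch with illustrative numbers (not values from the commit):

import torch

# Illustrative values only. A block scale of 96 fits in e4m3 but exceeds the
# kernel's usable range of 448/6 ~= 74.7; after the rescale it becomes 16.
weight_scale = torch.tensor(96.0)
alpha = torch.tensor(0.01)

new_scale = (weight_scale / 6.0).to(torch.float8_e4m3fn)  # 16.0, exact in e4m3
new_alpha = alpha * 6.0

# The effective multiplier weight_scale * alpha is preserved.
assert torch.isclose(weight_scale * alpha, new_scale.float() * new_alpha)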

0 commit comments
