Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit e9f26fe

Browse files
authored
[None][chore] Cherry-pick from (#7598) Make low_precision_combine as a llm arg (#7898)
Signed-off-by: Zongfei Jing <20381269+zongfeijing@users.noreply.github.com>
1 parent 28b9a81 · commit e9f26fe

File tree

6 files changed

+21
-4
lines changed

6 files changed

+21
-4
lines changed

‎examples/llm-api/quickstart_advanced.py‎

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,11 @@ def add_llm_args(parser):
7373
parser.add_argument('--moe_ep_size',type=int,default=-1)
7474
parser.add_argument('--moe_tp_size',type=int,default=-1)
7575
parser.add_argument('--moe_cluster_size',type=int,default=-1)
76+
parser.add_argument(
77+
'--use_low_precision_moe_combine',
78+
default=False,
79+
action='store_true',
80+
help='Use low precision combine in MoE (only for NVFP4 quantization)')
7681

7782
# KV cache
7883
parser.add_argument('--kv_cache_dtype',type=str,default='auto')
@@ -236,7 +241,7 @@ def setup_llm(args, **kwargs):
236241
enable_piecewise_cuda_graph= \
237242
args.use_piecewise_cuda_graph)
238243
if args.use_torch_compile else None,
239-
moe_config=MoeConfig(backend=args.moe_backend),
244+
moe_config=MoeConfig(backend=args.moe_backend, use_low_precision_moe_combine=args.use_low_precision_moe_combine),
240245
sampler_type=args.sampler_type,
241246
max_seq_len=args.max_seq_len,
242247
max_batch_size=args.max_batch_size,

‎tensorrt_llm/_torch/model_config.py‎

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,8 @@ class ModelConfig(Generic[TConfig]):
133133
moe_backend:str='CUTLASS'# options can be CUTLASS, TRTLLM
134134
# IF true, disables FC2+finalize fusion in CUTLASS MoE backend
135135
moe_disable_finalize_fusion:bool=False
136+
# If true, use low precision combine in MoE operations (only for NVFP4 quantization)
137+
use_low_precision_moe_combine:bool=False
136138

137139
allreduce_strategy:AllReduceStrategy=AllReduceStrategy.AUTO
138140

‎tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py‎

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -193,8 +193,7 @@ def __init__(
193193
ifself.enable_alltoall:
194194
self.use_postquant_alltoall= (os.environ.get(
195195
"TRTLLM_MOE_POST_QUANT_ALLTOALLV","1")=="1")
196-
self.use_low_precision_combine= (os.environ.get(
197-
"TRTLLM_MOE_USE_LOW_PRECISION_COMBINE","0")=="1")
196+
self.use_low_precision_combine=model_config.use_low_precision_moe_combine
198197

199198
ifself.alltoall_method_type==AlltoallMethodType.MNNVL:
200199
MnnvlMemory.initialize()

‎tensorrt_llm/_torch/pyexecutor/config.py‎

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ class PyTorchConfig:
6262
moe_backend:str='CUTLASS'
6363

6464
moe_disable_finalize_fusion:bool=False
65+
use_low_precision_moe_combine:bool=False
6566

6667
sampler_type:SamplerType=SamplerType.auto
6768
"""

‎tensorrt_llm/_torch/pyexecutor/model_loader.py‎

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,9 @@ def _load_and_validate_config(
303303
attn_backend=self.pytorch_backend_config.attn_backend,
304304
moe_backend=self.pytorch_backend_config.moe_backend,
305305
moe_disable_finalize_fusion=self.pytorch_backend_config.
306-
moe_disable_finalize_fusion)
306+
moe_disable_finalize_fusion,
307+
use_low_precision_moe_combine=self.pytorch_backend_config.
308+
use_low_precision_moe_combine)
307309

308310
validate_and_set_kv_cache_quant(
309311
config,self.pytorch_backend_config.kv_cache_dtype)

‎tensorrt_llm/llmapi/llm_args.py‎

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,12 @@ class MoeConfig(StrictBaseModel):
192192
"Disable FC2+finalize kernel fusion in CUTLASS MoE backend. Setting this to True recovers deterministic numerical behavior with top-k > 2."
193193
)
194194

195+
use_low_precision_moe_combine:bool=Field(
196+
default=False,
197+
description=
198+
"Use low precision combine in MoE operations (only for NVFP4 quantization). When enabled, uses lower precision for combining expert outputs to improve performance."
199+
)
200+
195201
@classmethod
196202
def from_dict(cls, data: dict):
197203
return cls(**data)
@@ -2614,6 +2620,8 @@ def get_pytorch_backend_config(self) -> "PyTorchConfig":
26142620
moe_load_balancer=self.moe_config.load_balancer,
26152621
attn_backend=self.attn_backend,
26162622
moe_backend=self.moe_config.backend,
2623+
use_low_precision_moe_combine=self.moe_config.
2624+
use_low_precision_moe_combine,
26172625
sampler_type=self.sampler_type,
26182626
kv_cache_dtype=self.kv_cache_config.dtype,
26192627
mamba_ssm_cache_dtype=self.kv_cache_config.mamba_ssm_cache_dtype,

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp