Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 75745c7

Browse files
authored
[None][chore] Make low_precision_combine as a llm arg (#7598)
Signed-off-by: Zongfei Jing <20381269+zongfeijing@users.noreply.github.com>
1 parent bc90a34 commit 75745c7

File tree

6 files changed

+21
-4
lines changed

6 files changed

+21
-4
lines changed

‎examples/llm-api/quickstart_advanced.py‎

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,11 @@ def add_llm_args(parser):
7373
parser.add_argument('--moe_ep_size', type=int, default=-1)
7474
parser.add_argument('--moe_tp_size', type=int, default=-1)
7575
parser.add_argument('--moe_cluster_size', type=int, default=-1)
76+
parser.add_argument(
77+
'--use_low_precision_moe_combine',
78+
default=False,
79+
action='store_true',
80+
help='Use low precision combine in MoE (only for NVFP4 quantization)')
7681

7782
# KV cache
7883
parser.add_argument('--kv_cache_dtype', type=str, default='auto')
@@ -229,7 +234,7 @@ def setup_llm(args, **kwargs):
229234
enable_piecewise_cuda_graph= \
230235
args.use_piecewise_cuda_graph)
231236
if args.use_torch_compile else None,
232-
moe_config=MoeConfig(backend=args.moe_backend),
237+
moe_config=MoeConfig(backend=args.moe_backend, use_low_precision_moe_combine=args.use_low_precision_moe_combine),
233238
sampler_type=args.sampler_type,
234239
max_seq_len=args.max_seq_len,
235240
max_batch_size=args.max_batch_size,

‎tensorrt_llm/_torch/model_config.py‎

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,8 @@ class ModelConfig(Generic[TConfig]):
117117
moe_backend: str = 'CUTLASS'  # options can be CUTLASS, TRTLLM
118118
# If true, disables FC2+finalize fusion in CUTLASS MoE backend
119119
moe_disable_finalize_fusion: bool = False
120+
# If true, use low precision combine in MoE operations (only for NVFP4 quantization)
121+
use_low_precision_moe_combine: bool = False
120122

121123
allreduce_strategy: AllReduceStrategy = AllReduceStrategy.AUTO
122124

‎tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py‎

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -195,9 +195,8 @@ def __init__(
195195
self.use_postquant_alltoall = (os.environ.get(
196196
"TRTLLM_MOE_POST_QUANT_ALLTOALLV", "1")
197197
== "1") and qm.has_nvfp4()
198-
self.use_low_precision_combine = (os.environ.get(
199-
"TRTLLM_MOE_USE_LOW_PRECISION_COMBINE", "0")
200-
== "1") and qm.has_nvfp4()
198+
self.use_low_precision_combine = model_config.use_low_precision_moe_combine and qm.has_nvfp4(
199+
)
201200

202201
if self.alltoall_method_type == AlltoallMethodType.MNNVL:
203202
MnnvlMemory.initialize()

‎tensorrt_llm/_torch/pyexecutor/config.py‎

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ class PyTorchConfig:
6060
moe_backend: str = 'CUTLASS'
6161

6262
moe_disable_finalize_fusion: bool = False
63+
use_low_precision_moe_combine: bool = False
6364

6465
enable_mixed_sampler: bool = False
6566
"""

‎tensorrt_llm/_torch/pyexecutor/model_engine.py‎

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,8 @@ def __init__(
307307
moe_backend=pytorch_backend_config.moe_backend,
308308
moe_disable_finalize_fusion=pytorch_backend_config.
309309
moe_disable_finalize_fusion,
310+
use_low_precision_moe_combine=pytorch_backend_config.
311+
use_low_precision_moe_combine,
310312
load_format=pytorch_backend_config.load_format,
311313
max_num_tokens=max_num_tokens,
312314
moe_max_num_tokens=pytorch_backend_config.moe_max_num_tokens,

‎tensorrt_llm/llmapi/llm_args.py‎

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,12 @@ class MoeConfig(StrictBaseModel):
191191
"Disable FC2+finalize kernel fusion in CUTLASS MoE backend. Setting this to True recovers deterministic numerical behavior with top-k > 2."
192192
)
193193

194+
use_low_precision_moe_combine: bool = Field(
195+
default=False,
196+
description=
197+
"Use low precision combine in MoE operations (only for NVFP4 quantization). When enabled, uses lower precision for combining expert outputs to improve performance."
198+
)
199+
194200
@classmethod
195201
def from_dict(cls, data: dict):
196202
return cls(**data)
@@ -2586,6 +2592,8 @@ def get_pytorch_backend_config(self) -> "PyTorchConfig":
25862592
moe_load_balancer=self.moe_config.load_balancer,
25872593
attn_backend=self.attn_backend,
25882594
moe_backend=self.moe_config.backend,
2595+
use_low_precision_moe_combine=self.moe_config.
2596+
use_low_precision_moe_combine,
25892597
enable_mixed_sampler=self.enable_mixed_sampler,
25902598
sampler_type=self.sampler_type,
25912599
kv_cache_dtype=self.kv_cache_config.dtype,

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp