Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 78ac556

Browse files
[None][fix] Fix the aux_stream in Llama4MinLatencyFusedMoE
Signed-off-by: Jinyang Yuan <154768711+jinyangyuan-nvidia@users.noreply.github.com>
1 parent fac5220 commit 78ac556

File tree

2 files changed

+8
-5
lines changed

2 files changed

+8
-5
lines changed

‎tensorrt_llm/_torch/models/modeling_llama.py‎

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
from ..modules.multi_stream_utils import maybe_execute_in_parallel
4242
from ..modules.rms_norm import RMSNorm
4343
from ..speculative import SpecMetadata
44-
from ..utils import Fp4QuantizedTensor
44+
from ..utils import AuxStreamType, Fp4QuantizedTensor
4545
from .modeling_multimodal_utils import fuse_input_embeds
4646
from .modeling_speculative import SpecDecOneEngineForCausalLM
4747
from .modeling_utils import (DecoderModel, DecoderModelForCausalLM,
@@ -293,6 +293,7 @@ def __init__(
293293
weight_loading_mode=MoEWeightLoadingMode.FUSED_GATE_UP_PROJ,
294294
model_config=model_config,
295295
apply_router_weight_on_input=True,
296+
aux_stream_dict={AuxStreamType.MoeChunkingOverlap:aux_stream},
296297
layer_idx=layer_idx)
297298

298299
self.router=Linear(

‎tensorrt_llm/_torch/models/modeling_llama_min_latency.py‎

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
WeightsLoadingConfig)
2424
from ..modules.multi_stream_utils import maybe_execute_in_parallel
2525
from ..speculative import SpecMetadata
26-
from ..utils import Fp4QuantizedTensor
26+
from ..utils import AuxStreamType, Fp4QuantizedTensor
2727
from .modeling_llama import Llama4Attention, Llama4DecoderLayer, Llama4MoE
2828

2929
# Perf heuristics thresholds.
@@ -438,7 +438,8 @@ def __init__(
438438
dtype:Optional[torch.dtype]=None,
439439
reduce_results:bool=False,
440440
model_config:ModelConfig=ModelConfig(),
441-
aux_stream:torch.cuda.Stream=torch.cuda.Stream(),
441+
aux_stream_dict:Optional[Dict[AuxStreamType,
442+
torch.cuda.Stream]]=None,
442443
weight_loading_mode:MoEWeightLoadingMode=MoEWeightLoadingMode.
443444
VANILLA,
444445
apply_router_weight_on_input:bool=False,
@@ -452,7 +453,7 @@ def __init__(
452453
dtype=dtype,
453454
reduce_results=reduce_results,
454455
model_config=model_config,
455-
aux_stream=aux_stream,
456+
aux_stream_dict=aux_stream_dict,
456457
weight_loading_mode=weight_loading_mode,
457458
apply_router_weight_on_input=apply_router_weight_on_input,
458459
)
@@ -554,6 +555,7 @@ def __init__(
554555
weight_loading_mode=MoEWeightLoadingMode.FUSED_GATE_UP_PROJ,
555556
model_config=model_config,
556557
apply_router_weight_on_input=True,
558+
aux_stream_dict={AuxStreamType.MoeChunkingOverlap:aux_stream},
557559
)
558560

559561
self.router=Llama4MinLatencyLinear(
@@ -801,7 +803,7 @@ def forward(
801803
or self.fusion_config.POST_MLP_FUSION
802804
if needs_post_allreduce and self.next_layer_layernorm is not None:
803805
if use_fp8_allreduce and self.next_attn is not None \
804-
and hasattr(elf.next_attn.qkv_proj, 'input_scale'):
806+
and hasattr(self.next_attn.qkv_proj, 'input_scale'):
805807
hidden_states,residual=self.all_reduce(
806808
hidden_states,
807809
all_reduce_params=AllReduceParams(

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp