Commit f877823

[#8781][fix] Cache the AllReduce wrapper to avoid re-allocating workspace which caused a hang (#8803)
Signed-off-by: Eran Geva <19514940+MrGeva@users.noreply.github.com>
1 parent da73410 · commit f877823

File tree

  • tensorrt_llm/_torch/auto_deploy/distributed

1 file changed: +16 −3 lines changed

tensorrt_llm/_torch/auto_deploy/distributed/trtllm.py

Lines changed: 16 additions & 3 deletions
@@ -8,6 +8,11 @@
 from ...distributed import AllReduce, allgather
 from ...modules.linear import AllReduceFusionOp, AllReduceParams, AllReduceStrategy
 
+# Cache AllReduce modules to avoid recreating on every call
+# This is critical for CUDA graph compatibility - recreating modules during
+# warmup causes hangs due to workspace allocation with CPU synchronization
+_allreduce_cache = {}
+
 def trtllm_allgather(tensor, dim, sizes=None):
     rank, world_size = get_rank_world_size()
     p_config = Mapping(world_size=world_size, tp_size=world_size, rank=rank)
@@ -16,9 +21,17 @@ def trtllm_allgather(tensor, dim, sizes=None):
 def trtllm_allreduce(tensor, op, all_reduce_params=None):
     rank, world_size = get_rank_world_size()
     assert op == ReduceOp.SUM, "TRT-LLM all reduce only supports SUM op."
-    p_config = Mapping(world_size=world_size, tp_size=world_size, rank=rank)
-    # Use Strategy.NCCL until https://nvbugspro.nvidia.com/bug/5331013 is fixed, then change to Strategy.AUTO
-    torch_op = AllReduce(mapping=p_config, strategy=AllReduceStrategy.NCCL)
+
+    # Cache key includes rank, world_size, and dtype to handle different configurations
+    cache_key = (rank, world_size, tensor.dtype)
+    if cache_key not in _allreduce_cache:
+        p_config = Mapping(world_size=world_size, tp_size=world_size, rank=rank)
+        # Use Strategy.AUTO for optimal performance
+        _allreduce_cache[cache_key] = AllReduce(
+            mapping=p_config, strategy=AllReduceStrategy.AUTO, dtype=tensor.dtype
+        )
+
+    torch_op = _allreduce_cache[cache_key]
     return torch_op(tensor, all_reduce_params=all_reduce_params)
 
 @torch.library.custom_op(
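The fix above is a memoization pattern: construct the expensive op once per configuration and reuse it, keying the cache on everything that affects construction (here rank, world size, and dtype). A minimal self-contained sketch of the same pattern, using a hypothetical `FakeAllReduce` stand-in rather than the real TensorRT-LLM `AllReduce` module:

```python
# Sketch of the caching pattern from this commit. FakeAllReduce is a
# hypothetical stand-in for an op whose construction allocates a workspace
# (which, during CUDA graph warmup, is what caused the hang).

class FakeAllReduce:
    instances = 0  # counts how many "workspaces" were allocated

    def __init__(self, rank, world_size, dtype):
        FakeAllReduce.instances += 1  # expensive allocation happens here
        self.key = (rank, world_size, dtype)

    def __call__(self, tensor):
        return tensor  # a real op would reduce the tensor across ranks


_cache = {}  # module-level, so it outlives individual calls


def cached_allreduce(tensor, rank, world_size, dtype):
    # Key on every parameter that affects construction, as the commit does
    key = (rank, world_size, dtype)
    if key not in _cache:
        _cache[key] = FakeAllReduce(rank, world_size, dtype)
    return _cache[key](tensor)
```

Repeated calls with the same configuration reuse the cached instance, so the allocation count stays at one per distinct (rank, world_size, dtype) tuple; only a genuinely new configuration triggers a new construction.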
