88from ...distributed import AllReduce ,allgather
99from ...modules .linear import AllReduceFusionOp ,AllReduceParams ,AllReduceStrategy
1010
# Cache AllReduce modules to avoid recreating on every call.
# This is critical for CUDA graph compatibility - recreating modules during
# warmup causes hangs due to workspace allocation with CPU synchronization.
# Keyed by (rank, world_size, dtype) — see trtllm_allreduce, which is the
# only writer/reader of this dict in this module.
_allreduce_cache: dict = {}
1116def trtllm_allgather (tensor ,dim ,sizes = None ):
1217rank ,world_size = get_rank_world_size ()
1318p_config = Mapping (world_size = world_size ,tp_size = world_size ,rank = rank )
@@ -16,9 +21,17 @@ def trtllm_allgather(tensor, dim, sizes=None):
def trtllm_allreduce(tensor, op, all_reduce_params=None):
    """Sum-reduce *tensor* across the tensor-parallel group via TRT-LLM.

    Only ``ReduceOp.SUM`` is supported. The underlying ``AllReduce`` module
    is cached per ``(rank, world_size, dtype)`` so it is constructed once
    per configuration — rebuilding it on every call allocates workspace
    with a CPU synchronization, which hangs CUDA-graph warmup.

    Args:
        tensor: input tensor to reduce across ranks.
        op: reduction op; must be ``ReduceOp.SUM``.
        all_reduce_params: optional ``AllReduceParams`` forwarded to the op.

    Returns:
        The reduced tensor as produced by the cached ``AllReduce`` module.
    """
    rank, world_size = get_rank_world_size()
    assert op == ReduceOp.SUM, "TRT-LLM all reduce only supports SUM op."

    # One cached module per (rank, world_size, dtype) configuration.
    key = (rank, world_size, tensor.dtype)
    reducer = _allreduce_cache.get(key)
    if reducer is None:
        mapping = Mapping(world_size=world_size, tp_size=world_size, rank=rank)
        # Strategy.AUTO lets TRT-LLM pick the best-performing implementation.
        reducer = AllReduce(
            mapping=mapping, strategy=AllReduceStrategy.AUTO, dtype=tensor.dtype
        )
        _allreduce_cache[key] = reducer

    return reducer(tensor, all_reduce_params=all_reduce_params)
2336
2437@torch .library .custom_op (