
[TRTLLM-6589][feat] Support CUDA graph for DeepEP#7514


Merged

Conversation

@yifeizhang-c (Contributor) commented Sep 4, 2025 (edited)

Description

Support CUDA Graph for the DeepEP all-to-all method. Previously, only DeepEPLowLatency supported CUDA Graph; this PR enables CUDA Graph for DeepEP as well, and prioritizes DeepEP over DeepEPLowLatency in all-to-all method selection.

DeepEP diff: https://github.com/deepseek-ai/DeepEP/compare/be2582ffe69b5e7d61c3bc9bf7a5316bc48261f9...5be51b228a7c82dbdb213ea58e77bffd12b38af8?expand=1

The DeepEP changes include:

  • Added CUDA Graph support for DeepEP internode
  • Added corresponding checks in tests/test_internode.py for validation
  • Cherry-picked NVSHMEM_QP_DEPTH related updates to prevent DeepEPLowLatency from hanging
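The core of the change is worst-case sizing: CUDA graph capture requires static tensor shapes, so the dispatch buffers must be sized up front for the maximum possible token count. A minimal sketch (the helper name is hypothetical; the parameter names follow the ones this PR adds to VariableLengthBuffer.dispatch):

```python
def compute_num_worst_tokens(all_rank_max_num_tokens: int, ep_size: int,
                             use_cuda_graph: bool) -> int:
    """Worst-case token count used to give DeepEP buffers a fixed shape.

    When CUDA graphs are enabled, the dispatch is sized for the worst
    case: every one of the ep_size ranks sending its maximum number of
    tokens. A value of 0 lets DeepEP size buffers dynamically per batch
    (no graph capture).
    """
    if use_cuda_graph:
        return all_rank_max_num_tokens * ep_size
    return 0
```

With graphs enabled, a larger but fixed buffer is the price for replayable kernels; without them, dynamic sizing avoids the over-allocation.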

@yifeizhang-c (Contributor, Author) commented Sep 4, 2025 (edited)

Changes to DeepEP are viewable at yifeizhang-c/DeepEP@a2a3923

@yuantailing (Member)

Please rebase DeepEP to the commit referenced by CMakeLists.txt

@yifeizhang-c force-pushed the dev-yifeiz-enable-DeepEP-cuda-graph branch from 776ba60 to c7157f3 on September 5, 2025 07:38
@yifeizhang-c (Contributor, Author)

Please rebase DeepEP to the commit referenced by CMakeLists.txt

Sorry for the previous oversight; I have updated the commit id.

@yifeizhang-c force-pushed the dev-yifeiz-enable-DeepEP-cuda-graph branch 2 times, most recently from 7ec92a3 to dee504d on September 26, 2025 05:43
Signed-off-by: Yifei Zhang <219273404+yifeizhang-c@users.noreply.github.com>
@yifeizhang-c force-pushed the dev-yifeiz-enable-DeepEP-cuda-graph branch from dee504d to 3c88c07 on September 29, 2025 06:59
@yifeizhang-c marked this pull request as ready for review September 29, 2025 07:15
@yifeizhang-c requested a review from a team as a code owner September 29, 2025 07:15
@yifeizhang-c (Contributor, Author)

/bot run

@yifeizhang-c (Contributor, Author)

/bot kill

@tensorrt-cicd (Collaborator)

PR_Github #20233 [ run ] triggered by Bot

@coderabbitai[bot] (Contributor) commented Sep 29, 2025 (edited)

📝 Walkthrough

The commit pin in CMake was updated. VariableLengthBuffer.dispatch gained new parameters and computes num_worst_tokens conditionally. Fused MoE wide EP logic now checks DeepEP feasibility to choose between DeepEP and DeepEPLowLatency, propagates use_cuda_graph and additional runtime parameters, and sets NVSHMEM_QP_DEPTH when using the low-latency path.
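The NVSHMEM_QP_DEPTH handling mentioned above can be sketched as follows. This is illustrative only: the depth formula 2 * (max_num_tokens + 1) mirrors the one quoted in the review comments, the function name is hypothetical, and the "don't overwrite an existing value" guard follows the reviewer's suggestion rather than a confirmed detail of the merged code.

```python
import os

def ensure_nvshmem_qp_depth(max_num_tokens: int) -> None:
    """Default the NVSHMEM queue-pair depth for the DeepEPLowLatency path.

    The depth must cover all in-flight work requests, otherwise the
    low-latency kernels can hang; an existing user-provided value is
    left untouched.
    """
    if 'NVSHMEM_QP_DEPTH' not in os.environ:
        os.environ['NVSHMEM_QP_DEPTH'] = str(2 * (max_num_tokens + 1))
```

The environment variable must be set before NVSHMEM initializes, which is why it happens at method-selection time rather than at dispatch time.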

Changes

Cohort / File(s) — Summary

  • Build pin update — cpp/tensorrt_llm/deep_ep/CMakeLists.txt: Updated DEEP_EP_COMMIT to 5be51b228a7c82dbdb213ea58e77bffd12b38af8; no other build logic changes.
  • DeepEP utils dispatch signature — tensorrt_llm/_torch/modules/fused_moe/deep_ep_utils.py: Extended VariableLengthBuffer.dispatch to accept all_rank_max_num_tokens, ep_size, use_cuda_graph; computes num_worst_tokens = all_rank_max_num_tokens * ep_size when use_cuda_graph else 0; forwards to self.buffer.dispatch.
  • Fused MoE EP selection and param propagation — tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py: Added DeepEP feasibility check to select between DeepEP and DeepEPLowLatency; propagated use_cuda_graph, all_rank_max_num_tokens, ep_size across dispatch/forward paths; set NVSHMEM_QP_DEPTH when using DeepEPLowLatency; imported local_mpi_size for internal checks.
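The feasibility-based selection summarized above can be sketched like this. It is a simplified model, not the exact implementation: the supported rank counts follow the constants quoted in the review comments, and local_size stands in for the repo's local_mpi_size() helper.

```python
def is_deepep_feasible(num_ranks: int, local_size: int) -> bool:
    """DeepEP's normal (non-low-latency) kernels support only certain
    rank layouts; anything else falls back to DeepEPLowLatency."""
    INTRANODE_SUPPORTED_RANKS = {2, 4, 8}
    REQUIRED_LOCAL_SIZE = 8
    INTERNODE_SUPPORTED_RDMA_NODES = {2, 4, 8, 16}
    # Intranode case: all EP ranks live on a single node.
    if num_ranks == local_size and num_ranks in INTRANODE_SUPPORTED_RANKS:
        return True
    # Internode case: requires full 8-GPU nodes and a supported node count.
    if local_size != REQUIRED_LOCAL_SIZE or num_ranks % local_size != 0:
        return False
    return (num_ranks // local_size) in INTERNODE_SUPPORTED_RDMA_NODES

def select_alltoall_method(moe_ep_size: int, local_size: int) -> str:
    # Prefer DeepEP when feasible, as this PR does; otherwise low latency.
    if is_deepep_feasible(moe_ep_size, local_size):
        return "DeepEP"
    return "DeepEPLowLatency"
```

For example, 16 EP ranks across two 8-GPU nodes select DeepEP, while 6 ranks on an 8-GPU node fall back to DeepEPLowLatency.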

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant Model
  participant FusedMoE
  participant Feasibility as DeepEP Feasibility
  participant DeepEP
  participant LowLatency as DeepEPLowLatency
  participant Env as Env Vars
  Model->>FusedMoE: forward(inputs, config)
  Note over FusedMoE: use_cuda_graph, all_rank_max_num_tokens, ep_size extracted
  FusedMoE->>Feasibility: is_deep_ep_feasible(config, local_mpi_size)
  alt DeepEP feasible
    FusedMoE->>DeepEP: dispatch(..., all_rank_max_num_tokens, ep_size, use_cuda_graph)
    Note over DeepEP: num_worst_tokens = all_rank_max_num_tokens * ep_size if use_cuda_graph else 0
    DeepEP-->>FusedMoE: results
  else Not feasible
    FusedMoE->>Env: set NVSHMEM_QP_DEPTH (for on-flight WR)
    Env-->>FusedMoE: ack
    FusedMoE->>LowLatency: dispatch(..., all_rank_max_num_tokens, ep_size, use_cuda_graph)
    LowLatency-->>FusedMoE: results
  end
  FusedMoE-->>Model: outputs

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Pre-merge checks and finishing touches

❌ Failed checks (2 warnings)

  • Docstring Coverage — ⚠️ Warning: Docstring coverage is 0.00%, which is insufficient; the required threshold is 80.00%. Resolution: run @coderabbitai generate docstrings to improve docstring coverage.
  • Description Check — ⚠️ Warning: The pull request description includes a clear "## Description" section but is missing the required "## Test Coverage" section listing relevant tests and the "## PR Checklist" section with the standard repository checklist items defined in the template. Resolution: add a "## Test Coverage" section detailing the tests that validate the new DeepEP CUDA Graph support and complete the "## PR Checklist" section with the required checklist items from the repository template.

✅ Passed checks (1 passed)

  • Title Check — ✅ Passed: The title succinctly describes the main change by indicating support for CUDA Graph in DeepEP, aligns with the ticket identifier and feature type, and clearly communicates the primary objective of the PR without extra noise.

@coderabbitai[bot] left a comment


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
tensorrt_llm/_torch/modules/fused_moe/deep_ep_utils.py (1)

1-2: Add NVIDIA Apache-2.0 header (2025).

Per coding guidelines, prepend the NVIDIA Apache-2.0 copyright header.

Apply:

+# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py (1)

1-1: Add NVIDIA Apache-2.0 header (2025).

Same header requirement as other Python sources.

🧹 Nitpick comments (5)
tensorrt_llm/_torch/modules/fused_moe/deep_ep_utils.py (2)

75-78: Guard against zero when capturing graphs.

If use_cuda_graph is True, assert all_rank_max_num_tokens > 0 to avoid capturing with a degenerate size.

Apply:

-        if use_cuda_graph:
-            num_worst_tokens = all_rank_max_num_tokens * ep_size
+        if use_cuda_graph:
+            assert all_rank_max_num_tokens > 0, "all_rank_max_num_tokens must be > 0 when using CUDA graph"
+            num_worst_tokens = all_rank_max_num_tokens * ep_size

217-223: typing.Map is undefined here; use Dict from typing.

Map isn’t imported; switch to Dict to avoid type-checker errors.

Apply:

-from typing import List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union
@@
-        self.buffers: Map[Mapping,
-                          weakref.ReferenceType[VariableLengthBuffer]] = {}
-        self.low_latency_buffers: Map[
-            Mapping,
-            weakref.ReferenceType[VariableLengthLowLatencyBuffer]] = {}
+        self.buffers: Dict[Mapping,
+                           weakref.ReferenceType[VariableLengthBuffer]] = {}
+        self.low_latency_buffers: Dict[
+            Mapping,
+            weakref.ReferenceType[VariableLengthLowLatencyBuffer]] = {}
tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py (3)

217-219: Don't overwrite NVSHMEM_QP_DEPTH if already set.

Respect user/env overrides and only set a default when unset. Keep the helpful issue link comment.

Apply:

-                os.environ['NVSHMEM_QP_DEPTH'] = str(
-                    2 * (self.deep_ep_max_num_tokens + 1))
+                if 'NVSHMEM_QP_DEPTH' not in os.environ:
+                    os.environ['NVSHMEM_QP_DEPTH'] = str(
+                        2 * (self.deep_ep_max_num_tokens + 1))

257-279: DeepEP feasibility: add divisibility check and clarify intranode detection.

Avoid false positives when moe_ep_size is not a multiple of local_mpi_size; also be explicit about internode divisibility.

Apply:

         def is_deepep_feasible(num_ranks: int) -> bool:
             NUM_INTRANODE_SUPPORTED_RANKS = {2, 4, 8}
             REQUIRED_LOCAL_MPI_SIZE = 8
             NUM_INTERNODE_SUPPORTED_RDMA_RANKS = {2, 4, 8, 16}
             mpi_size = local_mpi_size()
             # Intranode cases
             if num_ranks == mpi_size and num_ranks in NUM_INTRANODE_SUPPORTED_RANKS:
                 return True
             # Internode cases
             if mpi_size != REQUIRED_LOCAL_MPI_SIZE:
                 return False
-            num_rdma_nodes = num_ranks // mpi_size
+            if num_ranks % mpi_size != 0:
+                return False
+            num_rdma_nodes = num_ranks // mpi_size
             return num_rdma_nodes in NUM_INTERNODE_SUPPORTED_RDMA_RANKS

301-305: Selection fallback to DeepEPLowLatency is sensible; log reason optionally.

Optional: log when DeepEP is installed but not feasible (ranks mismatch), to aid configuration debugging.

-                if is_deepep_feasible(mapping.moe_ep_size):
-                    return AlltoallMethodType.DeepEP
-                return AlltoallMethodType.DeepEPLowLatency
+                if is_deepep_feasible(mapping.moe_ep_size):
+                    return AlltoallMethodType.DeepEP
+                logger.debug("DeepEP not feasible for moe_ep_size=%d (local_mpi_size=%d); using DeepEPLowLatency",
+                             mapping.moe_ep_size, local_mpi_size())
+                return AlltoallMethodType.DeepEPLowLatency
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between ce381d6 and 3c88c07.

📒 Files selected for processing (3)
  • cpp/tensorrt_llm/deep_ep/CMakeLists.txt (1 hunks)
  • tensorrt_llm/_torch/modules/fused_moe/deep_ep_utils.py (2 hunks)
  • tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py (8 hunks)
🧰 Additional context used
📓 Path-based instructions (3)
**/*.{h,hpp,hh,hxx,cpp,cxx,cc,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Use only spaces, no tabs; indent with 4 spaces.

Files:

  • tensorrt_llm/_torch/modules/fused_moe/deep_ep_utils.py
  • tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Python code must target Python 3.8+.
Indent Python code with 4 spaces; do not use tabs.
Maintain module namespace when importing; prefer 'from package.subpackage import foo' then 'foo.SomeClass()' instead of importing the class directly.
Python filenames should be snake_case (e.g., some_file.py).
Python classes use PascalCase names.
Functions and methods use snake_case names.
Local variables use snake_case; prefix 'k' for variables that start with a number (e.g., k_99th_percentile).
Global variables use upper SNAKE_CASE prefixed with 'G' (e.g., G_MY_GLOBAL).
Constants use upper SNAKE_CASE (e.g., MY_CONSTANT).
Avoid shadowing variables from an outer scope.
Initialize all externally visible members of a class in the constructor.
Prefer docstrings for interfaces that may be used outside a file; comments for in-function or file-local interfaces.
Use Google-style docstrings for classes and functions (Sphinx-parsable).
Document attributes and variables inline so they render under the class/function docstring.
Avoid reflection when a simpler, explicit approach suffices (e.g., avoid dict(**locals()) patterns).
In try/except, catch the most specific exceptions possible.
For duck-typing try/except, keep the try body minimal and use else for the main logic.

Files:

  • tensorrt_llm/_torch/modules/fused_moe/deep_ep_utils.py
  • tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
**/*.{cpp,cxx,cc,h,hpp,hh,hxx,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Prepend the NVIDIA Apache-2.0 copyright header with current year to the top of all source files (e.g., .cpp, .h, .cu, .py).

Files:

  • tensorrt_llm/_torch/modules/fused_moe/deep_ep_utils.py
  • tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
🧠 Learnings (1)
📚 Learning: 2025-08-14T06:36:40.701Z
Learnt from: timlee0212
PR: NVIDIA/TensorRT-LLM#6886
File: tensorrt_llm/_torch/models/modeling_deepseekv3.py:0-0
Timestamp: 2025-08-14T06:36:40.701Z
Learning: In DeepSeek V3 model (tensorrt_llm/_torch/models/modeling_deepseekv3.py), the disagreement between AllReduce.__init__ guard and _compute_mlp_tp_size logic for MNNVL usage is expected by design. The AllReduce component and MLP TP-size computation intentionally use different criteria for MNNVL availability decisions.

Applied to files:

  • tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
🧬 Code graph analysis (1)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py (2)
tensorrt_llm/_utils.py (1)
  • local_mpi_size (513-514)
tensorrt_llm/mapping.py (1)
  • moe_ep_rank (364-365)
🔇 Additional comments (5)
tensorrt_llm/_torch/modules/fused_moe/deep_ep_utils.py (1)

83-84: Propagating num_worst_tokens into Buffer.dispatch LGTM.

tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py (3)

104-105: Threading use_cuda_graph through the module LGTM.

Good to cache model_config.use_cuda_graph on the instance for later dispatch calls.


658-660: LGTM: DeepEP dispatch call updated for quantization; compatibility check remains valid.


568-574: DeepEP dispatch signature matches deep_ep_utils. Verified that DeepEPBuffer.dispatch(x, topk_idx, topk_weights, num_experts, global_expert_id_offset, all_rank_max_num_tokens, ep_size, use_cuda_graph) in deep_ep_utils.py exactly aligns with the call in fused_moe_wide_ep.py, so no changes are needed.

cpp/tensorrt_llm/deep_ep/CMakeLists.txt (1)

1-1: DeepEP commit pin updated: verify API compatibility and patchability

  • Tarball URL responds with HTTP 302; run curl -sI -L https://github.com/deepseek-ai/DeepEP/archive/5be51b228a7c82dbdb213ea58e77bffd12b38af8.tar.gz to confirm a final 200 status.
  • Confirm that commit 5be51b22…'s buffer.py defines Buffer.dispatch(num_worst_tokens=…).
  • Ensure existing nvshmem patches still apply cleanly against this commit.

@tensorrt-cicd (Collaborator)

PR_Github #20234 [ kill ] triggered by Bot

@tensorrt-cicd (Collaborator)

PR_Github #20233 [ run ] completed with state ABORTED
LLM/main/L0_MergeRequest_PR #15258 (Blue Ocean) completed with status: ABORTED

@tensorrt-cicd (Collaborator)

PR_Github #20234 [ kill ] completed with state SUCCESS
Successfully killed previous jobs for commit 3c88c07

@yifeizhang-c (Contributor, Author)

/bot run

@tensorrt-cicd (Collaborator)

PR_Github #20243 [ run ] triggered by Bot

@tensorrt-cicd (Collaborator)

PR_Github #20243 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #15263 completed with status: 'FAILURE'

@yifeizhang-c (Contributor, Author)

/bot run --disable-fail-fast

@tensorrt-cicd (Collaborator)

PR_Github #20257 [ run ] triggered by Bot

@tensorrt-cicd (Collaborator)

PR_Github #20257 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #15274 completed with status: 'SUCCESS'
Pipeline passed with automatically retried tests. Check the rerun report for details.

@hlu1 enabled auto-merge (squash) October 2, 2025 17:13
@hlu1 merged commit 34d158b into NVIDIA:main Oct 2, 2025
11 checks passed
evezhier pushed a commit to evezhier/TensorRT-LLM that referenced this pull request Oct 3, 2025
Signed-off-by: Yifei Zhang <219273404+yifeizhang-c@users.noreply.github.com>
faradawn pushed a commit to faradawn/TensorRT-LLM that referenced this pull request Oct 3, 2025
Signed-off-by: Yifei Zhang <219273404+yifeizhang-c@users.noreply.github.com>
Signed-off-by: Faradawn Yang <faradawny@gmail.com>
dominicshanshan pushed a commit to dominicshanshan/TensorRT-LLM that referenced this pull request Nov 1, 2025
Signed-off-by: Yifei Zhang <219273404+yifeizhang-c@users.noreply.github.com>
dominicshanshan pushed a commit to dominicshanshan/TensorRT-LLM that referenced this pull request Nov 3, 2025
Signed-off-by: Yifei Zhang <219273404+yifeizhang-c@users.noreply.github.com>
dominicshanshan pushed a commit to dominicshanshan/TensorRT-LLM that referenced this pull request Nov 3, 2025
Signed-off-by: Yifei Zhang <219273404+yifeizhang-c@users.noreply.github.com>
dominicshanshan pushed a commit to dominicshanshan/TensorRT-LLM that referenced this pull request Nov 3, 2025
Signed-off-by: Yifei Zhang <219273404+yifeizhang-c@users.noreply.github.com>

Reviewers

@coderabbitai[bot] left review comments

@yuantailing approved these changes

@hlu1 approved these changes

