Transformer sequence parallel forward #5560


Draft
Priya2698 wants to merge 3 commits into main from pm/transformer_sp_forward

Conversation

@Priya2698 (Collaborator) commented Nov 20, 2025 (edited)

Wall-clock time measured on 8 80GB H100 nodes, in ms:

| TE  | nvFuser |
|-----|---------|
| 2.5 | 2.1     |

Timings are after changes in #5563.

@github-actions
github-actions bot commented Nov 20, 2025 (edited)

Review updated until commit 7249b56

Description

  • Implement sequence parallel forward pass for transformer models

  • Fix decomposeLinear sharding with proper dtype casting and bias handling

  • Add support for both tensor parallel and sequence parallel modes

  • Update benchmarking and profiling to distinguish communication kernels
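
The tensor parallel vs. sequence parallel distinction in the description above can be sketched in a single-process NumPy simulation (shapes and device count are hypothetical; no actual nvFuser scheduling or communication is invoked). A row-parallel linear produces a partial output on each device; tensor parallel all-reduces those partials so every device keeps the full output, while sequence parallel reduce-scatters them so each device keeps only its sequence shard:

```python
import numpy as np

# Hypothetical sizes for illustration only.
num_devices, seq, d_in, d_out = 4, 8, 16, 6
rng = np.random.default_rng(0)
x = rng.standard_normal((seq, d_in))
w = rng.standard_normal((d_in, d_out))

# Row-parallel linear: each device holds a slice of the contraction
# dimension and computes a partial [seq, d_out] output.
chunk = d_in // num_devices
partials = [
    x[:, i * chunk:(i + 1) * chunk] @ w[i * chunk:(i + 1) * chunk, :]
    for i in range(num_devices)
]

# Tensor parallel: all-reduce -> full output replicated on every device.
tp_out = np.sum(partials, axis=0)

# Sequence parallel: reduce-scatter along the sequence dimension ->
# device i keeps only rows [i*shard, (i+1)*shard) of the reduced output.
shard = seq // num_devices
sp_shards = [
    np.sum([p[i * shard:(i + 1) * shard] for p in partials], axis=0)
    for i in range(num_devices)
]

# Concatenating the sequence shards recovers the all-reduced output.
assert np.allclose(np.concatenate(sp_shards, axis=0), tp_out)
```

The reduce-scatter moves (1/num_devices)-th of the data an all-reduce would, which is the usual motivation for sequence parallelism between layers.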

Changes walkthrough

Relevant files

Bug fix

decompose_reshardings.cpp — Fix decomposeLinear sharding with proper casting
csrc/preseg_passes/decompose_reshardings.cpp (+9/-4)

  • Fix decomposeRowParallelLinearWithBias to properly cast to Float before operations
  • Add bias broadcasting and dtype casting back to the original type
  • Apply TransformReplay to multiple tensors in a loop for proper domain replay

Enhancement

communication_executor.cpp — Update communication kernel profiling label
csrc/runtime/communication_executor.cpp (+1/-1)

  • Change the scheduler type from ExprEval to Communication in the profiler
  • Properly label communication kernels in performance profiling
benchmark_utils.py — Add parallelism type definitions
tests/python/multidevice/benchmark_utils.py (+8/-0)

  • Add a Parallelism enum with TENSOR_PARALLEL and SEQUENCE_PARALLEL options
  • Include documentation links to NVIDIA NeMo parallelism concepts
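
As a rough illustration, the enum described above might look like the following sketch. This is hypothetical: the actual definition lives in tests/python/multidevice/benchmark_utils.py and may differ in members, values, or docstrings.

```python
from enum import Enum, auto

class Parallelism(Enum):
    """Sketch of the parallelism modes described in the walkthrough.

    The real enum in benchmark_utils.py links to NVIDIA NeMo's
    documentation on these concepts; member names are taken from the
    walkthrough, everything else here is an assumption.
    """
    TENSOR_PARALLEL = auto()    # shard weights; activations replicated
    SEQUENCE_PARALLEL = auto()  # additionally shard activations along seq
```

Centralizing the enum here is what lets test_transformer_engine.py drop its duplicate definition (see the Refactoring entry below).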
test_matmul.py — Enhance the linear reduce-scatter test with bias and profiling
tests/python/multidevice/test_matmul.py (+30/-20)

  • Update the linear operation to use the BFloat16 dtype and add a bias parameter
  • Modify multidevice scheduling to handle the bias tensor properly
  • Add PythonProfiler verification that a single reduce-scatter kernel is issued
  • Change test parameters and use integer tensors for deterministic testing
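
One reason bias handling needs care in a reduce-scatter (row-parallel) linear: each device computes a partial matmul over its slice of the contraction dimension, so the bias must be added once, after the reduction, rather than to each partial sum. A single-process NumPy sketch with hypothetical shapes:

```python
import numpy as np

num_devices, d_in, d_out = 4, 8, 3
rng = np.random.default_rng(1)
x = rng.standard_normal((5, d_in))
w = rng.standard_normal((d_in, d_out))
b = rng.standard_normal(d_out)

# Partial matmuls, one per device's slice of the contraction dimension.
chunk = d_in // num_devices
partials = [
    x[:, i * chunk:(i + 1) * chunk] @ w[i * chunk:(i + 1) * chunk, :]
    for i in range(num_devices)
]

correct = np.sum(partials, axis=0) + b             # bias added once, post-reduce
wrong = np.sum([p + b for p in partials], axis=0)  # bias counted num_devices times

assert np.allclose(correct, x @ w + b)
assert not np.allclose(wrong, x @ w + b)  # off by (num_devices - 1) * b
```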
Tests

test_transformer.py — Add sequence parallel transformer forward test
tests/python/multidevice/test_transformer.py (+43/-30)

  • Add a parameterized test covering both tensor parallel and sequence parallel modes
  • Implement sequence parallel scheduling with input tensor sharding
  • Update input tensor creation and sharding based on the parallelism type
  • Adjust output shape assertions for the sharded sequence dimension
  • Add a sequence-length divisibility check for sequence parallel mode

Refactoring

test_transformer_engine.py — Remove duplicate parallelism enum definition
tests/python/multidevice/test_transformer_engine.py (+1/-8)

  • Remove the duplicate Parallelism enum definition
  • Import Parallelism from the benchmark_utils module
  • Maintain existing functionality while avoiding code duplication

PR Reviewer Guide

Here are some key observations to aid the review process:

🧪 PR contains tests
🔒 No security concerns identified
⚡ Recommended focus areas for review

Sequence Parallelism Implementation

The implementation adds sequence parallelism support to the transformer forward pass, splitting the input along the sequence dimension and parallelizing accordingly. Reviewers should verify that the sequence parallel path correctly handles all transformer components (attention, MLP, layer norm) and that the performance claims are accurate.

    if parallelism == Parallelism.SEQUENCE_PARALLEL:
        inp.split(1, num_devices, inner_split=False)
        inp.axis(1).parallelize(nvfuser.ParallelType.mesh_x)
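
Conceptually, that split factors the sequence axis into an outer device dimension, which is why the test adds a sequence-length divisibility check. A NumPy sketch of the equivalent reshape (shapes are hypothetical; nvFuser's split is a scheduling transform on the tensor domain, not an actual data movement):

```python
import numpy as np

# Hypothetical [batch, seq, hidden] activation.
num_devices, batch, seq, hidden = 4, 2, 12, 8
assert seq % num_devices == 0, "sequence parallel needs seq divisible by num_devices"

t = np.arange(batch * seq * hidden).reshape(batch, seq, hidden)

# Outer split of axis 1 (inner_split=False): seq -> [num_devices, seq // num_devices].
# The new outer factor is what gets parallelized across the device mesh.
split = t.reshape(batch, num_devices, seq // num_devices, hidden)

# Device i would own split[:, i, :, :], a contiguous sequence shard.
assert split[:, 1, :, :].shape == (batch, seq // num_devices, hidden)
```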
Linear Operation Decomposition

The changes to the decomposeRowParallelLinearWithBias function add an intermediate upcast operation and split the bias addition into separate steps. This appears to be an optimization, but it needs validation to ensure it doesn't introduce numerical instability or performance regressions.

    auto* upcast_without_bias = maybeCastOp(DataType::Float, without_bias);
    TensorView* broadcasted_bias = [&]() {
      const int64_t rank_after_broadcast = std::ssize(
          TensorDomain::noReductions(without_bias->getLogicalDomain()));
      NVF_ERROR(
          rank_after_broadcast > 0,
          "without_bias is expected to be at least 1D: ",
          without_bias);
      std::vector<bool> is_broadcast_dim(rank_after_broadcast, true);
      is_broadcast_dim.back() = false;
      return broadcast(linear_op->bias(), is_broadcast_dim);
    }();
    TensorView* with_bias = add(upcast_without_bias, broadcasted_bias);
    TensorView* new_out = maybeCastOp(out->dtype(), with_bias);
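
The motivation for the intermediate upcast can be illustrated with a small precision experiment. NumPy has no native BFloat16, so float16 stands in here; the effect is the same in kind: accumulating in a narrow dtype can round away low-order contributions that survive when the arithmetic is done in Float and cast back to the original dtype once at the end.

```python
import numpy as np

a = np.float16(1024.0)  # float16 spacing at 1024 is 1.0
b = np.float16(0.3)
c = np.float16(0.3)

# Narrow path: each add rounds back to float16, so both small
# contributions are lost (1024 + 0.3 rounds to 1024 each time).
narrow = (a + b) + c

# Upcast path: accumulate in float32, cast to float16 once at the end,
# analogous to maybeCastOp(Float, ...) then maybeCastOp(out->dtype(), ...).
upcast = np.float16(np.float32(a) + np.float32(b) + np.float32(c))

assert narrow == np.float16(1024.0)  # small adds rounded away
assert upcast == np.float16(1025.0)  # 1024.6 survives, rounds up once
```

Whether this matters in practice depends on magnitudes, which is exactly why the review guide flags the change for numerical validation.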

@Priya2698 changed the base branch from main to pm/decompose_linear on November 20, 2025 at 14:28
@Priya2698 changed the base branch from pm/decompose_linear to main on November 20, 2025 at 14:28

