Challenges in Enabling PyTorch Native Pipeline Parallelism for Hugging Face Transformer Models #589


Authors: @hemildesai

Introduction

As large language models (LLMs) continue to grow in scale - from billions to hundreds of billions of parameters - training these models efficiently across multiple GPU nodes has become increasingly challenging. While data parallelism works well for smaller models, larger models often exceed the memory capacity of a single GPU or a single node, necessitating more sophisticated parallelization strategies.

Pipeline parallelism is one such strategy. It splits a model's layers across different devices and processes them in a pipelined fashion. Each device processes a different stage of the model. This makes it possible to train models that wouldn't fit on a single device, while overlapped computation keeps GPU utilization high. You can read more about pipeline parallelism in this PyTorch guide or in the Megatron-LM paper.
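The sketch below makes "pipelined fashion" concrete by simulating the forward pass of a simple GPipe-style schedule in plain Python. It is an illustration only. AutoPipeline builds on PyTorch's own schedules (such as 1F1B and interleaved), which also interleave backward passes. With p stages and m microbatches, the forward pass takes m + p - 1 steps. The idle fraction (the "bubble") is therefore (p - 1) / (m + p - 1), and it shrinks as m grows.

```python
from __future__ import annotations


def gpipe_forward_timeline(num_stages: int, num_microbatches: int) -> list[list[int | None]]:
    """timeline[t][s] is the microbatch stage s works on at time step t (None = idle)."""
    total_steps = num_microbatches + num_stages - 1
    timeline = []
    for t in range(total_steps):
        row = []
        for s in range(num_stages):
            mb = t - s  # stage s can start microbatch mb once stage s-1 has finished it
            row.append(mb if 0 <= mb < num_microbatches else None)
        timeline.append(row)
    return timeline


def bubble_fraction(num_stages: int, num_microbatches: int) -> float:
    """Fraction of (stage, step) slots that sit idle during the forward pass."""
    timeline = gpipe_forward_timeline(num_stages, num_microbatches)
    idle = sum(cell is None for row in timeline for cell in row)
    return idle / (len(timeline) * num_stages)


for step, row in enumerate(gpipe_forward_timeline(num_stages=4, num_microbatches=4)):
    print(step, ["-" if mb is None else mb for mb in row])

print(bubble_fraction(4, 4))   # 3/7, about 0.43: many idle slots with few microbatches
print(bubble_fraction(4, 16))  # 3/19, about 0.16: more microbatches shrink the bubble
```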

NeMo Automodel is a GPU-accelerated PyTorch library for training LLMs. We recently added support for PyTorch native pipeline parallelism via:

  1. AutoPipeline for any Hugging Face Transformer language model, including popular LLMs in the AutoModelForCausalLM category such as Llama, Qwen, Mistral, and Gemma. Support for vision language models and additional architectures is coming soon.
  2. A functional API for custom models, or for users seeking more granular control. The functional API offers modular building blocks that can be adapted to any PyTorch model architecture, making pipeline parallelism accessible across the entire ecosystem.

This article focuses on AutoPipeline. See the guide here for more details on the functional API.

We drew inspiration from TorchTitan while developing our pipelining component, but enabling automatic pipeline parallelism for Hugging Face models presented a unique set of challenges. In this article, we explore those challenges and share the solutions we implemented in AutoPipeline to make pipeline parallelism both robust and user-friendly.

How AutoPipeline Works: High-Level Process

Here is a high-level overview of what happens under the hood when you call AutoPipeline(...).build(model, loss_fn):

  1. Model Analysis: Detect the model structure (whether it has a .model attribute, the number of layers, rotary embeddings, etc.)
  2. Stage Calculation: Determine the number of virtual stages based on layers_per_stage and validate it against the pipeline size
  3. Module Assignment: Generate module names for each pipeline stage, or use the ones provided (e.g., which layers go to which stage)
  4. Model Splitting: Deep copy the model (on the meta device) for each stage, then remove unneeded modules so that each stage keeps only its assigned modules
  5. Stage Creation: Wrap each model chunk in a PipelineStage with proper stage indexing and device placement
  6. Parallelization: Apply additional parallelization (DP/TP/FSDP) if a parallelize_fn is provided
  7. Schedule Building: Create the pipeline schedule (1F1B, interleaved, etc.) with the stages and loss function

The result is a complete pipeline-parallel setup with automatic handling of all the challenges described in this article.
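Steps 2 and 5 involve some index bookkeeping, and the sketch below shows one plausible version of it. The helper names, the ceil-based stage count, and the rule that stages must divide evenly across ranks are all assumptions for illustration, not AutoPipeline's actual code. Stages are handed to ranks round-robin, which is a common layout for interleaved schedules.

```python
import math


def compute_num_virtual_stages(num_layers: int, layers_per_stage: int, pp_size: int) -> int:
    """Hypothetical helper: how many virtual stages a given layers_per_stage implies."""
    num_stages = math.ceil(num_layers / layers_per_stage)
    # Assumed validation rule: every pipeline rank must own the same number of stages.
    if num_stages % pp_size != 0:
        raise ValueError(
            f"{num_stages} virtual stages cannot be split evenly across pp_size={pp_size}"
        )
    return num_stages


def stage_ids_for_rank(rank: int, pp_size: int, num_stages: int) -> list[int]:
    """Round-robin layout: rank r owns stages r, r + pp_size, r + 2 * pp_size, ..."""
    return list(range(rank, num_stages, pp_size))


# 32 layers, 4 layers per stage, 4 pipeline ranks -> 8 virtual stages, 2 per rank.
num_stages = compute_num_virtual_stages(num_layers=32, layers_per_stage=4, pp_size=4)
for rank in range(4):
    print(rank, stage_ids_for_rank(rank, pp_size=4, num_stages=num_stages))
# 0 [0, 4]
# 1 [1, 5]
# 2 [2, 6]
# 3 [3, 7]
```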

Challenge 1: Module Assignment - Understanding Model Structure

When implementing pipeline parallelism, one of the first challenges is determining how to split the model across pipeline stages. This isn't simply a matter of dividing layers equally - certain components need special treatment based on how Hugging Face models are structured.

Let's examine a typical Hugging Face causal language model structure using Qwen3 as an example:

```python
# From transformers/models/qwen3/modeling_qwen3.py (abridged)
class Qwen3ForCausalLM(Qwen3PreTrainedModel, GenerationMixin):
    def __init__(self, config):
        super().__init__(config)
        self.model = Qwen3Model(config)  # Inner model wrapper
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)  # Outside model.model


class Qwen3Model(Qwen3PreTrainedModel):
    def __init__(self, config: Qwen3Config):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)  # Inside model.model
        self.layers = nn.ModuleList(
            [Qwen3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)  # Inside model.model
        self.rotary_emb = Qwen3RotaryEmbedding(config=config)  # Shared utility inside model.model
```

This creates a hierarchical structure where:

  • Top level: Qwen3ForCausalLM contains model (the inner model) and lm_head (the output projection)
  • Inner model: model.model contains embed_tokens, layers, norm, and rotary_emb

When splitting this model across pipeline stages, different components have different placement requirements:

  1. Input Embeddings (model.embed_tokens): Must be in the first stage only - converts token IDs to embeddings
  2. Transformer Layers (model.layers): Distributed across multiple stages - the core computation
  3. Final Normalization (model.norm): Must be in the last or second-to-last stage - applies final layer normalization
  4. Language Modeling Head (lm_head): Must be in the last stage only - projects to vocabulary logits
  5. Rotary Embeddings (model.rotary_emb): Must be in all stages - shared position encoding utility

These placement constraints become even more pronounced for vision language models and other complex model architectures.
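The placement rules above can be written down as a checker. The sketch below, check_stage_placement, is a hypothetical helper (not part of NeMo Automodel) that validates a per-stage list of module names. It assumes the Qwen3-style names shown above.

```python
def _require(condition: bool, message: str) -> None:
    if not condition:
        raise ValueError(message)


def check_stage_placement(fqns_per_stage: list[list[str]], num_layers: int, prefix: str = "model.") -> None:
    """Hypothetical sanity check for the placement rules listed above."""
    last = len(fqns_per_stage) - 1
    for idx, fqns in enumerate(fqns_per_stage):
        names = set(fqns)
        _require((f"{prefix}embed_tokens" in names) == (idx == 0),
                 f"stage {idx}: embed_tokens must be on the first stage only")
        _require(("lm_head" in names) == (idx == last),  # top-level module, so no prefix
                 f"stage {idx}: lm_head must be on the last stage only")
        _require(f"{prefix}rotary_emb" in names,
                 f"stage {idx}: rotary_emb must be on every stage")

    norm_stages = [i for i, fqns in enumerate(fqns_per_stage) if f"{prefix}norm" in fqns]
    _require(len(norm_stages) == 1 and norm_stages[0] >= last - 1,
             "norm must appear once, on the last or second-to-last stage")

    # Every decoder layer must be assigned exactly once, in order.
    layers = [f for fqns in fqns_per_stage for f in fqns if f.startswith(f"{prefix}layers.")]
    _require(layers == [f"{prefix}layers.{i}" for i in range(num_layers)],
             "decoder layers are missing, duplicated, or out of order")


two_stages = [
    ["model.embed_tokens", "model.layers.0", "model.layers.1", "model.rotary_emb"],
    ["model.layers.2", "model.layers.3", "model.norm", "lm_head", "model.rotary_emb"],
]
check_stage_placement(two_stages, num_layers=4)  # passes silently
```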

Our generate_hf_model_fqn_per_model_part function in functional.py handles this complexity automatically for most cases:

```python
# From nemo_automodel/components/distributed/pipelining/functional.py
# (abridged; the layer-count bookkeeping is filled in so the excerpt is self-contained)
def generate_hf_model_fqn_per_model_part(
    num_stages: int,
    num_layers: int,
    include_embeddings: bool = True,
    include_lm_head: bool = True,
    include_rotary_emb: bool = True,
    fqn_prefix: str = "model.",  # "model." for nested HF models
) -> list[list[str]]:
    layers_per_stage, extra_layers = divmod(num_layers, num_stages)
    module_names_per_stage = []
    current_layer = 0
    for stage_idx in range(num_stages):
        stage_modules = []
        # First stage: add embeddings if requested
        if stage_idx == 0 and include_embeddings:
            stage_modules.append(f"{fqn_prefix}embed_tokens")  # model.embed_tokens
        # Add transformer layers for this stage (the first stages absorb any remainder)
        stage_layer_count = layers_per_stage + (1 if stage_idx < extra_layers else 0)
        for _ in range(stage_layer_count):
            stage_modules.append(f"{fqn_prefix}layers.{current_layer}")  # model.layers.X
            current_layer += 1
        # Last stage: add norm and lm_head if requested
        if stage_idx == num_stages - 1:
            stage_modules.append(f"{fqn_prefix}norm")  # model.norm (inside model.model)
            if include_lm_head:
                stage_modules.append("lm_head")  # lm_head (outside model.model, no prefix!)
        if include_rotary_emb:
            # Always include rotary_emb in all stages (shared utility)
            stage_modules.append(f"{fqn_prefix}rotary_emb")  # model.rotary_emb
        module_names_per_stage.append(stage_modules)
    return module_names_per_stage
```

This implementation demonstrates several key insights:

  1. Hierarchical Naming: The fqn_prefix="model." parameter accounts for Hugging Face's nested structure, where most components live inside model.model

    ```mermaid
    graph TD
      A[Qwen3ForCausalLM] --> B[model: Qwen3Model]
      A --> C[lm_head]
      B --> D[model.embed_tokens]
      B --> E[model.layers]
      B --> F[model.norm]
      B --> G[model.rotary_emb]
    ```
  2. Mixed Hierarchy Handling: Notice that lm_head has no prefix because it lives at the top level (Qwen3ForCausalLM.lm_head), while norm uses the prefix because it's inside the inner model (Qwen3ForCausalLM.model.norm)

    ```mermaid
    graph TD
      A[Qwen3ForCausalLM]
      A --> C[lm_head]
      A --> B[model]
      B --> N[model.norm]
    ```
  3. Shared Component Replication: The rotary_emb module is added to every stage because all transformer layers need position embeddings

  4. Smart Distribution: The function automatically calculates how many layers per stage, handling remainder layers by distributing them to the first few stages
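The "smart distribution" rule in item 4 amounts to a divmod. The sketch below is a standalone re-derivation based on that description, not the library's code. The function names split_layer_counts and layer_ranges are hypothetical.

```python
def split_layer_counts(num_layers: int, num_stages: int) -> list[int]:
    """Spread layers as evenly as possible; the first `remainder` stages get one extra."""
    base, remainder = divmod(num_layers, num_stages)
    return [base + (1 if stage < remainder else 0) for stage in range(num_stages)]


def layer_ranges(num_layers: int, num_stages: int) -> list[range]:
    """Contiguous global layer indices owned by each stage."""
    ranges, start = [], 0
    for count in split_layer_counts(num_layers, num_stages):
        ranges.append(range(start, start + count))
        start += count
    return ranges


print(split_layer_counts(32, 4))  # [8, 8, 8, 8]
print(split_layer_counts(30, 4))  # [8, 8, 7, 7]
print(layer_ranges(30, 4))        # [range(0, 8), range(8, 16), range(16, 23), range(23, 30)]
```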

To illustrate how this works in practice, consider a 32-layer Qwen3 model split across 4 stages:

```python
[
    # Stage 0: Input processing + first 8 layers + shared utilities
    ["model.embed_tokens", "model.layers.0", ..., "model.layers.7", "model.rotary_emb"],
    # Stage 1: Middle layers + shared utilities
    ["model.layers.8", ..., "model.layers.15", "model.rotary_emb"],
    # Stage 2: Middle layers + shared utilities
    ["model.layers.16", ..., "model.layers.23", "model.rotary_emb"],
    # Stage 3: Final layers + output processing + shared utilities
    ["model.layers.24", ..., "model.layers.31", "model.norm", "lm_head", "model.rotary_emb"],
]
```

This assignment gives each stage exactly what it needs while avoiding duplication of unique components like the embeddings and the language modeling head. It can also serve as a reference for automatically splitting a custom model for your own use case.

Challenge 2: nn.ModuleList vs nn.ModuleDict: The Indexing Problem

A subtle but critical issue in pipeline parallelism involves how PyTorch's nn.ModuleList and nn.ModuleDict behave when models are split across stages. This seemingly minor implementation detail can cause significant problems with checkpointing and state management.

Most Hugging Face models use nn.ModuleList to store transformer layers:

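```python
# Standard Hugging Face pattern (TransformerLayer stands in for the model's decoder layer class)
class TransformerModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([
            TransformerLayer() for _ in range(32)  # layers 0-31
        ])
```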

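The problem arises when we split this model across pipeline stages. Each stage keeps only a subset of the layers, but an nn.ModuleList always numbers its entries 0, 1, 2, and so on. A ModuleList that is sliced or pruned down to one stage's layers therefore re-indexes them starting from 0.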

```python
# After splitting across 4 stages:
# Stage 0 gets: layers[0:8]   -> Re-indexed as layers[0:8]   ✓ Correct
# Stage 1 gets: layers[8:16]  -> Re-indexed as layers[0:8]   ✗ Wrong!
# Stage 2 gets: layers[16:24] -> Re-indexed as layers[0:8]   ✗ Wrong!
# Stage 3 gets: layers[24:32] -> Re-indexed as layers[0:8]   ✗ Wrong!
```
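The renumbering is easy to reproduce with plain PyTorch. The toy example below uses nn.Linear modules as stand-ins for decoder layers, which is an assumption made for illustration. It shows both ways a stage might end up with a subset of layers, slicing and deleting, and in both cases the survivors are renamed from "0".

```python
import torch.nn as nn

# Toy stand-in for a stack of 8 decoder layers.
layers = nn.ModuleList([nn.Linear(4, 4) for _ in range(8)])

# Keep original layers 4..7, as the second of two stages would.
stage1 = layers[4:]  # slicing a ModuleList builds a new ModuleList
print(stage1[0] is layers[4])         # True: same module object...
print(list(stage1.state_dict())[:2])  # ['0.weight', '0.bias']: ...but now named "0"

# Deleting entries renumbers the survivors in the same way.
pruned = nn.ModuleList([nn.Linear(4, 4) for _ in range(8)])
del pruned[:4]
print([name for name, _ in pruned.named_children()])  # ['0', '1', '2', '3']
```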

This seemingly innocent re-indexing creates a disaster scenario for checkpointing:

```python
# During training, the model saves a checkpoint with state_dict keys:
{
    "stage_0.layers.0.weight": tensor(...),  # Actually layer 0
    "stage_0.layers.1.weight": tensor(...),  # Actually layer 1
    # ...
    "stage_1.layers.0.weight": tensor(...),  # Actually layer 8, but saved as 0!
    "stage_1.layers.1.weight": tensor(...),  # Actually layer 9, but saved as 1!
    # ...
}
# During loading, this creates total confusion:
# - A key like "layers.0" no longer identifies a layer: it could be layer 0, 8, 16, or 24
# - When mapped back to the original model, layers 8-15 land in the wrong slots or are lost
# - Model convergence is destroyed
```

Fortunately, AutoPipeline solves this by converting nn.ModuleList to nn.ModuleDict with explicit layer naming:

```python
# Excerpt from AutoPipeline's module-pruning logic. parent_module, name, module, and
# layers_to_keep (string indices such as "8") are defined by the enclosing code.
elif isinstance(module, nn.ModuleList):
    indices_to_keep = {int(idx) for idx in layers_to_keep if idx.isdigit()}
    new_layers = nn.ModuleDict(
        {str(i): layer for i, layer in enumerate(module) if i in indices_to_keep}
    )
    setattr(parent_module, name, new_layers)

# After conversion and splitting:
# Stage 0: {"0": layer_0, "1": layer_1, ..., "7": layer_7}
# Stage 1: {"8": layer_8, "9": layer_9, ..., "15": layer_15}
# Stage 2: {"16": layer_16, "17": layer_17, ..., "23": layer_23}
# Stage 3: {"24": layer_24, "25": layer_25, ..., "31": layer_31}
```

With this approach, checkpoint saving and loading work correctly across all pipeline stages, maintaining the original layer identities throughout the training process.
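The same toy setup can be used to check the fix. to_module_dict below is a hypothetical helper that mirrors the conversion above. Each stage keeps its layers under their original indices, so the per-stage state dicts merge back into exactly the unsplit model's key set.

```python
import torch.nn as nn

layers = nn.ModuleList([nn.Linear(4, 4) for _ in range(8)])


def to_module_dict(module_list: nn.ModuleList, keep: set[int]) -> nn.ModuleDict:
    """Same idea as the conversion above: keep chosen layers under their original index."""
    return nn.ModuleDict({str(i): layer for i, layer in enumerate(module_list) if i in keep})


stage0 = to_module_dict(layers, set(range(0, 4)))
stage1 = to_module_dict(layers, set(range(4, 8)))
print(list(stage1.state_dict())[:2])  # ['4.weight', '4.bias']

# Keys keep their global index, so per-stage state dicts merge without collisions
# and line up one-to-one with the unsplit model's keys.
merged = {**stage0.state_dict(), **stage1.state_dict()}
print(sorted(merged) == sorted(layers.state_dict()))  # True
```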

Challenge 3: Forward Method Patching: Handling Missing Modules

Another complex challenge in pipeline parallelism is ensuring that forward methods work correctly when modules are distributed across different pipeline stages. Standard Hugging Face forward methods assume all components are available locally, but in pipeline parallelism, this assumption breaks down.

To understand the issue, consider a standard Hugging Face model forward method:

```python
# Standard Hugging Face forward method (simplified)
def forward(self, input_ids, attention_mask=None, **kwargs):
    # This assumes embed_tokens exists on this stage
    inputs_embeds = self.embed_tokens(input_ids)  # ← Fails on stages 1, 2, 3!
    hidden_states = inputs_embeds

    # This assumes layers exist on this stage
    for layer in self.layers:  # ← Fails if there are no layers on this stage!
        hidden_states = layer(hidden_states, **kwargs)

    # This assumes norm exists on this stage
    hidden_states = self.norm(hidden_states)  # ← Fails on stages 0, 1, 2!
    ...
    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
```

Problem 1: When we split the model across stages:

  • Stage 0 has embed_tokens, but stages 1-3 don't
  • Stage 3 has norm and lm_head, but stages 0-2 don't
  • Calling self.embed_tokens(input_ids) on stage 1, where embed_tokens has been set to None, raises TypeError: 'NoneType' object is not callable

Problem 2: PyTorch's pipeline parallelism API expects each stage to return a single tensor output. That output is passed to the next stage or, on the final stage, used by the loss function. Hugging Face models, however, typically return structured output objects (such as CausalLMOutputWithPast), which are not directly compatible with this requirement.

To address these incompatibilities, AutoPipeline replaces the standard forward methods with pipeline-aware versions that handle missing modules and outputs gracefully. The actual implementation can be found in hf_utils.py, and AutoPipeline applies these patches automatically based on model type.

Let's examine how this transformation works in practice.

Before Patching (Fails):

```python
# Problem 1: Missing modules
# Stage 1 trying to run the standard forward method
hidden_states = self.embed_tokens(input_ids)  # ← TypeError! embed_tokens is None on this stage

# Problem 2: Complex return types
def forward(self, input_ids, **kwargs):
    # ... forward computation ...
    return CausalLMOutputWithPast(  # ← Pipeline expects a simple tensor!
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
```

After Patching (Works):

```python
# Problem 1 solution: Intelligent module detection
if hasattr(self, "embed_tokens") and self.embed_tokens is not None:
    hidden_states = self.embed_tokens(input_ids)  # First stage
else:
    hidden_states = input_ids  # Middle/last stage: the input is already embeddings


# Problem 2 solution: Pipeline-aware return types
def pipeline_forward_causal_lm(self, input_ids, **kwargs):
    # Get outputs from the inner model
    outputs = self.model(input_ids=input_ids, **kwargs)
    hidden_states = outputs if isinstance(outputs, torch.Tensor) else outputs.last_hidden_state

    # Apply the language modeling head if it exists on this stage
    if hasattr(self, "lm_head") and self.lm_head is not None:
        logits = self.lm_head(hidden_states)
        return logits  # ← Simple tensor for the pipeline
    else:
        return hidden_states  # ← Pass hidden states to the next stage
```

This comprehensive patching approach solves both the missing module problem and the output compatibility issue, allowing Hugging Face models to work seamlessly with PyTorch's pipeline parallelism API.
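The two fixes can be tried end to end on a toy model. ToyInner, ToyCausalLM, and make_stage below are hypothetical names, and this is not the code in hf_utils.py. Modules that live on another stage are set to None and skipped, and every stage returns a plain tensor. That makes it possible to chain stage 0 and stage 1 by hand and check that they reproduce the unsplit model's logits.

```python
import copy

import torch
import torch.nn as nn


class ToyInner(nn.Module):
    """Hypothetical stand-in for an HF inner model (embed_tokens / layers / norm)."""

    def __init__(self, vocab: int = 16, hidden: int = 8, num_layers: int = 4):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab, hidden)
        self.layers = nn.ModuleDict({str(i): nn.Linear(hidden, hidden) for i in range(num_layers)})
        self.norm = nn.LayerNorm(hidden)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Pipeline-aware: each piece is optional because it may live on another stage.
        if self.embed_tokens is not None:
            x = self.embed_tokens(x)  # first stage: token IDs -> embeddings
        for layer in self.layers.values():  # only the layers this stage owns
            x = torch.relu(layer(x))
        if self.norm is not None:
            x = self.norm(x)  # last stage only
        return x


class ToyCausalLM(nn.Module):
    def __init__(self, vocab: int = 16, hidden: int = 8, num_layers: int = 4):
        super().__init__()
        self.model = ToyInner(vocab, hidden, num_layers)
        self.lm_head = nn.Linear(hidden, vocab, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = self.model(x)
        # Always a plain tensor: logits on the last stage, hidden states elsewhere.
        return self.lm_head(hidden) if self.lm_head is not None else hidden


def make_stage(full: ToyCausalLM, keep_layers: set, first: bool, last: bool) -> ToyCausalLM:
    """Copy the full model and drop everything this stage does not own."""
    stage = copy.deepcopy(full)
    if not first:
        stage.model.embed_tokens = None
    if not last:
        stage.model.norm = None
        stage.lm_head = None
    for key in list(stage.model.layers.keys()):
        if int(key) not in keep_layers:
            del stage.model.layers[key]
    return stage


torch.manual_seed(0)
full = ToyCausalLM()
ids = torch.randint(0, 16, (2, 5))
stage0 = make_stage(full, {0, 1}, first=True, last=False)
stage1 = make_stage(full, {2, 3}, first=False, last=True)

with torch.no_grad():
    hidden = stage0(ids)     # tensor handed from stage 0 to stage 1
    logits = stage1(hidden)  # final logits from the last stage
    reference = full(ids)
print(hidden.shape, torch.allclose(logits, reference))  # torch.Size([2, 5, 8]) True
```

Setting a registered submodule to None is allowed by nn.Module, which is why a simple `is not None` check is enough to skip it.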

While this solution is effective, it does introduce some maintenance considerations. First, the patched forward methods must be kept in sync whenever the transformers version is upgraded; otherwise, they can cause unexpected errors. Second, not all language models share the same forward method skeleton, so a patch may be applied incorrectly and lead to subtle issues.

Challenge 4: Gradient Scaling

A subtle but critical challenge in pipeline parallelism is ensuring correct gradient scaling when combining multiple parallelism strategies. This issue emerges during real training scenarios and can impact model convergence.

The problem became apparent during convergence testing. We discovered that training with mixed parallelism (PP + DP) produced different gradient norms than training with data parallelism alone. This happened because, by default, gradients were scaled incorrectly when pipeline parallelism was combined with data parallelism, which led to diverging gradient norm curves.

According to PyTorch's pipeline parallelism documentation:

"Gradients are scaled by num_microbatches depending on the scale_grads argument, defaulting to True. This setting should match the configuration of your loss_fn, which may either average losses (scale_grads=True) or sum losses (scale_grads=False)."

However, our training recipes use per-token loss calculation, which required a different approach. As a result, we had to disable automatic scaling in the pipeline schedule (scale_grads=False) and handle gradient normalization manually in the training loop, ensuring proper scaling across all parallelism dimensions. This approach gives us precise control over gradient scaling, while maintaining compatibility with our per-token loss calculation.

Specifically, we scale gradients in pipeline parallelism by dividing them by num_label_tokens_in_batch / dp_group_size. The division by dp_group_size is needed because FSDP averages the gradients across the data parallel ranks during reduction (ref).
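A small calculation shows why this factor is correct. Suppose each of D data parallel ranks ends up with g_r, the gradient of its summed per-token loss (summed over microbatches because scale_grads=False). FSDP's reduction then yields (1/D) Σ g_r. The gradient of the mean per-token loss over the N label tokens of the global batch is (1/N) Σ g_r, which equals [(1/D) Σ g_r] / (N/D). The single-process sketch below checks this numerically. It simulates the DP ranks and the FSDP averaging with plain tensors, and it assumes the Hugging Face convention of -100 for ignored labels. It is an illustration of the math, not the NeMo Automodel training loop.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab, hidden, dp_size, tokens_per_rank, num_microbatches = 11, 6, 2, 12, 3
weight = torch.randn(vocab, hidden, dtype=torch.float64)
features = torch.randn(dp_size, tokens_per_rank, hidden, dtype=torch.float64)
labels = torch.randint(0, vocab, (dp_size, tokens_per_rank))
labels[0, :3] = -100  # e.g. prompt/padding tokens excluded from the loss
labels[1, :1] = -100


def summed_token_loss_grad(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Gradient of the *summed* per-token loss for one microbatch."""
    w = weight.clone().requires_grad_(True)
    F.cross_entropy(x @ w.T, y, ignore_index=-100, reduction="sum").backward()
    return w.grad


# Each simulated DP rank accumulates summed-loss gradients over its microbatches
# without dividing by num_microbatches (the scale_grads=False behaviour).
rank_grads = []
for r in range(dp_size):
    g = torch.zeros_like(weight)
    for xb, yb in zip(features[r].chunk(num_microbatches), labels[r].chunk(num_microbatches)):
        g += summed_token_loss_grad(xb, yb)
    rank_grads.append(g)

# FSDP-style reduction averages over DP ranks; then divide by N / D.
reduced = torch.stack(rank_grads).mean(dim=0)
num_label_tokens_in_batch = int((labels != -100).sum())
scaled = reduced / (num_label_tokens_in_batch / dp_size)

# Reference: mean per-token loss over the whole global batch in one shot.
w_ref = weight.clone().requires_grad_(True)
F.cross_entropy(
    features.reshape(-1, hidden) @ w_ref.T, labels.reshape(-1), ignore_index=-100
).backward()
print(torch.allclose(scaled, w_ref.grad))  # True
```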

The result is identical loss curves and gradient norm patterns across all parallelism configurations, ensuring that pipeline parallelism maintains correctness.

Verified HF models supported out of the box

After solving these challenges, many Hugging Face models that previously ran into GPU OOMs now train cleanly with AutoPipeline. Below is a summary of the models we successfully fine-tuned out of the box:

| Model family | Sizes verified | PP sizes used |
|---|---|---|
| Llama (2, 3/3.1/3.3) | 65B, 70B | 4, 8 |
| CodeLlama | 70B | 4, 8 |
| Qwen (1.5, 2, 2.5, Math) | 72B | 4, 8 |
| Mixtral (MoE) | 8x7B (46.7B total, 12.9B active), 8x22B (141B total, 39B active) | 4, 8 (8x22B: 8 only) |
| Mistral | Large (123B) | 8 |
| GLM | 32B, 4.5 Air (106B total, 12B active) | 4, 8 |
| Llama 70B finetunes (Hermes, H2OGPT, ChatQA, Tulu, Nemotron, etc.) | 70B | 4, 8 |

Note: This table summarizes models with at least one finished run. Many additional fine-tuned variants also ran successfully; the table groups them by family for brevity.

Conclusion

If you are training any Hugging Face Transformer model (Llama, Qwen, Mistral, Gemma, or others), AutoPipeline provides the tools needed to scale your training across multiple GPUs efficiently and correctly.

If you are training custom models and prefer more granular control, the functional API provides modular building blocks that can be adapted to any PyTorch model architecture, ensuring that the benefits of pipeline parallelism are accessible across the entire ecosystem.

Ready to get started? Check out an example recipe with pipeline parallelism here, as well as more documentation. For questions, issues, or contributions, visit our GitHub repository.

Contributors

This work wouldn't have been possible without the incredible contributions from our team.

Special thanks to Huiying Li, Adil Asif, and Alexandros Koumparoulis for their help adding pipelining support to Automodel, including checkpointing support, recipe integration, convergence sweeps, and more.

Additionally, a huge shoutout to Wenwen Gao, Bernard Nguyen, and Jennifer Gerhold for their invaluable guidance on the content, from shaping the narrative to ensuring technical accuracy and clarity.


Replies


  • Re "Challenge 3: Forward Method Patching"

Have you looked at:

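```python
import torch.distributed.pipelining as pipelining
from transformers import AutoModelForCausalLM

full_model = AutoModelForCausalLM.from_pretrained(...)
# mb_args: an example microbatch used for tracing; spec: split points keyed by module FQN
pipe = pipelining.pipeline(full_model, mb_args=(example_microbatch,), split_spec=spec)
my_submod = pipe.get_stage_module(my_pp_rank)
```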

This way, my_submod (an nn.Module) would have the desired forward function automatically.

  • Re "Challenge 2: nn.ModuleList vs nn.ModuleDict: The Indexing Problem":

my_submod created above would have the same FQN hierarchy as the original model, including the original indices (e.g. layers.8-16 instead of layers.0-8), thus avoiding Challenge 2 too.

Reference: Option 2: splitting a model automatically

@hemildesai

Hi @kwen2501, thanks a lot for the pointer. We did take a look initially but saw this note:

The pipeline frontend uses a tracer (torch.export) to capture your model into a single graph. If your model is not full-graph’able, you can use our manual frontend below.

Looking at a few implementations in Hugging Face, we saw a lot of conditionals, so we weren't sure whether the stock implementations would work with the tracer. But we're definitely happy to give Option 2: splitting a model automatically a try as well. Let me know if you have any thoughts on this point.

Separately, it looks like there are a few examples here. I created #604 to add the automatic splitting option to AutoPipeline and try it out.


Hello, for Challenge 4: Gradient Scaling, calculating num_label_tokens_in_batch seems challenging when there are multiple microbatches, since it represents the aggregate across all of them. Could you share how you addressed this? Thanks.

@hemildesai

Hi @xrrain, the PyTorch PP API splits a local batch into microbatches internally, so we just pass the full local batch to its step function. This way, we can calculate num_label_tokens for one local batch. We sum num_label_tokens across all gradient-accumulated local batches. At the end, we do an all-reduce across the data parallel group to get num_label_tokens_in_batch, which represents the total number of label tokens in a global batch. (Code here)
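The counting described in this reply can be sketched in a single process. In the sketch, count_label_tokens is a hypothetical helper, -100 is assumed as the ignore label (the Hugging Face convention), and the DP all-reduce is simulated with a plain sum. The linked code is the actual implementation.

```python
import torch

IGNORE_INDEX = -100  # assumed: Hugging Face convention for labels excluded from the loss


def count_label_tokens(local_batches: list[torch.Tensor]) -> int:
    """Label tokens on this rank, summed over all gradient-accumulation local batches."""
    return sum(int((labels != IGNORE_INDEX).sum()) for labels in local_batches)


# Simulated DP group of 2 ranks, each running 2 gradient-accumulation steps.
# Each local batch is passed whole to the PP schedule, so counting happens per local batch.
rank_batches = [
    [torch.tensor([[1, 2, -100]]), torch.tensor([[-100, -100, 7]])],  # rank 0: 2 + 1 tokens
    [torch.tensor([[3, 4, 5]]), torch.tensor([[6, -100, -100]])],     # rank 1: 3 + 1 tokens
]
local_counts = [count_label_tokens(batches) for batches in rank_batches]

# In real training, each rank would put its count in a tensor and call
# torch.distributed.all_reduce(t, op=torch.distributed.ReduceOp.SUM, group=dp_group).
num_label_tokens_in_batch = sum(local_counts)
print(local_counts, num_label_tokens_in_batch)  # [3, 4] 7
```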
