Pipeline Parallelism
Created On: Jun 16, 2025 | Last Updated On: Aug 13, 2025
Note
torch.distributed.pipelining is currently in alpha state and under development. API changes may be possible. It was migrated from the PiPPy project.
Why Pipeline Parallel?
Pipeline Parallelism is one of the primitive parallelism techniques for deep learning. It allows the execution of a model to be partitioned such that multiple micro-batches can execute different parts of the model code concurrently. Pipeline parallelism can be an effective technique for:
large-scale training
bandwidth-limited clusters
large model inference
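The following minimal pure-Python sketch makes the concurrency concrete. It is not part of torch.distributed.pipelining, and the function forward_timeline is hypothetical. It assumes that every stage spends exactly one time unit per micro-batch and that communication costs nothing, and it lists which stage works on which micro-batch at each step of a forward-only pass.

```python
def forward_timeline(num_stages: int, n_microbatches: int) -> list[list[tuple[int, int]]]:
    """For each time step, list the (stage, micro-batch) pairs running concurrently.

    Simplified model (an assumption): each stage spends one time unit per
    micro-batch, and stage s can only start micro-batch m after stage s - 1
    has finished it.
    """
    n_steps = num_stages + n_microbatches - 1
    timeline = []
    for t in range(n_steps):
        active = []
        for stage in range(num_stages):
            mb = t - stage  # micro-batch that reaches this stage at step t
            if 0 <= mb < n_microbatches:
                active.append((stage, mb))
        timeline.append(active)
    return timeline


for t, active in enumerate(forward_timeline(num_stages=2, n_microbatches=3)):
    print(f"step {t}: {active}")
# step 0: [(0, 0)]
# step 1: [(0, 1), (1, 0)]
# step 2: [(0, 2), (1, 1)]
# step 3: [(1, 2)]
```

From step 1 onward, both stages work at the same time on different micro-batches. The half-empty steps at the start and end are the pipeline "bubble".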
The above scenarios have in common that the computation per device cannot hide the communication of conventional parallelism, for example, the weight all-gather of FSDP.
What is torch.distributed.pipelining?
While promising for scaling, pipelining is often difficult to implement because it needs to partition the execution of a model in addition to model weights. The partitioning of execution often requires intrusive code changes to your model. Another aspect of complexity comes from scheduling micro-batches in a distributed environment, with data flow dependency considered.
The pipelining package provides a toolkit that does these things automatically, which allows easy implementation of pipeline parallelism on general models.
It consists of two parts: a splitting frontend and a distributed runtime. The splitting frontend takes your model code as-is, splits it up into “model partitions”, and captures the data-flow relationship. The distributed runtime executes the pipeline stages on different devices in parallel, handling things like micro-batch splitting, scheduling, communication, and gradient propagation.
Overall, the pipelining package provides the following features:
Splitting of model code based on simple specification.
Rich support for pipeline schedules, including GPipe, 1F1B, Interleaved 1F1B and Looped BFS, as well as the infrastructure for writing customized schedules.
First-class support for cross-host pipeline parallelism, as this is where PP is typically used (over slower interconnects).
Composability with other PyTorch parallel techniques such as data parallel (DDP, FSDP) or tensor parallel. The TorchTitan project demonstrates a “3D parallel” application on the Llama model.
Step 1: build PipelineStage
Before we can use a PipelineSchedule, we need to create PipelineStage objects that wrap the part of the model running in that stage. The PipelineStage is responsible for allocating communication buffers and creating send/recv ops to communicate with its peers. It manages intermediate buffers, e.g. for the outputs of forward that have not been consumed yet, and it provides a utility for running the backward pass for the stage model.
A PipelineStage needs to know the input and output shapes for the stage model, so that it can correctly allocate communication buffers. The shapes must be static, i.e. at runtime the shapes cannot change from step to step. A PipeliningShapeError will be raised if runtime shapes do not match the expected shapes. When composing with other parallelisms or applying mixed precision, these techniques must be taken into account so that the PipelineStage knows the correct shape (and dtype) for the output of the stage module at runtime.
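A small, torch-free sketch can illustrate the static-shape requirement. It models the kind of check a stage needs before it reuses a pre-allocated buffer. ShapeMismatch and validate_output are hypothetical stand-ins, not the library's PipeliningShapeError or its internal validation code, and dtypes are plain strings here.

```python
class ShapeMismatch(RuntimeError):
    """Hypothetical stand-in for PipeliningShapeError (illustration only)."""


def validate_output(expected_shape, expected_dtype, actual_shape, actual_dtype):
    """Raise if a runtime output differs from what the buffer was allocated for."""
    if tuple(actual_shape) != tuple(expected_shape):
        raise ShapeMismatch(
            f"shape mismatch: expected {tuple(expected_shape)}, got {tuple(actual_shape)}"
        )
    if actual_dtype != expected_dtype:
        raise ShapeMismatch(f"dtype mismatch: expected {expected_dtype}, got {actual_dtype}")


# The buffer was sized for fp32 outputs of shape (4, 16): this passes.
validate_output((4, 16), "float32", (4, 16), "float32")

# Mixed precision changed the runtime dtype, so the pre-allocated buffer no longer fits.
try:
    validate_output((4, 16), "float32", (4, 16), "bfloat16")
except ShapeMismatch as err:
    print(err)  # dtype mismatch: expected float32, got bfloat16
```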
Users may construct a PipelineStage instance directly, by passing in an nn.Module representing the portion of the model that should run on the stage. This may require changes to the original model code. See the example in Option 1: splitting a model manually.
Alternatively, the splitting frontend can use graph partitioning to split your model into a series of nn.Modules automatically. This technique requires the model to be traceable with torch.export. Composability of the resulting nn.Modules with other parallelism techniques is experimental and may require some workarounds. This frontend may be more appealing if the user cannot easily change the model code. See Option 2: splitting a model automatically for more information.
Step 2: use PipelineSchedule for execution
We can now attach the PipelineStage to a pipeline schedule, and run the schedule with input data. Here is a GPipe example:
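import torch
from torch.distributed.pipelining import ScheduleGPipe

# `stage` is the PipelineStage built in Step 1; `n_microbatches`, `batch_size`,
# `in_dim`, `device` and `rank` come from your own setup code.

# Create a schedule
schedule = ScheduleGPipe(stage, n_microbatches)

# Input data (whole batch)
x = torch.randn(batch_size, in_dim, device=device)

# Run the pipeline with input `x`
# `x` will be divided into microbatches automatically
if rank == 0:
    # The first stage feeds the whole-batch input
    schedule.step(x)
else:
    # Other stages receive activations from their predecessor;
    # the last stage returns the merged output
    output = schedule.step()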
Note that the above code needs to be launched for each worker, thus we use a launcher service to launch multiple processes:
torchrun --nproc_per_node=2 example.py
Options for Splitting a Model
Option 1: splitting a model manually
To directly construct a PipelineStage, the user is responsible for providing a single nn.Module instance that owns the relevant nn.Parameters and nn.Buffers, and defines a forward() method that executes the operations relevant for that stage. For example, a condensed version of the Transformer class defined in TorchTitan shows a pattern of building an easily partitionable model.
import torch
from torch import nn

# ModelArgs and TransformerBlock are defined elsewhere in TorchTitan.
class Transformer(nn.Module):
    def __init__(self, model_args: ModelArgs):
        super().__init__()

        self.tok_embeddings = nn.Embedding(...)

        # Using a ModuleDict lets us delete layers without affecting names,
        # ensuring checkpoints will correctly save and load.
        self.layers = torch.nn.ModuleDict()
        for layer_id in range(model_args.n_layers):
            self.layers[str(layer_id)] = TransformerBlock(...)

        # Final norm applied in forward(); set to None on non-final stages
        self.norm = nn.RMSNorm(...)
        self.output = nn.Linear(...)

    def forward(self, tokens: torch.Tensor):
        # Handling layers being 'None' at runtime enables easy pipeline splitting
        h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens

        # self.freqs_cis (precomputed rotary embeddings) is omitted from this condensed version
        for layer in self.layers.values():
            h = layer(h, self.freqs_cis)

        h = self.norm(h) if self.norm else h
        output = self.output(h).float() if self.output else h
        return output
A model defined in this manner can be easily configured per stage by first initializing the whole model (using the meta device to avoid OOM errors), deleting undesired layers for that stage, and then creating a PipelineStage that wraps the model. For example:
import torch
from torch.distributed.pipelining import PipelineStage

with torch.device("meta"):
    assert num_stages == 2, "This is a simple 2-stage example"

    # we construct the entire model, then delete the parts we do not need for this stage
    # in practice, this can be done using a helper function that automatically divides up layers across stages.
    model = Transformer(model_args)

    if stage_index == 0:
        # prepare the first stage model
        del model.layers["1"]
        model.norm = None
        model.output = None

    elif stage_index == 1:
        # prepare the second stage model
        model.tok_embeddings = None
        del model.layers["0"]

# Meta tensors have no storage: materialize the remaining parameters on the
# real device, then initialize or load their values before training.
model.to_empty(device=device)

stage = PipelineStage(
    model,
    stage_index,
    num_stages,
    device,
)
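The comment in the example above mentions a helper that divides layers across stages. Below is a minimal sketch of such a helper. layers_for_stage is hypothetical and not a torch API. It assumes layers are split into contiguous blocks of nearly equal size, with the earlier stages taking any remainder. It only selects layer keys; which stage keeps tok_embeddings, norm and output is still decided as in the example.

```python
def layers_for_stage(n_layers: int, num_stages: int, stage_index: int) -> list[str]:
    """Return the ModuleDict keys (as strings) that stage `stage_index` keeps.

    Layers are assigned in contiguous blocks; the first `n_layers % num_stages`
    stages receive one extra layer each.
    """
    base, extra = divmod(n_layers, num_stages)
    start = stage_index * base + min(stage_index, extra)
    count = base + (1 if stage_index < extra else 0)
    return [str(i) for i in range(start, start + count)]


for s in range(2):
    print(s, layers_for_stage(n_layers=5, num_stages=2, stage_index=s))
# 0 ['0', '1', '2']
# 1 ['3', '4']

# Hypothetical use inside the meta-device block, replacing the hard-coded deletes:
#   keep = set(layers_for_stage(len(model.layers), num_stages, stage_index))
#   for name in list(model.layers.keys()):
#       if name not in keep:
#           del model.layers[name]
```

Because the ModuleDict keys stay unchanged after deletion, each stage's parameter names still match the full model's checkpoint.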
When composing with other data or model parallelism techniques, output_args may also be required if the output shape/dtype of the model chunk will be affected.
Option 2: splitting a model automatically
If you have a full model and do not want to spend time on modifying it into a sequence of “model partitions”, the pipeline API is here to help. Here is a brief example:
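import torch

# Illustrative definitions of Layer and LMHead, matching the printed structure below
class Layer(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.lin = torch.nn.Linear(3, 3)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.lin(x)

class LMHead(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.proj = torch.nn.Linear(3, 3)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)

class Model(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.emb = torch.nn.Embedding(10, 3)
        self.layers = torch.nn.ModuleList(
            Layer() for _ in range(2)
        )
        self.lm = LMHead()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.emb(x)
        for layer in self.layers:
            x = layer(x)
        x = self.lm(x)
        return x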
If we print the model, we can see multiple hierarchies, which makes it hard to split by hand:
Model(
  (emb): Embedding(10, 3)
  (layers): ModuleList(
    (0-1): 2 x Layer(
      (lin): Linear(in_features=3, out_features=3, bias=True)
    )
  )
  (lm): LMHead(
    (proj): Linear(in_features=3, out_features=3, bias=True)
  )
)
Let us see how the pipeline API works:
import torch
from torch.distributed.pipelining import pipeline, SplitPoint

mod = Model()

# An example micro-batch input
x = torch.LongTensor([1, 2, 4, 5])

pipe = pipeline(
    module=mod,
    mb_args=(x,),
    split_spec={
        "layers.1": SplitPoint.BEGINNING,
    },
)
The pipeline API splits your model given a split_spec, where SplitPoint.BEGINNING stands for adding a split point before execution of a certain submodule in the forward function, and similarly, SplitPoint.END stands for adding a split point after it.
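The torch-free sketch below gives a rough mental model of split_spec: it cuts an ordered list of submodule calls into stages. Point and partition are hypothetical. Point only mirrors the two SplitPoint values. The real pipeline API operates on the graph traced by torch.export rather than on a list of names, so this sketch ignores functional ops between modules.

```python
from enum import Enum


class Point(Enum):
    """Local stand-in mirroring SplitPoint.BEGINNING and SplitPoint.END."""
    BEGINNING = "beginning"
    END = "end"


def partition(call_order: list[str], split_spec: dict[str, Point]) -> list[list[str]]:
    """Group submodule calls (in forward() execution order) into stages."""
    stages: list[list[str]] = [[]]
    for name in call_order:
        if split_spec.get(name) is Point.BEGINNING and stages[-1]:
            stages.append([])  # cut before this submodule runs
        stages[-1].append(name)
        if split_spec.get(name) is Point.END:
            stages.append([])  # cut after this submodule runs
    return [s for s in stages if s]  # drop an empty trailing stage


calls = ["emb", "layers.0", "layers.1", "lm"]
print(partition(calls, {"layers.1": Point.BEGINNING}))
# [['emb', 'layers.0'], ['layers.1', 'lm']]
print(partition(calls, {"layers.0": Point.END}))
# [['emb', 'layers.0'], ['layers.1', 'lm']]
```

Marking the beginning of "layers.1" and marking the end of "layers.0" describe the same cut. Both produce the two-stage split used in this example.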
If we print(pipe), we can see:
GraphModule(
  (submod_0): GraphModule(
    (emb): InterpreterModule()
    (layers): Module(
      (0): InterpreterModule(
        (lin): InterpreterModule()
      )
    )
  )
  (submod_1): GraphModule(
    (layers): Module(
      (1): InterpreterModule(
        (lin): InterpreterModule()
      )
    )
    (lm): InterpreterModule(
      (proj): InterpreterModule()
    )
  )
)

def forward(self, x):
    submod_0 = self.submod_0(x);  x = None
    submod_1 = self.submod_1(submod_0);  submod_0 = None
    return (submod_1,)
The “model partitions” are represented by submodules (submod_0, submod_1), each of which is reconstructed with original model operations, weights and hierarchies. In addition, a “root-level” forward function is reconstructed to capture the data flow between those partitions. Such data flow will be replayed by the pipeline runtime later, in a distributed fashion.
The Pipe object provides a method for retrieving the “model partitions”:
stage_mod : nn.Module = pipe.get_stage_module(stage_idx)
The returned stage_mod is an nn.Module, with which you can create an optimizer, save or load checkpoints, or apply other parallelisms.
Pipe also allows you to create a distributed stage runtime on a device given a ProcessGroup:
stage = pipe.build_stage(stage_idx, device, group)
Alternatively, if you would like to build the stage runtime later after some modification to the stage_mod, you can use a functional version of the build_stage API. For example:
from torch.distributed.pipelining import build_stage
from torch.nn.parallel import DistributedDataParallel

dp_mod = DistributedDataParallel(stage_mod)
info = pipe.info()
stage = build_stage(dp_mod, stage_idx, info, device, group)
Note
The pipeline frontend uses a tracer (torch.export) to capture your model into a single graph. If your model is not full-graph’able, you can use the manual frontend described in Option 1 above.
Hugging Face Examples
In the PiPPy repo where this package was originally created, we kept examples based on unmodified Hugging Face models. See the examples/huggingface directory.
Examples include:
Technical Deep Dive
How does the pipeline API split a model?
First, the pipeline API turns our model into a directed acyclic graph (DAG) by tracing the model. It traces the model using torch.export – a PyTorch 2 full-graph capturing tool.
Then, it groups together the operations and parameters needed by a stage into a reconstructed submodule: submod_0, submod_1, …
Unlike conventional submodule access methods such as Module.children(), the pipeline API cuts not only the module structure of your model, but also the forward function of your model.
This is necessary because model structure like Module.children() merely captures information during Module.__init__(), and does not capture any information about Module.forward(). Said differently, Module.children() lacks information about the following aspects key to pipelining:
Execution order of child modules in forward
Activation flows between child modules
Whether there are any functional operators between child modules (for example, relu or add operations will not be captured by Module.children()).
The pipeline API, on the contrary, makes sure that the forward behavior is truly preserved. It also captures the activation flow between the partitions, helping the distributed runtime to make correct send/receive calls without human intervention.
Another flexibility of the pipeline API is that split points can be at arbitrary levels within your model hierarchy. In the split partitions, the original model hierarchy related to that partition will be reconstructed at no cost to you. As a result, fully-qualified names (FQNs) pointing to a submodule or parameter remain valid, and services that rely on FQNs (such as FSDP, TP or checkpointing) can still run with your partitioned modules with almost zero code change.
Implementing Your Own Schedule
You can implement your own pipeline schedule by extending one of the following two classes:
PipelineScheduleSingle
PipelineScheduleMulti
PipelineScheduleSingle is for schedules that assign only one stage per rank. PipelineScheduleMulti is for schedules that assign multiple stages per rank.
For example, ScheduleGPipe and Schedule1F1B are subclasses of PipelineScheduleSingle, whereas ScheduleInterleaved1F1B, ScheduleLoopedBFS, ScheduleInterleavedZeroBubble, and ScheduleZBVZeroBubble are subclasses of PipelineScheduleMulti.
Logging
You can turn on additional logging using the TORCH_LOGS environment variable from torch._logging:
TORCH_LOGS=+pp will display logging.DEBUG messages and all levels above it.
TORCH_LOGS=pp will display logging.INFO messages and above.
TORCH_LOGS=-pp will display logging.WARNING messages and above.
API Reference
Model Split APIs
The following set of APIs transform your model into a pipeline representation.
- class torch.distributed.pipelining.SplitPoint(value)
Enum representing the points at which a split can occur in the execution of a submodule.
:ivar BEGINNING: Represents adding a split point before the execution of a certain submodule in the forward function.
:ivar END: Represents adding a split point after the execution of a certain submodule in the forward function.
- torch.distributed.pipelining.pipeline(module, mb_args, mb_kwargs=None, split_spec=None, split_policy=None)
Split a module based on a specification.
See Pipe for more details.
- Parameters
module (Module) – The module to be split.
mb_args (tuple[Any, ...]) – Example positional inputs, in micro-batch form.
mb_kwargs (Optional[dict[str, Any]]) – Example keyword inputs, in micro-batch form. (default: None)
split_spec (Optional[dict[str, torch.distributed.pipelining._IR.SplitPoint]]) – A dictionary using submodule names as split markers. (default: None)
split_policy (Optional[Callable[[GraphModule], GraphModule]]) – The policy to use for splitting the module. (default: None)
- Return type
A pipeline representation of class Pipe.
- class torch.distributed.pipelining.Pipe(split_gm, num_stages, has_loss_and_backward, loss_spec)
- torch.distributed.pipelining.pipe_split()
pipe_split is a special operator that is used to mark the boundary between stages in a module. It is used to split the module into stages. It is a no-op if your annotated module is run eagerly.
Example
>>> def forward(self, x):
>>>     x = torch.mm(x, self.mm_param)
>>>     x = torch.relu(x)
>>>     pipe_split()
>>>     x = self.lin(x)
>>>     return x
The above example will be split into two stages.
Microbatch Utilities
- class torch.distributed.pipelining.microbatch.TensorChunkSpec(split_dim)
Class used to specify chunking of inputs.
- torch.distributed.pipelining.microbatch.split_args_kwargs_into_chunks(args, kwargs, chunks, args_chunk_spec=None, kwargs_chunk_spec=None)
Given a sequence of args and kwargs, split them into a number of chunks according to their respective chunking specs.
- Parameters
chunks (int) – Number of chunks to split the args and kwargs into
args_chunk_spec (Optional[tuple[torch.distributed.pipelining.microbatch.TensorChunkSpec, ...]]) – chunking specs for args, in same shape as args
kwargs_chunk_spec (Optional[dict[str, torch.distributed.pipelining.microbatch.TensorChunkSpec]]) – chunking specs for kwargs, in same shape as kwargs
- Returns
args_split: List of sharded args
kwargs_split: List of sharded kwargs
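The chunking idea can be sketched without torch. The helpers below are hypothetical and only mimic the behavior described above. They assume every input is split along dimension 0, the role TensorChunkSpec(split_dim) plays for real tensors, and they use Python lists in place of tensors. Chunks are contiguous and nearly equal, with the first chunks taking one extra row when the batch does not divide evenly. The library's exact sizing for uneven batches may differ.

```python
def chunk_rows(rows: list, chunks: int) -> list[list]:
    """Split `rows` into exactly `chunks` contiguous, nearly equal pieces."""
    base, extra = divmod(len(rows), chunks)
    out, start = [], 0
    for i in range(chunks):
        size = base + (1 if i < extra else 0)
        out.append(rows[start:start + size])
        start += size
    return out


def split_batch(args: tuple, kwargs: dict, chunks: int) -> tuple[list[tuple], list[dict]]:
    """Split each positional and keyword batch along dim 0 into `chunks` pieces."""
    arg_chunks = [chunk_rows(a, chunks) for a in args]
    kwarg_chunks = {k: chunk_rows(v, chunks) for k, v in kwargs.items()}
    args_split = [tuple(c[i] for c in arg_chunks) for i in range(chunks)]
    kwargs_split = [{k: c[i] for k, c in kwarg_chunks.items()} for i in range(chunks)]
    return args_split, kwargs_split


args_split, kwargs_split = split_batch(
    args=([10, 11, 12, 13, 14, 15],),
    kwargs={"mask": [1, 1, 0, 1, 0, 0]},
    chunks=3,
)
print(args_split)    # [([10, 11],), ([12, 13],), ([14, 15],)]
print(kwargs_split)  # [{'mask': [1, 1]}, {'mask': [0, 1]}, {'mask': [0, 0]}]
```

Micro-batch i is then args_split[i] together with kwargs_split[i], so every argument of one micro-batch comes from the same rows of the original batch.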
Pipeline Stages
- class torch.distributed.pipelining.stage.PipelineStage(submodule, stage_index, num_stages, device, input_args=None, output_args=None, group=None, dw_builder=None)
A class representing a pipeline stage in a pipeline parallelism setup.
PipelineStage assumes sequential partitioning of the model, i.e. the model is split into chunks where outputs from one chunk feed into inputs of the next chunk, with no skip connections.
PipelineStage performs runtime shape/dtype inference automatically by propagating the outputs from stage 0 to stage 1 and so forth, in linear order. To bypass shape inference, pass the input_args and output_args to each PipelineStage instance.
- Parameters
submodule (nn.Module) – The PyTorch module wrapped by this stage.
stage_index (int) – The ID of this stage.
num_stages (int) – The total number of stages.
device (torch.device) – The device where this stage is located.
input_args (Union[torch.Tensor, Tuple[torch.Tensor]], optional) – The input arguments for the submodule.
output_args (Union[torch.Tensor, Tuple[torch.Tensor]], optional) – The output arguments for the submodule.
group (dist.ProcessGroup, optional) – The process group for distributed training. If None, the default group is used.
dw_builder (Optional[Callable[[], Callable[..., None]]]) – If provided, dw_builder will build a new dw_runner function that will run the W action (backward with respect to weights) for F, I, W (Forward, Input, Weight) zero bubble schedules.
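The linear shape-inference order described for PipelineStage can be modeled in a few lines. This torch-free sketch is an illustration, not the library's implementation. Each hypothetical stage is a function from input shape to output shape, and each stage's output shape becomes the next stage's input shape. This chaining only works because there are no skip connections between stages.

```python
from typing import Callable

Shape = tuple[int, ...]


def infer_stage_shapes(
    input_shape: Shape, stage_fns: list[Callable[[Shape], Shape]]
) -> list[tuple[Shape, Shape]]:
    """Return (input_shape, output_shape) for each stage, propagated from stage 0."""
    result = []
    current = input_shape
    for fn in stage_fns:
        out = fn(current)
        result.append((current, out))
        current = out  # stage i's output feeds stage i + 1
    return result


# Hypothetical 3-stage model: embedding -> transformer block -> LM head.
batch, seq, dim, vocab = 4, 8, 16, 100
stages = [
    lambda s: s + (dim,),         # (batch, seq) -> (batch, seq, dim)
    lambda s: s,                  # block keeps the shape
    lambda s: s[:-1] + (vocab,),  # (batch, seq, dim) -> (batch, seq, vocab)
]
for i, (inp, out) in enumerate(infer_stage_shapes((batch, seq), stages)):
    print(f"stage {i}: {inp} -> {out}")
# stage 0: (4, 8) -> (4, 8, 16)
# stage 1: (4, 8, 16) -> (4, 8, 16)
# stage 2: (4, 8, 16) -> (4, 8, 100)
```

Passing input_args and output_args to each stage amounts to supplying these pairs up front instead of computing them at runtime.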
- torch.distributed.pipelining.stage.build_stage(stage_module, stage_index, pipe_info, device, group=None)
Create a pipeline stage given a stage_module to be wrapped by this stage and pipeline information.
- Parameters
stage_module (torch.nn.Module) – the module to be wrapped by this stage
stage_index (int) – the index of this stage in the pipeline
pipe_info (PipeInfo) – information about the pipeline, can be retrieved by pipe.info()
device (torch.device) – the device to be used by this stage
group (Optional[dist.ProcessGroup]) – the process group to be used by this stage
- Returns
a pipeline stage that can run with PipelineSchedules.
- Return type
_PipelineStage
Pipeline Schedules
- class torch.distributed.pipelining.schedules.ScheduleGPipe(stage, n_microbatches, loss_fn=None, args_chunk_spec=None, kwargs_chunk_spec=None, output_merge_spec=None, scale_grads=True)
The GPipe schedule. Will go through all the microbatches in a fill-drain manner.
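To visualize fill-drain, here is a simplified, torch-free model of the order in which one GPipe stage processes micro-batches. gpipe_actions and peak_in_flight are hypothetical helpers. F&lt;m&gt; and B&lt;m&gt; denote the forward and backward of micro-batch m, and the sketch ignores communication and timing.

```python
def gpipe_actions(n_microbatches: int) -> list[str]:
    """Per-stage order in a fill-drain schedule: every forward, then every backward."""
    forwards = [f"F{m}" for m in range(n_microbatches)]
    backwards = [f"B{m}" for m in range(n_microbatches)]
    return forwards + backwards


def peak_in_flight(actions: list[str]) -> int:
    """Max number of micro-batches whose activations are held (forward done, backward not)."""
    live = peak = 0
    for action in actions:
        live += 1 if action.startswith("F") else -1
        peak = max(peak, live)
    return peak


actions = gpipe_actions(4)
print(actions)                  # ['F0', 'F1', 'F2', 'F3', 'B0', 'B1', 'B2', 'B3']
print(peak_in_flight(actions))  # 4
```

In this model, each stage holds activations for every micro-batch at once, so activation memory grows with n_microbatches.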
- class torch.distributed.pipelining.schedules.Schedule1F1B(stage, n_microbatches, loss_fn=None, args_chunk_spec=None, kwargs_chunk_spec=None, output_merge_spec=None, scale_grads=True)
The 1F1B schedule. Will perform one forward and one backward on the microbatches in steady state.
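Here is a comparable sketch for 1F1B, under the same simplifying assumptions: hypothetical helpers, with no communication or timing. In this model, stage i of num_stages runs three phases:

- Warm-up: num_stages - i - 1 forwards.
- Steady state: alternates one forward and one backward.
- Drain: runs the remaining backwards.

This follows the textbook 1F1B ordering; the library's implementation may differ in details.

```python
def one_f_one_b_actions(stage_index: int, num_stages: int, n_microbatches: int) -> list[str]:
    """Per-stage action order for a textbook 1F1B schedule (F = forward, B = backward)."""
    warmup = min(num_stages - stage_index - 1, n_microbatches)
    actions = [f"F{m}" for m in range(warmup)]
    next_f, next_b = warmup, 0
    # Steady state: one forward followed by one backward.
    while next_f < n_microbatches:
        actions.append(f"F{next_f}")
        next_f += 1
        actions.append(f"B{next_b}")
        next_b += 1
    # Cool-down: drain the backwards that are still outstanding.
    while next_b < n_microbatches:
        actions.append(f"B{next_b}")
        next_b += 1
    return actions


def peak_in_flight(actions: list[str]) -> int:
    """Max number of micro-batches whose activations are held at once."""
    live = peak = 0
    for action in actions:
        live += 1 if action.startswith("F") else -1
        peak = max(peak, live)
    return peak


for stage in range(4):
    acts = one_f_one_b_actions(stage, num_stages=4, n_microbatches=8)
    print(stage, peak_in_flight(acts), " ".join(acts))
```

In this model, a stage never holds more than num_stages - stage_index micro-batches (here 4, 3, 2, 1), however large n_microbatches is. The fill-drain order, by contrast, holds all of them.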
- class torch.distributed.pipelining.schedules.ScheduleInterleaved1F1B(stages, n_microbatches, loss_fn=None, args_chunk_spec=None, kwargs_chunk_spec=None, output_merge_spec=None, scale_grads=True)
The Interleaved 1F1B schedule. See https://arxiv.org/pdf/2104.04473 for details. Will perform one forward and one backward on the microbatches in steady state and supports multiple stages per rank. When microbatches are ready for multiple local stages, Interleaved 1F1B prioritizes the earlier microbatch (also called “depth first”).
This schedule is mostly similar to the original paper. It differs by relaxing the requirement of num_microbatch % pp_size == 0. Using the flex_pp schedule, we will have num_rounds = max(1, n_microbatches // pp_group_size), and it works as long as n_microbatches % num_rounds is 0. For example, the following configurations are supported:
pp_group_size = 4, n_microbatches = 10. We will have num_rounds = 2 and n_microbatches % 2 is 0.
pp_group_size = 4, n_microbatches = 3. We will have num_rounds = 1 and n_microbatches % 1 is 0.
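The relaxed divisibility rule above is plain integer arithmetic. The hypothetical helper below, which is not a torch API, reproduces it so that a configuration can be checked against the stated rule before launching a job.

```python
def interleaved_num_rounds(pp_group_size: int, n_microbatches: int) -> int:
    """num_rounds under the relaxed Interleaved 1F1B rule, or ValueError if unsupported."""
    num_rounds = max(1, n_microbatches // pp_group_size)
    if n_microbatches % num_rounds != 0:
        raise ValueError(
            f"n_microbatches={n_microbatches} is not divisible by num_rounds={num_rounds}"
        )
    return num_rounds


print(interleaved_num_rounds(4, 10))  # 2  (10 % 2 == 0)
print(interleaved_num_rounds(4, 3))   # 1  (3 % 1 == 0)
try:
    interleaved_num_rounds(4, 9)      # num_rounds = 2, but 9 % 2 == 1
except ValueError as err:
    print(err)
```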
- class torch.distributed.pipelining.schedules.ScheduleLoopedBFS(stages, n_microbatches, loss_fn=None, output_merge_spec=None, scale_grads=True)
Breadth-First Pipeline Parallelism. See https://arxiv.org/abs/2211.05953 for details. Similar to Interleaved 1F1B, Looped BFS supports multiple stages per rank. The difference is that when microbatches are ready for multiple local stages, Looped BFS prioritizes the earlier stage, running all available microbatches at once.
- class torch.distributed.pipelining.schedules.ScheduleInterleavedZeroBubble(stages, n_microbatches, loss_fn=None, args_chunk_spec=None, kwargs_chunk_spec=None, output_merge_spec=None, scale_grads=True)
The Interleaved Zero Bubble schedule. See https://arxiv.org/pdf/2401.10241 for details. Will perform one forward and one backward on inputs for the microbatches in steady state and supports multiple stages per rank. Uses the backward for weights to fill in the pipeline bubble.
In particular this is implementing the ZB1P schedule in the paper.
- class torch.distributed.pipelining.schedules.ScheduleZBVZeroBubble(stages, n_microbatches, loss_fn=None, args_chunk_spec=None, kwargs_chunk_spec=None, output_merge_spec=None, scale_grads=True)
The Zero Bubble schedule (ZBV variant). See https://arxiv.org/pdf/2401.10241 Section 6 for details.
This schedule requires exactly two stages per rank.
This schedule will perform one forward and one backward on inputs for the microbatches in steady state and runs multiple (exactly two) stages per rank. Uses backward with respect to weights to fill in the pipeline bubble.
This ZB-V schedule would have the “zero bubble” property only if time forward == time backward input == time backward weights. In practice, this is unlikely to hold for real models, so a greedy scheduler could be implemented instead for unequal/unbalanced times.
- class torch.distributed.pipelining.schedules.ScheduleDualPipeV(stages, n_microbatches, loss_fn=None, args_chunk_spec=None, kwargs_chunk_spec=None, output_merge_spec=None, scale_grads=True)
The DualPipeV schedule. A more efficient schedule variant based on the DualPipe schedule introduced by DeepSeek in https://arxiv.org/pdf/2412.19437
Based on the open-sourced code from deepseek-ai/DualPipe
- class torch.distributed.pipelining.schedules.PipelineScheduleSingle(stage, n_microbatches, loss_fn=None, args_chunk_spec=None, kwargs_chunk_spec=None, output_merge_spec=None, scale_grads=True)
Base class for single-stage schedules. Implements the step method. Derived classes should implement _step_microbatches.
Gradients are scaled by num_microbatches depending on the scale_grads argument, defaulting to True. This setting should match the configuration of your loss_fn, which may either average losses (scale_grads=True) or sum losses (scale_grads=False).
- step(*args, target=None, losses=None, **kwargs)
Run one iteration of the pipeline schedule with whole-batch input. Will chunk the input into microbatches automatically, and go through the microbatches according to the schedule implementation.
args: positional arguments to the model (as in non-pipeline case).
kwargs: keyword arguments to the model (as in non-pipeline case).
target: target for the loss function.
losses: a list to store the losses for each microbatch.
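Plain arithmetic shows why scale_grads=True pairs with an averaging loss. The sketch below is pure Python, not the schedule's code. It assumes equal-sized micro-batches and a one-parameter model trained with mean squared error. Under those assumptions, gradients summed across micro-batches must be divided by the number of micro-batches to equal the full-batch gradient.

```python
import math


def grad_mse(w: float, xs: list[float], ys: list[float]) -> float:
    """Derivative with respect to w of mean((w * x - y) ** 2) over the given samples."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.1, 5.9, 8.2]
w = 0.5
n_microbatches = 2
size = len(xs) // n_microbatches  # equal-sized micro-batches (assumption)

full_batch = grad_mse(w, xs, ys)

# Each micro-batch backward adds its own (averaged) gradient into .grad.
accumulated = sum(
    grad_mse(w, xs[i:i + size], ys[i:i + size]) for i in range(0, len(xs), size)
)

# scale_grads=True divides the accumulated gradient by the number of micro-batches.
scaled = accumulated / n_microbatches
print(math.isclose(accumulated, full_batch))  # False: off by a factor of n_microbatches
print(math.isclose(scaled, full_batch))       # True
```

With a summing loss_fn, the per-micro-batch sums already add up to the full-batch sum, so no division is needed. That case corresponds to scale_grads=False.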
- class torch.distributed.pipelining.schedules.PipelineScheduleMulti(stages, n_microbatches, loss_fn=None, args_chunk_spec=None, kwargs_chunk_spec=None, output_merge_spec=None, use_full_backward=None, scale_grads=True)
Base class for multi-stage schedules. Implements the step method.
Gradients are scaled by num_microbatches depending on the scale_grads argument, defaulting to True. This setting should match the configuration of your loss_fn, which may either average losses (scale_grads=True) or sum losses (scale_grads=False).
- step(*args, target=None, losses=None, **kwargs)
Run one iteration of the pipeline schedule with whole-batch input. Will chunk the input into microbatches automatically, and go through the microbatches according to the schedule implementation.
args: positional arguments to the model (as in non-pipeline case).
kwargs: keyword arguments to the model (as in non-pipeline case).
target: target for the loss function.
losses: a list to store the losses for each microbatch.