Large Scale Transformer model training with Tensor Parallel (TP)#
Created On: Apr 19, 2024 | Last Updated: Jul 18, 2025 | Last Verified: Nov 05, 2024
Author: Wanchao Liang, Tianyu Liu
This tutorial demonstrates how to train a large Transformer-like model across hundreds to thousands of GPUs using Tensor Parallel and Fully Sharded Data Parallel.
Prerequisites:
PyTorch 2.3.0 or later installed with CUDA/Linux
How Tensor Parallel works#
Tensor Parallel (TP) was originally proposed in the Megatron-LM paper, and it is an efficient model parallelism technique to train large scale Transformer models. Sequence Parallel (SP), which we mention in this tutorial, is a variant of Tensor Parallel that shards on the sequence dimension for nn.LayerNorm or RMSNorm to further save activation memory during training. As the model becomes larger, the activation memory becomes the bottleneck, so Tensor Parallel training usually applies Sequence Parallel to the LayerNorm or RMSNorm layers.

Figure 1. The sharding in Tensor Parallel style on a Transformer model’s MLP and Self-Attention layer, where the matrix multiplications in both attention/MLP happen through sharded computations (image source)#
At a high level, PyTorch Tensor Parallel works as follows (a minimal sketch is shown after this list):
Sharding initialization
Determine which ParallelStyle to apply to each layer and shard the initialized module by calling parallelize_module. The parallelized modules have their model parameters swapped to DTensors, and DTensor is responsible for running the parallelized module using sharded computation.
Runtime forward/backward
Depending on the input/output DTensor layouts the user specified for each ParallelStyle, it runs the proper communication operations to transform the DTensor layouts for inputs/outputs (such as allreduce, allgather and reduce_scatter), and runs the sharded computation for the parallelized layers to save compute/memory (for example, nn.Linear, nn.Embedding).
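As a rough illustration of these two phases, below is a minimal, hedged sketch on a toy two-layer MLP. The module, plan, and launch details are illustrative assumptions and not the tutorial’s model; the real Llama 2 plan is developed step by step later in this tutorial. The script is expected to be launched with one process per GPU (for example, torchrun --nproc_per_node=8).

import os
import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module

class ToyMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.up = nn.Linear(16, 64)
        self.down = nn.Linear(64, 16)

    def forward(self, x):
        return self.down(torch.relu(self.up(x)))

torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))  # torchrun sets LOCAL_RANK
mesh = init_device_mesh("cuda", (8,))                  # also creates the NCCL process group

model = ToyMLP().cuda()
# phase 1: sharding initialization -- parameters become DTensors according to the plan
parallelize_module(model, mesh, {"up": ColwiseParallel(), "down": RowwiseParallel()})

# phase 2: runtime forward/backward -- sharded matmuls, with an allreduce on the final output
torch.manual_seed(0)  # identical input on every rank, since TP expects replicated inputs by default
out = model(torch.randn(4, 16, device="cuda"))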
When and Why you should apply Tensor Parallel#
The PyTorch Fully Sharded Data Parallel (FSDP) already has the capability to scale model training to a specific number of GPUs. However, when it comes to further scaling the model training in terms of model size and GPU quantity, many additional challenges arise that may require combining Tensor Parallel with FSDP:
1. As the world size (number of GPUs) becomes excessively large (exceeding 128/256 GPUs), the FSDP collectives (such as allgather) become dominated by ring latency. By implementing TP/SP on top of FSDP, the FSDP world size could be reduced by 8 by applying FSDP inter-host only, consequently decreasing the latency costs by the same amount.
2. Hitting the data parallelism limit, where you cannot raise the global batch size above the number of GPUs due to both convergence and GPU memory limitations; Tensor/Sequence Parallel is the only known way to “ballpark” the global batch size and continue scaling with more GPUs. This means both model size and number of GPUs could continue to scale.
3. For certain types of models, when the local batch size becomes smaller, TP/SP can yield matrix multiplication shapes that are more optimized for floating point operations (FLOPS).
So, when pre-training, how easy is it to hit those limits? As of now, pre-training a Large Language Model (LLM) with billions or trillions of tokens could take months, even when using thousands of GPUs.
Training an LLM at large scale will always hit limitation 1. For example, Llama 2 70B was trained with 2k GPUs for 35 days; multi-dimensional parallelism is needed at the 2k scale.
When the Transformer model becomes larger (such as Llama 2 70B), it will also quickly hit limitation 2. One could not use FSDP alone even with local batch_size=1, due to memory and convergence constraints. For example, the Llama 2 global batch size is 1K, so data parallelism alone cannot be used at 2K GPUs.
How to apply Tensor Parallel#
PyTorch Tensor Parallel APIs offer a set of module-level primitives (ParallelStyle) to configure the sharding of each individual layer of the model, including:
ColwiseParallel and RowwiseParallel: Shard nn.Linear and nn.Embedding in the column or row fashion.
SequenceParallel: Perform sharded computations on nn.LayerNorm, nn.Dropout, RMSNorm, etc.
PrepareModuleInput and PrepareModuleOutput: Configure the module inputs/outputs sharding layouts with proper communication operations.
To demonstrate how to use the PyTorch native Tensor Parallel APIs, let us look at a common Transformer model. In this tutorial, we use the most recent Llama2 model as a reference Transformer model implementation, as it is also widely used in the community.
Since Tensor Parallel shards individual tensors over a set of devices, we need to set up the distributed environment (such as NCCL communicators) first. Tensor Parallelism is a Single-Program Multiple-Data (SPMD) sharding algorithm similar to PyTorch DDP/FSDP, and under the hood it leverages the PyTorch DTensor to perform sharding. It also utilizes the DeviceMesh abstraction (which under the hood manages ProcessGroups) for device management and sharding. To see how to utilize DeviceMesh to set up multi-dimensional parallelisms, please refer to this tutorial. Tensor Parallel usually works within each host, so let us first initialize a DeviceMesh that connects 8 GPUs within a host.
from torch.distributed.device_mesh import init_device_mesh

tp_mesh = init_device_mesh("cuda", (8,))
Now that we have initialized the DeviceMesh, let us take a detailed look at the Llama 2 model architecture and see how we should perform the Tensor Parallel sharding. Here we focus on the core TransformerBlock, where the Transformer model stacks identical TransformerBlocks to scale up the model.
The core TransformerBlock consists of an Attention layer and a FeedForward layer. Let us first look at the simpler FeedForward layer. The FeedForward layer consists of three Linear layers and performs a SwiGLU-style MLP. Looking at its forward function:
# forward in the FeedForward layer
def forward(self, x):
    return self.w2(F.silu(self.w1(x)) * self.w3(x))
It performs the w1 and w3 matmuls concurrently, followed by a w2 matmul on the combined result of the w1/w3 linear projections. This means we could use the idea from the Tensor Parallelism paper to shard the w1/w3 Linear layers in the colwise fashion and shard the w2 Linear layer in the rowwise fashion, so that there is only one allreduce communication happening at the end of all three layers. With the PyTorch native Tensor Parallel, we can simply create a parallelize_plan for the FeedForward layer like below:
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module

layer_tp_plan = {
    # by default ColwiseParallel input layouts is replicated
    # and RowwiseParallel output layouts is replicated
    "feed_forward.w1": ColwiseParallel(),
    "feed_forward.w2": RowwiseParallel(),
    "feed_forward.w3": ColwiseParallel(),
}
That’s simply how we configure the shardings for the FeedForward layer using the PyTorch Tensor Parallel APIs. Note that users only need to specify how to shard the individual layers; the communications (for example, allreduce) will happen under the hood.
Moving on to the Attention layer. It consists of wq, wk, wv Linear layers to project the input to q/k/v, and then it performs attention and output projection with the wo Linear layer. Tensor Parallelism here intends to perform column-wise sharding for the q/k/v projections and row-wise sharding for the wo linear projection. So we can add the Attention plan to the tp_plan that we just drafted up:
layer_tp_plan = {
    # by default ColwiseParallel input layouts is replicated
    # and RowwiseParallel output layouts is replicated
    "attention.wq": ColwiseParallel(use_local_output=False),
    "attention.wk": ColwiseParallel(use_local_output=False),
    "attention.wv": ColwiseParallel(use_local_output=False),
    "attention.wo": RowwiseParallel(),
    "feed_forward.w1": ColwiseParallel(),
    "feed_forward.w2": RowwiseParallel(),
    "feed_forward.w3": ColwiseParallel(),
}
This is almost the layer_tp_plan we need to apply Tensor Parallelism to the TransformerBlock. However, one thing to be aware of is that when sharding the linear layer column-wise, the output of the linear layers becomes sharded on the last tensor dimension, and the row-wise sharded linear layer directly accepts an input that is sharded on the last dimension. If there are any other tensor operations (such as view operations) between the column-wise linear and the row-wise linear, we would need to adjust the relevant shape-related ops to the sharded shape.
For the Llama model, in the attention layer there are several view operations related to shape. Specifically, for column-wise parallelism in the wq/wk/wv linear layers, the activation tensor is sharded on the num_heads dimension. To manage the difference between the global and local num_heads, we should set use_local_output=False to ensure the output is a DTensor. Unlike a regular tensor, a DTensor is aware of the parallelism plans and will automatically handle changes in the num_heads dimension.
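To make the global-versus-local shape point concrete, here is a small, hedged standalone illustration. It assumes a recent PyTorch where DTensor is public under torch.distributed.tensor, uses illustrative head sizes (not Llama 2’s), and is run with one process per GPU, for example under torchrun --nproc_per_node=8.

import os
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))  # torchrun sets LOCAL_RANK
mesh = init_device_mesh("cuda", (8,))

n_heads, head_dim, seqlen = 32, 128, 16  # illustrative sizes
torch.manual_seed(0)  # same global tensor on every rank

# pretend this is a q projection output for one sequence, sharded over the heads dimension
xq = distribute_tensor(torch.randn(seqlen, n_heads * head_dim, device="cuda"), mesh, [Shard(1)])

# the view can keep using the *global* n_heads: the DTensor tracks the sharding and
# maps the operation onto each rank's local shard (here 4 of the 32 heads per rank)
xq = xq.view(seqlen, n_heads, head_dim)
print(xq.shape, xq.to_local().shape)  # [16, 32, 128] globally vs. [16, 4, 128] locally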
Finally, we need to call the parallelize_module API to make the plan for each TransformerBlock effective. Under the hood, it distributes the model parameters inside the Attention and FeedForward layers to DTensors, and registers communication hooks for model inputs and outputs (before and after each module, respectively), if necessary:
for layer_id, transformer_block in enumerate(model.layers):
    layer_tp_plan = {...}  # i.e. the plan we just generated
    parallelize_module(
        module=transformer_block,
        device_mesh=tp_mesh,
        parallelize_plan=layer_tp_plan,
    )
We have now elaborated the sharding plan for each TransformerBlock. Beyond that, there is usually a nn.Embedding in the first layer and a final nn.Linear projection layer; the user could apply row-wise or column-wise sharding to the first nn.Embedding, and column-wise sharding to the last nn.Linear projection layer, with proper input and output layouts specified. Here is an example:
# DTensor placement types (public under torch.distributed.tensor in recent PyTorch releases)
from torch.distributed.tensor import Replicate

model = parallelize_module(
    model,
    tp_mesh,
    {
        "tok_embeddings": RowwiseParallel(
            input_layouts=Replicate(),
        ),
        "output": ColwiseParallel(
            output_layouts=Replicate(),
        ),
    }
)
Note
If the model to be partitioned is too large to fit into CPU memory, one could either use meta device initialization (for example, initialize the model on the meta device first, shard the layers, and then materialize the model), or parallelize the TransformerBlock layer by layer during the Transformer model initialization.
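For reference, here is a hedged sketch of the meta-device approach mentioned in the note above, reusing the tp_mesh and layer_tp_plan from earlier. The Transformer class, model_args, and init_weights method are illustrative placeholders standing in for whatever the actual model provides.

import torch
from torch.distributed.tensor.parallel import parallelize_module

# 1. allocate the model on the meta device: parameters get shapes but no real storage
with torch.device("meta"):
    model = Transformer(model_args)  # hypothetical model class and args

# 2. shard the layers while they are still on the meta device
for transformer_block in model.layers:
    parallelize_module(transformer_block, tp_mesh, layer_tp_plan)

# 3. materialize only the local shards on the actual device, then initialize the weights
model.to_empty(device="cuda")
model.init_weights()  # hypothetical initialization routine defined by the model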
Apply Sequence Parallel to LayerNorm/RMSNorm layers#
Sequence Parallel works on top of the Tensor Parallel illustrated above. Compared with basic Tensor Parallel, which only shards tensors within the Attention and FeedForward modules and keeps their module inputs and outputs (namely activations in the forward pass and gradients in the backward pass) replicated, Sequence Parallel keeps them sharded on the sequence dimension.
In a typical TransformerBlock, the forward function combines norm layers (LayerNorm or RMSNorm), an attention layer, a feed forward layer, and residual connections. For example:
# forward in a TransformerBlock
def forward(self, x):
    h = x + self.attention(self.attention_norm(x))
    out = h + self.feed_forward(self.ffn_norm(h))
    return out
In most use cases, the activations (and gradients) are of the shape [batch size, sequence length, hidden dimension] outside the Attention and FeedForward modules. In DTensor’s language, Sequence Parallel performs activation computation using the Shard(1) layout for both the forward and backward of the module. Following the code example earlier, the code below demonstrates how we apply Sequence Parallel to the norm layers within a TransformerBlock:
First let’s import the required dependencies for Sequence Parallel:
from torch.distributed.tensor import Shard  # DTensor placement used in the plan below
from torch.distributed.tensor.parallel import (
    PrepareModuleInput,
    SequenceParallel,
)
Next let’s adjust the layer_tp_plan to enable Sequence Parallel on the RMSNorm layers:
layer_tp_plan = {
    # Now the input and output of SequenceParallel has Shard(1) layouts,
    # to represent the input/output tensors sharded on the sequence dimension
    "attention_norm": SequenceParallel(),
    "attention": PrepareModuleInput(
        input_layouts=(Shard(1), Replicate()),
        desired_input_layouts=(Replicate(), Replicate()),
    ),
    "attention.wq": ColwiseParallel(use_local_output=False),
    "attention.wk": ColwiseParallel(use_local_output=False),
    "attention.wv": ColwiseParallel(use_local_output=False),
    "attention.wo": RowwiseParallel(output_layouts=Shard(1)),
    "ffn_norm": SequenceParallel(),
    "feed_forward": PrepareModuleInput(
        input_layouts=(Shard(1),),
        desired_input_layouts=(Replicate(),),
    ),
    "feed_forward.w1": ColwiseParallel(),
    "feed_forward.w2": RowwiseParallel(output_layouts=Shard(1)),
    "feed_forward.w3": ColwiseParallel(),
}
One can see that we now use PrepareModuleInput to modify the module input layouts to the Attention and FeedForward layers from Shard(1) to Replicate(), and mark their output layouts as Shard(1). Just like with Tensor Parallelism, one only needs to specify the tensor sharding layouts of the inputs and outputs, and the communication between layers will happen automatically.
Note that with Sequence Parallel, we assume the inputs and outputs of a TransformerBlock are always sharded on the sequence dimension, so that multiple TransformerBlocks can be concatenated seamlessly. This can be facilitated by explicitly specifying the output of the beginning nn.Embedding layer and the input of the final nn.Linear projection layer to be Shard(1):
model = parallelize_module(
    model,
    tp_mesh,
    {
        "tok_embeddings": RowwiseParallel(
            input_layouts=Replicate(),
            output_layouts=Shard(1),
        ),
        "norm": SequenceParallel(),
        "output": ColwiseParallel(
            input_layouts=Shard(1),
            output_layouts=Replicate(),
        ),
    }
)
Apply Loss Parallel#
Loss Parallel is a related technique to save memory and communication when the loss function is computed, as model outputs are usually very large. In Loss Parallel, when the model outputs are sharded on the (often huge) vocabulary dimension, the cross-entropy loss can be computed efficiently, without gathering all the model outputs to every single GPU. This not only significantly reduces the memory consumption, but also improves training speed by reducing communication overhead and doing sharded computation in parallel. The picture below briefly illustrates how Loss Parallel avoids gathering all model outputs to every GPU by doing sharded computation.

Figure 2. Cross-entropy loss forward computation with loss parallel on one GPU. Blue represents sharded tensors; green represents replicated tensors; yellow represents tensors with partial values (to be all-reduced). Black arrows are local computations; red arrows are functional collectives among GPUs.#
In the PyTorch Tensor Parallel API, Loss Parallel can be enabled via the loss_parallel context manager, with which one can directly use torch.nn.functional.cross_entropy or torch.nn.CrossEntropyLoss without modifying other parts of their code.
To apply Loss Parallel, the model predictions, usually of the shape [batch size, sequence length, vocabulary size], should be sharded on the vocabulary dimension. This can be easily done by marking the output layouts of the last linear projection layer output:
model = parallelize_module(
    model,
    tp_mesh,
    {
        "tok_embeddings": RowwiseParallel(
            input_layouts=Replicate(),
            output_layouts=Shard(1),
        ),
        "norm": SequenceParallel(),
        "output": ColwiseParallel(
            input_layouts=Shard(1),
            # use DTensor as the output
            use_local_output=False,
        ),
    },
)
In the code above, we also apply Sequence Parallel to the norm layer before the output. We set use_local_output=False to let the output stay as a DTensor, which works with the loss_parallel context manager. After that, one can simply call the cross_entropy loss function as shown below. Note that the backward computation also needs to happen within the context.
import torch.nn.functional as F
from torch.distributed.tensor.parallel import loss_parallel

pred = model(input_ids)

with loss_parallel():
    # assuming pred and labels are of the shape [batch, seq, vocab]
    loss = F.cross_entropy(pred.flatten(0, 1), labels.flatten(0, 1))
    loss.backward()
Combine Tensor Parallel with Fully Sharded Data Parallel together#
Now that we have shown how to apply Tensor/Sequence Parallel to the model, let us also take a look at how Tensor Parallel and Fully Sharded Data Parallel could work together. Since Tensor Parallelism incurs communications that block the computation, we want to make sure it runs within a fast communication channel, such as NVLink. In practice, we usually apply Tensor Parallel within each host, and apply Fully Sharded Data Parallel across the hosts.

Figure 3. FSDP and TP work on separate device dimensions, FSDP communication happens inter-host and TP communication happens intra-host.#
This 2-D parallelism pattern can be easily expressed via a 2-D DeviceMesh; we just need to pass each “sub” DeviceMesh to the individual parallelism APIs:
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module
from torch.distributed.fsdp import fully_shard

# i.e. 2-D mesh is [dp, tp], training on 64 GPUs that performs 8 way DP and 8 way TP
mesh_2d = init_device_mesh("cuda", (8, 8), mesh_dim_names=("dp", "tp"))
tp_mesh = mesh_2d["tp"]  # a submesh that connects intra-host devices
dp_mesh = mesh_2d["dp"]  # a submesh that connects inter-host devices

model = Model(...)

tp_plan = {...}

# apply Tensor Parallel intra-host on tp_mesh
model_tp = parallelize_module(model, tp_mesh, tp_plan)
# apply FSDP inter-host on dp_mesh
model_2d = fully_shard(model_tp, mesh=dp_mesh, ...)
This allows us to easily apply Tensor Parallel within each host (intra-host) and apply FSDP across hosts (inter-host), with 0-code changes to the Llama model. The Tensor (Model) Parallel and Data Parallel techniques combined together provide the ability to continue increasing model size and training efficiently using a large number of GPUs.
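To tie the pieces together, below is a hedged sketch of a single training step with the 2-D parallelized model, reusing the loss_parallel pattern from the previous section. The optimizer choice and the input_ids/labels data loader are illustrative assumptions, not part of the tutorial code.

import torch
import torch.nn.functional as F
from torch.distributed.tensor.parallel import loss_parallel

# a foreach optimizer over the (DTensor) parameters of the 2-D parallelized model
optimizer = torch.optim.AdamW(model_2d.parameters(), lr=3e-4, foreach=True)

for input_ids, labels in dataloader:  # illustrative per-rank data loader
    optimizer.zero_grad()
    pred = model_2d(input_ids)
    with loss_parallel():
        # assuming pred is of shape [batch, seq, vocab] and labels is [batch, seq]
        loss = F.cross_entropy(pred.flatten(0, 1), labels.flatten(0, 1))
        loss.backward()  # the backward must also run inside the loss_parallel context
    optimizer.step()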
Conclusion#
This tutorial demonstrates how to train a large Transformer-like model across hundreds to thousands of GPUs using Tensor Parallel in combination with Fully Sharded Data Parallel. It explains how to apply Tensor Parallel to different parts of the model, with no code changes to the model itself. Tensor Parallel is an efficient model parallelism technique for large scale training.
To see the complete end-to-end code example explained in this tutorial, please refer to the Tensor Parallel examples in the pytorch/examples repository.