Tensor Parallelism - torch.distributed.tensor.parallel

Created On: Jun 13, 2025 | Last Updated On: Sep 09, 2025

Tensor Parallelism (TP) is built on top of the PyTorch DistributedTensor (DTensor) and provides different parallelism styles: Colwise, Rowwise, and Sequence Parallelism.

Warning

Tensor Parallelism APIs are experimental and subject to change.

The entrypoint to parallelize your nn.Module using Tensor Parallelism is:

torch.distributed.tensor.parallel.parallelize_module(module, device_mesh=None, parallelize_plan=None, *, src_data_rank=0)[source]

Apply Tensor Parallelism in PyTorch by parallelizing modules or sub-modules based on a user-specified plan.

We parallelize the module or sub-modules based on a parallelize_plan. The parallelize_plan contains ParallelStyle, which indicates how the user wants the module or sub-module to be parallelized.

Users can also specify different parallel styles per module fully qualified name (FQN).

Note that parallelize_module only accepts a 1-D DeviceMesh. If you have a 2-D or N-D DeviceMesh, slice the DeviceMesh to a 1-D sub-DeviceMesh first and then pass it to this API (e.g. device_mesh["tp"]).

Parameters:
  • module (nn.Module) – Module to be parallelized.

  • device_mesh (DeviceMesh, optional) – Object which describes the mesh topology of devices for the DTensor. If not specified, the call must be under a DeviceMesh context.

  • parallelize_plan (Union[ParallelStyle, Dict[str, ParallelStyle]], optional) – The plan used to parallelize the module. It can be either a ParallelStyle object which contains how we prepare input/output for Tensor Parallelism or it can be a dict of module FQN and its corresponding ParallelStyle object. If not specified, the call will do nothing at the moment.

Keyword Arguments:

src_data_rank (int, optional) – the rank of the source data for the logical/global tensor; it is used by distribute_tensor() to scatter/broadcast the shards/replicas to other ranks. By default, we use group_rank=0 on each DeviceMesh dimension as the source data to preserve the single-device semantic. If passing None explicitly, parallelize_module() simply uses its local data instead of trying to preserve the single-device semantic via scatter/broadcast. Default: 0

Returns:

A nn.Module object parallelized.

Return type:

Module

Example::
>>> from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel, RowwiseParallel
>>> from torch.distributed.device_mesh import init_device_mesh
>>>
>>> # Define the module.
>>> m = Model(...)
>>> tp_mesh = init_device_mesh("cuda", (8,))
>>> m = parallelize_module(m, tp_mesh, {"w1": ColwiseParallel(), "w2": RowwiseParallel()})

Note

For complex module architectures like Attention and MLP layers, we recommend composing different ParallelStyles together (i.e. ColwiseParallel and RowwiseParallel) and passing them as a parallelize_plan, to achieve the desired sharding computation.
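For instance, such a composed plan for a hypothetical two-layer MLP (the submodule FQNs "w1" and "w2" are assumed names) might look like the following. Colwise shards w1's output on the last dimension, which is exactly the input sharding Rowwise expects, so the intermediate activation stays sharded and only one collective is needed at the end:

```python
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel

# Hypothetical FQNs for an MLP: "w1" is the up-projection, "w2" the
# down-projection. No resharding happens between the two linears because
# Colwise's default output layout matches Rowwise's default input layout.
mlp_plan = {
    "w1": ColwiseParallel(),
    "w2": RowwiseParallel(),
}
```

The plan dict would then be passed as the parallelize_plan argument of parallelize_module.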

Tensor Parallelism supports the following parallel styles:

class torch.distributed.tensor.parallel.ColwiseParallel(*, input_layouts=None, output_layouts=None, use_local_output=True)[source]

Partition a compatible nn.Module in a column-wise fashion. Currently supports nn.Linear and nn.Embedding. Users can compose it together with RowwiseParallel to achieve the sharding of more complicated modules (e.g. MLP, Attention).

Keyword Arguments:
  • input_layouts (Placement, optional) – The DTensor layout of the input tensor for the nn.Module, this is used to annotate the input tensor to become a DTensor. If not specified, we assume the input tensor to be replicated.

  • output_layouts (Placement, optional) – The DTensor layout of the output for the nn.Module, this is used to ensure the output of the nn.Module has the user-desired layout. If not specified, the output tensor is sharded on the last dimension.

  • use_local_output (bool, optional) – Whether to use local torch.Tensor instead of DTensor for the module output, default: True.

Returns:

A ParallelStyle object that represents Colwise sharding of the nn.Module.

Example::
>>> from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel
>>> from torch.distributed.device_mesh import init_device_mesh
>>> ...
>>> m = Model(...)  # m is a nn.Module that contains a "w1" nn.Linear submodule
>>> tp_mesh = init_device_mesh("cuda", (8,))
>>>
>>> # By default, the input of the "w1" Linear will be converted to Replicated DTensor
>>> # and the output of "w1" will return :class:`torch.Tensor` that shards on the last dim.
>>>
>>> sharded_mod = parallelize_module(m, tp_mesh, {"w1": ColwiseParallel()})
>>> ...

Note

By default, the ColwiseParallel output is sharded on the last dimension if output_layouts is not specified. If there are operators that require a specific tensor shape (e.g. before the paired RowwiseParallel), keep in mind that if the output is sharded, the operator might need to be adjusted to the sharded size.

class torch.distributed.tensor.parallel.RowwiseParallel(*, input_layouts=None, output_layouts=None, use_local_output=True)[source]

Partition a compatible nn.Module in a row-wise fashion. Currently supports nn.Linear and nn.Embedding. Users can compose it with ColwiseParallel to achieve the sharding of more complicated modules (e.g. MLP, Attention).

Keyword Arguments:
  • input_layouts (Placement, optional) – The DTensor layout of the input tensor for the nn.Module, this is used to annotate the input tensor to become a DTensor. If not specified, we assume the input tensor to be sharded on the last dimension.

  • output_layouts (Placement, optional) – The DTensor layout of the output for the nn.Module, this is used to ensure the output of the nn.Module has the user-desired layout. If not specified, the output tensor is replicated.

  • use_local_output (bool, optional) – Whether to use local torch.Tensor instead of DTensor for the module output, default: True.

Returns:

A ParallelStyle object that represents Rowwise sharding of the nn.Module.

Example::
>>> from torch.distributed.tensor.parallel import parallelize_module, RowwiseParallel
>>> from torch.distributed.device_mesh import init_device_mesh
>>> ...
>>> m = Model(...)  # m is a nn.Module that contains a "w2" nn.Linear submodule
>>> tp_mesh = init_device_mesh("cuda", (8,))
>>>
>>> # By default, the input of the "w2" Linear will be converted to DTensor that shards on the last dim
>>> # and the output of "w2" will return a replicated :class:`torch.Tensor`.
>>>
>>> sharded_mod = parallelize_module(m, tp_mesh, {"w2": RowwiseParallel()})
>>> ...

class torch.distributed.tensor.parallel.SequenceParallel(*, sequence_dim=1, use_local_output=False)[source]

SequenceParallel replicates the parameters of a compatible nn.Module and runs the sharded computation with input sharded on the sequence dimension. This currently supports nn.LayerNorm, nn.Dropout, and the RMSNorm python implementation.

This style implements the operation that is described in the paper Reducing Activation Recomputation in Large Transformer Models.

If the input passed in to this nn.Module is a torch.Tensor, it assumes that the input is already sharded on the sequence dimension and converts the input to a DTensor sharded on the sequence dimension. If the input passed in to this nn.Module is already a DTensor but is not sharded on the sequence dimension, it would redistribute the input to be sharded on the sequence dimension.

The output of the nn.Module will be sharded on the sequence dimension.

Keyword Arguments:
  • sequence_dim (int, optional) – The sequence dimension of the input tensor for the nn.Module, this is used to annotate the input tensor to become a DTensor that is sharded on the sequence dimension, default: 1.

  • use_local_output (bool, optional) – Whether to use local torch.Tensor instead of DTensor for the module output, default: False.

Returns:

A ParallelStyle object that represents Sequence Parallel of the nn.Module.

Example::
>>> from torch.distributed.tensor.parallel import parallelize_module, SequenceParallel
>>> from torch.distributed.device_mesh import init_device_mesh
>>> ...
>>> m = Model(...)  # m is a nn.Module that contains a "norm" nn.LayerNorm submodule
>>> tp_mesh = init_device_mesh("cuda", (8,))
>>>
>>> # By default, the input of the "norm" will be converted to DTensor that shards on the sequence dim
>>> # and the output of "norm" will return a :class:`DTensor` sharded on the sequence dimension.
>>>
>>> sharded_mod = parallelize_module(m, tp_mesh, {"norm": SequenceParallel()})
>>> ...

Note

SequenceParallel style assumes ones initialization if there are weights in the nn.Module (i.e. nn.LayerNorm or RMSNorm, which by default have ones initialization). If you have custom inits for the weights on those modules, you need to broadcast the weights before/after parallelizing to ensure that they are replicated.
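A single-process sketch of that broadcast step is shown below. A gloo group of world size 1 stands in for the real TP process group here so the snippet runs standalone; in a real job, dist.broadcast would replicate rank 0's custom weights to every rank in the group:

```python
import os
import torch
import torch.distributed as dist

# Single-process group so the sketch runs without a launcher; a real job
# would already have a process group set up by torchrun.
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)

norm = torch.nn.LayerNorm(16)
torch.nn.init.normal_(norm.weight)  # custom init instead of the default ones

# Replicate rank 0's weights on every rank so SequenceParallel's
# replication assumption holds (a no-op with world size 1).
dist.broadcast(norm.weight.data, src=0)
dist.broadcast(norm.bias.data, src=0)

dist.destroy_process_group()
```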

To simply configure the nn.Module’s inputs and outputs with DTensor layouts and perform necessary layout redistributions, without distributing the module parameters to DTensors, the following ParallelStyles can be used in the parallelize_plan when calling parallelize_module:

class torch.distributed.tensor.parallel.PrepareModuleInput(*, input_layouts=None, desired_input_layouts=None, input_kwarg_layouts=None, desired_input_kwarg_layouts=None, use_local_output=False)[source]

Configure the nn.Module’s inputs to convert the input tensors of the nn.Module to DTensors at runtime according to input_layouts, and perform layout redistribution according to the desired_input_layouts.

Keyword Arguments:
  • input_layouts (Union[Placement, Tuple[Optional[Placement]]]) – The DTensor layouts of input tensors for the nn.Module, this is used to convert the input tensors to DTensors. If some inputs are not torch.Tensor or do not need to be converted to DTensors, None needs to be specified as a placeholder. default: None.

  • desired_input_layouts (Union[Placement, Tuple[Optional[Placement]]]) – The desired DTensor layouts of input tensors for the nn.Module, this is used to ensure the inputs of the nn.Module have the desired DTensor layouts. This argument needs to have the same length as input_layouts. default: None.

  • input_kwarg_layouts (Dict[str, Placement]) – The DTensor layouts of input kwargs for the nn.Module, this is used to convert the input kwarg tensors to DTensors. default: None.

  • desired_input_kwarg_layouts (Dict[str, Placement]) – The desired DTensor layouts of input kwargs for the nn.Module, this is used to ensure the inputs of the nn.Module have the desired DTensor layouts. default: None.

  • use_local_output (bool, optional) – Whether to use local torch.Tensor instead of DTensor for the module inputs, default: False.

Returns:

A ParallelStyle object that prepares the sharding layouts of the nn.Module’s inputs.

Example::
>>> from torch.distributed.tensor import Replicate, Shard
>>> from torch.distributed.tensor.parallel import parallelize_module, PrepareModuleInput
>>> from torch.distributed.device_mesh import init_device_mesh
>>> ...
>>> block = TransformerBlock(...)  # block is a nn.Module that contains an "attn" Attention submodule
>>> tp_mesh = init_device_mesh("cuda", (8,))
>>>
>>> # According to the style specified below, the first input of attn will be annotated as Sharded DTensor
>>> # and then redistributed to Replicated DTensor.
>>> parallelize_module(
>>>     block,  # this can be a submodule or module
>>>     tp_mesh,
>>>     parallelize_plan={
>>>         "attn": PrepareModuleInput(
>>>             input_layouts=(Shard(0), None, None, ...),
>>>             desired_input_layouts=(Replicate(), None, None, ...)
>>>         ),
>>>     }
>>> )

class torch.distributed.tensor.parallel.PrepareModuleOutput(*, output_layouts, desired_output_layouts, use_local_output=True)[source]

Configure the nn.Module’s outputs to convert the output tensors of the nn.Module to DTensors at runtime according to output_layouts, and perform layout redistribution according to the desired_output_layouts.

Keyword Arguments:
  • output_layouts (Union[Placement, Tuple[Placement]]) – The DTensor layouts of output tensors for the nn.Module, this is used to convert the output tensors to DTensors if they are torch.Tensor. If some outputs are not torch.Tensor or do not need to be converted to DTensors, None needs to be specified as a placeholder.

  • desired_output_layouts (Union[Placement, Tuple[Placement]]) – The desired DTensor layouts of output tensors for the nn.Module, this is used to ensure the outputs of the nn.Module have the desired DTensor layouts.

  • use_local_output (bool, optional) – Whether to use local torch.Tensor instead of DTensor for the module outputs, default: True.

Returns:

A ParallelStyle object that prepares the sharding layouts of the nn.Module’s outputs.

Example::
>>> from torch.distributed.tensor import Replicate, Shard
>>> from torch.distributed.tensor.parallel import parallelize_module, PrepareModuleOutput
>>> from torch.distributed.device_mesh import init_device_mesh
>>> ...
>>> block = TransformerBlock(...)  # block is a nn.Module that contains an "attn" Attention submodule
>>> tp_mesh = init_device_mesh("cuda", (8,))
>>>
>>> # According to the style specified below, the output of the TransformerBlock will be converted to Replicated DTensor
>>> # and then redistributed to Sharded DTensor.
>>> parallelize_module(
>>>     block,  # this can be a submodule or module
>>>     tp_mesh,
>>>     parallelize_plan=PrepareModuleOutput(
>>>         output_layouts=Replicate(),
>>>         desired_output_layouts=Shard(0)
>>>     )
>>> )

class torch.distributed.tensor.parallel.PrepareModuleInputOutput(*, input_layouts=None, desired_input_layouts=None, input_kwarg_layouts=None, desired_input_kwarg_layouts=None, use_local_input=False, output_layouts, desired_output_layouts, use_local_output=True)[source]

Configure the nn.Module’s inputs (and outputs) to convert the input tensors (and output tensors, respectively) of the nn.Module to DTensors at runtime according to input_layouts (and output_layouts, respectively), and perform layout redistribution according to the desired_input_layouts (and desired_output_layouts, respectively). This is a combination of PrepareModuleInput and PrepareModuleOutput.

Keyword Arguments:
  • input_layouts (Union[Placement, Tuple[Optional[Placement]]]) – The DTensor layouts of input tensors for the nn.Module, this is used to convert the input tensors to DTensors. If some inputs are not torch.Tensor or do not need to be converted to DTensors, None needs to be specified as a placeholder. default: None.

  • desired_input_layouts (Union[Placement, Tuple[Optional[Placement]]]) – The desired DTensor layouts of input tensors for the nn.Module, this is used to ensure the inputs of the nn.Module have the desired DTensor layouts. This argument needs to have the same length as input_layouts. default: None.

  • input_kwarg_layouts (Dict[str, Placement]) – The DTensor layouts of input kwargs for the nn.Module, this is used to convert the input kwarg tensors to DTensors. default: None.

  • desired_input_kwarg_layouts (Dict[str, Placement]) – The desired DTensor layouts of input kwargs for the nn.Module, this is used to ensure the inputs of the nn.Module have the desired DTensor layouts. default: None.

  • use_local_input (bool, optional) – Whether to use local torch.Tensor instead of DTensor for the module inputs, default: False.

  • output_layouts (Union[Placement, Tuple[Placement]]) – The DTensor layouts of output tensors for the nn.Module, this is used to convert the output tensors to DTensors if they are torch.Tensor. If some outputs are not torch.Tensor or do not need to be converted to DTensors, None needs to be specified as a placeholder.

  • desired_output_layouts (Union[Placement, Tuple[Placement]]) – The desired DTensor layouts of output tensors for the nn.Module, this is used to ensure the outputs of the nn.Module have the desired DTensor layouts.

  • use_local_output (bool, optional) – Whether to use local torch.Tensor instead of DTensor for the module outputs, default: True.

Returns:

A ParallelStyle object that prepares the sharding layouts of the nn.Module’s inputs and outputs.

Example::
>>> from torch.distributed.tensor import Replicate, Shard
>>> from torch.distributed.tensor.parallel import parallelize_module, PrepareModuleInputOutput
>>> from torch.distributed.device_mesh import init_device_mesh
>>> ...
>>> block = TransformerBlock(...)  # block is a nn.Module that contains an "attn" Attention submodule
>>> tp_mesh = init_device_mesh("cuda", (8,))
>>>
>>> # According to the style specified below, the first input of attn will be annotated as Sharded DTensor
>>> # and then redistributed to Replicated DTensor, and the output of the TransformerBlock will be annotated
>>> # as Replicated DTensor and then redistributed to Sharded DTensor.
>>> parallelize_module(
>>>     block,  # this can be a submodule or module
>>>     tp_mesh,
>>>     parallelize_plan={
>>>         "attn": PrepareModuleInputOutput(
>>>             input_layouts=(Shard(0), None, None, ...),
>>>             desired_input_layouts=(Replicate(), None, None, ...),
>>>             output_layouts=Replicate(),
>>>             desired_output_layouts=Shard(0),
>>>         ),
>>>     }
>>> )

Note

When using Shard(dim) as the input/output layouts for the above ParallelStyles, we assume the input/output activation tensors are evenly sharded on the tensor dimension dim on the DeviceMesh that TP operates on. For instance, since RowwiseParallel accepts input that is sharded on the last dimension, it assumes the input tensor has already been evenly sharded on the last dimension. For the case of unevenly sharded activation tensors, one could pass DTensor directly to the partitioned modules and use use_local_output=False to return DTensor after each ParallelStyle, where DTensor could track the uneven sharding information.

For models like Transformer, we recommend users use ColwiseParallel and RowwiseParallel together in the parallelize_plan to achieve the desired sharding for the entire model (i.e. Attention and MLP).
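A sketch of such a combined plan for a single transformer block, using hypothetical submodule FQNs for a Llama-style layer (real model FQNs will differ):

```python
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel

# Hypothetical FQNs: q/k/v projections and the MLP up-projection go
# column-wise; the attention output and MLP down-projection go row-wise,
# so each pair needs only one collective at its end.
layer_plan = {
    "attention.wq": ColwiseParallel(),
    "attention.wk": ColwiseParallel(),
    "attention.wv": ColwiseParallel(),
    "attention.wo": RowwiseParallel(),
    "feed_forward.w1": ColwiseParallel(),
    "feed_forward.w2": RowwiseParallel(),
}
```

The same plan dict would typically be applied per layer with parallelize_module on each TransformerBlock.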

Parallelized cross-entropy loss computation (loss parallelism) is supported via the following context manager:

torch.distributed.tensor.parallel.loss_parallel()[source]

A context manager that enables loss parallelism, where efficient parallelized loss computation can be performed when the input is sharded on the class dimension. Currently only the cross-entropy loss is supported.

Within this context manager, one can use cross_entropy() or CrossEntropyLoss as usual, with the following assumptions on the input parameters. The corresponding backward() call, if any, also needs to happen under this context manager.

Parameters:
  • input (DTensor) – Input logits. Assumed to be sharded on the class dimension.

  • target (Union[torch.Tensor, DTensor]) – Must be ground truth class indices (class probabilities are currently not supported). Assumed to be replicated across the DeviceMesh.

  • weight (Union[torch.Tensor, DTensor], optional) – If given, assumed to be replicated across the DeviceMesh.

  • label_smoothing – Currently not supported.

Returns:

A replicated DTensor.

Example

A sharded DTensor is manually created here to showcase the usage. In practice, it is usually the output of a TP module.

>>> import torch
>>> import torch.nn.functional as F
>>> from torch.distributed.tensor import distribute_tensor, Shard
>>> from torch.distributed.tensor.parallel import loss_parallel
>>> from torch.distributed.device_mesh import init_device_mesh
>>> ...
>>> device_mesh = init_device_mesh("cuda", (8,))
>>> input = torch.randn(4, 16, device="cuda", requires_grad=True)
>>> dist_input = distribute_tensor(input, device_mesh, placements=[Shard(1)])
>>> target = torch.randint(16, (4,), device="cuda")
>>> with loss_parallel():
>>>     loss = F.cross_entropy(dist_input, target, reduction="mean")
>>>     loss.backward()
>>> ...

Warning

The loss_parallel API is experimental and subject to change.