
torch.distributed.tensor

Created On: Jun 13, 2025 | Last Updated On: Aug 23, 2025

Note

torch.distributed.tensor is currently in alpha state and under development. We are committing to backward compatibility for most APIs listed in this doc, but there might be API changes if necessary.

PyTorch DTensor (Distributed Tensor)

PyTorch DTensor offers simple and flexible tensor sharding primitives that transparently handle distributed logic, including sharded storage, operator computation and collective communications across devices/hosts. DTensor can be used to build different parallelism solutions and supports sharded state_dict representation when working with multi-dimensional sharding.

Please see examples from the PyTorch native parallelism solutions that are built on top of DTensor:

DTensor follows the SPMD (single program, multiple data) programming model to empower users to write distributed programs as if they were single-device programs with the same convergence property. It provides a uniform tensor sharding layout (DTensor Layout) by specifying the DeviceMesh and Placement:

  • DeviceMesh represents the device topology and the communicators of the cluster using an n-dimensional array.

  • Placement describes the sharding layout of the logical tensor on the DeviceMesh. DTensor supports three types of placements: Shard, Replicate and Partial.
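As a rough mental model, the three placement types can be simulated on a 1-D mesh with plain Python lists. This is not the DTensor API: `local_view` and the string placement names are hypothetical, the sketch assumes an evenly divisible dimension, and the zero-filled summands for Partial are just one of many valid decompositions.

```python
def local_view(global_values, placement, rank, world_size):
    """What `rank` holds for a 1-D list under a placement (hypothetical helper)."""
    if placement == "shard":
        # Shard(0): each rank keeps one contiguous, equally sized slice.
        size = len(global_values) // world_size
        return global_values[rank * size:(rank + 1) * size]
    if placement == "replicate":
        # Replicate(): every rank keeps a full copy of the logical tensor.
        return list(global_values)
    if placement == "partial":
        # Partial("sum"): every rank keeps a same-shaped summand. Here rank 0
        # holds the real values and the other ranks hold zeros.
        return list(global_values) if rank == 0 else [0] * len(global_values)
    raise ValueError(f"unknown placement: {placement}")


data = [10, 20, 30, 40]
world_size = 2
shards = [local_view(data, "shard", r, world_size) for r in range(world_size)]
replicas = [local_view(data, "replicate", r, world_size) for r in range(world_size)]
partials = [local_view(data, "partial", r, world_size) for r in range(world_size)]

# Summing the partial values element-wise recovers the logical tensor.
reconstructed = [sum(column) for column in zip(*partials)]
print(shards)         # [[10, 20], [30, 40]]
print(reconstructed)  # [10, 20, 30, 40]
```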

DTensor Class APIs

DTensor is a torch.Tensor subclass. This means that once a DTensor is created, it can be used in a very similar way to torch.Tensor, including running different types of PyTorch operators as if running them on a single device, allowing proper distributed computation for PyTorch operators.

In addition to existing torch.Tensor methods, it also offers a set of additional methods to interact with torch.Tensor, redistribute the DTensor Layout to a new DTensor, get the full tensor content on all devices, etc.

class torch.distributed.tensor.DTensor(local_tensor, spec, *, requires_grad)

DTensor (Distributed Tensor) is a subclass of torch.Tensor that provides a single-device-like abstraction to program with multi-device torch.Tensor. It describes the distributed tensor sharding layout (DTensor Layout) through the DeviceMesh and the following types of Placement:

  • Shard: Tensor sharded on the tensor dimension dim on the devices of the DeviceMesh dimension

  • Replicate: Tensor replicated on the devices of the DeviceMesh dimension

  • Partial: Tensor is pending reduction on the devices of the DeviceMesh dimension

When calling PyTorch operators, DTensor overrides the PyTorch operators to perform sharded computation and issue communications whenever necessary. Along with the operator computation, DTensor transforms or propagates the placements (DTensor Layout) properly (based on the semantics of the operator itself) and generates new DTensor outputs.

To ensure numerical correctness of the DTensor sharded computation when calling PyTorch operators, DTensor requires every Tensor argument of the operator to be a DTensor.

Note

Directly using the Tensor subclass constructor here is not the recommended way to create a DTensor (i.e. it does not handle autograd correctly and hence is not the public API). Please refer to the create_dtensor section to see how to create a DTensor.

Return type

DTensor

__create_chunk_list__()

Return a list of ChunkStorageMetadata, which is a dataclass that describes the size/offset of the local shard/replica on the current rank. For DTensor, each rank has a single local shard/replica, so the returned list usually has only one element.

This dunder method is primarily used for distributed checkpointing.

Returns

A List[ChunkStorageMetadata] object that represents the shard size/offset on the current rank.
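For a 1-D mesh with a Shard(dim) placement, the size/offset pair for each rank can be sketched as below. `ChunkMeta` is a hypothetical stand-in for ChunkStorageMetadata (only per-dimension offsets and sizes are modeled), and the split follows the torch.chunk semantics described for Shard; how the real method reports offsets for empty trailing shards is not modeled here.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ChunkMeta:
    """Hypothetical stand-in for ChunkStorageMetadata: per-dim offsets and sizes."""
    offsets: tuple
    sizes: tuple


def local_chunk(global_shape, shard_dim, rank, world_size):
    """Chunk list for `rank` under Shard(shard_dim) on a 1-D mesh (simulation)."""
    dim_size = global_shape[shard_dim]
    chunk = -(-dim_size // world_size)  # ceil division, as torch.chunk does
    start = min(rank * chunk, dim_size)
    end = min(start + chunk, dim_size)
    offsets = [0] * len(global_shape)
    sizes = list(global_shape)
    offsets[shard_dim] = start
    sizes[shard_dim] = end - start
    # One local shard per rank, so the list has a single element.
    return [ChunkMeta(tuple(offsets), tuple(sizes))]


for r in range(4):
    print(r, local_chunk((10, 6), shard_dim=0, rank=r, world_size=4))
```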

static from_local(local_tensor, device_mesh=None, placements=None, *, run_check=False, shape=None, stride=None)

Create a DTensor from a local torch.Tensor on each rank according to the device_mesh and placements specified.

Parameters
  • local_tensor (torch.Tensor) – local torch.Tensor on each rank.

  • device_mesh (DeviceMesh, optional) – DeviceMesh to place the tensor; if not specified, must be called under a DeviceMesh context manager. default: None

  • placements (List[Placement], optional) – the placements that describe how to place the local torch.Tensor on the DeviceMesh; must have the same number of elements as device_mesh.ndim.

Keyword Arguments
  • run_check (bool, optional) – at the cost of extra communications, perform a sanity check across ranks to check each local tensor's meta information to ensure correctness. If there is a Replicate in placements, the data on the first rank of the device mesh dimension will be broadcast to the other ranks. default: False

  • shape (torch.Size, optional) – A list of ints which specifies the size of the DTensor built on top of local_tensor. Note this needs to be provided if the shape of local_tensor differs across ranks. If not provided, shape will be computed assuming the given distributed tensor is evenly sharded across ranks. default: None

  • stride (tuple, optional) – A list of ints which specifies the stride of the DTensor. If not provided, stride will be computed assuming the given distributed tensor is evenly sharded across ranks. default: None

Returns

A DTensor object

Return type

DTensor

Note

When run_check=False, it is the user's responsibility to ensure the local tensor passed in is correct across ranks (i.e. the tensor is sharded for the Shard(dim) placement or replicated for the Replicate() placement). If not, the behavior of the created DTensor is undefined.

Note

from_local is differentiable; the requires_grad of the created DTensor object depends on whether local_tensor requires grad.
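The shape keyword matters because, when it is omitted, the global size along a sharded dimension is inferred as if every rank held the same number of elements. The sketch below, using a hypothetical `inferred_global_dim` helper, shows how that guess breaks for an uneven Shard(0) split of 10 rows over 4 ranks (torch.chunk gives local sizes 3, 3, 3, 1).

```python
def inferred_global_dim(local_size, world_size):
    # Assumes even sharding: every rank holds exactly `local_size` rows.
    return local_size * world_size


local_sizes = [3, 3, 3, 1]  # uneven Shard(0) of 10 rows over 4 ranks
guesses = [inferred_global_dim(n, len(local_sizes)) for n in local_sizes]
true_size = sum(local_sizes)

print(guesses)    # [12, 12, 12, 4]: inconsistent across ranks, none equals 10
print(true_size)  # 10: the value to pass explicitly via `shape`
```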

full_tensor(*, grad_placements=None)

Return the full tensor of this DTensor. It performs the necessary collectives to gather the local tensors from other ranks in its DeviceMesh and concatenates them together. It is syntactic sugar for the following code:

dtensor.redistribute(placements=[Replicate()]*mesh.ndim).to_local()

Keyword Arguments

grad_placements (List[Placement], optional) – the placements that describe the future layout of any gradient of the full Tensor returned from this function. full_tensor converts the DTensor to a full torch.Tensor, and the returned torch.Tensor might not be used with the original replicated DTensor layout later in the code. This argument is the hint that the user can give to autograd in case the gradient layout of the returned tensor does not match the original replicated DTensor layout. If not specified, we assume the gradient layout of the full tensor is replicated.

Returns

A torch.Tensor object that represents the full tensor of this DTensor.

Return type

Tensor

Note

full_tensor is differentiable.
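For a DTensor with a single Shard(0) placement, the effect of full_tensor can be pictured as an all_gather followed by concatenation in rank order. The sketch simulates that with lists; `simulated_all_gather` and `simulated_full_tensor` are hypothetical names, and the real method works on torch tensors through DTensor's collectives.

```python
def simulated_all_gather(local_tensors):
    """Every rank receives a copy of every rank's local list, in rank order."""
    return [list(local_tensors) for _ in local_tensors]


def simulated_full_tensor(local_tensors):
    # Shard(0) -> Replicate(): all_gather, then concatenate chunks in rank order.
    gathered = simulated_all_gather(local_tensors)
    return [[x for chunk in per_rank for x in chunk] for per_rank in gathered]


shards = [[1, 2], [3, 4], [5]]  # uneven Shard(0) of 5 elements across 3 ranks
result = simulated_full_tensor(shards)
print(result)  # [[1, 2, 3, 4, 5], [1, 2, 3, 4, 5], [1, 2, 3, 4, 5]]
```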

redistribute(device_mesh=None, placements=None, *, async_op=False, forward_dtype=None, backward_dtype=None)

redistribute performs the necessary collective operations to redistribute the current DTensor from its current placements to new placements, or from its current DeviceMesh to a new DeviceMesh. For example, we can turn a sharded DTensor into a replicated DTensor by specifying a Replicate placement for each dimension of the DeviceMesh.

When redistributing from the current to the new placements on one device mesh dimension, we perform the following collective communications or local operations:

  1. Shard(dim) -> Replicate(): all_gather

  2. Shard(src_dim) -> Shard(dst_dim): all_to_all

  3. Replicate() -> Shard(dim): local chunking (i.e. torch.chunk)

  4. Partial() -> Replicate(): all_reduce

  5. Partial() -> Shard(dim): reduce_scatter
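The five per-mesh-dimension transitions above can be read as a lookup table. The sketch below encodes them with a hypothetical string notation ("S0" for Shard(0), "R" for Replicate(), "P" for Partial()) and plans one step per mesh dimension; the real redistribute may order or combine steps differently on an N-D mesh, and transitions not in the list simply raise here.

```python
def plan_step(src, dst):
    """Operation for one mesh dimension, per the list above (hypothetical helper)."""
    if src == dst:
        return "no-op"
    if src.startswith("S") and dst == "R":
        return "all_gather"
    if src.startswith("S") and dst.startswith("S"):
        return "all_to_all"
    if src == "R" and dst.startswith("S"):
        return "local chunk"
    if src == "P" and dst == "R":
        return "all_reduce"
    if src == "P" and dst.startswith("S"):
        return "reduce_scatter"
    raise NotImplementedError(f"{src} -> {dst} is not covered by the list above")


def plan(src_placements, dst_placements):
    # Naive plan: one independent step per mesh dimension.
    return [plan_step(s, d) for s, d in zip(src_placements, dst_placements)]


print(plan(["S0", "P"], ["R", "S1"]))  # ['all_gather', 'reduce_scatter']
```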

redistribute correctly figures out the necessary redistribute steps for DTensors created on either a 1-D or an N-D DeviceMesh.

Parameters
  • device_mesh (DeviceMesh, optional) – DeviceMesh to place the DTensor. If not specified, the current DTensor's DeviceMesh is used. default: None

  • placements (List[Placement], optional) – the new placements that describe how to place the DTensor into the DeviceMesh; must have the same number of elements as device_mesh.ndim. default: replicate on all mesh dimensions

Keyword Arguments
  • async_op (bool, optional) – whether to perform the DTensor redistribute operation asynchronously or not. Default: False

  • forward_dtype (torch.dtype, optional) – the local tensor datatype can be converted to forward_dtype before redistributing the local tensor in its forward. The resulting DTensor will be in forward_dtype. Default: None.

  • backward_dtype (torch.dtype, optional) – the local tensor datatype can be converted to backward_dtype before redistributing the local tensor in its backward. The resulting DTensor gradient is converted back to the current DTensor dtype. Default: None

Returns

A DTensor object

Return type

DTensor

Note

redistribute is differentiable, which means users do not need to worry about the backward formula of the redistribute operation.

Note

redistribute currently only supports redistributing a DTensor on the same DeviceMesh. Please file an issue if you need to redistribute a DTensor to a different DeviceMesh.

to_local(*, grad_placements=None)

Get the local tensor of this DTensor on its current rank. For sharding, it returns a local shard of the logical tensor view; for replication, it returns the replica on its current rank.

Keyword Arguments

grad_placements (List[Placement], optional) – the placements that describe the future layout of any gradient of the Tensor returned from this function. to_local converts the DTensor to a local tensor, and the returned local tensor might not be used with the original DTensor layout later in the code. This argument is the hint that the user can give to autograd in case the gradient layout of the returned tensor does not match the original DTensor layout. If not specified, we assume the gradient layout remains the same as the original DTensor and use that for gradient computation.

Returns

A torch.Tensor or AsyncCollectiveTensor object. It represents the local tensor on its current rank. When an AsyncCollectiveTensor object is returned, the local tensor is not ready yet (i.e. communication is not finished). In this case, the user needs to call wait to wait for the local tensor to be ready.

Return type

Tensor

Note

to_local is differentiable; the requires_grad of the returned local tensor depends on whether the DTensor requires grad.

property device_mesh: DeviceMesh

The DeviceMesh attribute associated with this DTensor object.

Note

device_mesh is a read-only property; it cannot be set.

property placements: tuple[torch.distributed.tensor.placement_types.Placement, ...]

The placements attribute of this DTensor that describes the layout of this DTensor on its DeviceMesh.

Note

placements is a read-only property; it cannot be set.

DeviceMesh as the distributed communicator

DeviceMesh originated in DTensor as the abstraction to describe the cluster's device topology and represent multi-dimensional communicators (on top of ProcessGroup). To see the details of how to create/use a DeviceMesh, please refer to the DeviceMesh recipe.

DTensor Placement Types

DTensor supports the following types of Placement on each DeviceMesh dimension:

class torch.distributed.tensor.placement_types.Shard(dim)

The Shard(dim) placement describes the DTensor sharding on tensor dimension dim over a corresponding DeviceMesh dimension, where each rank on the DeviceMesh dimension only holds a shard/piece of the global Tensor. The Shard(dim) placement follows the torch.chunk(dim) semantics, where the last few shards on the DeviceMesh dimension might be empty when the tensor dimension is not evenly divisible on the DeviceMesh dimension. The Shard placement can be used by all DTensor APIs (i.e. distribute_tensor, from_local, etc.)

Parameters

dim (int) – The tensor dimension along which the DTensor is sharded over its corresponding DeviceMesh dimension.

Warning

Sharding on a tensor dimension whose size is not evenly divisible on a DeviceMesh dimension is currently experimental and subject to change.

dim: int
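The torch.chunk semantics mentioned above can be simulated to see which ranks end up with empty shards. `shard_sizes` is a hypothetical helper: it uses ceil(dim_size / num_ranks) as the chunk size, so trailing ranks can receive fewer elements or none at all.

```python
def shard_sizes(dim_size, num_ranks):
    """Per-rank sizes of Shard(dim) under torch.chunk semantics (simulation)."""
    chunk = -(-dim_size // num_ranks)  # ceil(dim_size / num_ranks)
    sizes = []
    for rank in range(num_ranks):
        start = min(rank * chunk, dim_size)
        end = min(start + chunk, dim_size)
        sizes.append(end - start)  # trailing ranks may get 0
    return sizes


print(shard_sizes(8, 4))  # [2, 2, 2, 2]
print(shard_sizes(5, 4))  # [2, 2, 1, 0]
print(shard_sizes(2, 4))  # [1, 1, 0, 0]
```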
class torch.distributed.tensor.placement_types.Replicate

The Replicate() placement describes the DTensor replicating on a corresponding DeviceMesh dimension, where each rank on the DeviceMesh dimension holds a replica of the global Tensor. The Replicate placement can be used by all DTensor APIs (i.e. distribute_tensor, DTensor.from_local, etc.)

class torch.distributed.tensor.placement_types.Partial(reduce_op='sum')

The Partial(reduce_op) placement describes a DTensor that is pending reduction on a specified DeviceMesh dimension, where each rank on the DeviceMesh dimension holds a partial value of the global Tensor. Users can redistribute the Partial DTensor to a Replicate or Shard(dim) placement on the specified DeviceMesh dimension using redistribute, which triggers the necessary communication operations under the hood (i.e. allreduce, reduce_scatter).

Parameters

reduce_op (str, optional) – The reduction op to be used for the partial DTensor to produce a Replicated/Sharded DTensor. Only element-wise reduction operations are supported, including: "sum", "avg", "product", "max", "min". default: "sum".

Note

The Partial placement can be generated as a result of DTensor operators, and can only be used by the DTensor.from_local API.

reduce_op: str = 'sum'
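What each supported reduce_op produces when a Partial DTensor is resolved to Replicate (an all_reduce) can be simulated element-wise over the per-rank values. `resolve_partial` and the `REDUCE_OPS` table are hypothetical names used only for this illustration.

```python
import math

# Element-wise reductions named by Partial(reduce_op), simulated on lists.
REDUCE_OPS = {
    "sum": sum,
    "avg": lambda vals: sum(vals) / len(vals),
    "product": math.prod,
    "max": max,
    "min": min,
}


def resolve_partial(per_rank_values, reduce_op="sum"):
    """Value every rank sees after Partial(reduce_op) -> Replicate()."""
    reduce = REDUCE_OPS[reduce_op]
    return [reduce(column) for column in zip(*per_rank_values)]


partials = [[1, 4], [3, 2]]  # two ranks, two elements each
print(resolve_partial(partials, "sum"))  # [4, 6]
print(resolve_partial(partials, "max"))  # [3, 4]
print(resolve_partial(partials, "avg"))  # [2.0, 3.0]
```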
class torch.distributed.tensor.placement_types.Placement

The base class for the Placement type, which describes how a DTensor is placed onto the DeviceMesh. Placement and DeviceMesh together describe the DTensor Layout. It is the base class of the three main DTensor Placement types: Shard, Replicate, and Partial.

This class is not meant to be used directly; it mainly serves as a typing stub.

is_partial(reduce_op=None)
Return type

bool

is_replicate()
Return type

bool

is_shard(dim=None)
Return type

bool
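The optional arguments of these predicates narrow the check: the mock below assumes is_shard(dim) matches only a Shard on that tensor dimension and is_partial(reduce_op) only a Partial with that op, while the no-argument forms match any instance. These are stand-in classes written for illustration, not the real placement_types; consult the PyTorch source for the exact behavior.

```python
from dataclasses import dataclass
from typing import Optional


class Placement:
    """Mock of the Placement predicates (not the real torch classes)."""

    def is_shard(self, dim: Optional[int] = None) -> bool:
        if not isinstance(self, Shard):
            return False
        return dim is None or self.dim == dim

    def is_replicate(self) -> bool:
        return isinstance(self, Replicate)

    def is_partial(self, reduce_op: Optional[str] = None) -> bool:
        if not isinstance(self, Partial):
            return False
        return reduce_op is None or self.reduce_op == reduce_op


@dataclass(frozen=True)
class Shard(Placement):
    dim: int


@dataclass(frozen=True)
class Replicate(Placement):
    pass


@dataclass(frozen=True)
class Partial(Placement):
    reduce_op: str = "sum"


placements = (Shard(0), Replicate(), Partial("max"))
print([p.is_shard() for p in placements])  # [True, False, False]
print(Shard(1).is_shard(dim=0))            # False
print(Partial("max").is_partial("sum"))    # False
```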

Different ways to create a DTensor

There are three ways to construct a DTensor:
  • distribute_tensor() creates a DTensor from a logical or "global" torch.Tensor on each rank. This can be used to shard leaf torch.Tensors (i.e. model parameters/buffers and inputs).

  • DTensor.from_local() creates a DTensor from a local torch.Tensor on each rank, which can be used to create a DTensor from non-leaf torch.Tensors (i.e. intermediate activation tensors during forward/backward).

  • DTensor provides dedicated tensor factory functions (e.g. empty(), ones(), randn(), etc.) to allow different DTensor creations by directly specifying the DeviceMesh and Placement. Compared to distribute_tensor(), this can directly materialize the sharded memory on device, instead of performing sharding after initializing the logical Tensor memory.

Create DTensor from a logical torch.Tensor

The SPMD (single program, multiple data) programming model in torch.distributed launches multiple processes (i.e. via torchrun) to execute the same program. This means that the model inside the program is first initialized on each process (i.e. the model might be initialized on CPU, on the meta device, or directly on GPU if there is enough memory).

DTensor offers a distribute_tensor() API that can shard the model weights or Tensors into DTensors, creating a DTensor from the "logical" Tensor on each process. This lets the created DTensors comply with single-device semantics, which is critical for numerical correctness.

torch.distributed.tensor.distribute_tensor(tensor, device_mesh=None, placements=None, *, src_data_rank=0)

Distribute a leaf torch.Tensor (i.e. nn.Parameter/buffers) to the device_mesh according to the placements specified. The rank of device_mesh and placements must be the same. The tensor to distribute is the logical or "global" tensor, and the API uses the tensor from the first rank of the DeviceMesh dimension as the source of truth to preserve single-device semantics. If you want to construct a DTensor in the middle of the Autograd computation, please use DTensor.from_local() instead.

Parameters
  • tensor (torch.Tensor) – torch.Tensor to be distributed. Note that if you want to shard a tensor on a dimension that is not evenly divisible by the number of devices in that mesh dimension, we use torch.chunk semantics to shard the tensor and scatter the shards. The uneven sharding behavior is experimental and subject to change.

  • device_mesh (DeviceMesh, optional) – DeviceMesh to distribute the tensor; if not specified, must be called under a DeviceMesh context manager. default: None

  • placements (List[Placement], optional) – the placements that describe how to place the tensor on the DeviceMesh; must have the same number of elements as device_mesh.ndim. If not specified, we will by default replicate the tensor across the device_mesh from the first rank of each dimension of the device_mesh.

Keyword Arguments

src_data_rank (int, optional) – the rank of the source data for the logical/global tensor; it is used by distribute_tensor() to scatter/broadcast the shards/replicas to other ranks. By default, we use group_rank=0 on each DeviceMesh dimension as the source data to preserve single-device semantics. If None is passed explicitly, distribute_tensor() simply uses its local data instead of trying to preserve single-device semantics via scatter/broadcast. Default: 0

Returns

A DTensor or XLAShardedTensor object.

Return type

DTensor

Note

When initializing the DeviceMesh with the xla device_type, distribute_tensor returns an XLAShardedTensor instead. See this issue for more details. The XLA integration is experimental and subject to change.
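Why src_data_rank matters can be seen when ranks initialize their "global" tensor differently (for example, unseeded random initialization). The simulation below uses a hypothetical `distribute_1d` helper for Shard(0) on a 1-D mesh with even sharding: with the default source rank every shard comes from the same logical tensor, while None lets each rank slice its own copy, which can mix data from different sources.

```python
def distribute_1d(per_rank_tensors, world_size, src_data_rank=0):
    """Simulated Shard(0) distribution; per_rank_tensors[r] is rank r's copy."""
    shards = []
    for rank in range(world_size):
        # With a source rank, every rank receives a slice of that rank's data
        # (a scatter); with None, each rank slices its own local copy.
        source = per_rank_tensors[rank if src_data_rank is None else src_data_rank]
        size = len(source) // world_size  # assumes even sharding
        shards.append(source[rank * size:(rank + 1) * size])
    return shards


# The two ranks initialized different "global" tensors.
per_rank = [[1, 2, 3, 4], [9, 9, 9, 9]]
consistent = distribute_1d(per_rank, 2)
local_only = distribute_1d(per_rank, 2, src_data_rank=None)
print(consistent)  # [[1, 2], [3, 4]]: one logical tensor, single-device semantics
print(local_only)  # [[1, 2], [9, 9]]: shards drawn from different sources
```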

Along with distribute_tensor(), DTensor also offers a distribute_module() API to allow easier sharding at the nn.Module level.

torch.distributed.tensor.distribute_module(module, device_mesh=None, partition_fn=None, input_fn=None, output_fn=None)

This function exposes three functions to control the parameters/inputs/outputs of the module:

  1. To perform sharding on the module before runtime execution by specifying the partition_fn (i.e. allow the user to convert Module parameters to DTensor parameters according to the partition_fn specified).

  2. To control the inputs or outputs of the module during runtime execution by specifying the input_fn and output_fn (i.e. convert the input to DTensor, convert the output back to torch.Tensor).

Parameters
  • module (nn.Module) – user module to be partitioned.

  • device_mesh (DeviceMesh) – the device mesh to place the module.

  • partition_fn (Callable) – the function to partition parameters (i.e. shard certain parameters across the device_mesh). If partition_fn is not specified, by default we replicate all module parameters of module across the mesh.

  • input_fn (Callable) – specify the input distribution, i.e. it can control how the input of the module is sharded. input_fn will be installed as a module forward_pre_hook (pre forward hook).

  • output_fn (Callable) – specify the output distribution, i.e. it can control how the output is sharded, or convert it back to torch.Tensor. output_fn will be installed as a module forward_hook (post forward hook).

Returns

A module that contains parameters/buffers that are all DTensors.

Return type

Module

Note

When initializing the DeviceMesh with the xla device_type, distribute_module returns an nn.Module with PyTorch/XLA SPMD annotated parameters. See this issue for more details. The XLA integration is experimental and subject to change.

DTensor Factory Functions

DTensor also provides dedicated tensor factory functions to allow creating a DTensor directly using torch.Tensor-like factory function APIs (i.e. torch.ones, torch.empty, etc.), by additionally specifying the DeviceMesh and Placement for the DTensor created:

torch.distributed.tensor.zeros(*size, requires_grad=False, dtype=None, layout=torch.strided, device_mesh=None, placements=None)

Returns a DTensor filled with the scalar value 0.

Parameters

size (int...) – a sequence of integers defining the shape of the output DTensor. Can be a variable number of arguments or a collection like a list or tuple. E.g.: zeros(1,2,3..) or zeros([1,2,3..]) or zeros((1,2,3..))

Keyword Arguments
  • requires_grad (bool, optional) – If autograd should record operations on the returned DTensor. Default: False.

  • dtype (torch.dtype, optional) – the desired data type of the returned DTensor. Default: if None, uses a global default (see torch.set_default_dtype()).

  • layout (torch.layout, optional) – the desired layout of the returned DTensor. Default: torch.strided.

  • device_mesh – DeviceMesh type, contains the mesh info of ranks

  • placements – a sequence of Placement type: Shard, Replicate

Returns

A DTensor object on each rank

Return type

DTensor

torch.distributed.tensor.ones(*size, dtype=None, layout=torch.strided, requires_grad=False, device_mesh=None, placements=None)

Returns a DTensor filled with the scalar value 1, with the shape defined by the variable argument size.

Parameters

size (int...) – a sequence of integers defining the shape of the output DTensor. Can be a variable number of arguments or a collection like a list or tuple. E.g.: ones(1,2,3..) or ones([1,2,3..]) or ones((1,2,3..))

Keyword Arguments
  • dtype (torch.dtype, optional) – the desired data type of the returned DTensor. Default: if None, uses a global default (see torch.set_default_dtype()).

  • layout (torch.layout, optional) – the desired layout of the returned DTensor. Default: torch.strided.

  • requires_grad (bool, optional) – If autograd should record operations on the returned DTensor. Default: False.

  • device_mesh – DeviceMesh type, contains the mesh info of ranks

  • placements – a sequence of Placement type: Shard, Replicate

Returns

A DTensor object on each rank

Return type

DTensor

torch.distributed.tensor.empty(*size, dtype=None, layout=torch.strided, requires_grad=False, device_mesh=None, placements=None)

Returns a DTensor filled with uninitialized data. The shape of the DTensor is defined by the variable argument size.

Parameters

size (int...) – a sequence of integers defining the shape of the output DTensor. Can be a variable number of arguments or a collection like a list or tuple. E.g.: empty(1,2,3..) or empty([1,2,3..]) or empty((1,2,3..))

Keyword Arguments
  • dtype (torch.dtype, optional) – the desired data type of the returned DTensor. Default: if None, uses a global default (see torch.set_default_dtype()).

  • layout (torch.layout, optional) – the desired layout of the returned DTensor. Default: torch.strided.

  • requires_grad (bool, optional) – If autograd should record operations on the returned DTensor. Default: False.

  • device_mesh – DeviceMesh type, contains the mesh info of ranks

  • placements – a sequence of Placement type: Shard, Replicate

Returns

A DTensor object on each rank

Return type

DTensor

torch.distributed.tensor.full(size, fill_value, *, dtype=None, layout=torch.strided, requires_grad=False, device_mesh=None, placements=None)

Returns a DTensor filled with fill_value according to device_mesh and placements, with the shape defined by the argument size.

Parameters
  • size (int...) – a sequence of integers defining the shape of the output DTensor, passed as a collection like a list or tuple. E.g.: full([1,2,3..], fill_value) or full((1,2,3..), fill_value)

  • fill_value (Scalar) – the value to fill the output tensor with.

Keyword Arguments
  • dtype (torch.dtype, optional) – the desired data type of the returned DTensor. Default: if None, uses a global default (see torch.set_default_dtype()).

  • layout (torch.layout, optional) – the desired layout of the returned DTensor. Default: torch.strided.

  • requires_grad (bool, optional) – If autograd should record operations on the returned DTensor. Default: False.

  • device_mesh – DeviceMesh type, contains the mesh info of ranks.

  • placements – a sequence of Placement type: Shard, Replicate

Returns

A DTensor object on each rank

Return type

DTensor

torch.distributed.tensor.rand(*size, requires_grad=False, dtype=None, layout=torch.strided, device_mesh=None, placements=None)

Returns a DTensor filled with random numbers from a uniform distribution on the interval [0, 1). The shape of the tensor is defined by the variable argument size.

Parameters

size (int...) – a sequence of integers defining the shape of the output DTensor. Can be a variable number of arguments or a collection like a list or tuple. E.g.: rand(1,2,3..) or rand([1,2,3..]) or rand((1,2,3..))

Keyword Arguments
  • dtype (torch.dtype, optional) – the desired data type of the returned DTensor. Default: if None, uses a global default (see torch.set_default_dtype()).

  • layout (torch.layout, optional) – the desired layout of the returned DTensor. Default: torch.strided.

  • requires_grad (bool, optional) – If autograd should record operations on the returned DTensor. Default: False.

  • device_mesh – DeviceMesh type, contains the mesh info of ranks.

  • placements – a sequence of Placement type: Shard, Replicate

Returns

A DTensor object on each rank

Return type

DTensor

torch.distributed.tensor.randn(*size, requires_grad=False, dtype=None, layout=torch.strided, device_mesh=None, placements=None)

Returns a DTensor filled with random numbers from a normal distribution with mean 0 and variance 1. The shape of the tensor is defined by the variable argument size.

Parameters

size (int...) – a sequence of integers defining the shape of the output DTensor. Can be a variable number of arguments or a collection like a list or tuple. E.g.: randn(1,2,3..) or randn([1,2,3..]) or randn((1,2,3..))

Keyword Arguments
  • dtype (torch.dtype, optional) – the desired data type of the returned DTensor. Default: if None, uses a global default (see torch.set_default_dtype()).

  • layout (torch.layout, optional) – the desired layout of the returned DTensor. Default: torch.strided.

  • requires_grad (bool, optional) – If autograd should record operations on the returned DTensor. Default: False.

  • device_mesh – DeviceMesh type, contains the mesh info of ranks.

  • placements – a sequence of Placement type: Shard, Replicate

Returns

A DTensor object on each rank

Return type

DTensor
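Because the factory functions materialize only the local shard, each rank's allocation depends on its coordinate in the mesh. The sketch computes that local shape for a call such as zeros(8, 6) on a 2x2 mesh. `local_shape` and the tuple encoding of placements are hypothetical; the split follows torch.chunk semantics, and the real ordering when several mesh dimensions shard the same tensor dimension may differ.

```python
def local_shape(global_shape, mesh_shape, placements, coordinate):
    """Local shape on the rank at `coordinate` in the mesh (hypothetical helper).

    placements[i] is ("shard", dim) or ("replicate",) for mesh dimension i.
    """
    shape = list(global_shape)
    for mesh_dim, placement in enumerate(placements):
        if placement[0] != "shard":
            continue  # Replicate keeps the current size on this mesh dim
        dim = placement[1]
        n = mesh_shape[mesh_dim]
        chunk = -(-shape[dim] // n)  # torch.chunk-style ceil split
        start = min(coordinate[mesh_dim] * chunk, shape[dim])
        shape[dim] = min(start + chunk, shape[dim]) - start
    return tuple(shape)


# e.g. zeros(8, 6) on a 2x2 mesh with placements [Shard(0), Shard(1)]
print(local_shape((8, 6), (2, 2), [("shard", 0), ("shard", 1)], (1, 0)))  # (4, 3)
# the same call with placements [Shard(0), Replicate()]
print(local_shape((8, 6), (2, 2), [("shard", 0), ("replicate",)], (0, 1)))  # (4, 6)
```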

Random Operations

DTensor provides distributed RNG functionality to ensure that random operations on sharded tensors get unique values, and random operations on replicated tensors get the same values. This system requires that all participating ranks (e.g. SPMD ranks) start out using the same generator state before each DTensor random operation is performed; if this holds, it ensures they all end up at the same state after each DTensor random operation completes. No communication is performed during random operations to synchronize RNG states.

Operators that accept a generator kwarg will use the user-passed generator, if passed, or the default generator for the device otherwise. Whichever generator is used, it will be advanced after the DTensor operation. It is valid to use the same generator for both DTensor and non-DTensor operations, but if so, care must be taken to ensure the non-DTensor operations advance the generator state equally on all ranks.

When using DTensor together with Pipeline Parallelism, ranks for each pipeline stage should use a distinct seed, and ranks within a pipeline stage should use the same seed.

DTensor's RNG infrastructure is based on the Philox RNG algorithm and supports any Philox-based backend (CUDA and other CUDA-like devices), but unfortunately does not yet support the CPU backend.
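The guarantees above (unique values per shard, identical values per replica, identical final state, no communication) can be illustrated with a conceptual model in which every rank starts from the same seed and offset and the shard index selects a disjoint region of one shared random stream. This is only a model: Python's Mersenne Twister stands in for Philox, and `sharded_rand`, the seed 42, and the 2x2 mesh layout are assumptions for illustration.

```python
import random


def sharded_rand(seed, base_offset, shard_index, num_shards, numel):
    """Conceptual offset-based RNG (hypothetical; not DTensor's Philox code)."""
    rng = random.Random(seed)  # same starting state on every rank
    start = base_offset + shard_index * numel
    for _ in range(start):
        rng.random()  # skip ahead to this shard's region of the stream
    values = [rng.random() for _ in range(numel)]
    # Every rank advances past the whole logical tensor, so states stay in sync.
    new_offset = base_offset + num_shards * numel
    return values, new_offset


# 2x2 mesh: Shard on mesh dim 0 (2 shards), Replicate on mesh dim 1 (2 replicas).
results = {
    (shard, replica): sharded_rand(seed=42, base_offset=0, shard_index=shard,
                                   num_shards=2, numel=3)
    for shard in range(2)
    for replica in range(2)
}
for key, (values, offset) in sorted(results.items()):
    print(key, [round(v, 3) for v in values], offset)
```

Each rank computes its own region from the shared state, which is why no communication is needed; a subsequent random operation would start from the returned offset on every rank.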

Debugging

Logging

When launching the program, you can turn on additional logging using the TORCH_LOGS environment variable from torch._logging:

  • TORCH_LOGS=+dtensor will display logging.DEBUG messages and all levels above it.

  • TORCH_LOGS=dtensor will display logging.INFO messages and above.

  • TORCH_LOGS=-dtensor will display logging.WARNING messages and above.

Debugging Tools

To debug a program that uses DTensor and understand more details about which collectives happened under the hood, DTensor provides a CommDebugMode:

class torch.distributed.tensor.debug.CommDebugMode

CommDebugMode is a context manager that counts the number of functional collectives within its context. It does this using a TorchDispatchMode.

Note

Not all collectives are supported yet.

Example usage

from torch.distributed.tensor.debug import CommDebugMode

mod = ...
comm_mode = CommDebugMode()
with comm_mode:
    mod.sum().backward()
print(comm_mode.get_comm_counts())

generate_comm_debug_tracing_table(noise_level=3)

Generates a detailed table displaying operations and collective tracing information on a module level. The amount of information depends on noise_level:

  0. prints module-level collective counts

  1. prints DTensor operations not included in trivial operations, module information

  2. prints operations not included in trivial operations

  3. prints all operations

generate_json_dump(file_name='comm_mode_log.json', noise_level=3)

Creates a JSON file used to build a browser visual. The amount of information depends on noise_level:

  0. prints module-level collective counts

  1. prints DTensor operations not included in trivial operations

  2. prints operations not included in trivial operations

  3. prints all operations

get_comm_counts()

Returns the communication counts as a dictionary.

Returns

The communication counts as a dictionary.

Return type

Dict[Any,int]

get_parameter_info()
Return type

dict[str,dict[str,Any]]

get_sharding_info()
Return type

dict[str,dict[str,Any]]

get_total_counts()
Return type

int

log_comm_debug_tracing_table_to_file(file_name='comm_mode_log.txt', noise_level=3)

An alternative to the console CommDebugMode output; writes to the file specified by the user.

To visualize the sharding of a DTensor that has fewer than 3 dimensions, DTensor provides visualize_sharding():

torch.distributed.tensor.debug.visualize_sharding(dtensor, header='', use_rich=False)

Visualizes sharding in the terminal for DTensors that are 1D or 2D.

Note

This requires the tabulate package, or rich and matplotlib. No sharding info will be printed for empty tensors.

Experimental Features

DTensor also provides a set of experimental features. These features are either in the prototyping stage, or their basic functionality is done and we are looking for user feedback. Please submit an issue to PyTorch if you have feedback on these features.

torch.distributed.tensor.experimental.context_parallel(mesh, *, buffers=None, buffer_seq_dims=None, no_restore_buffers=None)

context_parallel is an experimental API to enable context parallelism (CP). This API performs two actions: 1) patch the SDPA (torch.nn.functional.scaled_dot_product_attention) with the CP-enabled one, 2) shard buffers along the sequence dimension, with each rank keeping the corresponding shard according to mesh.

Parameters
  • mesh (DeviceMesh) – the device mesh for the context parallelism.

  • buffers (Optional[List[torch.Tensor]]) – buffers whose usage depends on the sequence dimension. Examples are input batch, labels and positional embedding buffers. These buffers must be sharded along the sequence dimension to ensure accuracy. The sharding happens in-place, so the buffer's shape will change within the context. The buffers will be restored after the context finishes. no_restore_buffers can be used to specify which buffers don't need to be restored. Note that buffers should not contain any nn.Parameter.

  • buffer_seq_dims (Optional[List[int]]) – the sequence dimensions of buffers.

  • no_restore_buffers (Optional[Set[torch.Tensor]]) – buffers in this set won't be restored after the context exits. This set must be a subset of buffers. If the buffers won't be used after the context exits, they can be put in this set to avoid extra restore time.

Return type

Generator[None, None, None]

Warning

torch.distributed.tensor.experimental.context_parallel is a prototype feature in PyTorch. The API is subject to change.
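The buffer handling of context_parallel (in-place sharding on entry, restoring on exit unless listed in no_restore_buffers) can be sketched as a context manager over plain lists. `fake_context_parallel` is hypothetical: it does not patch SDPA, it uses simple contiguous chunks along the only dimension, and the real API may use a different, load-balanced layout.

```python
from contextlib import contextmanager


@contextmanager
def fake_context_parallel(rank, world_size, buffers, no_restore_buffers=()):
    """Hypothetical simulation of CP buffer sharding/restoring on plain lists."""
    originals = [list(buf) for buf in buffers]
    for buf in buffers:
        size = len(buf) // world_size  # assumes a divisible sequence length
        buf[:] = buf[rank * size:(rank + 1) * size]  # shard in place
    try:
        yield
    finally:
        for buf, original in zip(buffers, originals):
            # Identity check: only skip the exact buffer objects listed.
            if not any(buf is skip for skip in no_restore_buffers):
                buf[:] = original


tokens = list(range(8))
labels = list(range(100, 108))
with fake_context_parallel(1, 2, [tokens, labels], no_restore_buffers=[labels]):
    inside = (list(tokens), list(labels))

print(inside)          # ([4, 5, 6, 7], [104, 105, 106, 107])
print(tokens, labels)  # tokens restored; labels left sharded
```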

torch.distributed.tensor.experimental.local_map(func=None, out_placements=None, in_placements=None, in_grad_placements=None, device_mesh=None, *, redistribute_inputs=False)

local_map() is an experimental API that allows users to pass DTensors to a function that is written to be applied on torch.Tensors. It does so by extracting the local components of the DTensors, calling the function, and wrapping the outputs into DTensors according to out_placements.

Parameters
  • func (Callable) – the function to be applied on each local shard of DTensors.

  • out_placements (Union[PlacementType, Tuple[PlacementType, …]]) – the desired placements of the DTensors in func’s flattened output. If the flattened output is a single value, out_placements should be of type PlacementType. Otherwise, if the flattened output has multiple values, out_placements should be a tuple of PlacementType values mapping 1:1 to the flattened output. For a Tensor output, PlacementType is its placements (a Tuple[Placement] value); for a non-Tensor output, the PlacementType should be None. The only exception is when no DTensor argument is passed in: in that case, even if out_placements is not None, the resulting function ignores the desired placements because it is not running with DTensors.

  • in_placements (Tuple[PlacementType, …], optional) – the required placements of the DTensors in the flattened inputs of func. If in_placements is specified, local_map() checks whether the placements of each DTensor argument match the required placements. If they differ and redistribute_inputs is False, an exception is raised; if redistribute_inputs is True, the argument is first redistributed to the required placements before its local tensor is passed to func. The only exception is when the required placements are not None and the argument is a torch.Tensor: in that case the placement check is skipped and the argument is passed to func directly. If in_placements is None, no placement check is performed. Default: None

  • in_grad_placements (Tuple[PlacementType, …], optional) – a placement hint for the gradients of the flattened DTensor inputs. This is the hint the user can give to to_local() in case the gradient layout of the local tensor input does not match its DTensor input layout. If not specified, the gradient layout of the local tensor input is assumed to be the same as that of the original DTensor input, and that layout is used for gradient computation. Default: None.

  • device_mesh (DeviceMesh, optional) – the device mesh that the output DTensors are placed on. If not specified, it is inferred from the first input DTensor’s device mesh. Default: None.

Keyword Arguments

redistribute_inputs (bool, optional) – whether to reshard the input DTensors when their placements differ from the required input placements. If this value is False and some DTensor input has a different placement, an exception is raised. Default: False.

Returns

A Callable that applies func to each local shard of the input DTensors and returns a DTensor constructed from the return value of func.

Raises
  • AssertionError – For any non-DTensor output, the corresponding output placement in out_placements must be None; an AssertionError is raised otherwise.

  • ValueError – If redistribute_inputs=False but an input DTensor needs a redistribution according to in_placements.
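The input checks, unwrapping, and output wrapping described above can be modeled in a few lines of plain Python. This is a toy, not the real implementation: ToyDTensor, toy_local_map, and the redistribute callback are hypothetical, Python lists stand in for tensors, strings such as "S(0)" and "R" stand in for Shard(0) and Replicate(), and gradients, device meshes, and the no-DTensor-input special case are ignored.

```python
from dataclasses import dataclass
from typing import Any, Callable, List, Optional, Sequence, Tuple

Placements = Tuple[str, ...]  # toy placements on a 1-d mesh, e.g. ("S(0)",) or ("R",)


@dataclass
class ToyDTensor:
    """Hypothetical stand-in for a DTensor: a local shard plus its placements."""
    local: List[int]
    placements: Placements


def toy_local_map(
    func: Callable[..., Any],
    out_placements: Sequence[Optional[Placements]],
    in_placements: Optional[Sequence[Optional[Placements]]] = None,
    redistribute_inputs: bool = False,
    redistribute: Optional[Callable[[ToyDTensor, Placements], ToyDTensor]] = None,
) -> Callable[..., Any]:
    def wrapped(*args: Any) -> Any:
        if in_placements is not None:
            assert len(in_placements) == len(args), "need one entry per input"
        local_args: List[Any] = []
        for idx, arg in enumerate(args):
            if not isinstance(arg, ToyDTensor):
                local_args.append(arg)  # plain values pass through unchecked
                continue
            required = None if in_placements is None else in_placements[idx]
            if required is not None and arg.placements != required:
                # Toy rule: without a redistribute callback we cannot reshard.
                if not redistribute_inputs or redistribute is None:
                    raise ValueError(f"input {idx} is {arg.placements}, expected {required}")
                arg = redistribute(arg, required)  # reshard before unwrapping
            local_args.append(arg.local)  # func only ever sees local shards
        outs = func(*local_args)
        outs = outs if isinstance(outs, tuple) else (outs,)
        results: List[Any] = []
        for out, placement in zip(outs, out_placements):
            if isinstance(out, list):  # lists play the role of tensors in this toy
                results.append(ToyDTensor(out, placement))
            else:
                assert placement is None, "non-tensor outputs need a None placement"
                results.append(out)
        return results[0] if len(results) == 1 else tuple(results)

    return wrapped


def scale_plus_one(x: List[int], scale: int) -> List[int]:
    return [scale * (v + 1) for v in x]


f = toy_local_map(scale_plus_one, out_placements=[("S(0)",)],
                  in_placements=[("S(0)",), None])
print(f(ToyDTensor([1, 2], ("S(0)",)), 10))
# ToyDTensor(local=[20, 30], placements=('S(0)',))
try:
    f(ToyDTensor([1, 2], ("R",)), 10)  # wrong placement, redistribute_inputs=False
except ValueError as err:
    print("ValueError:", err)
```

Note how the scalar argument gets a None entry in in_placements: the toy, like the parameter description above, only checks placements for DTensor-like inputs.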

Example

>>> # Assumes torch.distributed is initialized and `device_mesh` is a 1-d DeviceMesh.
>>> import torch
>>> import torch.distributed._functional_collectives as funcol
>>> from torch.distributed.tensor import Replicate, Shard, distribute_tensor
>>> from torch.distributed.tensor.experimental import local_map
>>>
>>> def mm_allreduce_forward(device_mesh, W, X):
...     partial_sum_tensor = torch.mm(W, X)
...     reduced_tensor = funcol.all_reduce(partial_sum_tensor, "sum", device_mesh)
...     return reduced_tensor
...
>>> W = torch.randn(12, 8, requires_grad=False)
>>> X = torch.randn(8, 16, requires_grad=False)
>>> Y = torch.mm(W, X)
>>> row_wise = [Shard(0)]  # row-wise sharding placements on 1-d mesh
>>> col_wise = [Shard(1)]  # col-wise sharding placements on 1-d mesh
>>>
>>> # local_mm_allreduce_forward is the function wrapped with DTensor/Tensor conversion
>>> local_mm_allreduce_forward = local_map(
...     mm_allreduce_forward,
...     out_placements=[Replicate()],
...     in_placements=(None, col_wise, row_wise),  # None: device_mesh is not a DTensor
...     device_mesh=device_mesh,
... )
>>>
>>> W_dt = distribute_tensor(W, device_mesh, col_wise)  # col-wisely sharded W tensor
>>> X_dt = distribute_tensor(X, device_mesh, row_wise)  # row-wisely sharded X tensor
>>> Y_dt = local_mm_allreduce_forward(device_mesh, W_dt, X_dt)  # apply to DTensors
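The example uses out_placements=[Replicate()] because, with W column-sharded and X row-sharded, each rank's local torch.mm produces only a partial sum of W @ X, and the "sum" all-reduce turns those partials into the full product on every rank. The pure-Python check below verifies that arithmetic on a hypothetical 2-rank split of small integer matrices; matmul and add are illustrative helpers, and no PyTorch or process group is involved.

```python
from typing import List

Matrix = List[List[int]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]


def add(a: Matrix, b: Matrix) -> Matrix:
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


world_size = 2
W = [[1, 2, 3, 4], [5, 6, 7, 8]]        # 2 x 4
X = [[1, 0], [0, 1], [1, 1], [2, -1]]   # 4 x 2
k = len(X) // world_size                # inner-dimension chunk per rank

# Shard(1) on W: rank r keeps columns [r*k, (r+1)*k); Shard(0) on X: the same rows.
W_shards = [[row[r * k:(r + 1) * k] for row in W] for r in range(world_size)]
X_shards = [X[r * k:(r + 1) * k] for r in range(world_size)]

# Each rank's local matmul is only a partial sum of the full product ...
partials = [matmul(W_shards[r], X_shards[r]) for r in range(world_size)]
# ... and a "sum" all-reduce adds them up, leaving every rank with the full result.
reduced = partials[0]
for p in partials[1:]:
    reduced = add(reduced, p)

print(reduced)                   # [[12, 1], [28, 5]]
print(reduced == matmul(W, X))   # True
```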

Note

This API is currently experimental and subject to change.

torch.distributed.tensor.experimental.register_sharding(op)

register_sharding() is an experimental API that allows users to register sharding strategies for an operator when the tensor inputs and outputs are DTensors. It is useful when (1) no default sharding strategy exists for op, e.g. when op is a custom operator that DTensor does not support, or (2) users want to override the default sharding strategies of existing operators.

Parameters

op (Union[OpOverload, List[OpOverload]]) – an op or a list of ops to register the customized sharding function for.

Returns

A function decorator that wraps a function defining the sharding strategy for the operator(s) specified in op. The defined sharding strategy is registered to DTensor and overrides the default sharding strategy if DTensor already implements the operator. The customized sharding function takes the same inputs as the original op, except that each torch.Tensor argument is replaced by a tensor-like object that DTensor uses internally. The function should return a sequence of 2-tuples, each specifying acceptable output placements and the corresponding input placements.
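To illustrate the shape of that return value, the toy below builds a list of (output placements, input placements) pairs for a hypothetical elementwise op taking two tensors and one scalar, then picks the entry matching the placements it is given. Everything here is an assumption for illustration: toy_elementwise_sharding and pick are hypothetical helpers, strings stand in for Shard and Replicate, non-tensor arguments get None as in the softmax example, and DTensor's actual strategy selection is more involved than this exact-match-or-fallback rule.

```python
from typing import List, Optional, Sequence, Tuple

# (output placements, input placements); None marks a non-tensor input.
Strategy = Tuple[List[str], List[Optional[str]]]


def toy_elementwise_sharding(ndim: int) -> List[Strategy]:
    """Hypothetical strategies for an elementwise op(x, y, scalar) on a 1-d mesh."""
    strategies: List[Strategy] = [(["R"], ["R", "R", None])]  # replicate everything
    for d in range(ndim):
        # Elementwise ops can shard any dim, as long as both inputs agree.
        strategies.append(([f"S({d})"], [f"S({d})", f"S({d})", None]))
    return strategies


def pick(strategies: Sequence[Strategy], actual: Sequence[Optional[str]]) -> Strategy:
    """Toy selection: first strategy whose tensor-input placements match exactly."""
    for out, inp in strategies:
        if all(want is None or want == have for want, have in zip(inp, actual)):
            return out, inp
    return strategies[0]  # toy fallback: the all-replicate entry


strategies = toy_elementwise_sharding(ndim=2)
print(len(strategies))                             # 3
print(pick(strategies, ["S(1)", "S(1)", None]))   # (['S(1)'], ['S(1)', 'S(1)', None])
print(pick(strategies, ["S(0)", "R", None]))      # (['R'], ['R', 'R', None])
```

In the second call no entry matches the mixed input placements, so the toy falls back to replicating; a real system would have to redistribute the inputs to satisfy whichever strategy it chooses.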

Example

>>> import torch
>>> from torch.distributed.tensor import Replicate, Shard
>>> from torch.distributed.tensor.experimental import register_sharding
>>> aten = torch.ops.aten
>>>
>>> @register_sharding(aten._softmax.default)
... def custom_softmax_sharding(x, dim, half_to_float):
...     softmax_dim = dim if dim >= 0 else dim + x.ndim
...     acceptable_shardings = []
...
...     all_replicate = ([Replicate()], [Replicate(), None, None])
...     acceptable_shardings.append(all_replicate)
...
...     for sharding_dim in range(x.ndim):
...         if sharding_dim != softmax_dim:
...             all_sharded = (
...                 [Shard(sharding_dim)],
...                 [Shard(sharding_dim), None, None],
...             )
...             acceptable_shardings.append(all_sharded)
...
...     return acceptable_shardings
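The example leaves softmax_dim out of the sharded strategies because softmax normalizes over that dimension: a rank holding only part of each slice would divide by a partial sum. The pure-Python check below simulates two ranks on a 2x4 input with softmax over dim 1; softmax_rows and allclose are hypothetical helpers, and no PyTorch is involved.

```python
import math
from typing import List

Matrix = List[List[float]]


def softmax_rows(x: Matrix) -> Matrix:
    """Softmax along dim 1 of a 2-d list, max-subtracted for numerical stability."""
    out = []
    for row in x:
        m = max(row)
        exps = [math.exp(v - m) for v in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


def allclose(a: Matrix, b: Matrix) -> bool:
    return all(math.isclose(u, v, rel_tol=1e-12)
               for ra, rb in zip(a, b) for u, v in zip(ra, rb))


x = [[1.0, 2.0, 3.0, 4.0], [0.5, -1.0, 2.0, 0.0]]
full = softmax_rows(x)  # softmax_dim = 1

# Shard(0): each of 2 ranks owns one row, so its local softmax is already correct.
dim0 = softmax_rows(x[:1]) + softmax_rows(x[1:])
print(allclose(dim0, full))  # True

# Shard(1), the softmax dim: each rank normalizes over only half of each row.
left = softmax_rows([r[:2] for r in x])
right = softmax_rows([r[2:] for r in x])
dim1 = [lh + rh for lh, rh in zip(left, right)]
print(allclose(dim1, full))  # False
```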

Note

This API is currently experimental and subject to change.
