torch.distributed.tensor#
Created On: Jun 13, 2025 | Last Updated On: Aug 23, 2025
Note
torch.distributed.tensor is currently in alpha state and under development. We are committed to backward compatibility for most APIs listed in this doc, but there may be API changes if necessary.
PyTorch DTensor (Distributed Tensor)#
PyTorch DTensor offers simple and flexible tensor sharding primitives that transparently handle distributed logic, including sharded storage, operator computation and collective communications across devices/hosts. DTensor can be used to build different parallelism solutions and supports a sharded state_dict representation when working with multi-dimensional sharding.
Please see examples from the PyTorch native parallelism solutions that are built on top of DTensor:
DTensor follows the SPMD (single program, multiple data) programming model to empower users to write distributed programs as if they were single-device programs with the same convergence property. It provides a uniform tensor sharding layout (DTensor Layout) through specifying the DeviceMesh and Placement:
- DeviceMesh represents the device topology and the communicators of the cluster using an n-dimensional array.
- Placement describes the sharding layout of the logical tensor on the DeviceMesh. DTensor supports three types of placements: Shard, Replicate and Partial.
DTensor Class APIs#
DTensor is a torch.Tensor subclass. This means that once a DTensor is created, it can be used in a very similar way to torch.Tensor, including running different types of PyTorch operators as if running them on a single device, allowing proper distributed computation for PyTorch operators.
In addition to existing torch.Tensor methods, it also offers a set of additional methods to interact with torch.Tensor, redistribute the DTensor Layout to a new DTensor, get the full tensor content on all devices, etc.
- class torch.distributed.tensor.DTensor(local_tensor, spec, *, requires_grad)#
DTensor (Distributed Tensor) is a subclass of torch.Tensor that provides a single-device-like abstraction to program with multi-device torch.Tensor. It describes the distributed tensor sharding layout (DTensor Layout) through the DeviceMesh and the following types of Placement:
- Shard: Tensor sharded on the tensor dimension dim on the devices of the DeviceMesh dimension
- Replicate: Tensor replicated on the devices of the DeviceMesh dimension
- Partial: Tensor is pending reduction on the devices of the DeviceMesh dimension
When calling PyTorch operators, DTensor overrides the PyTorch operators to perform sharded computation and issue communications whenever necessary. Along with the operator computation, DTensor will transform or propagate the placements (DTensor Layout) properly (based on the operator semantic itself) and generate new DTensor outputs.
To ensure numerical correctness of the DTensor sharded computation when calling PyTorch operators, DTensor requires every Tensor argument of the operator to be a DTensor.
Note
Directly using the Tensor subclass constructor here is not the recommended way to create a DTensor (i.e. it does not handle autograd correctly and hence is not the public API). Please refer to the "Different ways to create a DTensor" section to see how to create a DTensor.
- Return type
- DTensor
- __create_chunk_list__()#
Return a list of ChunkStorageMetadata, which is a dataclass that describes the size/offset of the local shard/replica on the current rank. For DTensor, each rank will have a single local shard/replica, so the returned list usually only has one element.
This dunder method is primarily used for distributed checkpointing purposes.
- Returns
A List[ChunkStorageMetadata] object that represents the shard size/offset on the current rank.
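To make the size/offset idea concrete, here is a pure-Python sketch that computes, for a tensor sharded with Shard(dim) over a 1-D mesh, the (sizes, offsets) pair each rank would describe. The helper name local_chunk is hypothetical and not part of the DTensor API; it assumes the torch.chunk-style split described under Shard, including empty trailing shards.

```python
def local_chunk(global_shape, shard_dim, rank, num_ranks):
    """Hypothetical helper: (sizes, offsets) of `rank`'s shard under Shard(shard_dim)."""
    dim_size = global_shape[shard_dim]
    chunk = -(-dim_size // num_ranks)      # ceil division, as torch.chunk does
    start = min(rank * chunk, dim_size)    # empty trailing shards start at the end
    length = max(0, min(chunk, dim_size - start))

    sizes = list(global_shape)
    sizes[shard_dim] = length
    offsets = [0] * len(global_shape)
    offsets[shard_dim] = start
    return tuple(sizes), tuple(offsets)


# A (10, 4) tensor sharded on dim 0 over 4 ranks: rows split 3 + 3 + 3 + 1.
for r in range(4):
    print(r, local_chunk((10, 4), 0, r, 4))
```

Each rank only needs its own entry, which matches the note above that the returned list usually has a single element.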
- static from_local(local_tensor, device_mesh=None, placements=None, *, run_check=False, shape=None, stride=None)#
Create a DTensor from a local torch.Tensor on each rank according to the device_mesh and placements specified.
- Parameters
local_tensor (torch.Tensor) – local torch.Tensor on each rank.
device_mesh (DeviceMesh, optional) – DeviceMesh to place the tensor; if not specified, must be called under a DeviceMesh context manager. Default: None
placements (List[Placement], optional) – the placements that describe how to place the local torch.Tensor on DeviceMesh; must have the same number of elements as device_mesh.ndim.
- Keyword Arguments
run_check (bool, optional) – at a cost of extra communications, perform a sanity check across ranks to check each local tensor's meta information to ensure correctness. If there is a Replicate in placements, the data on the first rank of the device mesh dimension will be broadcast to other ranks. Default: False
shape (torch.Size, optional) – A List of int which specifies the size of the DTensor built on top of local_tensor. Note that this needs to be provided if the shape of local_tensor differs across ranks. If not provided, shape will be computed assuming the given distributed tensor is evenly sharded across ranks. Default: None
stride (tuple, optional) – A List of int which specifies the stride of the DTensor. If not provided, stride will be computed assuming the given distributed tensor is evenly sharded across ranks. Default: None
- Returns
A DTensor object
- Return type
- DTensor
Note
When run_check=False, it is the user's responsibility to ensure the local tensor passed in is correct across ranks (i.e. the tensor is sharded for the Shard(dim) placement or replicated for the Replicate() placement). If not, the behavior of the created DTensor is undefined.
Note
from_local is differentiable; the requires_grad of the created DTensor object will depend on whether local_tensor requires grad.
- full_tensor(*, grad_placements=None)#
Return the full tensor of this DTensor. It will perform the necessary collectives to gather the local tensors from other ranks in its DeviceMesh and concatenate them together. It's syntactic sugar for the following code:
dtensor.redistribute(placements=[Replicate()] * mesh.ndim).to_local()
- Keyword Arguments
grad_placements (List[Placement], optional) – the placements that describe the future layout of any gradient of the full Tensor returned from this function. full_tensor converts a DTensor to a full torch.Tensor, and the returned torch.Tensor might not be used with the original replicated DTensor layout later in the code. This argument is a hint the user can give to autograd in case the gradient layout of the returned tensor does not match the original replicated DTensor layout. If not specified, we assume the gradient layout of the full tensor is replicated.
- Returns
A torch.Tensor object that represents the full tensor of this DTensor.
- Return type
- Tensor
Note
full_tensor is differentiable.
- redistribute(device_mesh=None, placements=None, *, async_op=False, forward_dtype=None, backward_dtype=None)#
redistribute performs the necessary collective operations to redistribute the current DTensor from its current placements to new placements, or from its current DeviceMesh to a new DeviceMesh. For example, we can turn a sharded DTensor into a replicated DTensor by specifying a Replicate placement for each dimension of the DeviceMesh. When redistributing from the current placements to the new placements on one device mesh dimension, we will perform the following operations, each either a communication collective or a local operation:
- Shard(dim) -> Replicate(): all_gather
- Shard(src_dim) -> Shard(dst_dim): all_to_all
- Replicate() -> Shard(dim): local chunking (i.e. torch.chunk)
- Partial() -> Replicate(): all_reduce
- Partial() -> Shard(dim): reduce_scatter
redistribute correctly figures out the necessary redistribution steps for DTensors that are created on either a 1-D or an N-D DeviceMesh.
- Parameters
device_mesh (DeviceMesh, optional) – DeviceMesh to place the DTensor. If not specified, it will use the current DTensor's DeviceMesh. Default: None
placements (List[Placement], optional) – the new placements that describe how to place the DTensor into the DeviceMesh; must have the same number of elements as device_mesh.ndim. Default: replicate on all mesh dimensions
- Keyword Arguments
async_op (bool, optional) – whether to perform the DTensor redistribute operation asynchronously or not. Default: False
forward_dtype (torch.dtype, optional) – the local tensor datatype can be converted to forward_dtype before redistributing the local tensor in its forward. The resulting DTensor will be in forward_dtype. Default: None.
backward_dtype (torch.dtype, optional) – the local tensor datatype can be converted to backward_dtype before redistributing the local tensor in its backward. The resulting DTensor gradient will be converted back to the current DTensor dtype. Default: None
- Returns
A DTensor object
- Return type
- DTensor
Note
redistribute is differentiable, which means the user does not need to worry about the backward formula of the redistribute operation.
Note
redistribute currently only supports redistributing a DTensor on the same DeviceMesh. Please file an issue if you need to redistribute a DTensor to a different DeviceMesh.
- to_local(*, grad_placements=None)#
Get the local tensor of this DTensor on its current rank. For sharding, it returns a local shard of the logical tensor view; for replication, it returns the replica on its current rank.
- Keyword Arguments
grad_placements (List[Placement], optional) – the placements that describe the future layout of any gradient of the Tensor returned from this function. to_local converts a DTensor to a local tensor, and the returned local tensor might not be used with the original DTensor layout later in the code. This argument is a hint the user can give to autograd in case the gradient layout of the returned tensor does not match the original DTensor layout. If not specified, we assume the gradient layout remains the same as the original DTensor and use that for gradient computation.
- Returns
A torch.Tensor or AsyncCollectiveTensor object. It represents the local tensor on its current rank. When an AsyncCollectiveTensor object is returned, it means the local tensor is not ready yet (i.e. communication is not finished). In this case, the user needs to call wait to wait for the local tensor to be ready.
- Return type
- Tensor
Note
to_local is differentiable; the requires_grad of the local tensor returned will depend on whether the DTensor requires grad.
- property device_mesh: DeviceMesh#
The DeviceMesh attribute associated with this DTensor object.
Note
device_mesh is a read-only property; it cannot be set.
- property placements: tuple[torch.distributed.tensor.placement_types.Placement, ...]#
The placements attribute of this DTensor that describes the layout of this DTensor on its DeviceMesh.
Note
placements is a read-only property; it cannot be set.
DeviceMesh as the distributed communicator#
DeviceMesh was built from DTensor as the abstraction to describe the cluster's device topology and represent multi-dimensional communicators (on top of ProcessGroup). To see the details of how to create/use a DeviceMesh, please refer to the DeviceMesh recipe.
DTensor Placement Types#
DTensor supports the following types of Placement on each DeviceMesh dimension:
- class torch.distributed.tensor.placement_types.Shard(dim)#
The Shard(dim) placement describes the DTensor sharding on tensor dimension dim over a corresponding DeviceMesh dimension, where each rank on the DeviceMesh dimension only holds a shard/piece of the global Tensor. The Shard(dim) placement follows the torch.chunk(dim) semantic, where the last few shards on the DeviceMesh dimension might be empty when the tensor dimension is not evenly divisible on the DeviceMesh dimension. The Shard placement can be used by all DTensor APIs (i.e. distribute_tensor, from_local, etc.)
- Parameters
dim (int) – The tensor dimension along which the DTensor is sharded over its corresponding DeviceMesh dimension.
Warning
Sharding on a tensor dimension whose size is not evenly divisible by the DeviceMesh dimension size is currently experimental and subject to change.
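A pure-tensor sketch (no process group needed) of what the torch.chunk semantic means for uneven sharding: a dimension of size 5 split over a hypothetical 4-rank mesh dimension. The padding step mirrors the "last few shards might be empty" behavior described above; it is an illustration, not DTensor's internal code.

```python
import torch

t = torch.arange(5)
num_ranks = 4  # hypothetical size of the DeviceMesh dimension

# torch.chunk uses ceil(5 / 4) = 2 elements per chunk and may return FEWER chunks.
chunks = list(torch.chunk(t, num_ranks))
print([c.numel() for c in chunks])  # [2, 2, 1]

# Padding the missing trailing chunks with empty tensors gives one shard per rank.
shards = chunks + [t.new_empty(0)] * (num_ranks - len(chunks))
print([s.numel() for s in shards])  # [2, 2, 1, 0]
```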
- class torch.distributed.tensor.placement_types.Replicate#
The Replicate() placement describes the DTensor replicating on a corresponding DeviceMesh dimension, where each rank on the DeviceMesh dimension holds a replica of the global Tensor. The Replicate placement can be used by all DTensor APIs (i.e. distribute_tensor, DTensor.from_local, etc.)
- class torch.distributed.tensor.placement_types.Partial(reduce_op='sum')#
The Partial(reduce_op) placement describes a DTensor that is pending reduction on a specified DeviceMesh dimension, where each rank on the DeviceMesh dimension holds a partial value of the global Tensor. The user can redistribute the Partial DTensor to a Replicate or Shard(dim) placement on the specified DeviceMesh dimension using redistribute, which triggers the necessary communication operations under the hood (i.e. allreduce, reduce_scatter).
- Parameters
reduce_op (str, optional) – The reduction op to be used for the partial DTensor to produce a Replicated/Sharded DTensor. Only element-wise reduction operations are supported, including: “sum”, “avg”, “product”, “max”, “min”. Default: “sum”.
Note
The Partial placement can be generated as a result of DTensor operators, and can only be used by the DTensor.from_local API.
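A single-process simulation of the Partial("sum") semantics, with a Python list standing in for the ranks of a hypothetical 4-rank mesh dimension. No collectives run here; the sums only stand in for what all_reduce and reduce_scatter would compute.

```python
import torch

num_ranks = 4  # hypothetical size of the DeviceMesh dimension

# Partial("sum"): rank r holds a partial value; the logical tensor is their sum.
partials = [torch.full((4,), float(r)) for r in range(num_ranks)]
logical = torch.stack(partials).sum(dim=0)  # 0 + 1 + 2 + 3 = 6 in every element

# Partial -> Replicate (all_reduce): every rank ends up with the whole sum.
replicated = [logical.clone() for _ in range(num_ranks)]

# Partial -> Shard(0) (reduce_scatter): rank r keeps only chunk r of the sum.
sharded = list(torch.chunk(logical, num_ranks))
print(replicated[0], [s.tolist() for s in sharded])
```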
- class torch.distributed.tensor.placement_types.Placement#
The base class for the Placement type, which describes how a DTensor is placed onto the DeviceMesh. Placement and DeviceMesh together describe the DTensor Layout. It is the base class of the three main DTensor Placement types: Shard, Replicate, and Partial.
This class is not meant to be used directly; it mainly serves as a typing stub.
Different ways to create a DTensor#
There are three ways to construct a DTensor:
- distribute_tensor() creates a DTensor from a logical or “global” torch.Tensor on each rank. This can be used to shard leaf torch.Tensors (i.e. model parameters/buffers and inputs).
- DTensor.from_local() creates a DTensor from a local torch.Tensor on each rank, which can be used to create a DTensor from non-leaf torch.Tensors (i.e. intermediate activation tensors during forward/backward).
- DTensor provides dedicated tensor factory functions (e.g. empty(), ones(), randn(), etc.) to allow different DTensor creations by directly specifying the DeviceMesh and Placement. Compared to distribute_tensor(), this can directly materialize the sharded memory on device, instead of performing sharding after initializing the logical Tensor memory.
Create DTensor from a logical torch.Tensor#
The SPMD (single program, multiple data) programming model in torch.distributed launches multiple processes (i.e. via torchrun) to execute the same program. This means that the model inside the program is first initialized on each process (i.e. the model might be initialized on CPU, on the meta device, or directly on GPU if there is enough memory).
DTensor offers a distribute_tensor() API that can shard the model weights or Tensors into DTensors, creating a DTensor from the “logical” Tensor on each process. This enables the created DTensors to comply with single-device semantics, which is critical for numerical correctness.
- torch.distributed.tensor.distribute_tensor(tensor, device_mesh=None, placements=None, *, src_data_rank=0)#
Distribute a leaf torch.Tensor (i.e. nn.Parameter/buffers) to the device_mesh according to the placements specified. The rank of device_mesh and placements must be the same. The tensor to distribute is the logical or “global” tensor, and the API uses the tensor from the first rank of the DeviceMesh dimension as the source of truth to preserve the single-device semantic. If you want to construct a DTensor in the middle of the Autograd computation, please use DTensor.from_local() instead.
- Parameters
tensor (torch.Tensor) – torch.Tensor to be distributed. Note that if you want to shard a tensor on a dimension that is not evenly divisible by the number of devices in that mesh dimension, we use the torch.chunk semantic to shard the tensor and scatter the shards. The uneven sharding behavior is experimental and subject to change.
device_mesh (DeviceMesh, optional) – DeviceMesh to distribute the tensor; if not specified, must be called under a DeviceMesh context manager. Default: None
placements (List[Placement], optional) – the placements that describe how to place the tensor on DeviceMesh; must have the same number of elements as device_mesh.ndim. If not specified, we will by default replicate the tensor across the device_mesh from the first rank of each dimension of the device_mesh.
- Keyword Arguments
src_data_rank (int, optional) – the rank of the source data for the logical/global tensor; it is used by distribute_tensor() to scatter/broadcast the shards/replicas to other ranks. By default, we use group_rank=0 on each DeviceMesh dimension as the source data to preserve the single-device semantic. If None is passed explicitly, distribute_tensor() simply uses its local data instead of trying to preserve the single-device semantic via scatter/broadcast. Default: 0
- Returns
A DTensor or XLAShardedTensor object.
- Return type
- DTensor
Note
When initializing the DeviceMesh with the xla device_type, distribute_tensor returns an XLAShardedTensor instead. See this issue for more details. The XLA integration is experimental and subject to change.
Along with distribute_tensor(), DTensor also offers a distribute_module() API to allow easier sharding at the nn.Module level:
- torch.distributed.tensor.distribute_module(module, device_mesh=None, partition_fn=None, input_fn=None, output_fn=None)#
This function exposes three functions to control the parameters/inputs/outputs of the module:
1. To perform sharding on the module before runtime execution by specifying the partition_fn (i.e. allow the user to convert Module parameters to DTensor parameters according to the partition_fn specified).
2. To control the inputs or outputs of the module during runtime execution by specifying the input_fn and output_fn (i.e. convert the input to DTensor, convert the output back to torch.Tensor).
- Parameters
module (nn.Module) – user module to be partitioned.
device_mesh (DeviceMesh) – the device mesh to place the module.
partition_fn (Callable) – the function to partition parameters (i.e. shard certain parameters across the device_mesh). If partition_fn is not specified, by default we replicate all module parameters of module across the mesh.
input_fn (Callable) – specify the input distribution, i.e. could control how the input of the module is sharded. input_fn will be installed as a module forward_pre_hook (pre forward hook).
output_fn (Callable) – specify the output distribution, i.e. could control how the output is sharded, or convert it back to torch.Tensor. output_fn will be installed as a module forward_hook (post forward hook).
- Returns
A module that contains parameters/buffers that are all DTensors.
- Return type
- Module
Note
When initializing the DeviceMesh with the xla device_type, distribute_module returns an nn.Module with PyTorch/XLA SPMD annotated parameters. See this issue for more details. The XLA integration is experimental and subject to change.
DTensor Factory Functions#
DTensor also provides dedicated tensor factory functions to allow creating a DTensor directly using torch.Tensor-like factory function APIs (i.e. torch.ones, torch.empty, etc.), by additionally specifying the DeviceMesh and Placement for the DTensor created:
- torch.distributed.tensor.zeros(*size, requires_grad=False, dtype=None, layout=torch.strided, device_mesh=None, placements=None)#
Returns a DTensor filled with the scalar value 0.
- Parameters
size (int...) – a sequence of integers defining the shape of the output DTensor. Can be a variable number of arguments or a collection like a list or tuple. E.g.: zeros(1,2,3..) or zeros([1,2,3..]) or zeros((1,2,3..))
- Keyword Arguments
requires_grad (bool, optional) – If autograd should record operations on the returned DTensor. Default: False.
dtype (torch.dtype, optional) – the desired data type of the returned DTensor. Default: if None, uses a global default (see torch.set_default_dtype()).
layout (torch.layout, optional) – the desired layout of the returned DTensor. Default: torch.strided.
device_mesh – DeviceMesh type, contains the mesh info of ranks.
placements – a sequence of Placement type: Shard, Replicate
- Returns
A DTensor object on each rank
- Return type
- DTensor
- torch.distributed.tensor.ones(*size, dtype=None, layout=torch.strided, requires_grad=False, device_mesh=None, placements=None)#
Returns a DTensor filled with the scalar value 1, with the shape defined by the variable argument size.
- Parameters
size (int...) – a sequence of integers defining the shape of the output DTensor. Can be a variable number of arguments or a collection like a list or tuple. E.g.: ones(1,2,3..) or ones([1,2,3..]) or ones((1,2,3..))
- Keyword Arguments
dtype (torch.dtype, optional) – the desired data type of the returned DTensor. Default: if None, uses a global default (see torch.set_default_dtype()).
layout (torch.layout, optional) – the desired layout of the returned DTensor. Default: torch.strided.
requires_grad (bool, optional) – If autograd should record operations on the returned DTensor. Default: False.
device_mesh – DeviceMesh type, contains the mesh info of ranks.
placements – a sequence of Placement type: Shard, Replicate
- Returns
A DTensor object on each rank
- Return type
- DTensor
- torch.distributed.tensor.empty(*size, dtype=None, layout=torch.strided, requires_grad=False, device_mesh=None, placements=None)#
Returns a DTensor filled with uninitialized data. The shape of the DTensor is defined by the variable argument size.
- Parameters
size (int...) – a sequence of integers defining the shape of the output DTensor. Can be a variable number of arguments or a collection like a list or tuple. E.g.: empty(1,2,3..) or empty([1,2,3..]) or empty((1,2,3..))
- Keyword Arguments
dtype (torch.dtype, optional) – the desired data type of the returned DTensor. Default: if None, uses a global default (see torch.set_default_dtype()).
layout (torch.layout, optional) – the desired layout of the returned DTensor. Default: torch.strided.
requires_grad (bool, optional) – If autograd should record operations on the returned DTensor. Default: False.
device_mesh – DeviceMesh type, contains the mesh info of ranks.
placements – a sequence of Placement type: Shard, Replicate
- Returns
A DTensor object on each rank
- Return type
- DTensor
- torch.distributed.tensor.full(size, fill_value, *, dtype=None, layout=torch.strided, requires_grad=False, device_mesh=None, placements=None)#
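Returns a DTensor filled with fill_value according to device_mesh and placements, with the shape defined by the argument size.
- Parameters
size (int...) – a list, tuple, or torch.Size of integers defining the shape of the output DTensor.
fill_value (Scalar) – the value to fill the output DTensor with.
- Keyword Arguments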
dtype (torch.dtype, optional) – the desired data type of the returned DTensor. Default: if None, uses a global default (see torch.set_default_dtype()).
layout (torch.layout, optional) – the desired layout of the returned DTensor. Default: torch.strided.
requires_grad (bool, optional) – If autograd should record operations on the returned DTensor. Default: False.
device_mesh – DeviceMesh type, contains the mesh info of ranks.
placements – a sequence of Placement type: Shard, Replicate
- Returns
A DTensor object on each rank
- Return type
- DTensor
- torch.distributed.tensor.rand(*size, requires_grad=False, dtype=None, layout=torch.strided, device_mesh=None, placements=None)#
Returns a DTensor filled with random numbers from a uniform distribution on the interval [0, 1). The shape of the tensor is defined by the variable argument size.
- Parameters
size (int...) – a sequence of integers defining the shape of the output DTensor. Can be a variable number of arguments or a collection like a list or tuple. E.g.: rand(1,2,3..) or rand([1,2,3..]) or rand((1,2,3..))
- Keyword Arguments
dtype (torch.dtype, optional) – the desired data type of the returned DTensor. Default: if None, uses a global default (see torch.set_default_dtype()).
layout (torch.layout, optional) – the desired layout of the returned DTensor. Default: torch.strided.
requires_grad (bool, optional) – If autograd should record operations on the returned DTensor. Default: False.
device_mesh – DeviceMesh type, contains the mesh info of ranks.
placements – a sequence of Placement type: Shard, Replicate
- Returns
A DTensor object on each rank
- Return type
- DTensor
- torch.distributed.tensor.randn(*size, requires_grad=False, dtype=None, layout=torch.strided, device_mesh=None, placements=None)#
Returns a DTensor filled with random numbers from a normal distribution with mean 0 and variance 1. The shape of the tensor is defined by the variable argument size.
- Parameters
size (int...) – a sequence of integers defining the shape of the output DTensor. Can be a variable number of arguments or a collection like a list or tuple. E.g.: randn(1,2,3..) or randn([1,2,3..]) or randn((1,2,3..))
- Keyword Arguments
dtype (torch.dtype, optional) – the desired data type of the returned DTensor. Default: if None, uses a global default (see torch.set_default_dtype()).
layout (torch.layout, optional) – the desired layout of the returned DTensor. Default: torch.strided.
requires_grad (bool, optional) – If autograd should record operations on the returned DTensor. Default: False.
device_mesh – DeviceMesh type, contains the mesh info of ranks.
placements – a sequence of Placement type: Shard, Replicate
- Returns
A DTensor object on each rank
- Return type
- DTensor
Random Operations#
DTensor provides distributed RNG functionality to ensure that random operations on sharded tensors get unique values, and random operations on replicated tensors get the same values. This system requires that all participating ranks (e.g. SPMD ranks) start out using the same generator state before each DTensor random operation is performed; if this holds, it ensures they all end up at the same state after each DTensor random operation completes. No communication is performed during random operations to synchronize RNG states.
Operators that accept a generator kwarg will use the user-passed generator, if passed, or the default generator for the device otherwise. Whichever generator is used, it will be advanced after the DTensor operation. It is valid to use the same generator for both DTensor and non-DTensor operations, but if so, care must be taken to ensure the non-DTensor operations advance the generator state equally on all ranks.
When using DTensor together with Pipeline Parallelism, ranks for each pipeline stage should use a distinct seed, and ranks within a pipeline stage should use the same seed.
DTensor’s RNG infrastructure is based on the Philox RNG algorithm and supports any Philox-based backend (cuda and other cuda-like devices), but unfortunately it does not yet support the CPU backend.
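A small sketch of the pipeline-parallel seeding rule above. The seed scheme base_seed + pp_stage and the helper names are hypothetical; in a real job each rank would call torch.manual_seed(...) with its stage's seed before running DTensor random ops. The snippet uses CPU generators only to show the same-seed/different-seed property, not DTensor's Philox-based RNG.

```python
import torch


def stage_seed(base_seed: int, pp_stage: int) -> int:
    """Hypothetical helper: one seed per pipeline stage, shared within the stage."""
    return base_seed + pp_stage


def draw(seed: int) -> torch.Tensor:
    # Stand-in for "this rank seeds its RNG, then runs a random op".
    gen = torch.Generator().manual_seed(seed)
    return torch.rand(4, generator=gen)


# Two ranks in stage 0 start from the same state, so their draws match.
rank0_stage0 = draw(stage_seed(1234, pp_stage=0))
rank1_stage0 = draw(stage_seed(1234, pp_stage=0))
# A rank in stage 1 uses a distinct seed, so it gets a different stream.
rank2_stage1 = draw(stage_seed(1234, pp_stage=1))
```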
Debugging#
Logging#
When launching the program, you can turn on additional logging using the TORCH_LOGS environment variable from torch._logging:
- TORCH_LOGS=+dtensor will display logging.DEBUG messages and all levels above it.
- TORCH_LOGS=dtensor will display logging.INFO messages and above.
- TORCH_LOGS=-dtensor will display logging.WARNING messages and above.
Debugging Tools#
To debug a program that uses DTensor and understand more details about which collectives happened under the hood, DTensor provides a CommDebugMode:
- class torch.distributed.tensor.debug.CommDebugMode#
CommDebugMode is a context manager that counts the number of functional collectives within its context. It does this using a TorchDispatchMode.
Note
Not all collectives are supported yet.
Example usage
mod = ...
comm_mode = CommDebugMode()
with comm_mode:
    mod.sum().backward()
print(comm_mode.get_comm_counts())
- generate_comm_debug_tracing_table(noise_level=3)#
Generates a detailed table displaying operations and collective tracing information on a module level. The amount of information depends on noise_level:
0. prints module-level collective counts
1. prints DTensor operations not included in trivial operations, module information
2. prints operations not included in trivial operations
3. prints all operations
- generate_json_dump(file_name='comm_mode_log.json', noise_level=3)#
Creates a JSON file used to build the browser visual. The amount of information depends on noise_level:
0. prints module-level collective counts
1. prints DTensor operations not included in trivial operations
2. prints operations not included in trivial operations
3. prints all operations
To visualize the sharding of a DTensor that has fewer than 3 dimensions, DTensor provides visualize_sharding():
Experimental Features#
DTensor also provides a set of experimental features. These features are either in the prototyping stage, or their basic functionality is done but we are looking for user feedback. Please submit an issue to PyTorch if you have feedback on these features.
- torch.distributed.tensor.experimental.context_parallel(mesh, *, buffers=None, buffer_seq_dims=None, no_restore_buffers=None)#
context_parallel is an experimental API to enable context parallelism (CP). This API performs two actions: 1) patch the SDPA (torch.nn.functional.scaled_dot_product_attention) with the CP-enabled one, 2) shard buffers along the sequence dimension, with each rank preserving the corresponding shard according to mesh.
- Parameters
mesh (DeviceMesh) – the device mesh for the context parallelism.
buffers (Optional[List[torch.Tensor]]) – buffers whose usage depends on the sequence dimension. Examples are the input batch, labels and positional embedding buffers. These buffers must be sharded along the sequence dimension to ensure accuracy. The sharding will happen in-place; the buffer’s shape will change within the context. The buffers will be restored after the context finishes. no_restore_buffers can be used to specify which buffers don’t need to be restored. Note that buffers should not contain any nn.Parameter.
buffer_seq_dims (Optional[List[int]]) – the sequence dimensions of buffers.
no_restore_buffers (Optional[Set[torch.Tensor]]) – buffers in this set won’t be restored after the context exits. This set must be a subset of buffers. If the buffers won’t be used after the context exits, these buffers can be put in this set to avoid extra restore time.
- Return type
Generator[None, None, None]
Warning
torch.distributed.tensor.experimental.context_parallel is a prototype feature in PyTorch. The API is subject to change.
- torch.distributed.tensor.experimental.local_map(func=None, out_placements=None, in_placements=None, in_grad_placements=None, device_mesh=None, *, redistribute_inputs=False)#
local_map() is an experimental API that allows users to pass DTensors to a function that is written to be applied on torch.Tensors. It does so by extracting the local components of each DTensor, calling the function, and wrapping the outputs into DTensors according to out_placements.
- Parameters
func (Callable) – the function to be applied on each local shard of
DTensors.
out_placements (Union[PlacementType, Tuple[PlacementType, …]]) – the desired placements of the DTensors in func's flattened output. If the flattened output is a single value, out_placements should be of type PlacementType. Otherwise, if the flattened output has multiple values, out_placements should be a tuple of PlacementType values mapping 1:1 to the flattened output. For Tensor output, PlacementType is used as its placements (a Tuple[Placement] value). For non-Tensor output, the PlacementType should be None. The only exception is when no DTensor argument is passed in. In that case, even if out_placements is not None, the resulting function ignores the desired placements because it is not running with DTensors.
in_placements (Tuple[PlacementType, …], optional) – the required placements of the DTensors in the flattened inputs of func. If in_placements is specified, local_map() checks whether the placements of each DTensor argument match the required placements. If they do not match and redistribute_inputs is False, an exception will be raised. Otherwise, if redistribute_inputs is True, the argument will first be redistributed to the required sharding placements before its local tensor is passed to func. The only exception is when the required placements are not None and the argument is a torch.Tensor. In that case, the placement check is skipped and the argument is passed directly to func. If in_placements is None, no placement check is performed. Default: None
in_grad_placements (Tuple[PlacementType, …], optional) – the placement hint for the gradients of the flattened input DTensors. This argument is a hint that users can give to to_local() in case the gradient layout of the local tensor input does not match its DTensor input layout. If not specified, the gradient layout of the local tensor input is assumed to be the same as the original DTensor input, and that layout is used for gradient computation. Default: None.
device_mesh (DeviceMesh, optional) – the device mesh that the output DTensors are placed on. If not specified, it will be inferred from the first input DTensor's device mesh. Default: None.
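The pairing rule for out_placements is easier to picture with a concrete model. The following is a torch-free toy sketch, not the local_map implementation: FakeTensor, pair_out_placements, and the string placements such as "Replicate()" are hypothetical stand-ins chosen only to show how a single PlacementType, a tuple of them, and None line up with the flattened output.

```python
from typing import Any, List, Sequence, Tuple


class FakeTensor:
    """Hypothetical stand-in for a torch.Tensor output (illustration only)."""


def pair_out_placements(
    flat_outputs: Sequence[Any],
    out_placements: Any,
    any_dtensor_input: bool = True,
) -> List[Tuple[Any, Any]]:
    """Toy model of how out_placements lines up with func's flattened output."""
    if not any_dtensor_input:
        # No DTensor argument was passed in: the desired placements are ignored.
        return [(output, None) for output in flat_outputs]

    if len(flat_outputs) == 1:
        # A single output takes one PlacementType directly.
        per_output = [out_placements]
    else:
        # Several outputs take a tuple of PlacementType values, matched 1:1.
        # (Using assert for a length mismatch is an assumption of this model.)
        assert isinstance(out_placements, tuple), "expected a tuple of PlacementType"
        assert len(out_placements) == len(flat_outputs), "expected a 1:1 mapping"
        per_output = list(out_placements)

    pairs = []
    for output, placement in zip(flat_outputs, per_output):
        if not isinstance(output, FakeTensor):
            # Non-Tensor outputs must map to None (the documented AssertionError).
            assert placement is None, "non-Tensor output requires a None placement"
        pairs.append((output, placement))
    return pairs
```

In real code, a func returning a tensor and an int would correspondingly take something like out_placements=([Replicate()], None).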
- Keyword Arguments
redistribute_inputs (bool, optional) – the bool value indicating whether to reshard the input DTensors when their placements are different from the required input placements. If this value is False and some DTensor input has a different placement, an exception will be raised. Default: False.
- Returns
A Callable that applies func to each local shard of the input DTensor and returns a DTensor constructed from the return value of func.
- Raises
AssertionError – For any non-DTensor output, its corresponding output placement in out_placements is required to be None. An AssertionError will be raised if this is not the case.
ValueError – If redistribute_inputs=False but the input DTensor needs a redistribution according to in_placements.
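The interplay between in_placements and redistribute_inputs, including the ValueError listed above, can be summarized as a small decision procedure. The sketch below is a torch-free toy model, not the local_map source: FakeDTensor, FakePlainTensor, prepare_local_args, and the string placements are hypothetical, and the "redistribution" here only relabels placements instead of moving any data.

```python
from dataclasses import dataclass
from typing import Any, List, Optional, Sequence, Tuple

Placements = Tuple[str, ...]  # toy placements, e.g. ("Shard(0)",)


@dataclass
class FakeDTensor:
    """Hypothetical stand-in for a DTensor: local data plus its placements."""
    local: Any
    placements: Placements


class FakePlainTensor:
    """Hypothetical stand-in for a plain torch.Tensor argument."""


def prepare_local_args(
    flat_args: Sequence[Any],
    in_placements: Optional[Sequence[Optional[Placements]]],
    redistribute_inputs: bool = False,
) -> List[Any]:
    """Toy model of the input placement check; not the local_map source."""
    if in_placements is None:
        # No placement check: DTensors are unwrapped, everything else passes through.
        return [a.local if isinstance(a, FakeDTensor) else a for a in flat_args]

    assert len(in_placements) == len(flat_args), "one entry per flattened input"
    local_args = []
    for arg, required in zip(flat_args, in_placements):
        if isinstance(arg, FakeDTensor):
            # Treating a None entry as "no requirement" is an assumption of this model.
            if required is not None and arg.placements != tuple(required):
                if not redistribute_inputs:
                    raise ValueError(
                        f"placements {arg.placements} differ from required {tuple(required)}"
                    )
                # Stand-in for a real redistribution: only the placements are relabeled.
                arg = FakeDTensor(arg.local, tuple(required))
            local_args.append(arg.local)
        else:
            # Plain tensors and non-tensor values skip the check and pass through as-is.
            local_args.append(arg)
    return local_args
```

Keeping redistribute_inputs=False as the default makes an unexpected layout fail loudly rather than silently triggering communication.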
Example
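>>> import torch
>>> import torch.distributed._functional_collectives as funcol
>>> from torch.distributed.tensor import Replicate, Shard, distribute_tensor
>>> from torch.distributed.tensor.experimental import local_map
>>>
>>> # assumes device_mesh is an existing 1-D DeviceMesh (e.g. from init_device_mesh())
>>> def mm_allreduce_forward(device_mesh, W, X):
>>>     partial_sum_tensor = torch.mm(W, X)
>>>     reduced_tensor = funcol.all_reduce(partial_sum_tensor, "sum", device_mesh)
>>>     return reduced_tensor
>>>
>>> W = torch.randn(12, 8, requires_grad=False)
>>> X = torch.randn(8, 16, requires_grad=False)
>>> Y = torch.mm(W, X)
>>> row_wise = [Shard(0)]  # row-wise sharding placements on 1-d mesh
>>> col_wise = [Shard(1)]  # col-wise sharding placements on 1-d mesh
>>>
>>> # local_mm_allreduce_forward is the function wrapped with DTensor/Tensor conversion.
>>> # in_placements has one entry per flattened input (device_mesh, W, X);
>>> # device_mesh is not a DTensor, so its entry is None.
>>> local_mm_allreduce_forward = local_map(
>>>     mm_allreduce_forward,
>>>     out_placements=[Replicate()],
>>>     in_placements=(None, col_wise, row_wise),
>>>     device_mesh=device_mesh,
>>> )
>>>
>>> W_dt = distribute_tensor(W, device_mesh, col_wise)  # col-wisely sharded W tensor
>>> X_dt = distribute_tensor(X, device_mesh, row_wise)  # row-wisely sharded X tensor
>>> Y_dt = local_mm_allreduce_forward(device_mesh, W_dt, X_dt)  # apply to DTensors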
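Why can this example declare Replicate() as its output placement? With W sharded along its columns and X along its rows, each rank's torch.mm yields a partial sum over its slice of the shared dimension of size 8, and the "sum" all-reduce adds those partials into the full product. The pure-Python simulation below checks that identity numerically for integer inputs; it uses no torch and no real communication, and world_size and the helper names (matmul, add, simulate_local_mm_allreduce) are hypothetical.

```python
import random


def matmul(a, b):
    """Naive dense product of two matrices stored as nested lists."""
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def add(a, b):
    """Elementwise sum of two equally shaped matrices (what a 'sum' all-reduce adds)."""
    return [[x + y for x, y in zip(row_a, row_b)] for row_a, row_b in zip(a, b)]


def simulate_local_mm_allreduce(W, X, world_size):
    """Simulate per-rank local matmuls on sharded inputs, then a sum all-reduce."""
    inner = len(X)  # shared dimension: columns of W == rows of X
    assert inner % world_size == 0, "this toy model assumes an even split"
    chunk = inner // world_size
    total = None
    for rank in range(world_size):
        lo, hi = rank * chunk, (rank + 1) * chunk
        W_local = [row[lo:hi] for row in W]  # like Shard(1): a slice of W's columns
        X_local = X[lo:hi]  # like Shard(0): the matching slice of X's rows
        partial = matmul(W_local, X_local)  # the local torch.mm result on this rank
        total = partial if total is None else add(total, partial)
    return total


rng = random.Random(0)
W = [[rng.randint(-3, 3) for _ in range(8)] for _ in range(12)]  # 12 x 8, as in the example
X = [[rng.randint(-3, 3) for _ in range(16)] for _ in range(8)]  # 8 x 16
assert simulate_local_mm_allreduce(W, X, world_size=4) == matmul(W, X)
```

Integers keep the comparison exact; with the example's float tensors the reduced result would match torch.mm(W, X) only up to floating-point rounding.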
Note
This API is currently experimental and subject to change.
- torch.distributed.tensor.experimental.register_sharding(op)
register_sharding() is an experimental API that allows users to register sharding strategies for an operator when the tensor inputs and outputs are DTensors. It can be useful when: (1) no default sharding strategy exists for op, e.g. when op is a custom operator that DTensor does not support; or (2) users would like to override the default sharding strategies of existing operators.
- Parameters
op (Union[OpOverload, List[OpOverload]]) – An op or a list of ops for which to register the customized sharding function.
- Returns
A function decorator that can be used to wrap a function defining the sharding strategy for the operator specified in op. The defined sharding strategy will be registered to DTensor and will override the default sharding strategy if DTensor has already implemented the operator. The customized sharding function takes the same inputs as the original op (except that if an arg is a torch.Tensor, it will be replaced by a tensor-like object that DTensor uses internally). The function should return a sequence of 2-tuples, each specifying acceptable output placements and their corresponding input placements.
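Before registering a sharding function, it can help to check its return value against the contract described above. The helper below, check_sharding_strategies, is hypothetical and not part of PyTorch. Its rules that the input placements hold one entry per op argument and that non-tensor arguments receive None are assumptions inferred from the softmax example in this section, not quoted from the API.

```python
from typing import Any, Callable, Sequence


def check_sharding_strategies(
    strategies: Sequence[Any],
    op_args: Sequence[Any],
    is_tensor: Callable[[Any], bool],
) -> None:
    """Hypothetical sanity check of a custom sharding function's return value.

    Each strategy must be a 2-tuple (output_placements, input_placements).
    The per-argument rules below are assumptions inferred from the example.
    """
    for strategy in strategies:
        assert isinstance(strategy, tuple) and len(strategy) == 2, "expected a 2-tuple"
        _output_placements, input_placements = strategy
        # Assumed: one input placement per positional argument of the op.
        assert len(input_placements) == len(op_args), "one input placement per op argument"
        for arg, placement in zip(op_args, input_placements):
            if not is_tensor(arg):
                # Assumed: non-tensor arguments (e.g. dim, half_to_float) get None.
                assert placement is None, "non-tensor arguments are given None"
```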
Example
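>>> import torch
>>> from torch.distributed.tensor import Replicate, Shard
>>> from torch.distributed.tensor.experimental import register_sharding
>>>
>>> aten = torch.ops.aten
>>>
>>> @register_sharding(aten._softmax.default)
>>> def custom_softmax_sharding(x, dim, half_to_float):
>>>     softmax_dim = dim if dim >= 0 else dim + x.ndim
>>>     acceptable_shardings = []
>>>
>>>     # replicating both the input and the output is always acceptable
>>>     all_replicate = ([Replicate()], [Replicate(), None, None])
>>>     acceptable_shardings.append(all_replicate)
>>>
>>>     # sharding is acceptable on any dim except the one softmax normalizes over
>>>     for sharding_dim in range(x.ndim):
>>>         if sharding_dim != softmax_dim:
>>>             all_sharded = (
>>>                 [Shard(sharding_dim)],
>>>                 [Shard(sharding_dim), None, None],
>>>             )
>>>             acceptable_shardings.append(all_sharded)
>>>
>>>     return acceptable_shardings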
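To see which strategies the example actually offers, the same enumeration can be run without a distributed setup. In the sketch below, Shard and Replicate are stand-in dataclasses (not the torch.distributed.tensor classes), and softmax_strategies is a hypothetical helper that mirrors custom_softmax_sharding but takes only ndim and dim.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Shard:
    """Stand-in for torch.distributed.tensor.Shard (illustration only)."""
    dim: int


@dataclass(frozen=True)
class Replicate:
    """Stand-in for torch.distributed.tensor.Replicate (illustration only)."""


def softmax_strategies(ndim: int, dim: int):
    """Same enumeration as custom_softmax_sharding, driven by ndim alone."""
    softmax_dim = dim if dim >= 0 else dim + ndim
    # Full replication is always acceptable.
    strategies = [([Replicate()], [Replicate(), None, None])]
    # Sharding is acceptable on every dim except the softmax dim; the two None
    # entries correspond to the non-tensor args dim and half_to_float.
    strategies += [
        ([Shard(d)], [Shard(d), None, None]) for d in range(ndim) if d != softmax_dim
    ]
    return strategies


for out_placements, in_placements in softmax_strategies(ndim=3, dim=-1):
    print(out_placements, in_placements)
```

For a 3-D input with dim=-1, this yields three strategies: full replication, sharding on dim 0, and sharding on dim 1. Sharding on dim 2 is never offered because each rank would then hold only part of the values that softmax normalizes over.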
Note
This API is currently experimental and subject to change.