Distributed communication package - torch.distributed#

Created On: Jul 12, 2017 | Last Updated On: Sep 04, 2025

Note

Please refer to PyTorch Distributed Overview for a brief introduction to all features related to distributed training.

Backends#

torch.distributed supports four built-in backends, each with different capabilities. The table below shows which functions are available for use with a CPU or GPU for each backend. For NCCL, GPU refers to CUDA GPU, while for XCCL it refers to XPU GPU.

MPI supports CUDA only if the implementation used to build PyTorch supports it.

Backend           gloo          mpi           nccl          xccl
Device            CPU    GPU    CPU    GPU    CPU    GPU    CPU    GPU
send                                   ?
recv                                   ?
broadcast                              ?
all_reduce                             ?
reduce                                 ?
all_gather                             ?
gather                                 ?
scatter                                ?
reduce_scatter
all_to_all                             ?
barrier                                ?

Backends that come with PyTorch#

PyTorch distributed package supports Linux (stable), MacOS (stable), and Windows (prototype). By default for Linux, the Gloo and NCCL backends are built and included in PyTorch distributed (NCCL only when building with CUDA). MPI is an optional backend that can only be included if you build PyTorch from source (e.g., building PyTorch on a host that has MPI installed).

Note

As of PyTorch v1.8, Windows supports all collective communications backends except NCCL. If the init_method argument of init_process_group() points to a file, it must adhere to the following schema:

  • Local file system, init_method="file:///d:/tmp/some_file"

  • Shared file system, init_method="file://////{machine_name}/{share_folder_name}/some_file"

As on Linux, you can enable TcpStore by setting the environment variables MASTER_ADDR and MASTER_PORT.

Which backend to use?#

In the past, we were often asked: “which backend should I use?”.

  • Rule of thumb

    • Use the NCCL backend for distributed training with CUDA GPU.

    • Use the XCCL backend for distributed training with XPU GPU.

    • Use the Gloo backend for distributed training with CPU.

  • GPU hosts with InfiniBand interconnect

    • Use NCCL, since it’s the only backend that currently supports InfiniBand and GPUDirect.

  • GPU hosts with Ethernet interconnect

    • Use NCCL, since it currently provides the best distributed GPU training performance, especially for multiprocess single-node or multi-node distributed training. If you encounter any problem with NCCL, use Gloo as the fallback option. (Note that Gloo currently runs slower than NCCL for GPUs.)

  • CPU hosts with InfiniBand interconnect

    • If your InfiniBand has enabled IP over IB, use Gloo; otherwise, use MPI instead. We are planning on adding InfiniBand support for Gloo in the upcoming releases.

  • CPU hosts with Ethernet interconnect

    • Use Gloo, unless you have specific reasons to use MPI.
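The rule of thumb above can be condensed into a small lookup. The sketch below is illustrative only: pick_backend and RULE_OF_THUMB are hypothetical names, not part of torch.distributed, and they cover only the device-type defaults listed above, not the interconnect-specific advice.

```python
# Hypothetical helper mirroring the rule of thumb above; it is not part of
# torch.distributed.
RULE_OF_THUMB = {
    "cuda": "nccl",  # distributed training with CUDA GPUs
    "xpu": "xccl",   # distributed training with XPU GPUs
    "cpu": "gloo",   # distributed training with CPUs
}


def pick_backend(device_type: str) -> str:
    """Return the suggested backend name for a device type such as "cuda"."""
    try:
        return RULE_OF_THUMB[device_type]
    except KeyError:
        raise ValueError(f"no rule of thumb for device type {device_type!r}") from None


print(pick_backend("cuda"))  # nccl
print(pick_backend("cpu"))   # gloo
```

The returned string could then be passed as the backend argument of init_process_group().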

Common environment variables#

Choosing the network interface to use#

By default, both the NCCL and Gloo backends will try to find the right network interface to use. If the automatically detected interface is not correct, you can override it using the following environment variables (applicable to the respective backend):

  • NCCL_SOCKET_IFNAME, for example export NCCL_SOCKET_IFNAME=eth0

  • GLOO_SOCKET_IFNAME, for example export GLOO_SOCKET_IFNAME=eth0

If you’re using the Gloo backend, you can specify multiple interfaces by separating them with a comma, like this: export GLOO_SOCKET_IFNAME=eth0,eth1,eth2,eth3. The backend will dispatch operations in a round-robin fashion across these interfaces. It is imperative that all processes specify the same number of interfaces in this variable.
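The same variable can also be set from Python, as long as that happens before the process group is created. This is a minimal sketch: set_gloo_interfaces is a hypothetical helper, and eth0 and eth1 are placeholder interface names that would need to exist on every host.

```python
import os


def set_gloo_interfaces(names):
    """Export GLOO_SOCKET_IFNAME as a comma-separated interface list."""
    if not names:
        raise ValueError("at least one interface name is required")
    os.environ["GLOO_SOCKET_IFNAME"] = ",".join(names)


# Placeholder interface names; every process must list the same number.
set_gloo_interfaces(["eth0", "eth1"])
print(os.environ["GLOO_SOCKET_IFNAME"])  # eth0,eth1
```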

Other NCCL environment variables#

Debugging - in case of NCCL failure, you can set NCCL_DEBUG=INFO to print an explicit warning message as well as basic NCCL initialization information.

You may also use NCCL_DEBUG_SUBSYS to get more details about a specific aspect of NCCL. For example, NCCL_DEBUG_SUBSYS=COLL would print logs of collective calls, which may be helpful when debugging hangs, especially those caused by collective type or message size mismatch. In case of topology detection failure, it would be helpful to set NCCL_DEBUG_SUBSYS=GRAPH to inspect the detailed detection result and save it as a reference if further help from the NCCL team is needed.

Performance tuning - NCCL performs automatic tuning based on its topology detection to save users’ tuning effort. On some socket-based systems, users may still try tuning NCCL_SOCKET_NTHREADS and NCCL_NSOCKS_PERTHREAD to increase socket network bandwidth. These two environment variables have been pre-tuned by NCCL for some cloud providers, such as AWS or GCP.

For a full list of NCCL environment variables, please refer to NVIDIA NCCL’s official documentation.

You can tune NCCL communicators even further using torch.distributed.ProcessGroupNCCL.NCCLConfig and torch.distributed.ProcessGroupNCCL.Options. Learn more about them using help (e.g., help(torch.distributed.ProcessGroupNCCL.NCCLConfig)) in the interpreter.

Basics#

The torch.distributed package provides PyTorch support and communication primitives for multiprocess parallelism across several computation nodes running on one or more machines. The class torch.nn.parallel.DistributedDataParallel() builds on this functionality to provide synchronous distributed training as a wrapper around any PyTorch model. This differs from the kinds of parallelism provided by Multiprocessing package - torch.multiprocessing and torch.nn.DataParallel() in that it supports multiple network-connected machines and in that the user must explicitly launch a separate copy of the main training script for each process.

In the single-machine synchronous case, torch.distributed or the torch.nn.parallel.DistributedDataParallel() wrapper may still have advantages over other approaches to data-parallelism, including torch.nn.DataParallel():

  • Each process maintains its own optimizer and performs a complete optimization step with each iteration. While this may appear redundant, since the gradients have already been gathered together and averaged across processes and are thus the same for every process, this means that no parameter broadcast step is needed, reducing time spent transferring tensors between nodes.

  • Each process contains an independent Python interpreter, eliminating the extra interpreter overhead and “GIL-thrashing” that comes from driving several execution threads, model replicas, or GPUs from a single Python process. This is especially important for models that make heavy use of the Python runtime, including models with recurrent layers or many small components.

Initialization#

The package needs to be initialized using the torch.distributed.init_process_group() or torch.distributed.device_mesh.init_device_mesh() function before calling any other methods. Both block until all processes have joined.

Warning

Initialization is not thread-safe. Process group creation should be performed from a single thread, to prevent inconsistent ‘UUID’ assignment across ranks, and to prevent races during initialization that can lead to hangs.

torch.distributed.is_available()[source]#

Return True if the distributed package is available.

Otherwise, torch.distributed does not expose any other APIs. Currently, torch.distributed is available on Linux, MacOS and Windows. Set USE_DISTRIBUTED=1 to enable it when building PyTorch from source. Currently, the default value is USE_DISTRIBUTED=1 for Linux and Windows, USE_DISTRIBUTED=0 for MacOS.

Return type

bool

torch.distributed.init_process_group(backend=None, init_method=None, timeout=None, world_size=-1, rank=-1, store=None, group_name='', pg_options=None, device_id=None)[source]#

Initialize the default distributed process group.

This will also initialize the distributed package.

There are 2 main ways to initialize a process group:
  1. Specify store, rank, and world_size explicitly.

  2. Specify init_method (a URL string) which indicates where/how to discover peers. Optionally specify rank and world_size, or encode all required parameters in the URL and omit them.

If neither is specified, init_method is assumed to be “env://”.

Parameters
  • backend (str or Backend, optional) – The backend to use. Depending on build-time configurations, valid values include mpi, gloo, nccl, ucc, xccl, or one that is registered by a third-party plugin. Since 2.6, if backend is not provided, c10d will use a backend registered for the device type indicated by the device_id kwarg (if provided). The known default registrations today are: nccl for cuda, gloo for cpu, xccl for xpu. If neither backend nor device_id is provided, c10d will detect the accelerator on the run-time machine and use a backend registered for that detected accelerator (or cpu). This field can be given as a lowercase string (e.g., "gloo"), which can also be accessed via Backend attributes (e.g., Backend.GLOO). If using multiple processes per machine with the nccl backend, each process must have exclusive access to every GPU it uses, as sharing GPUs between processes can result in deadlock or NCCL invalid usage. The ucc backend is experimental. The default backend for a device can be queried with get_default_backend_for_device().

  • init_method (str, optional) – URL specifying how to initialize the process group. Default is “env://” if no init_method or store is specified. Mutually exclusive with store.

  • world_size (int, optional) – Number of processes participating in the job. Required if store is specified.

  • rank (int, optional) – Rank of the current process (it should be a number between 0 and world_size-1). Required if store is specified.

  • store (Store, optional) – Key/value store accessible to all workers, used to exchange connection/address information. Mutually exclusive with init_method.

  • timeout (timedelta, optional) – Timeout for operations executed against the process group. Default value is 10 minutes for NCCL and 30 minutes for other backends. This is the duration after which collectives will be aborted asynchronously and the process will crash. This is done since CUDA execution is async and it is no longer safe to continue executing user code, since failed async NCCL operations might result in subsequent CUDA operations running on corrupted data. When TORCH_NCCL_BLOCKING_WAIT is set, the process will block and wait for this timeout.

  • group_name (str, optional, deprecated) – Group name. This argument is ignored.

  • pg_options (ProcessGroupOptions, optional) – process group options specifying what additional options need to be passed in during the construction of specific process groups. As of now, the only option we support is ProcessGroupNCCL.Options for the nccl backend; is_high_priority_stream can be specified so that the nccl backend can pick up high-priority CUDA streams when there are compute kernels waiting. For other available options to configure nccl, see https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/types.html#ncclconfig-t

  • device_id (torch.device | int, optional) – a single, specific device this process will work on, allowing for backend-specific optimizations. Currently this has two effects, only under NCCL: the communicator is immediately formed (calling ncclCommInit* immediately rather than the normal lazy call) and sub-groups will use ncclCommSplit when possible to avoid unnecessary overhead of group creation. If you want to surface NCCL initialization errors early, you can also use this field. If an int is provided, the API assumes that the accelerator type at compile time will be used.

Note

To enable backend == Backend.MPI, PyTorch needs to be built from source on a system that supports MPI.

Note

Support for multiple backends is experimental. Currently when no backend is specified, both gloo and nccl backends will be created. The gloo backend will be used for collectives with CPU tensors and the nccl backend will be used for collectives with CUDA tensors. A custom backend can be specified by passing in a string with format “<device_type>:<backend_name>,<device_type>:<backend_name>”, e.g. “cpu:gloo,cuda:custom_backend”.
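The first initialization style (explicit store, rank, and world_size) can be tried in a single process. The sketch below assumes a CPU-only machine with the Gloo backend available and uses HashStore, an in-memory store that only works when all ranks live in one process; a real multi-process job would use a store that is reachable from every worker. The five-minute timeout is an arbitrary illustrative value.

```python
from datetime import timedelta

import torch.distributed as dist

# In-memory store: suitable only for this single-process demonstration.
store = dist.HashStore()

dist.init_process_group(
    backend="gloo",                # CPU backend, so no GPU is required
    store=store,                   # explicit store ...
    rank=0,                        # ... plus explicit rank ...
    world_size=1,                  # ... and explicit world size
    timeout=timedelta(minutes=5),  # illustrative value, not a recommendation
)

rank = dist.get_rank()
world_size = dist.get_world_size()
print(rank, world_size)

# Release the resources held by the process group.
dist.destroy_process_group()
```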

torch.distributed.device_mesh.init_device_mesh(device_type, mesh_shape, *, mesh_dim_names=None, backend_override=None)[source]#

Initializes a DeviceMesh based on device_type, mesh_shape, and mesh_dim_names parameters.

This creates a DeviceMesh with an n-dimensional array layout, where n is the length of mesh_shape. If mesh_dim_names is provided, each dimension is labeled as mesh_dim_names[i].

Note

init_device_mesh follows the SPMD programming model, meaning the same PyTorch Python program runs on all processes/ranks in the cluster. Ensure mesh_shape (the dimensions of the nD array describing device layout) is identical across all ranks. An inconsistent mesh_shape may lead to hanging.

Note

If no process group is found, init_device_mesh will initialize the distributed process group/groups required for distributed communications behind the scenes.

Parameters
  • device_type (str) – The device type of the mesh. Currently supports: “cpu”, “cuda/cuda-like”, “xpu”. Passing in a device type with a GPU index, such as “cuda:0”, is not allowed.

  • mesh_shape (Tuple[int]) – A tuple defining the dimensions of the multi-dimensional array describing the layout of devices.

  • mesh_dim_names (Tuple[str], optional) – A tuple of mesh dimension names to assign to each dimension of the multi-dimensional array describing the layout of devices. Its length must match the length of mesh_shape. Each string in mesh_dim_names must be unique.

  • backend_override (Dict[int | str, tuple[str, Options] | str | Options], optional) – Overrides for some or all of the ProcessGroups that will be created for each mesh dimension. Each key can be either the index of a dimension or its name (if mesh_dim_names is provided). Each value can be a tuple containing the name of the backend and its options, or just one of these two components (in which case the other will be set to its default value).

Returns

A DeviceMesh object representing the device layout.

Return type

DeviceMesh

Example:

>>> from torch.distributed.device_mesh import init_device_mesh
>>>
>>> mesh_1d = init_device_mesh("cuda", mesh_shape=(8,))
>>> mesh_2d = init_device_mesh("cuda", mesh_shape=(2, 8), mesh_dim_names=("dp", "tp"))

torch.distributed.is_initialized()[source]#

Check if the default process group has been initialized.

Return type

bool

torch.distributed.is_mpi_available()[source]#

Check if the MPI backend is available.

Return type

bool

torch.distributed.is_nccl_available()[source]#

Check if the NCCL backend is available.

Return type

bool

torch.distributed.is_gloo_available()[source]#

Check if the Gloo backend is available.

Return type

bool

torch.distributed.distributed_c10d.is_xccl_available()[source]#

Check if the XCCL backend is available.

Return type

bool

torch.distributed.is_torchelastic_launched()[source]#

Check whether this process was launched with torch.distributed.elastic (aka torchelastic).

The existence of the TORCHELASTIC_RUN_ID environment variable is used as a proxy to determine whether the current process was launched with torchelastic. This is a reasonable proxy since TORCHELASTIC_RUN_ID maps to the rendezvous id, which is always a non-null value indicating the job id for peer discovery purposes.

Return type

bool
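The proxy described above amounts to an environment lookup. The following is a rough stand-in written for illustration, not the library's implementation; launched_with_torchelastic is a hypothetical name, and in practice torch.distributed.is_torchelastic_launched() should be called instead.

```python
import os


def launched_with_torchelastic(environ=os.environ) -> bool:
    """Approximate the documented proxy: is TORCHELASTIC_RUN_ID set?"""
    return environ.get("TORCHELASTIC_RUN_ID") is not None


# Plain dicts stand in for the process environment here.
print(launched_with_torchelastic({}))                              # False
print(launched_with_torchelastic({"TORCHELASTIC_RUN_ID": "job"}))  # True
```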

torch.distributed.get_default_backend_for_device(device)[source]#

Return the default backend for the given device.

Parameters

device (Union[str,torch.device]) – The device to get the default backend for.

Returns

The default backend for the given device as a lower case string.

Return type

str


Currently three initialization methods are supported:

TCP initialization#

There are two ways to initialize using TCP, both requiring a network address reachable from all processes and a desired world_size. The first way requires specifying an address that belongs to the rank 0 process. This initialization method requires that all processes have manually specified ranks.

Note that multicast address is not supported anymore in the latest distributed package. group_name is deprecated as well.

import torch.distributed as dist

# Use address of one of the machines
dist.init_process_group(backend, init_method='tcp://10.1.1.20:23456',
                        rank=args.rank, world_size=4)

Shared file-system initialization#

Another initialization method makes use of a file system that is shared and visible from all machines in a group, along with a desired world_size. The URL should start with file:// and contain a path to a non-existent file (in an existing directory) on a shared file system. File-system initialization will automatically create that file if it doesn’t exist, but will not delete the file. Therefore, it is your responsibility to make sure that the file is cleaned up before the next init_process_group() call on the same file path/name.

Note that automatic rank assignment is not supported anymore in the latest distributed package and group_name is deprecated as well.

Warning

This method assumes that the file system supports locking using fcntl - most local systems and NFS support it.

Warning

This method will always create the file and try its best to clean up and remove the file at the end of the program. In other words, each initialization with the file init method will need a brand new empty file in order for the initialization to succeed. If the same file used by the previous initialization (which happens not to get cleaned up) is used again, this is unexpected behavior and can often cause deadlocks and failures. Therefore, even though this method will try its best to clean up the file, if the auto-delete happens to be unsuccessful, it is your responsibility to ensure that the file is removed at the end of the training to prevent the same file from being reused the next time. This is especially important if you plan to call init_process_group() multiple times on the same file name. In other words, if the file is not removed/cleaned up and you call init_process_group() again on that file, failures are expected. The rule of thumb here is to make sure that the file is non-existent or empty every time init_process_group() is called.

import torch.distributed as dist

# rank should always be specified
dist.init_process_group(backend, init_method='file:///mnt/nfs/sharedfile',
                        world_size=4, rank=args.rank)

Environment variable initialization#

This method will read the configuration from environment variables, allowingone to fully customize how the information is obtained. The variables to be setare:

  • MASTER_PORT - required; has to be a free port on machine with rank 0

  • MASTER_ADDR - required (except for rank 0); address of rank 0 node

  • WORLD_SIZE - required; can be set either here, or in a call to init function

  • RANK - required; can be set either here, or in a call to init function

The machine with rank 0 will be used to set up all connections.

This is the default method, meaning that init_method does not have to be specified (or can be env://).
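As a rough illustration of the environment-variable method, the sketch below sets the four variables inline and initializes a one-process Gloo group on the local machine. In a real job a launcher would normally provide these values; find_free_port is a hypothetical helper used only to pick an unused port for the demonstration.

```python
import os
import socket

import torch.distributed as dist


def find_free_port() -> int:
    """Ask the OS for an unused TCP port (hypothetical helper for this sketch)."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("127.0.0.1", 0))
        return s.getsockname()[1]


# Normally provided by a launcher; set inline for a single local process.
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = str(find_free_port())
os.environ["RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"

# init_method is omitted, so the default "env://" is used.
dist.init_process_group(backend="gloo")
initialized = dist.is_initialized()
print(initialized)

dist.destroy_process_group()
```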

Improving initialization time#

  • TORCH_GLOO_LAZY_INIT - establishes connections on demand rather than using a full mesh, which can greatly improve initialization time for non-all2all operations.

Post-Initialization#

Once torch.distributed.init_process_group() has been run, the following functions can be used. To check whether the process group has already been initialized, use torch.distributed.is_initialized().

class torch.distributed.Backend(name)[source]#

An enum-like class for backends.

Available backends: GLOO, NCCL, UCC, MPI, XCCL, and other registered backends.

The values of this class are lowercase strings, e.g., "gloo". They can be accessed as attributes, e.g., Backend.NCCL.

This class can be directly called to parse the string, e.g., Backend(backend_str) will check if backend_str is valid, and return the parsed lowercase string if so. It also accepts uppercase strings, e.g., Backend("GLOO") returns "gloo".

Note

The entry Backend.UNDEFINED is present but only used as the initial value of some fields. Users should neither use it directly nor assume its existence.

classmethod register_backend(name, func, extended_api=False, devices=None)[source]#
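The attribute access and string parsing described above can be checked interactively. This short sketch only exercises the behaviors stated in this section and assumes a PyTorch build in which torch.distributed is available.

```python
from torch.distributed import Backend

# Attribute values are lowercase strings.
nccl_name = Backend.NCCL
# Calling the class parses a backend string; uppercase input is accepted.
parsed = Backend("GLOO")

print(nccl_name)               # nccl
print(parsed)                  # gloo
print(parsed == Backend.GLOO)  # True
```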

Register a new backend with the given name and instantiating function.

This class method is used by 3rd party ProcessGroup extensions to register new backends.

Parameters
  • name (str) – Backend name of the ProcessGroup extension. It should match the one in init_process_group().

  • func (function) – Function handler that instantiates the backend. The function should be implemented in the backend extension and takes four arguments: store, rank, world_size, and timeout.

  • extended_api (bool, optional) – Whether the backend supports extended argument structure. Default: False. If set to True, the backend will get an instance of c10d::DistributedBackendOptions, and a process group options object as defined by the backend implementation.

  • devices (str or list of str, optional) – device type(s) this backend supports, e.g. “cpu”, “cuda”, etc. If None, both “cpu” and “cuda” are assumed.

Note

This support of 3rd party backend is experimental and subject to change.

torch.distributed.get_backend(group=None)[source]#

Return the backend of the given process group.

Parameters

group (ProcessGroup, optional) – The process group to work on. The default is the general main process group. If another specific group is specified, the calling process must be part of group.

Returns

The backend of the given process group as a lower case string.

Return type

Backend

torch.distributed.get_rank(group=None)[source]#

Return the rank of the current process in the provided group, or in the default group if none is provided.

Rank is a unique identifier assigned to each process within a distributed process group. Ranks are always consecutive integers ranging from 0 to world_size - 1.

Parameters

group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used.

Returns

The rank within the process group, or -1 if the process is not part of the group

Return type

int

torch.distributed.get_world_size(group=None)[source]#

Return the number of processes in the current process group.

Parameters

group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used.

Returns

The world size of the process group, or -1 if the process is not part of the group

Return type

int

Shutdown#

It is important to clean up resources on exit by calling destroy_process_group().

The simplest pattern to follow is to destroy every process group and backend by calling destroy_process_group() with the default value of None for the group argument, at a point in the training script where communications are no longer needed, usually near the end of main(). The call should be made once per trainer-process, not at the outer process-launcher level.

If destroy_process_group() is not called by all ranks in a pg within the timeout duration, especially when there are multiple process-groups in the application, e.g. for N-D parallelism, hangs on exit are possible. This is because the destructor for ProcessGroupNCCL calls ncclCommAbort, which must be called collectively, but the order of calling ProcessGroupNCCL’s destructor if called by Python’s GC is not deterministic. Calling destroy_process_group() helps by ensuring ncclCommAbort is called in a consistent order across ranks, and avoids calling ncclCommAbort during ProcessGroupNCCL’s destructor.
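One way to make sure the cleanup call always runs is to wrap initialization and teardown in a context manager. This is a sketch of that pattern rather than an API provided by the package: process_group is a hypothetical helper, and the single-process Gloo group with an in-memory HashStore is used only so the example can run on one CPU machine.

```python
from contextlib import contextmanager

import torch.distributed as dist


@contextmanager
def process_group(**init_kwargs):
    """Initialize the default group and always destroy it on exit."""
    dist.init_process_group(**init_kwargs)
    try:
        yield
    finally:
        # group=None (the default) destroys every process group and backend.
        dist.destroy_process_group()


# Single-process demonstration with an in-memory store.
with process_group(backend="gloo", store=dist.HashStore(), rank=0, world_size=1):
    inside = dist.is_initialized()

after = dist.is_initialized()
print(inside, after)
```

In a real trainer the body of the with block would hold the training loop, so the teardown runs once per trainer process even if the loop raises.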

Reinitialization#

destroy_process_group can also be used to destroy individual process groups. One use case could be fault tolerant training, where a process group may be destroyed and then a new one initialized during runtime. In this case, it’s critical to synchronize the trainer processes using some means other than torch.distributed primitives after calling destroy and before subsequently initializing. This behavior is currently unsupported/untested, due to the difficulty of achieving this synchronization, and is considered a known issue. Please file a GitHub issue or RFC if this is a use case that’s blocking you.


Groups#

By default collectives operate on the default group (also called the world) and require all processes to enter the distributed function call. However, some workloads can benefit from more fine-grained communication. This is where distributed groups come into play. The new_group() function can be used to create new groups, with arbitrary subsets of all processes. It returns an opaque group handle that can be given as a group argument to all collectives (collectives are distributed functions to exchange information in certain well-known programming patterns).

torch.distributed.new_group(ranks=None, timeout=None, backend=None, pg_options=None, use_local_synchronization=False, group_desc=None, device_id=None)[source]#

Create a new distributed group.

This function requires that all processes in the main group (i.e. all processes that are part of the distributed job) enter this function, even if they are not going to be members of the group. Additionally, groups should be created in the same order in all processes.

Warning

Safe concurrent usage: When using multiple process groups with the NCCL backend, the user must ensure a globally consistent execution order of collectives across ranks.

If multiple threads within a process issue collectives, explicit synchronization is necessary to ensure consistent ordering.

When using async variants of torch.distributed communication APIs, a work object is returned and the communication kernel is enqueued on a separate CUDA stream, allowing overlap of communication and computation. Once one or more async ops have been issued on one process group, they must be synchronized with other CUDA streams by calling work.wait() before using another process group.

See Using multiple NCCL communicators concurrently <https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/communicators.html#using-multiple-nccl-communicators-concurrently> for more details.

Parameters
  • ranks (list[int]) – List of ranks of group members. If None, will be set to all ranks. Default is None.

  • timeout (timedelta, optional) – see init_process_group for details and default value.

  • backend (str or Backend, optional) – The backend to use. Depending on build-time configurations, valid values are gloo and nccl. By default uses the same backend as the global group. This field should be given as a lowercase string (e.g., "gloo"), which can also be accessed via Backend attributes (e.g., Backend.GLOO). If None is passed in, the backend corresponding to the default process group will be used. Default is None.

  • pg_options (ProcessGroupOptions, optional) – process group options specifying what additional options need to be passed in during the construction of specific process groups, e.g. for the nccl backend, is_high_priority_stream can be specified so that the process group can pick up high-priority CUDA streams. For other available options to configure nccl, see https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/types.html#ncclconfig-t

  • use_local_synchronization (bool, optional) – perform a group-local barrier at the end of the process group creation. This is different in that non-member ranks don’t need to call into the API and don’t join the barrier.

  • group_desc (str, optional) – a string to describe the process group.

  • device_id (torch.device, optional) – a single, specific device to “bind” this process to. The new_group call will try to initialize a communication backend immediately for the device if this field is given.

Returns

A handle of distributed group that can be given to collective calls, or GroupMember.NON_GROUP_MEMBER if the rank is not part of ranks.

N.B. use_local_synchronization doesn’t work with MPI.

N.B. While use_local_synchronization=True can be significantly faster with larger clusters and small process groups, care must be taken since it changes cluster behavior as non-member ranks don’t join the group barrier().

N.B. use_local_synchronization=True can lead to deadlocks when each rank creates multiple overlapping process groups. To avoid that, make sure all ranks follow the same global creation order.

torch.distributed.get_group_rank(group,global_rank)[source]#

Translate a global rank into a group rank.

global_rank must be part of group, otherwise this raises RuntimeError.

Parameters
  • group (ProcessGroup) – ProcessGroup to find the relative rank.

  • global_rank (int) – Global rank to query.

Returns

Group rank of global_rank relative to group

Return type

int

N.B. calling this function on the default process group returns identity

torch.distributed.get_global_rank(group,group_rank)[source]#

Translate a group rank into a global rank.

group_rank must be part of group, otherwise this raises RuntimeError.

Parameters
  • group (ProcessGroup) – ProcessGroup to find the global rank from.

  • group_rank (int) – Group rank to query.

Returns

Global rank of group_rank relative to group

Return type

int

N.B. calling this function on the default process group returns identity

torch.distributed.get_process_group_ranks(group)[source]#

Get all ranks associated with group.

Parameters

group (Optional[ProcessGroup]) – ProcessGroup to get all ranks from. If None, the default process group will be used.

Returns

List of global ranks ordered by group rank.

Return type

list[int]
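The relationship between the three functions above can be pictured with a plain list. The sketch below is a pure-Python model and does not call torch.distributed: group_ranks stands in for what get_process_group_ranks would return for a hypothetical group made of global ranks 2, 5, and 7, and the two helper names are invented for illustration.

```python
# Hypothetical group membership: global ranks ordered by group rank.
group_ranks = [2, 5, 7]


def group_rank_of(ranks, global_rank):
    """Global rank -> group rank; raises if the rank is not a member."""
    if global_rank not in ranks:
        raise RuntimeError(f"global rank {global_rank} is not part of the group")
    return ranks.index(global_rank)


def global_rank_of(ranks, group_rank):
    """Group rank -> global rank; raises if the index is out of range."""
    if not 0 <= group_rank < len(ranks):
        raise RuntimeError(f"group rank {group_rank} is not part of the group")
    return ranks[group_rank]


print(group_rank_of(group_ranks, 5))   # 1
print(global_rank_of(group_ranks, 2))  # 7
```

For the default process group the list would be 0..world_size-1, which is why both translations reduce to the identity there.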

DeviceMesh#

DeviceMesh is a higher level abstraction that manages process groups (or NCCL communicators). It allows users to easily create inter-node and intra-node process groups without worrying about how to set up the ranks correctly for different sub process groups, and it helps manage those distributed process groups easily. The init_device_mesh() function can be used to create a new DeviceMesh, with a mesh shape describing the device topology.

class torch.distributed.device_mesh.DeviceMesh(device_type, mesh, *, mesh_dim_names=None, backend_override=None, _init_backend=True)[source]#

DeviceMesh represents a mesh of devices, where the layout of devices can be represented as an n-dimensional array, and each value of the n-dimensional array is the global id of the default process group ranks.

DeviceMesh can be used to set up the N-dimensional device connections across the cluster, and manage the ProcessGroups for N-dimensional parallelisms. Communications can happen on each dimension of the DeviceMesh separately. DeviceMesh respects the device that the user has already selected (i.e. if the user calls torch.cuda.set_device before the DeviceMesh initialization), and will select/set the device for the current process if the user does not set the device beforehand. Note that manual device selection should happen BEFORE the DeviceMesh initialization.

DeviceMesh can also be used as a context manager when using together with DTensor APIs.

Note

DeviceMesh follows the SPMD programming model, which means the same PyTorch Python program is running on all processes/ranks in the cluster. Therefore, users need to make sure the mesh array (which describes the layout of devices) is identical across all ranks. An inconsistent mesh will lead to a silent hang.

Parameters
  • device_type (str) – The device type of the mesh. Currently supports: “cpu”, “cuda/cuda-like”.

  • mesh (ndarray) – A multi-dimensional array or an integer tensor describing the layoutof devices, where the IDs are global IDs of the default process group.

Returns

A DeviceMesh object representing the device layout.

Return type

DeviceMesh

The following program runs on each process/rank in an SPMD manner. In this example, we have 2 hosts with 4 GPUs each. A reduction over the first dimension of mesh will reduce across columns (0, 4), ... and (3, 7); a reduction over the second dimension of mesh reduces across rows (0, 1, 2, 3) and (4, 5, 6, 7).

Example:

>>> from torch.distributed.device_mesh import DeviceMesh
>>>
>>> # Initialize device mesh as (2, 4) to represent the topology
>>> # of cross-host (dim 0), and within-host (dim 1).
>>> mesh = DeviceMesh(device_type="cuda", mesh=[[0, 1, 2, 3], [4, 5, 6, 7]])

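To make the column/row grouping of the (2, 4) mesh concrete, the following pure-Python sketch lists which global ranks would communicate along each mesh dimension. It does not create any process groups; the variable names are illustrative only.

```python
mesh = [[0, 1, 2, 3],
        [4, 5, 6, 7]]

# Dimension 0 (cross-host): ranks that share a column index.
dim0_groups = [tuple(col) for col in zip(*mesh)]
# Dimension 1 (within-host): ranks that share a row index.
dim1_groups = [tuple(row) for row in mesh]

print(dim0_groups)  # [(0, 4), (1, 5), (2, 6), (3, 7)]
print(dim1_groups)  # [(0, 1, 2, 3), (4, 5, 6, 7)]
```
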
static from_group(group, device_type, mesh=None, *, mesh_dim_names=None)[source]#

Constructs a DeviceMesh with device_type from an existing ProcessGroup or a list of existing ProcessGroups.

The constructed device mesh has a number of dimensions equal to the number of groups passed. For example, if a single process group is passed in, the resulting DeviceMesh is a 1D mesh. If a list of 2 process groups is passed in, the resulting DeviceMesh is a 2D mesh.

If more than one group is passed, then the mesh and mesh_dim_names arguments are required. The order of the process groups passed in determines the topology of the mesh. For example, the first process group will be the 0th dimension of the DeviceMesh. The mesh tensor passed in must have the same number of dimensions as the number of process groups passed in, and the order of the dimensions in the mesh tensor must match the order in the process groups passed in.

Parameters
  • group (ProcessGroup or list[ProcessGroup]) – the existing ProcessGroup or a list of existing ProcessGroups.

  • device_type (str) – The device type of the mesh. Currently supports: “cpu”, “cuda/cuda-like”. Passing in a device type with a GPU index, such as “cuda:0”, is not allowed.

  • mesh (torch.Tensor or ArrayLike, optional) – A multi-dimensional array or an integer tensor describing the layout of devices, where the IDs are global IDs of the default process group. Default is None.

  • mesh_dim_names (tuple[str], optional) – A tuple of mesh dimension names to assign to each dimension of the multi-dimensional array describing the layout of devices. Its length must match the length of mesh_shape. Each string in mesh_dim_names must be unique. Default is None.

Returns

A DeviceMesh object representing the device layout.

Return type

DeviceMesh

get_all_groups()[source]#

Returns a list of ProcessGroups for all mesh dimensions.

Returns

A list of ProcessGroup objects.

Return type

list[torch.distributed.distributed_c10d.ProcessGroup]

get_coordinate()[source]#

Return the indices of this rank relative to all dimensions of the mesh. If this rank is not part of the mesh, return None.

Return type

Optional[list[int]]

get_group(mesh_dim=None)[source]#

Returns the single ProcessGroup specified by mesh_dim, or, if mesh_dim is not specified and the DeviceMesh is 1-dimensional, returns the only ProcessGroup in the mesh.

Parameters
  • mesh_dim (str or int, optional) – The name or the index of the mesh dimension. Default is None.

Returns

A ProcessGroup object.

Return type

ProcessGroup

get_local_rank(mesh_dim=None)[source]#

Returns the local rank of the given mesh_dim of the DeviceMesh.

Parameters
  • mesh_dim (str or int, optional) – The name or the index of the mesh dimension. Default is None.

Returns

An integer denoting the local rank.

Return type

int

The following program runs on each process/rank in an SPMD manner. In this example, we have 2 hosts with 4 GPUs each.
Calling mesh.get_local_rank(mesh_dim=0) on rank 0, 1, 2, 3 would return 0.
Calling mesh.get_local_rank(mesh_dim=0) on rank 4, 5, 6, 7 would return 1.
Calling mesh.get_local_rank(mesh_dim=1) on rank 0, 4 would return 0.
Calling mesh.get_local_rank(mesh_dim=1) on rank 1, 5 would return 1.
Calling mesh.get_local_rank(mesh_dim=1) on rank 2, 6 would return 2.
Calling mesh.get_local_rank(mesh_dim=1) on rank 3, 7 would return 3.

Example:

>>> from torch.distributed.device_mesh import DeviceMesh
>>>
>>> # Initialize device mesh as (2, 4) to represent the topology
>>> # of cross-host (dim 0), and within-host (dim 1).
>>> mesh = DeviceMesh(device_type="cuda", mesh=[[0, 1, 2, 3], [4, 5, 6, 7]])
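The local ranks listed above follow directly from each rank's position in the mesh array. The sketch below is a plain-Python model under that assumption; `coordinate_of` is a hypothetical helper, not a DeviceMesh method. It looks up the coordinate of a global rank in the (2, 4) mesh and reads the per-dimension local rank from it.

```python
# Pure-Python model; the real values come from DeviceMesh at runtime.
mesh = [[0, 1, 2, 3],
        [4, 5, 6, 7]]


def coordinate_of(mesh, rank):
    """Return [row, col] of `rank` in a 2-D mesh, or None if it is absent."""
    for row_idx, row in enumerate(mesh):
        for col_idx, member in enumerate(row):
            if member == rank:
                return [row_idx, col_idx]
    return None


# In this layout the local rank along a dimension equals the coordinate entry.
local_rank_dim0 = {rank: coordinate_of(mesh, rank)[0] for rank in range(8)}
local_rank_dim1 = {rank: coordinate_of(mesh, rank)[1] for rank in range(8)}
```

A rank that does not appear in the mesh yields `None`, mirroring the documented behavior of `get_coordinate()`.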
get_rank()[source]#

Returns the current global rank.

Return type

int

Point-to-point communication#

torch.distributed.send(tensor,dst=None,group=None,tag=0,group_dst=None)[source]#

Send a tensor synchronously.

Warning

tag is not supported with the NCCL backend.

Parameters
  • tensor (Tensor) – Tensor to send.

  • dst (int) – Destination rank on global process group (regardless of group argument). Destination rank should not be the same as the rank of the current process.

  • group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used.

  • tag (int, optional) – Tag to match send with remote recv.

  • group_dst (int, optional) – Destination rank on group. Invalid to specify both dst and group_dst.
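The difference between dst (a global rank) and group_dst (a rank within group) can be illustrated with a small model. This is a sketch under the assumption of a hypothetical subgroup containing global ranks 2, 5 and 7, ordered as listed; the two helper functions are illustrative stand-ins for the translation that torch.distributed.get_global_rank() and torch.distributed.get_group_rank() perform on real process groups.

```python
# Hypothetical subgroup made of global ranks 2, 5 and 7 (in that order).
group_ranks = [2, 5, 7]


def to_global_rank(group_ranks, group_rank):
    """Translate a rank within the group into a global rank."""
    return group_ranks[group_rank]


def to_group_rank(group_ranks, global_rank):
    """Translate a global rank into its rank within the group."""
    return group_ranks.index(global_rank)


# Addressing the middle member of the subgroup in both conventions.
as_dst = to_global_rank(group_ranks, 1)       # value to pass as dst
as_group_dst = to_group_rank(group_ranks, 5)  # value to pass as group_dst
```

Under these assumptions, passing dst=5 or group_dst=1 together with the subgroup would address the same peer; only one of the two may be given.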

torch.distributed.recv(tensor,src=None,group=None,tag=0,group_src=None)[source]#

Receives a tensor synchronously.

Warning

tag is not supported with the NCCL backend.

Parameters
  • tensor (Tensor) – Tensor to fill with received data.

  • src (int, optional) – Source rank on global process group (regardless of group argument). Will receive from any process if unspecified.

  • group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used.

  • tag (int, optional) – Tag to match recv with remote send.

  • group_src (int, optional) – Source rank on group. Invalid to specify both src and group_src.

Returns

Sender rank. -1, if not part of the group.

Return type

int

isend() and irecv() return distributed request objects when used. In general, the type of this object is unspecified as they should never be created manually, but they are guaranteed to support two methods:

  • is_completed() - returns True if the operation has finished

  • wait() - will block the process until the operation is finished. is_completed() is guaranteed to return True once it returns.

torch.distributed.isend(tensor,dst=None,group=None,tag=0,group_dst=None)[source]#

Send a tensor asynchronously.

Warning

Modifying tensor before the request completes causes undefined behavior.

Warning

tag is not supported with the NCCL backend.

Unlike send, which is blocking, isend allows src == dst rank, i.e. send to self.

Parameters
  • tensor (Tensor) – Tensor to send.

  • dst (int) – Destination rank on global process group (regardless of group argument).

  • group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used.

  • tag (int, optional) – Tag to match send with remote recv.

  • group_dst (int, optional) – Destination rank on group. Invalid to specify both dst and group_dst.

Returns

A distributed request object. None, if not part of the group.

Return type

Optional[Work]

torch.distributed.irecv(tensor,src=None,group=None,tag=0,group_src=None)[source]#

Receives a tensor asynchronously.

Warning

tag is not supported with the NCCL backend.

Unlike recv, which is blocking, irecv allows src == dst rank, i.e. recv from self.

Parameters
  • tensor (Tensor) – Tensor to fill with received data.

  • src (int, optional) – Source rank on global process group (regardless of group argument). Will receive from any process if unspecified.

  • group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used.

  • tag (int, optional) – Tag to match recv with remote send.

  • group_src (int, optional) – Source rank on group. Invalid to specify both src and group_src.

Returns

A distributed request object. None, if not part of the group.

Return type

Optional[Work]

torch.distributed.send_object_list(object_list,dst=None,group=None,device=None,group_dst=None,use_batch=False)[source]#

Sends picklable objects in object_list synchronously.

Similar to send(), but Python objects can be passed in. Note that all objects in object_list must be picklable in order to be sent.

Parameters
  • object_list (List[Any]) – List of input objects to send. Each object must be picklable. Receiver must provide lists of equal sizes.

  • dst (int) – Destination rank to send object_list to. Destination rank is based on global process group (regardless of group argument).

  • group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used. Default is None.

  • device (torch.device, optional) – If not None, the objects are serialized and converted to tensors which are moved to the device before sending. Default is None.

  • group_dst (int, optional) – Destination rank on group. Must specify one of dst and group_dst but not both.

  • use_batch (bool, optional) – If True, use batch p2p operations instead of regular send operations. This avoids initializing 2-rank communicators and uses existing entire group communicators. See batch_isend_irecv for usage and assumptions. Default is False.

Returns

None.

Note

For NCCL-based process groups, internal tensor representations of objects must be moved to the GPU device before communication takes place. In this case, the device used is given by torch.cuda.current_device() and it is the user’s responsibility to ensure that this is set so that each rank has an individual GPU, via torch.cuda.set_device().

Warning

Object collectives have a number of serious performance and scalability limitations. See Object collectives for details.

Warning

send_object_list() uses the pickle module implicitly, which is known to be insecure. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling. Only call this function with data you trust.

Warning

Calling send_object_list() with GPU tensors is not well supported and inefficient as it incurs GPU -> CPU transfer since tensors would be pickled. Please consider using send() instead.

Example:

>>> # Note: Process group initialization omitted on each rank.
>>> import torch
>>> import torch.distributed as dist
>>> # Assumes backend is not NCCL
>>> device = torch.device("cpu")
>>> if dist.get_rank() == 0:
>>>     # Assumes world_size of 2.
>>>     objects = ["foo", 12, {1: 2}]  # any picklable object
>>>     dist.send_object_list(objects, dst=1, device=device)
>>> else:
>>>     objects = [None, None, None]
>>>     dist.recv_object_list(objects, src=0, device=device)
>>> objects
['foo', 12, {1: 2}]
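Because the objects are serialized before being communicated, it can help to see what the round trip looks like. The following is a minimal sketch using only the standard pickle module; it is not the library's implementation, the helper names are hypothetical, and a list of integers stands in for the byte tensor that would actually be transferred.

```python
import pickle


def to_byte_list(obj):
    """Serialize `obj` with pickle and expose the payload as a list of ints."""
    return list(pickle.dumps(obj))


def from_byte_list(payload):
    """Rebuild the object from a payload produced by `to_byte_list`."""
    return pickle.loads(bytes(payload))


objects = ["foo", 12, {1: 2}]  # any picklable object
payloads = [to_byte_list(obj) for obj in objects]  # what the sender would ship
received = [from_byte_list(p) for p in payloads]   # what the receiver rebuilds
```

The `pickle.loads` step is where untrusted data could execute arbitrary code, which is why the warning above applies to every object collective.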
torch.distributed.recv_object_list(object_list,src=None,group=None,device=None,group_src=None,use_batch=False)[source]#

Receives picklable objects in object_list synchronously.

Similar to recv(), but can receive Python objects.

Parameters
  • object_list (List[Any]) – List of objects to receive into. Must provide a list of sizes equal to the size of the list being sent.

  • src (int, optional) – Source rank from which to recv object_list. Source rank is based on global process group (regardless of group argument). Will receive from any rank if set to None. Default is None.

  • group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used. Default is None.

  • device (torch.device, optional) – If not None, receives on this device. Default is None.

  • group_src (int, optional) – Source rank on group. Invalid to specify both src and group_src.

  • use_batch (bool, optional) – If True, use batch p2p operations instead of regular send operations. This avoids initializing 2-rank communicators and uses existing entire group communicators. See batch_isend_irecv for usage and assumptions. Default is False.

Returns

Sender rank. -1 if rank is not part of the group. If rank is part of the group, object_list will contain the sent objects from src rank.

Note

For NCCL-based process groups, internal tensor representations of objects must be moved to the GPU device before communication takes place. In this case, the device used is given by torch.cuda.current_device() and it is the user’s responsibility to ensure that this is set so that each rank has an individual GPU, via torch.cuda.set_device().

Warning

Object collectives have a number of serious performance and scalability limitations. See Object collectives for details.

Warning

recv_object_list() uses the pickle module implicitly, which is known to be insecure. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling. Only call this function with data you trust.

Warning

Calling recv_object_list() with GPU tensors is not well supported and inefficient as it incurs GPU -> CPU transfer since tensors would be pickled. Please consider using recv() instead.

Example:

>>> # Note: Process group initialization omitted on each rank.
>>> import torch
>>> import torch.distributed as dist
>>> # Assumes backend is not NCCL
>>> device = torch.device("cpu")
>>> if dist.get_rank() == 0:
>>>     # Assumes world_size of 2.
>>>     objects = ["foo", 12, {1: 2}]  # any picklable object
>>>     dist.send_object_list(objects, dst=1, device=device)
>>> else:
>>>     objects = [None, None, None]
>>>     dist.recv_object_list(objects, src=0, device=device)
>>> objects
['foo', 12, {1: 2}]
torch.distributed.batch_isend_irecv(p2p_op_list)[source]#

Send or Receive a batch of tensors asynchronously and return a list of requests.

Process each of the operations in p2p_op_list and return the corresponding requests. NCCL, Gloo, and UCC backends are currently supported.

Parameters

p2p_op_list (list[torch.distributed.distributed_c10d.P2POp]) – A list of point-to-point operations (type of each operator is torch.distributed.P2POp). The order of the isend/irecv in the list matters and it needs to match with corresponding isend/irecv on the remote end.

Returns

A list of distributed request objects returned by calling the corresponding op in the op_list.

Return type

list[torch.distributed.distributed_c10d.Work]

Examples

>>> # Assumes `import torch.distributed as dist`, and that `rank` and
>>> # `world_size` (2 here) are defined on each rank.
>>> send_tensor = torch.arange(2, dtype=torch.float32) + 2 * rank
>>> recv_tensor = torch.randn(2, dtype=torch.float32)
>>> send_op = dist.P2POp(dist.isend, send_tensor, (rank + 1) % world_size)
>>> recv_op = dist.P2POp(
...     dist.irecv, recv_tensor, (rank - 1 + world_size) % world_size
... )
>>> reqs = dist.batch_isend_irecv([send_op, recv_op])
>>> for req in reqs:
>>>     req.wait()
>>> recv_tensor
tensor([2., 3.])     # Rank 0
tensor([0., 1.])     # Rank 1

Note

Note that when this API is used with the NCCL PG backend, users must set the current GPU device with torch.cuda.set_device, otherwise it will lead to unexpected hang issues.

In addition, if this API is the first collective call in the group passed to dist.P2POp, all ranks of the group must participate in this API call; otherwise, the behavior is undefined. If this API call is not the first collective call in the group, batched P2P operations involving only a subset of ranks of the group are allowed.
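The ring pattern used in the example above (each rank sends to rank + 1 and receives from rank - 1, modulo the world size) can be traced by hand with a small model. This is a pure-Python sketch that launches no processes; `simulate_ring` is a hypothetical helper and plain lists stand in for tensors.

```python
def simulate_ring(world_size):
    """Model one ring exchange: rank r sends [2r, 2r + 1] to rank (r + 1) % N."""
    # Send buffer on each rank, mirroring arange(2) + 2 * rank.
    send = {rank: [2 * rank, 2 * rank + 1] for rank in range(world_size)}
    recv = {}
    for rank in range(world_size):
        dst = (rank + 1) % world_size
        # The receive posted on `dst` for source (dst - 1) % N matches this send.
        recv[dst] = send[rank]
    return recv


two_ranks = simulate_ring(2)
four_ranks = simulate_ring(4)
```

With two ranks the model reproduces the values shown in the example: rank 0 ends up with rank 1's buffer and vice versa.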

class torch.distributed.P2POp(op, tensor, peer=None, group=None, tag=0, group_peer=None)[source]#

A class to build point-to-point operations for batch_isend_irecv.

This class builds the type of P2P operation, communication buffer, peer rank, Process Group, and tag. Instances of this class will be passed to batch_isend_irecv for point-to-point communications.

Parameters
  • op (Callable) – A function to send data to or receive data from a peer process. The type of op is either torch.distributed.isend or torch.distributed.irecv.

  • tensor (Tensor) – Tensor to send or receive.

  • peer (int,optional) – Destination or source rank.

  • group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used.

  • tag (int,optional) – Tag to match send with recv.

  • group_peer (int,optional) – Destination or source rank.

Synchronous and asynchronous collective operations#

Every collective operation function supports the following two kinds of operations, depending on the setting of the async_op flag passed into the collective:

Synchronous operation - the default mode, when async_op is set to False. When the function returns, it is guaranteed that the collective operation is performed. In the case of CUDA operations, it is not guaranteed that the CUDA operation is completed, since CUDA operations are asynchronous. For CPU collectives, any further function calls utilizing the output of the collective call will behave as expected. For CUDA collectives, function calls utilizing the output on the same CUDA stream will behave as expected. Users must take care of synchronization under the scenario of running under different streams. For details on CUDA semantics such as stream synchronization, see CUDA Semantics. See the below script to see examples of differences in these semantics for CPU and CUDA operations.

Asynchronous operation - when async_op is set to True. The collective operation function returns a distributed request object. In general, you don’t need to create it manually and it is guaranteed to support the following methods:

  • is_completed() - in the case of CPU collectives, returns True if completed. In the case of CUDA operations, returns True if the operation has been successfully enqueued onto a CUDA stream and the output can be utilized on the default stream without further synchronization.

  • wait() - in the case of CPU collectives, will block the process until the operation is completed. In the case of CUDA collectives, will block the currently active CUDA stream until the operation is completed (but will not block the CPU).

  • get_future() - returns torch._C.Future object. Supported for NCCL, also supported for most operations on GLOO and MPI, except for peer to peer operations. Note: as we continue adopting Futures and merging APIs, get_future() call might become redundant.

Example

The following code can serve as a reference regarding semantics for CUDA operations when using distributed collectives.It shows the explicit need to synchronize when using collective outputs on different CUDA streams:

# Code runs on each rank; `rank` is assumed to be set by the launcher.
import torch
import torch.distributed as dist

dist.init_process_group("nccl", rank=rank, world_size=2)
output = torch.tensor([rank]).cuda(rank)
s = torch.cuda.Stream()
handle = dist.all_reduce(output, async_op=True)
# Wait ensures the operation is enqueued, but not necessarily complete.
handle.wait()
# Using result on non-default stream.
with torch.cuda.stream(s):
    s.wait_stream(torch.cuda.default_stream())
    output.add_(100)
if rank == 0:
    # if the explicit call to wait_stream was omitted, the output below will be
    # non-deterministically 1 or 101, depending on whether the allreduce overwrote
    # the value after the add completed.
    print(output)

Collective functions#

torch.distributed.broadcast(tensor,src=None,group=None,async_op=False,group_src=None)[source]#

Broadcasts the tensor to the whole group.

tensor must have the same number of elements in all processes participating in the collective.

Parameters
  • tensor (Tensor) – Data to be sent if src is the rank of current process, and tensor to be used to save received data otherwise.

  • src (int) – Source rank on global process group (regardless of group argument).

  • group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used.

  • async_op (bool, optional) – Whether this op should be an async op.

  • group_src (int) – Source rank on group. Must specify one of group_src and src but not both.

Returns

Async work handle, if async_op is set to True. None, if not async_op or if not part of the group.

torch.distributed.broadcast_object_list(object_list,src=None,group=None,device=None,group_src=None)[source]#

Broadcasts picklable objects in object_list to the whole group.

Similar to broadcast(), but Python objects can be passed in. Note that all objects in object_list must be picklable in order to be broadcast.

Parameters
  • object_list (List[Any]) – List of input objects to broadcast. Each object must be picklable. Only objects on the src rank will be broadcast, but each rank must provide lists of equal sizes.

  • src (int) – Source rank from which to broadcast object_list. Source rank is based on global process group (regardless of group argument).

  • group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used. Default is None.

  • device (torch.device, optional) – If not None, the objects are serialized and converted to tensors which are moved to the device before broadcasting. Default is None.

  • group_src (int) – Source rank on group. Must specify one of group_src and src but not both.

Returns

None. If rank is part of the group, object_list will contain the broadcast objects from src rank.

Note

For NCCL-based process groups, internal tensor representations of objects must be moved to the GPU device before communication takes place. In this case, the device used is given by torch.cuda.current_device() and it is the user’s responsibility to ensure that this is set so that each rank has an individual GPU, via torch.cuda.set_device().

Note

Note that this API differs slightly from the broadcast() collective since it does not provide an async_op handle and thus will be a blocking call.

Warning

Object collectives have a number of serious performance and scalability limitations. See Object collectives for details.

Warning

broadcast_object_list() uses the pickle module implicitly, which is known to be insecure. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling. Only call this function with data you trust.

Warning

Calling broadcast_object_list() with GPU tensors is not well supported and inefficient as it incurs GPU -> CPU transfer since tensors would be pickled. Please consider using broadcast() instead.

Example:

>>> # Note: Process group initialization omitted on each rank.
>>> import torch
>>> import torch.distributed as dist
>>> if dist.get_rank() == 0:
>>>     # Assumes world_size of 3.
>>>     objects = ["foo", 12, {1: 2}]  # any picklable object
>>> else:
>>>     objects = [None, None, None]
>>> # Assumes backend is not NCCL
>>> device = torch.device("cpu")
>>> dist.broadcast_object_list(objects, src=0, device=device)
>>> objects
['foo', 12, {1: 2}]
torch.distributed.all_reduce(tensor,op=<RedOpType.SUM:0>,group=None,async_op=False)[source]#

Reduces the tensor data across all machines in a way that all get the final result.

After the call tensor is going to be bitwise identical in all processes.

Complex tensors are supported.

Parameters
  • tensor (Tensor) – Input and output of the collective. The function operates in-place.

  • op (optional) – One of the values from torch.distributed.ReduceOp enum. Specifies an operation used for element-wise reductions.

  • group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used.

  • async_op (bool, optional) – Whether this op should be an async op.

Returns

Async work handle, if async_op is set to True. None, if not async_op or if not part of the group.

Examples

>>> # All tensors below are of torch.int64 type.
>>> # We have 2 process groups, 2 ranks.
>>> device = torch.device(f"cuda:{rank}")
>>> tensor = torch.arange(2, dtype=torch.int64, device=device) + 1 + 2 * rank
>>> tensor
tensor([1, 2], device='cuda:0') # Rank 0
tensor([3, 4], device='cuda:1') # Rank 1
>>> dist.all_reduce(tensor, op=ReduceOp.SUM)
>>> tensor
tensor([4, 6], device='cuda:0') # Rank 0
tensor([4, 6], device='cuda:1') # Rank 1

>>> # All tensors below are of torch.cfloat type.
>>> # We have 2 process groups, 2 ranks.
>>> tensor = torch.tensor(
...     [1 + 1j, 2 + 2j], dtype=torch.cfloat, device=device
... ) + 2 * rank * (1 + 1j)
>>> tensor
tensor([1.+1.j, 2.+2.j], device='cuda:0') # Rank 0
tensor([3.+3.j, 4.+4.j], device='cuda:1') # Rank 1
>>> dist.all_reduce(tensor, op=ReduceOp.SUM)
>>> tensor
tensor([4.+4.j, 6.+6.j], device='cuda:0') # Rank 0
tensor([4.+4.j, 6.+6.j], device='cuda:1') # Rank 1
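The element-wise semantics of all_reduce can be reproduced with a single-process model. The sketch below uses plain Python lists in place of tensors and a hypothetical helper `simulate_all_reduce`; it only models the result each rank ends up with, not how the backend communicates.

```python
import operator
from functools import reduce


def simulate_all_reduce(per_rank, op=operator.add):
    """Reduce element-wise across ranks and hand every rank the same result."""
    # zip(*per_rank) groups the i-th element of every rank's buffer together.
    reduced = [reduce(op, column) for column in zip(*per_rank)]
    # Every rank receives its own copy of the reduced buffer.
    return [list(reduced) for _ in per_rank]


int_result = simulate_all_reduce([[1, 2], [3, 4]])
complex_result = simulate_all_reduce([[1 + 1j, 2 + 2j], [3 + 3j, 4 + 4j]])
max_result = simulate_all_reduce([[1, 2], [3, 4]], op=max)
```

The first two calls mirror the SUM examples above; the last one swaps in a maximum as a stand-in for a different reduction operation.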
torch.distributed.reduce(tensor,dst=None,op=<RedOpType.SUM:0>,group=None,async_op=False,group_dst=None)[source]#

Reduces the tensor data across all machines.

Only the process with rank dst is going to receive the final result.

Parameters
  • tensor (Tensor) – Input and output of the collective. The function operates in-place.

  • dst (int) – Destination rank on global process group (regardless of group argument).

  • op (optional) – One of the values from torch.distributed.ReduceOp enum. Specifies an operation used for element-wise reductions.

  • group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used.

  • async_op (bool, optional) – Whether this op should be an async op.

  • group_dst (int) – Destination rank on group. Must specify one of group_dst and dst but not both.

Returns

Async work handle, if async_op is set to True. None, if not async_op or if not part of the group.

torch.distributed.all_gather(tensor_list,tensor,group=None,async_op=False)[source]#

Gathers tensors from the whole group in a list.

Complex and uneven sized tensors are supported.

Parameters
  • tensor_list (list[Tensor]) – Output list. It should contain correctly-sized tensors to be used for output of the collective. Uneven sized tensors are supported.

  • tensor (Tensor) – Tensor to be broadcast from current process.

  • group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used.

  • async_op (bool, optional) – Whether this op should be an async op.

Returns

Async work handle, if async_op is set to True. None, if not async_op or if not part of the group.

Examples

>>> # All tensors below are of torch.int64 dtype.
>>> # We have 2 process groups, 2 ranks.
>>> device = torch.device(f"cuda:{rank}")
>>> tensor_list = [
...     torch.zeros(2, dtype=torch.int64, device=device) for _ in range(2)
... ]
>>> tensor_list
[tensor([0, 0], device='cuda:0'), tensor([0, 0], device='cuda:0')] # Rank 0
[tensor([0, 0], device='cuda:1'), tensor([0, 0], device='cuda:1')] # Rank 1
>>> tensor = torch.arange(2, dtype=torch.int64, device=device) + 1 + 2 * rank
>>> tensor
tensor([1, 2], device='cuda:0') # Rank 0
tensor([3, 4], device='cuda:1') # Rank 1
>>> dist.all_gather(tensor_list, tensor)
>>> tensor_list
[tensor([1, 2], device='cuda:0'), tensor([3, 4], device='cuda:0')] # Rank 0
[tensor([1, 2], device='cuda:1'), tensor([3, 4], device='cuda:1')] # Rank 1

>>> # All tensors below are of torch.cfloat dtype.
>>> # We have 2 process groups, 2 ranks.
>>> tensor_list = [
...     torch.zeros(2, dtype=torch.cfloat, device=device) for _ in range(2)
... ]
>>> tensor_list
[tensor([0.+0.j, 0.+0.j], device='cuda:0'), tensor([0.+0.j, 0.+0.j], device='cuda:0')] # Rank 0
[tensor([0.+0.j, 0.+0.j], device='cuda:1'), tensor([0.+0.j, 0.+0.j], device='cuda:1')] # Rank 1
>>> tensor = torch.tensor(
...     [1 + 1j, 2 + 2j], dtype=torch.cfloat, device=device
... ) + 2 * rank * (1 + 1j)
>>> tensor
tensor([1.+1.j, 2.+2.j], device='cuda:0') # Rank 0
tensor([3.+3.j, 4.+4.j], device='cuda:1') # Rank 1
>>> dist.all_gather(tensor_list, tensor)
>>> tensor_list
[tensor([1.+1.j, 2.+2.j], device='cuda:0'), tensor([3.+3.j, 4.+4.j], device='cuda:0')] # Rank 0
[tensor([1.+1.j, 2.+2.j], device='cuda:1'), tensor([3.+3.j, 4.+4.j], device='cuda:1')] # Rank 1
torch.distributed.all_gather_into_tensor(output_tensor,input_tensor,group=None,async_op=False)[source]#

Gather tensors from all ranks and put them in a single output tensor.

This function requires all tensors to be the same size on each process.

Parameters
  • output_tensor (Tensor) – Output tensor to accommodate tensor elements from all ranks. It must be correctly sized to have one of the following forms: (i) a concatenation of all the input tensors along the primary dimension; for definition of “concatenation”, see torch.cat(); (ii) a stack of all the input tensors along the primary dimension; for definition of “stack”, see torch.stack(). Examples below may better explain the supported output forms.

  • input_tensor (Tensor) – Tensor to be gathered from current rank. Different from the all_gather API, the input tensors in this API must have the same size across all ranks.

  • group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used.

  • async_op (bool, optional) – Whether this op should be an async op.

Returns

Async work handle, if async_op is set to True. None, if not async_op or if not part of the group.

Examples

>>> # All tensors below are of torch.int64 dtype and on CUDA devices.
>>> # We have two ranks.
>>> device = torch.device(f"cuda:{rank}")
>>> tensor_in = torch.arange(2, dtype=torch.int64, device=device) + 1 + 2 * rank
>>> tensor_in
tensor([1, 2], device='cuda:0') # Rank 0
tensor([3, 4], device='cuda:1') # Rank 1
>>> # Output in concatenation form
>>> tensor_out = torch.zeros(world_size * 2, dtype=torch.int64, device=device)
>>> dist.all_gather_into_tensor(tensor_out, tensor_in)
>>> tensor_out
tensor([1, 2, 3, 4], device='cuda:0') # Rank 0
tensor([1, 2, 3, 4], device='cuda:1') # Rank 1
>>> # Output in stack form
>>> tensor_out2 = torch.zeros(world_size, 2, dtype=torch.int64, device=device)
>>> dist.all_gather_into_tensor(tensor_out2, tensor_in)
>>> tensor_out2
tensor([[1, 2],
        [3, 4]], device='cuda:0') # Rank 0
tensor([[1, 2],
        [3, 4]], device='cuda:1') # Rank 1
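The two accepted output forms differ only in shape. As a minimal sketch, the hypothetical helper below computes both shapes from the per-rank input shape and the world size, following the concatenation and stack definitions given for output_tensor; it assumes the input has at least one dimension and does not call torch.

```python
def allowed_output_shapes(input_shape, world_size):
    """Return (concatenation_form, stack_form) for a given per-rank input shape."""
    first, *rest = input_shape
    # (i) concatenation along the primary dimension, as with torch.cat.
    concat_form = (world_size * first, *rest)
    # (ii) stacking along a new primary dimension, as with torch.stack.
    stack_form = (world_size, *input_shape)
    return concat_form, stack_form


example_1d = allowed_output_shapes((2,), world_size=2)
example_2d = allowed_output_shapes((3, 5), world_size=4)
```

For the example above (input of shape (2,) on two ranks), the helper yields the shapes of tensor_out and tensor_out2 respectively.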
torch.distributed.all_gather_object(object_list,obj,group=None)[source]#

Gathers picklable objects from the whole group into a list.

Similar to all_gather(), but Python objects can be passed in. Note that the object must be picklable in order to be gathered.

Parameters
  • object_list (list[Any]) – Output list. It should be correctly sized as the size of the group for this collective and will contain the output.

  • obj (Any) – Picklable Python object to be broadcast from current process.

  • group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used. Default is None.

Returns

None. If the calling rank is part of this group, the output of the collective will be populated into the input object_list. If the calling rank is not part of the group, the passed in object_list will be unmodified.

Note

Note that this API differs slightly from the all_gather() collective since it does not provide an async_op handle and thus will be a blocking call.

Note

For NCCL-based process groups, internal tensor representations of objects must be moved to the GPU device before communication takes place. In this case, the device used is given by torch.cuda.current_device() and it is the user’s responsibility to ensure that this is set so that each rank has an individual GPU, via torch.cuda.set_device().

Warning

Object collectives have a number of serious performance and scalability limitations. See Object collectives for details.

Warning

all_gather_object() uses the pickle module implicitly, which is known to be insecure. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling. Only call this function with data you trust.

Warning

Calling all_gather_object() with GPU tensors is not well supported and inefficient as it incurs GPU -> CPU transfer since tensors would be pickled. Please consider using all_gather() instead.

Example:

>>> # Note: Process group initialization omitted on each rank.
>>> import torch.distributed as dist
>>> # Assumes world_size of 3.
>>> gather_objects = ["foo", 12, {1: 2}]  # any picklable object
>>> output = [None for _ in gather_objects]
>>> dist.all_gather_object(output, gather_objects[dist.get_rank()])
>>> output
['foo', 12, {1: 2}]
torch.distributed.gather(tensor,gather_list=None,dst=None,group=None,async_op=False,group_dst=None)[source]#

Gathers a list of tensors in a single process.

This function requires all tensors to be the same size on each process.

Parameters
  • tensor (Tensor) – Input tensor.

  • gather_list (list[Tensor], optional) – List of appropriately, same-sized tensors to use for gathered data (default is None, must be specified on the destination rank).

  • dst (int, optional) – Destination rank on global process group (regardless of group argument). (If both dst and group_dst are None, default is global rank 0.)

  • group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used.

  • async_op (bool, optional) – Whether this op should be an async op.

  • group_dst (int, optional) – Destination rank on group. Invalid to specify both dst and group_dst.

Returns

Async work handle, if async_op is set to True. None, if not async_op or if not part of the group.

Note

Note that all Tensors in gather_list must have the same size.

Example:

>>> # We have 2 process groups, 2 ranks.
>>> tensor_size = 2
>>> device = torch.device(f'cuda:{rank}')
>>> tensor = torch.ones(tensor_size, device=device) + rank
>>> if dist.get_rank() == 0:
>>>     gather_list = [torch.zeros_like(tensor, device=device) for i in range(2)]
>>> else:
>>>     gather_list = None
>>> dist.gather(tensor, gather_list, dst=0)
>>> # Rank 0 gets gathered data.
>>> gather_list
[tensor([1., 1.], device='cuda:0'), tensor([2., 2.], device='cuda:0')] # Rank 0
None                                                                   # Rank 1
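To contrast gather with all_gather, the following single-process model shows what each rank holds afterwards: only the destination rank receives the gathered list. It is a sketch with a hypothetical helper `simulate_gather`, using plain lists in place of tensors.

```python
def simulate_gather(per_rank, dst=0):
    """Return what each rank's gather_list holds after a gather to `dst`."""
    world_size = len(per_rank)
    # Only the destination rank gets the inputs of all ranks, in rank order;
    # on every other rank gather_list stays None.
    return [
        [list(buf) for buf in per_rank] if rank == dst else None
        for rank in range(world_size)
    ]


# Mirrors torch.ones(2) + rank for ranks 0 and 1.
inputs = [[1.0, 1.0], [2.0, 2.0]]
to_rank0 = simulate_gather(inputs, dst=0)
to_rank1 = simulate_gather(inputs, dst=1)
```

Changing `dst` moves the gathered list to a different rank without changing its contents.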
torch.distributed.gather_object(obj,object_gather_list=None,dst=None,group=None,group_dst=None)[source]#

Gathers picklable objects from the whole group in a single process.

Similar to gather(), but Python objects can be passed in. Note that the object must be picklable in order to be gathered.

Parameters
  • obj (Any) – Input object. Must be picklable.

  • object_gather_list (list[Any]) – Output list. On the dst rank, it should be correctly sized as the size of the group for this collective and will contain the output. Must be None on non-dst ranks. (default is None)

  • dst (int, optional) – Destination rank on global process group (regardless of group argument). (If both dst and group_dst are None, default is global rank 0.)

  • group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used. Default is None.

  • group_dst (int, optional) – Destination rank on group. Invalid to specify both dst and group_dst.

Returns

None. On the dst rank, object_gather_list will contain the output of the collective.

Note

Note that this API differs slightly from the gather collective since it does not provide an async_op handle and thus will be a blocking call.

Note

For NCCL-based process groups, internal tensor representations of objects must be moved to the GPU device before communication takes place. In this case, the device used is given by torch.cuda.current_device() and it is the user’s responsibility to ensure that this is set so that each rank has an individual GPU, via torch.cuda.set_device().

Warning

Object collectives have a number of serious performance and scalability limitations. See Object collectives for details.

Warning

gather_object() uses the pickle module implicitly, which is known to be insecure. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling. Only call this function with data you trust.

Warning

Calling gather_object() with GPU tensors is not well supported and inefficient as it incurs GPU -> CPU transfer since tensors would be pickled. Please consider using gather() instead.

Example::
>>> # Note: Process group initialization omitted on each rank.
>>> import torch.distributed as dist
>>> # Assumes world_size of 3.
>>> gather_objects = ["foo", 12, {1: 2}]  # any picklable object
>>> output = [None for _ in gather_objects]
>>> dist.gather_object(
...     gather_objects[dist.get_rank()],
...     output if dist.get_rank() == 0 else None,
...     dst=0
... )
>>> # On rank 0
>>> output
['foo', 12, {1: 2}]
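Because gather_object() serializes each object with pickle, it can help to check up front that an object survives a pickle round trip. The sketch below uses only the standard library; is_picklable is a hypothetical helper introduced for illustration and is not part of torch.distributed.

```python
import pickle


def is_picklable(obj) -> bool:
    """Return True if obj survives a pickle round trip (hypothetical helper)."""
    try:
        pickle.loads(pickle.dumps(obj))
    except (pickle.PicklingError, AttributeError, TypeError):
        return False
    return True


# The objects used in the gather_object example are all picklable.
candidates = ["foo", 12, {1: 2}]
results = [is_picklable(obj) for obj in candidates]

# A lambda cannot be pickled by the standard pickle module,
# so it could not be passed to an object collective.
lambda_ok = is_picklable(lambda x: x)
```

A check like this only confirms that serialization works; it does not make unpickling untrusted data safe.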
torch.distributed.scatter(tensor,scatter_list=None,src=None,group=None,async_op=False,group_src=None)[source]#

Scatters a list of tensors to all processes in a group.

Each process will receive exactly one tensor and store its data in the tensor argument.

Complex tensors are supported.

Parameters
  • tensor (Tensor) – Output tensor.

  • scatter_list (list[Tensor]) – List of tensors to scatter (default is None, must be specified on the source rank)

  • src (int) – Source rank on global process group (regardless of group argument). (If both src and group_src are None, default is global rank 0)

  • group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used.

  • async_op (bool, optional) – Whether this op should be an async op

  • group_src (int, optional) – Source rank on group. Invalid to specify both src and group_src

Returns

Async work handle, if async_op is set to True. None, if not async_op or if not part of the group.

Note

Note that all Tensors in scatter_list must have the same size.

Example::
>>> # Note: Process group initialization omitted on each rank.
>>> import torch.distributed as dist
>>> tensor_size = 2
>>> device = torch.device(f'cuda:{rank}')
>>> output_tensor = torch.zeros(tensor_size, device=device)
>>> if dist.get_rank() == 0:
>>>     # Assumes world_size of 2.
>>>     # Only tensors, all of which must be the same size.
>>>     t_ones = torch.ones(tensor_size, device=device)
>>>     t_fives = torch.ones(tensor_size, device=device) * 5
>>>     scatter_list = [t_ones, t_fives]
>>> else:
>>>     scatter_list = None
>>> dist.scatter(output_tensor, scatter_list, src=0)
>>> # Rank i gets scatter_list[i].
>>> output_tensor
tensor([1., 1.], device='cuda:0') # Rank 0
tensor([5., 5.], device='cuda:1') # Rank 1
torch.distributed.scatter_object_list(scatter_object_output_list,scatter_object_input_list=None,src=None,group=None,group_src=None)[source]#

Scatters picklable objects in scatter_object_input_list to the whole group.

Similar to scatter(), but Python objects can be passed in. On each rank, the scattered object will be stored as the first element of scatter_object_output_list. Note that all objects in scatter_object_input_list must be picklable in order to be scattered.

Parameters
  • scatter_object_output_list (List[Any]) – Non-empty list whose first element will store the object scattered to this rank.

  • scatter_object_input_list (List[Any], optional) – List of input objects to scatter. Each object must be picklable. Only objects on the src rank will be scattered, and the argument can be None for non-src ranks.

  • src (int) – Source rank from which to scatter scatter_object_input_list. Source rank is based on global process group (regardless of group argument). (If both src and group_src are None, default is global rank 0)

  • group (Optional[ProcessGroup]) – The process group to work on. If None, the default process group will be used. Default is None.

  • group_src (int, optional) – Source rank on group. Invalid to specify both src and group_src

Returns

None. If rank is part of the group, scatter_object_output_list will have its first element set to the scattered object for this rank.

Note

Note that this API differs slightly from the scatter collective since it does not provide an async_op handle and thus will be a blocking call.

Warning

Object collectives have a number of serious performance and scalability limitations. See Object collectives for details.

Warning

scatter_object_list() uses the pickle module implicitly, which is known to be insecure. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling. Only call this function with data you trust.

Warning

Calling scatter_object_list() with GPU tensors is not well supported and inefficient as it incurs GPU -> CPU transfer since tensors would be pickled. Please consider using scatter() instead.

Example::
>>> # Note: Process group initialization omitted on each rank.
>>> import torch.distributed as dist
>>> if dist.get_rank() == 0:
>>>     # Assumes world_size of 3.
>>>     objects = ["foo", 12, {1: 2}]  # any picklable object
>>> else:
>>>     # Can be any list on non-src ranks, elements are not used.
>>>     objects = [None, None, None]
>>> output_list = [None]
>>> dist.scatter_object_list(output_list, objects, src=0)
>>> # Rank i gets objects[i]. For example, on rank 2:
>>> output_list
[{1: 2}]
torch.distributed.reduce_scatter(output,input_list,op=<RedOpType.SUM:0>,group=None,async_op=False)[source]#

Reduces, then scatters a list of tensors to all processes in a group.

Parameters
  • output (Tensor) – Output tensor.

  • input_list (list[Tensor]) – List of tensors to reduce and scatter.

  • op (optional) – One of the values from the torch.distributed.ReduceOp enum. Specifies an operation used for element-wise reductions.

  • group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used.

  • async_op (bool, optional) – Whether this op should be an async op.

Returns

Async work handle, if async_op is set to True. None, if not async_op or if not part of the group.

torch.distributed.reduce_scatter_tensor(output,input,op=<RedOpType.SUM:0>,group=None,async_op=False)[source]#

Reduces, then scatters a tensor to all ranks in a group.

Parameters
  • output (Tensor) – Output tensor. It should have the same size across all ranks.

  • input (Tensor) – Input tensor to be reduced and scattered. Its size should be output tensor size times the world size. The input tensor can have one of the following shapes: (i) a concatenation of the output tensors along the primary dimension, or (ii) a stack of the output tensors along the primary dimension. For definition of “concatenation”, see torch.cat(). For definition of “stack”, see torch.stack().

  • group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used.

  • async_op (bool, optional) – Whether this op should be an async op.

Returns

Async work handle, if async_op is set to True. None, if not async_op or if not part of the group.

Examples

>>> # All tensors below are of torch.int64 dtype and on CUDA devices.
>>> # We have two ranks.
>>> device = torch.device(f"cuda:{rank}")
>>> tensor_out = torch.zeros(2, dtype=torch.int64, device=device)
>>> # Input in concatenation form
>>> tensor_in = torch.arange(world_size * 2, dtype=torch.int64, device=device)
>>> tensor_in
tensor([0, 1, 2, 3], device='cuda:0') # Rank 0
tensor([0, 1, 2, 3], device='cuda:1') # Rank 1
>>> dist.reduce_scatter_tensor(tensor_out, tensor_in)
>>> tensor_out
tensor([0, 2], device='cuda:0') # Rank 0
tensor([4, 6], device='cuda:1') # Rank 1
>>> # Input in stack form
>>> tensor_in = torch.reshape(tensor_in, (world_size, 2))
>>> tensor_in
tensor([[0, 1],
        [2, 3]], device='cuda:0') # Rank 0
tensor([[0, 1],
        [2, 3]], device='cuda:1') # Rank 1
>>> dist.reduce_scatter_tensor(tensor_out, tensor_in)
>>> tensor_out
tensor([0, 2], device='cuda:0') # Rank 0
tensor([4, 6], device='cuda:1') # Rank 1
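The semantics of the concatenation form can be modeled in plain Python, without any process group: rank r ends up with the element-wise sum, over all ranks, of the r-th chunk of each rank's input. The function simulate_reduce_scatter_sum below is a hypothetical single-process model written for illustration; it is not how the collective is implemented.

```python
def simulate_reduce_scatter_sum(inputs_per_rank):
    """Model reduce_scatter_tensor with SUM on plain lists.

    inputs_per_rank[r] is the flat input held by rank r; every input has
    length world_size * chunk. Returns the output each rank would hold.
    """
    world_size = len(inputs_per_rank)
    length = len(inputs_per_rank[0])
    if any(len(x) != length for x in inputs_per_rank):
        raise ValueError("all ranks must pass inputs of the same size")
    if length % world_size != 0:
        raise ValueError("input size must be a multiple of world_size")
    chunk = length // world_size

    outputs = []
    for r in range(world_size):
        lo, hi = r * chunk, (r + 1) * chunk
        # Element-wise sum of chunk r across every rank's input.
        outputs.append([sum(x[i] for x in inputs_per_rank) for i in range(lo, hi)])
    return outputs


# Two ranks, each holding [0, 1, 2, 3], as in the reduce_scatter_tensor example.
outputs = simulate_reduce_scatter_sum([[0, 1, 2, 3], [0, 1, 2, 3]])
print(outputs)  # [[0, 2], [4, 6]]
```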
torch.distributed.all_to_all_single(output,input,output_split_sizes=None,input_split_sizes=None,group=None,async_op=False)[source]#

Split input tensor and then scatter the split list to all processes in a group.

Later the received tensors are concatenated from all the processes in the group and returned as a single output tensor.

Complex tensors are supported.

Parameters
  • output (Tensor) – Gathered concatenated output tensor.

  • input (Tensor) – Input tensor to scatter.

  • output_split_sizes (list[Int], optional) – Output split sizes for dim 0. If None or empty, dim 0 of the output tensor must divide equally by world_size.

  • input_split_sizes (list[Int], optional) – Input split sizes for dim 0. If None or empty, dim 0 of the input tensor must divide equally by world_size.

  • group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used.

  • async_op (bool, optional) – Whether this op should be an async op.

Returns

Async work handle, if async_op is set to True. None, if not async_op or if not part of the group.

Warning

all_to_all_single is experimental and subject to change.

Examples

>>> input = torch.arange(4) + rank * 4
>>> input
tensor([0, 1, 2, 3])     # Rank 0
tensor([4, 5, 6, 7])     # Rank 1
tensor([8, 9, 10, 11])   # Rank 2
tensor([12, 13, 14, 15]) # Rank 3
>>> output = torch.empty([4], dtype=torch.int64)
>>> dist.all_to_all_single(output, input)
>>> output
tensor([0, 4, 8, 12])    # Rank 0
tensor([1, 5, 9, 13])    # Rank 1
tensor([2, 6, 10, 14])   # Rank 2
tensor([3, 7, 11, 15])   # Rank 3

>>> # Essentially, it is similar to following operation:
>>> scatter_list = list(input.chunk(world_size))
>>> gather_list = list(output.chunk(world_size))
>>> for i in range(world_size):
>>>     dist.scatter(gather_list[i], scatter_list if i == rank else [], src=i)

>>> # Another example with uneven split
>>> input
tensor([0, 1, 2, 3, 4, 5])                                       # Rank 0
tensor([10, 11, 12, 13, 14, 15, 16, 17, 18])                     # Rank 1
tensor([20, 21, 22, 23, 24])                                     # Rank 2
tensor([30, 31, 32, 33, 34, 35, 36])                             # Rank 3
>>> input_splits
[2, 2, 1, 1]                                                     # Rank 0
[3, 2, 2, 2]                                                     # Rank 1
[2, 1, 1, 1]                                                     # Rank 2
[2, 2, 2, 1]                                                     # Rank 3
>>> output_splits
[2, 3, 2, 2]                                                     # Rank 0
[2, 2, 1, 2]                                                     # Rank 1
[1, 2, 1, 2]                                                     # Rank 2
[1, 2, 1, 1]                                                     # Rank 3
>>> output = ...
>>> dist.all_to_all_single(output, input, output_splits, input_splits)
>>> output
tensor([ 0,  1, 10, 11, 12, 20, 21, 30, 31])                     # Rank 0
tensor([ 2,  3, 13, 14, 22, 32, 33])                             # Rank 1
tensor([ 4, 15, 16, 23, 34, 35])                                 # Rank 2
tensor([ 5, 17, 18, 24, 36])                                     # Rank 3

>>> # Another example with tensors of torch.cfloat type.
>>> input = torch.tensor(
...     [1 + 1j, 2 + 2j, 3 + 3j, 4 + 4j], dtype=torch.cfloat
... ) + 4 * rank * (1 + 1j)
>>> input
tensor([1+1j, 2+2j, 3+3j, 4+4j])                                # Rank 0
tensor([5+5j, 6+6j, 7+7j, 8+8j])                                # Rank 1
tensor([9+9j, 10+10j, 11+11j, 12+12j])                          # Rank 2
tensor([13+13j, 14+14j, 15+15j, 16+16j])                        # Rank 3
>>> # The output buffer must use the same complex dtype as the input.
>>> output = torch.empty([4], dtype=torch.cfloat)
>>> dist.all_to_all_single(output, input)
>>> output
tensor([1+1j, 5+5j, 9+9j, 13+13j])                              # Rank 0
tensor([2+2j, 6+6j, 10+10j, 14+14j])                            # Rank 1
tensor([3+3j, 7+7j, 11+11j, 15+15j])                            # Rank 2
tensor([4+4j, 8+8j, 12+12j, 16+16j])                            # Rank 3
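In the uneven-split example, each rank's output_splits follows from the other ranks' input_splits: rank r receives input_splits[j][r] elements from rank j. The following single-process model, using plain lists and a hypothetical helper simulate_all_to_all_single, reproduces the outputs listed in that example; it is a sketch of the data movement only and does not involve a process group.

```python
def simulate_all_to_all_single(inputs, input_splits):
    """Model all_to_all_single on plain lists.

    inputs[j] is rank j's flat input and input_splits[j][r] is how many
    elements rank j sends to rank r. Returns (outputs, output_splits).
    """
    world_size = len(inputs)
    # Cut every rank's input into one piece per destination rank.
    pieces = []
    for data, splits in zip(inputs, input_splits):
        if sum(splits) != len(data):
            raise ValueError("input_splits must sum to the input length")
        row, start = [], 0
        for n in splits:
            row.append(data[start:start + n])
            start += n
        pieces.append(row)

    outputs, output_splits = [], []
    for r in range(world_size):
        received = [pieces[j][r] for j in range(world_size)]
        output_splits.append([len(p) for p in received])
        # Received pieces are concatenated in source-rank order.
        outputs.append([x for p in received for x in p])
    return outputs, output_splits


inputs = [
    [0, 1, 2, 3, 4, 5],
    [10, 11, 12, 13, 14, 15, 16, 17, 18],
    [20, 21, 22, 23, 24],
    [30, 31, 32, 33, 34, 35, 36],
]
input_splits = [[2, 2, 1, 1], [3, 2, 2, 2], [2, 1, 1, 1], [2, 2, 2, 1]]
outputs, output_splits = simulate_all_to_all_single(inputs, input_splits)
```

In this model the output_splits matrix is the transpose of the input_splits matrix, which is why each rank has to know how much every peer intends to send before allocating its output.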
torch.distributed.all_to_all(output_tensor_list,input_tensor_list,group=None,async_op=False)[source]#

Scatters list of input tensors to all processes in a group and return gathered list of tensors in output list.

Complex tensors are supported.

Parameters
  • output_tensor_list (list[Tensor]) – List of tensors to be gathered one per rank.

  • input_tensor_list (list[Tensor]) – List of tensors to scatter one per rank.

  • group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used.

  • async_op (bool, optional) – Whether this op should be an async op.

Returns

Async work handle, if async_op is set to True. None, if not async_op or if not part of the group.

Warning

all_to_all is experimental and subject to change.

Examples

>>> input = torch.arange(4) + rank * 4
>>> input = list(input.chunk(4))
>>> input
[tensor([0]), tensor([1]), tensor([2]), tensor([3])]     # Rank 0
[tensor([4]), tensor([5]), tensor([6]), tensor([7])]     # Rank 1
[tensor([8]), tensor([9]), tensor([10]), tensor([11])]   # Rank 2
[tensor([12]), tensor([13]), tensor([14]), tensor([15])] # Rank 3
>>> output = list(torch.empty([4], dtype=torch.int64).chunk(4))
>>> dist.all_to_all(output, input)
>>> output
[tensor([0]), tensor([4]), tensor([8]), tensor([12])]    # Rank 0
[tensor([1]), tensor([5]), tensor([9]), tensor([13])]    # Rank 1
[tensor([2]), tensor([6]), tensor([10]), tensor([14])]   # Rank 2
[tensor([3]), tensor([7]), tensor([11]), tensor([15])]   # Rank 3

>>> # Essentially, it is similar to following operation:
>>> scatter_list = input
>>> gather_list = output
>>> for i in range(world_size):
>>>     dist.scatter(gather_list[i], scatter_list if i == rank else [], src=i)

>>> input
tensor([0, 1, 2, 3, 4, 5])                                       # Rank 0
tensor([10, 11, 12, 13, 14, 15, 16, 17, 18])                     # Rank 1
tensor([20, 21, 22, 23, 24])                                     # Rank 2
tensor([30, 31, 32, 33, 34, 35, 36])                             # Rank 3
>>> input_splits
[2, 2, 1, 1]                                                     # Rank 0
[3, 2, 2, 2]                                                     # Rank 1
[2, 1, 1, 1]                                                     # Rank 2
[2, 2, 2, 1]                                                     # Rank 3
>>> output_splits
[2, 3, 2, 2]                                                     # Rank 0
[2, 2, 1, 2]                                                     # Rank 1
[1, 2, 1, 2]                                                     # Rank 2
[1, 2, 1, 1]                                                     # Rank 3
>>> input = list(input.split(input_splits))
>>> input
[tensor([0, 1]), tensor([2, 3]), tensor([4]), tensor([5])]                   # Rank 0
[tensor([10, 11, 12]), tensor([13, 14]), tensor([15, 16]), tensor([17, 18])] # Rank 1
[tensor([20, 21]), tensor([22]), tensor([23]), tensor([24])]                 # Rank 2
[tensor([30, 31]), tensor([32, 33]), tensor([34, 35]), tensor([36])]         # Rank 3
>>> output = ...
>>> dist.all_to_all(output, input)
>>> output
[tensor([0, 1]), tensor([10, 11, 12]), tensor([20, 21]), tensor([30, 31])]   # Rank 0
[tensor([2, 3]), tensor([13, 14]), tensor([22]), tensor([32, 33])]           # Rank 1
[tensor([4]), tensor([15, 16]), tensor([23]), tensor([34, 35])]              # Rank 2
[tensor([5]), tensor([17, 18]), tensor([24]), tensor([36])]                  # Rank 3

>>> # Another example with tensors of torch.cfloat type.
>>> input = torch.tensor(
...     [1 + 1j, 2 + 2j, 3 + 3j, 4 + 4j], dtype=torch.cfloat
... ) + 4 * rank * (1 + 1j)
>>> input = list(input.chunk(4))
>>> input
[tensor([1+1j]), tensor([2+2j]), tensor([3+3j]), tensor([4+4j])]            # Rank 0
[tensor([5+5j]), tensor([6+6j]), tensor([7+7j]), tensor([8+8j])]            # Rank 1
[tensor([9+9j]), tensor([10+10j]), tensor([11+11j]), tensor([12+12j])]      # Rank 2
[tensor([13+13j]), tensor([14+14j]), tensor([15+15j]), tensor([16+16j])]    # Rank 3
>>> # The output buffers must use the same complex dtype as the input.
>>> output = list(torch.empty([4], dtype=torch.cfloat).chunk(4))
>>> dist.all_to_all(output, input)
>>> output
[tensor([1+1j]), tensor([5+5j]), tensor([9+9j]), tensor([13+13j])]          # Rank 0
[tensor([2+2j]), tensor([6+6j]), tensor([10+10j]), tensor([14+14j])]        # Rank 1
[tensor([3+3j]), tensor([7+7j]), tensor([11+11j]), tensor([15+15j])]        # Rank 2
[tensor([4+4j]), tensor([8+8j]), tensor([12+12j]), tensor([16+16j])]        # Rank 3
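Viewed across the whole group, all_to_all behaves like a transpose of the per-rank lists: the j-th entry of rank r's output is the r-th entry of rank j's input. A minimal single-process illustration on plain lists follows; simulate_all_to_all is a hypothetical name and no process group is involved.

```python
def simulate_all_to_all(inputs):
    """Model all_to_all: inputs[j][r] is what rank j sends to rank r."""
    world_size = len(inputs)
    if any(len(row) != world_size for row in inputs):
        raise ValueError("each rank must provide one entry per rank")
    # outputs[r][j] is what rank r receives from rank j.
    return [[inputs[j][r] for j in range(world_size)] for r in range(world_size)]


# Same data as the first example: rank k holds [[4k], [4k+1], [4k+2], [4k+3]].
inputs = [[[4 * k + i] for i in range(4)] for k in range(4)]
outputs = simulate_all_to_all(inputs)
print(outputs[0])  # [[0], [4], [8], [12]]
```

Applying the model twice returns the original lists, which is a quick sanity check that the exchange is a pure permutation of the data.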
torch.distributed.barrier(group=None,async_op=False,device_ids=None)[source]#

Synchronize all processes.

This collective blocks processes until the whole group enters this function, if async_op is False, or if async work handle is called on wait().

Parameters
  • group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used.

  • async_op (bool, optional) – Whether this op should be an async op

  • device_ids ([int], optional) – List of device/GPU ids. Only one id is expected.

Returns

Async work handle, if async_op is set to True. None, if not async_op or if not part of the group.

Note

ProcessGroupNCCL now blocks the CPU thread until the barrier collective completes.

Note

ProcessGroupNCCL implements barrier as an all_reduce of a 1-element tensor. A device must be chosen for allocating this tensor. The device choice is made by checking in this order: (1) the first device passed to the device_ids arg of barrier if not None, (2) the device passed to init_process_group if not None, (3) the device that was first used with this process group, if another collective with tensor inputs has been performed, (4) the device index indicated by the global rank mod local device count.
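The lookup order described in the note can be written as a small function. This is only a model of the documented order: choose_barrier_device and its parameter names are hypothetical, and the real selection happens inside ProcessGroupNCCL.

```python
from typing import Optional, Sequence


def choose_barrier_device(
    device_ids: Optional[Sequence[int]],
    init_device: Optional[int],
    first_used_device: Optional[int],
    global_rank: int,
    local_device_count: int,
) -> int:
    """Return the device index a barrier would use, per the documented order."""
    if device_ids:  # (1) explicit device_ids argument of barrier
        return device_ids[0]
    if init_device is not None:  # (2) device given to init_process_group
        return init_device
    if first_used_device is not None:  # (3) device first used by this group
        return first_used_device
    # (4) fall back to the global rank modulo the local device count
    return global_rank % local_device_count


# Rank 5 on a host with 4 devices and no other hints falls back to device 1.
fallback = choose_barrier_device(None, None, None, global_rank=5, local_device_count=4)
# An explicit device_ids argument takes precedence over everything else.
explicit = choose_barrier_device([2], 0, 3, global_rank=5, local_device_count=4)
```

The fallback in step (4) is worth noting: if ranks are not mapped to devices in that modulo pattern, passing device_ids explicitly avoids allocating the barrier tensor on an unintended device.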

torch.distributed.monitored_barrier(group=None,timeout=None,wait_all_ranks=False)[source]#

Synchronize processes similar to torch.distributed.barrier, but consider a configurable timeout.

It is able to report ranks that did not pass this barrier within the provided timeout. Specifically, non-zero ranks will block until a send/recv is processed from rank 0. Rank 0 will block until all send/recv from other ranks are processed, and will report failures for ranks that failed to respond in time. Note that if one rank does not reach the monitored_barrier (for example due to a hang), all other ranks would fail in monitored_barrier.

This collective will block all processes/ranks in the group, until the whole group exits the function successfully, making it useful for debugging and synchronizing. However, it can have a performance impact and should only be used for debugging or scenarios that require full synchronization points on the host-side. For debugging purposes, this barrier can be inserted before the application’s collective calls to check if any ranks are desynchronized.

Note

Note that this collective is only supported with the GLOO backend.

Parameters
  • group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used.

  • timeout (datetime.timedelta, optional) – Timeout for monitored_barrier. If None, the default process group timeout will be used.

  • wait_all_ranks (bool, optional) – Whether to collect all failed ranks or not. By default, this is False and monitored_barrier on rank 0 will throw on the first failed rank it encounters in order to fail fast. By setting wait_all_ranks=True, monitored_barrier will collect all failed ranks and throw an error containing information about all failed ranks.

Returns

None.

Example::
>>> # Note: Process group initialization omitted on each rank.
>>> import torch.distributed as dist
>>> if dist.get_rank() != 1:
>>>     dist.monitored_barrier()  # Raises exception indicating that
>>>     # rank 1 did not call into monitored_barrier.
>>> # Example with wait_all_ranks=True
>>> if dist.get_rank() == 0:
>>>     dist.monitored_barrier(wait_all_ranks=True)  # Raises exception
>>>     # indicating that ranks 1, 2, ... world_size - 1 did not call into
>>>     # monitored_barrier.
class torch.distributed.Work#

A Work object represents the handle to a pending asynchronous operation in PyTorch’s distributed package. It is returned by non-blocking collective operations, such as dist.all_reduce(tensor, async_op=True).

block_current_stream(self:torch._C._distributed_c10d.Work)None#

Blocks the currently active GPU stream until the operation completes. For GPU-based collectives this is equivalent to synchronize. For CPU-initiated collectives, such as with Gloo, this will block the CUDA stream until the operation is complete.

This returns immediately in all cases.

To check whether an operation was successful you should check the Work object result asynchronously.

boxed(self:torch._C._distributed_c10d.Work)object#
exception(self:torch._C._distributed_c10d.Work)std::__exception_ptr::exception_ptr#
get_future(self:torch._C._distributed_c10d.Work)torch.Future#
Returns

A torch.futures.Future object which is associated with the completion of the Work. As an example, a future object can be retrieved by fut = process_group.allreduce(tensors).get_future().

Example::

Below is an example of a simple allreduce DDP communication hook that uses the get_future API to retrieve a Future associated with the completion of allreduce.

>>> def allreduce(process_group: dist.ProcessGroup, bucket: dist.GradBucket) -> torch.futures.Future:
>>>     group_to_use = process_group if process_group is not None else torch.distributed.group.WORLD
>>>     tensor = bucket.buffer().div_(group_to_use.size())
>>>     return torch.distributed.all_reduce(tensor, group=group_to_use, async_op=True).get_future()
>>> ddp_model.register_comm_hook(state=None, hook=allreduce)

Warning

get_future API supports NCCL, and partially GLOO and MPI backends (no support for peer-to-peer operations like send/recv) and will return a torch.futures.Future.

In the example above, the allreduce work is done on GPU using the NCCL backend. fut.wait() returns after synchronizing the appropriate NCCL streams with PyTorch’s current device streams, so that CUDA execution stays asynchronous; it does not wait for the entire operation to complete on GPU. Note that CUDAFuture does not support the TORCH_NCCL_BLOCKING_WAIT flag or NCCL’s barrier(). In addition, if a callback function was added by fut.then(), it will wait until WorkNCCL’s NCCL streams synchronize with ProcessGroupNCCL’s dedicated callback stream and invoke the callback inline after running the callback on the callback stream. fut.then() will return another CUDAFuture that holds the return value of the callback and a CUDAEvent that recorded the callback stream.

  1. For CPU work, fut.done() returns true when work has been completed and value() tensors are ready.

  2. For GPU work, fut.done() only indicates whether the operation has been enqueued.

  3. For mixed CPU-GPU work (e.g. sending GPU tensors with GLOO), fut.done() returns true when tensors have arrived on respective nodes, but not yet necessarily synched on respective GPUs (similarly to GPU work).

get_future_result(self:torch._C._distributed_c10d.Work)torch.Future#
Returns

A torch.futures.Future object of int type which maps to the enum type of WorkResult. As an example, a future object can be retrieved by fut = process_group.allreduce(tensor).get_future_result().

Example::

Users can call fut.wait() to block until the work completes and then get the WorkResult via fut.value(). Users can also call fut.then(call_back_func) to register a callback function to be called when the work is completed, without blocking the current thread.

Warning

get_future_result API supports NCCL

is_completed(self:torch._C._distributed_c10d.Work)bool#
is_success(self:torch._C._distributed_c10d.Work)bool#
result(self:torch._C._distributed_c10d.Work)list[torch.Tensor]#
source_rank(self:torch._C._distributed_c10d.Work)int#
synchronize(self:torch._C._distributed_c10d.Work)None#
staticunbox(arg0:object)torch._C._distributed_c10d.Work#
wait(self:torch._C._distributed_c10d.Work,timeout:datetime.timedelta=datetime.timedelta(0))bool#
Returns

true/false.

Example::
try:
    work.wait(timeout)
except Exception:
    ...  # some handling

Warning

In normal cases, users do not need to set the timeout. Calling wait() is the same as calling synchronize(): it lets the current stream block on the completion of the NCCL work. However, if timeout is set, it will block the CPU thread until the NCCL work is completed or timed out. On timeout, an exception will be thrown.

class torch.distributed.ReduceOp#

An enum-like class for available reduction operations: SUM, PRODUCT, MIN, MAX, BAND, BOR, BXOR, and PREMUL_SUM.

BAND, BOR, and BXOR reductions are not available when using the NCCL backend.

AVG divides values by the world size before summing across ranks. AVG is only available with the NCCL backend, and only for NCCL versions 2.10 or later.

PREMUL_SUM multiplies inputs by a given scalar locally before reduction. PREMUL_SUM is only available with the NCCL backend, and only available for NCCL versions 2.11 or later. Users are supposed to use torch.distributed._make_nccl_premul_sum.

Additionally, MAX, MIN and PRODUCT are not supported for complex tensors.

The values of this class can be accessed as attributes, e.g., ReduceOp.SUM. They are used in specifying strategies for reduction collectives, e.g., reduce().

This class does not support the __members__ property.
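AVG and PREMUL_SUM differ only in where the scaling comes from. The sketch below models both on plain Python floats, one value per rank; the function names are hypothetical and nothing here calls NCCL or torch.distributed.

```python
def model_avg(values_per_rank):
    """AVG: each value is divided by the world size, then the results are summed."""
    world_size = len(values_per_rank)
    return sum(v / world_size for v in values_per_rank)


def model_premul_sum(values_per_rank, scalar):
    """PREMUL_SUM: each rank scales its input by `scalar` locally, then SUM."""
    return sum(scalar * v for v in values_per_rank)


values = [4.0, 8.0, 12.0, 16.0]  # one value per rank, world size 4
avg = model_avg(values)
premul = model_premul_sum(values, 0.5)
# With scalar = 1 / world_size, PREMUL_SUM matches AVG in this model.
premul_as_avg = model_premul_sum(values, 1 / len(values))
```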

class torch.distributed.reduce_op#

Deprecated enum-like class for reduction operations: SUM, PRODUCT, MIN, and MAX.

Use ReduceOp instead.

Distributed Key-Value Store#

The distributed package comes with a distributed key-value store, which can be used to share information between processes in the group as well as to initialize the distributed package in torch.distributed.init_process_group() (by explicitly creating the store as an alternative to specifying init_method). There are 3 choices for Key-Value Stores: TCPStore, FileStore, and HashStore.

class torch.distributed.Store#

Base class for all store implementations, such as the 3 provided by PyTorch distributed: (TCPStore, FileStore, and HashStore).

__init__(self:torch._C._distributed_c10d.Store)None#
add(self:torch._C._distributed_c10d.Store,arg0:str,arg1:SupportsInt)int#

The first call to add for a given key creates a counter associated with key in the store, initialized to amount. Subsequent calls to add with the same key increment the counter by the specified amount. Calling add() with a key that has already been set in the store by set() will result in an exception.

Parameters
  • key (str) – The key in the store whose counter will be incremented.

  • amount (int) – The quantity by which the counter will be incremented.

Example::
>>> import torch.distributed as dist
>>> from datetime import timedelta
>>> # Using TCPStore as an example, other store types can also be used
>>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
>>> store.add("first_key", 1)
>>> store.add("first_key", 6)
>>> # Should return 7
>>> store.get("first_key")
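The counter behaviour of add() can be mimicked with an in-memory dictionary. ToyCounterStore below is a hypothetical stand-in for illustration only: it does not talk to a real store, and the exception type it raises is an assumption of this sketch (the description of add() only says that an exception results).

```python
class ToyCounterStore:
    """In-memory model of the set()/add() interaction described for add()."""

    def __init__(self):
        self._values = {}    # keys written with set()
        self._counters = {}  # keys created by add()

    def set(self, key: str, value: str) -> None:
        self._values[key] = value

    def add(self, key: str, amount: int) -> int:
        if key in self._values:
            # Assumption: the real store raises some exception here.
            raise RuntimeError(f"{key!r} was already set with set()")
        # First call initializes the counter to `amount`; later calls increment it.
        self._counters[key] = self._counters.get(key, 0) + amount
        return self._counters[key]


store = ToyCounterStore()
first = store.add("first_key", 1)   # counter created with value 1
second = store.add("first_key", 6)  # counter incremented to 7
```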
append(self:torch._C._distributed_c10d.Store,arg0:str,arg1:str)None#

Append the key-value pair into the store based on the supplied key and value. If key does not exist in the store, it will be created.

Parameters
  • key (str) – The key to be appended to the store.

  • value (str) – The value associated with key to be added to the store.

Example::
>>> import torch.distributed as dist
>>> from datetime import timedelta
>>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
>>> store.append("first_key", "po")
>>> store.append("first_key", "tato")
>>> # Should return "potato"
>>> store.get("first_key")
check(self:torch._C._distributed_c10d.Store,arg0:collections.abc.Sequence[str])bool#

Checks whether each key in a given list of keys has a value stored in the store. This call returns immediately in normal cases but can still deadlock in some edge cases, e.g., calling check after the TCPStore has been destroyed. Call check() with the list of keys whose presence in the store you want to test.

Parameters

keys (list[str]) – The keys to look up in the store.

Example::
>>> import torch.distributed as dist
>>> from datetime import timedelta
>>> # Using TCPStore as an example, other store types can also be used
>>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
>>> store.add("first_key", 1)
>>> # Should return True
>>> store.check(["first_key"])
clone(self:torch._C._distributed_c10d.Store)torch._C._distributed_c10d.Store#

Clones the store and returns a new object that points to the same underlying store. The returned store can be used concurrently with the original object. This is intended to provide a safe way to use a store from multiple threads by cloning one store per thread.

compare_set(self:torch._C._distributed_c10d.Store,arg0:str,arg1:str,arg2:str)bytes#

Inserts the key-value pair into the store based on the supplied key and performs comparison between expected_value and desired_value before inserting. desired_value will only be set if expected_value for the key already exists in the store or if expected_value is an empty string.

Parameters
  • key (str) – The key to be checked in the store.

  • expected_value (str) – The value associated with key to be checked before insertion.

  • desired_value (str) – The value associated with key to be added to the store.

Example::
>>> import torch.distributed as dist
>>> from datetime import timedelta
>>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
>>> store.set("key", "first_value")
>>> store.compare_set("key", "first_value", "second_value")
>>> # Should return "second_value"
>>> store.get("key")
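The rule stated for compare_set (write desired_value only when the stored value matches expected_value, or when expected_value is empty) can be modeled on a plain dictionary. toy_compare_set is a hypothetical helper; reading the empty-string case as applying to a missing key, and returning the value held after the call, are simplifying assumptions of this sketch rather than documented behavior.

```python
from typing import Dict, Optional


def toy_compare_set(
    store: Dict[str, str], key: str, expected_value: str, desired_value: str
) -> Optional[str]:
    """Compare-and-set on a dict; returns the value held after the call."""
    current = store.get(key)
    if current is None:
        # Key absent: only an empty expected_value allows the insert (assumption).
        if expected_value == "":
            store[key] = desired_value
    elif current == expected_value:
        # Key present and matching: swap in the desired value.
        store[key] = desired_value
    return store.get(key)


store = {"key": "first_value"}
after_match = toy_compare_set(store, "key", "first_value", "second_value")
after_mismatch = toy_compare_set(store, "key", "stale_value", "third_value")
created = toy_compare_set(store, "new_key", "", "initial")
```

The mismatch case is the useful one in practice: the caller can see that someone else changed the key first and decide whether to retry.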
delete_key(self:torch._C._distributed_c10d.Store,arg0:str)bool#

Deletes the key-value pair associated with key from the store. Returns true if the key was successfully deleted, and false if it was not.

Warning

The delete_key API is only supported by the TCPStore and HashStore. Using this API with the FileStore will result in an exception.

Parameters

key (str) – The key to be deleted from the store

Returns

True if key was deleted, otherwise False.

Example::
>>> import torch.distributed as dist
>>> from datetime import timedelta
>>> # Using TCPStore as an example, HashStore can also be used
>>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
>>> store.set("first_key", "first_value")
>>> # This should return true
>>> store.delete_key("first_key")
>>> # This should return false
>>> store.delete_key("bad_key")
get(self:torch._C._distributed_c10d.Store,arg0:str)bytes#

Retrieves the value associated with the given key in the store. If key is not present in the store, the function will wait for timeout, which is defined when initializing the store, before throwing an exception.

Parameters

key (str) – The function will return the value associated with this key.

Returns

Value associated with key if key is in the store.

Example::
>>> import torch.distributed as dist
>>> from datetime import timedelta
>>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
>>> store.set("first_key", "first_value")
>>> # Should return "first_value"
>>> store.get("first_key")
has_extended_api(self:torch._C._distributed_c10d.Store)bool#

Returns true if the store supports extended operations.

multi_get(self:torch._C._distributed_c10d.Store,arg0:collections.abc.Sequence[str])list[bytes]#

Retrieve all values in keys. If any key in keys is not present in the store, the function will wait for timeout.

Parameters

keys (List[str]) – The keys to be retrieved from the store.

Example::
>>> import torch.distributed as dist
>>> from datetime import timedelta
>>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
>>> store.set("first_key", "po")
>>> store.set("second_key", "tato")
>>> # Should return [b"po", b"tato"]
>>> store.multi_get(["first_key", "second_key"])
multi_set(self:torch._C._distributed_c10d.Store,arg0:collections.abc.Sequence[str],arg1:collections.abc.Sequence[str])None#

Inserts a list key-value pair into the store based on the suppliedkeys andvalues

Parameters
  • keys (List[str]) – The keys to insert.

  • values (List[str]) – The values to insert.

Example::
>>> import torch.distributed as dist
>>> from datetime import timedelta
>>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
>>> store.multi_set(["first_key", "second_key"], ["po", "tato"])
>>> # Should return b"po"
>>> store.get("first_key")
num_keys(self: torch._C._distributed_c10d.Store) -> int#

Returns the number of keys set in the store. Note that this number will typically be one greater than the number of keys added by set() and add() since one key is used to coordinate all the workers using the store.

Warning

When used with the FileStore, num_keys returns the number of keys written to the underlying file. If the store is destructed and another store is created with the same file, the original keys will be retained.

Returns

The number of keys present in the store.

Example::
>>> import torch.distributed as dist
>>> from datetime import timedelta
>>> # Using TCPStore as an example, other store types can also be used
>>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
>>> store.set("first_key", "first_value")
>>> # This should return 2
>>> store.num_keys()

queue_len(self: torch._C._distributed_c10d.Store, arg0: str) -> int#

Returns the length of the specified queue.

If the queue doesn’t exist it returns 0.

See queue_push for more details.

Parameters

key (str) – The key of the queue whose length is returned.

queue_pop(self: torch._C._distributed_c10d.Store, key: str, block: bool = True) -> bytes#

Pops a value from the specified queue or waits until timeout if the queue is empty.

See queue_push for more details.

If block is False, a dist.QueueEmptyError will be raised if the queue is empty.

Parameters
  • key (str) – The key of the queue to pop from.

  • block (bool) – Whether to block waiting for the key or immediately return.

queue_push(self: torch._C._distributed_c10d.Store, arg0: str, arg1: str) -> None#

Pushes a value into the specified queue.

Using the same key for queues and set/get operations may result in unexpected behavior.

wait/check operations are supported for queues.

wait with queues will only wake one waiting worker rather than all.

Parameters
  • key (str) – The key of the queue to push to.

  • value (str) – The value to push into the queue.

set(self: torch._C._distributed_c10d.Store, arg0: str, arg1: str) -> None#

Inserts the key-value pair into the store based on the supplied key and value. If key already exists in the store, it will overwrite the old value with the new supplied value.

Parameters
  • key (str) – The key to be added to the store.

  • value (str) – The value associated with key to be added to the store.

Example::
>>> import torch.distributed as dist
>>> from datetime import timedelta
>>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
>>> store.set("first_key", "first_value")
>>> # Should return b"first_value"
>>> store.get("first_key")

set_timeout(self: torch._C._distributed_c10d.Store, arg0: datetime.timedelta) -> None#

Sets the store's default timeout. This timeout is used during initialization and in wait() and get().

Parameters

timeout (timedelta) – timeout to be set in the store.

Example::
>>> import torch.distributed as dist
>>> from datetime import timedelta
>>> # Using TCPStore as an example, other store types can also be used
>>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
>>> store.set_timeout(timedelta(seconds=10))
>>> # This will throw an exception after 10 seconds
>>> store.wait(["bad_key"])

property timeout#

Gets the timeout of the store.

wait(*args, **kwargs)#

Overloaded function.

  1. wait(self: torch._C._distributed_c10d.Store, arg0: collections.abc.Sequence[str]) -> None

Waits for each key in keys to be added to the store. If not all keys are set before the timeout (set during store initialization), then wait will throw an exception.

Parameters

keys (list) – List of keys on which to wait until they are set in the store.

Example::
>>> import torch.distributed as dist
>>> from datetime import timedelta
>>> # Using TCPStore as an example, other store types can also be used
>>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
>>> # This will throw an exception after 30 seconds
>>> store.wait(["bad_key"])

  2. wait(self: torch._C._distributed_c10d.Store, arg0: collections.abc.Sequence[str], arg1: datetime.timedelta) -> None

Waits for each key in keys to be added to the store, and throws an exception if the keys have not been set by the supplied timeout.

Parameters
  • keys (list) – List of keys on which to wait until they are set in the store.

  • timeout (timedelta) – Time to wait for the keys to be added before throwing an exception.

Example::
>>> import torch.distributed as dist
>>> from datetime import timedelta
>>> # Using TCPStore as an example, other store types can also be used
>>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
>>> # This will throw an exception after 10 seconds
>>> store.wait(["bad_key"], timedelta(seconds=10))

class torch.distributed.TCPStore#

A TCP-based distributed key-value store implementation. The server store holds the data, while the client stores can connect to the server store over TCP and perform actions such as set() to insert a key-value pair, get() to retrieve a key-value pair, etc. There should always be one server store initialized because the client store(s) will wait for the server to establish a connection.

Parameters
  • host_name (str) – The hostname or IP Address the server store should run on.

  • port (int) – The port on which the server store should listen for incoming requests.

  • world_size (int, optional) – The total number of store users (number of clients + 1 for the server). Default is None (None indicates a non-fixed number of store users).

  • is_master (bool, optional) – True when initializing the server store and False for client stores. Default is False.

  • timeout (timedelta, optional) – Timeout used by the store during initialization and for methods such as get() and wait(). Default is timedelta(seconds=300).

  • wait_for_workers (bool, optional) – Whether to wait for all the workers to connect with the server store. This is only applicable when world_size is a fixed value. Default is True.

  • multi_tenant (bool, optional) – If True, all TCPStore instances in the current process with the same host/port will use the same underlying TCPServer. Default is False.

  • master_listen_fd (int, optional) – If specified, the underlying TCPServer will listen on this file descriptor, which must be a socket already bound to port. To bind an ephemeral port we recommend setting the port to 0 and reading .port. Default is None (meaning the server creates a new socket and attempts to bind it to port).

  • use_libuv (bool, optional) – If True, use libuv for TCPServer backend. Default is True.

Example::
>>> import torch.distributed as dist
>>> from datetime import timedelta
>>> # Run on process 1 (server)
>>> server_store = dist.TCPStore("127.0.0.1", 1234, 2, True, timedelta(seconds=30))
>>> # Run on process 2 (client)
>>> client_store = dist.TCPStore("127.0.0.1", 1234, 2, False)
>>> # Use any of the store methods from either the client or server after initialization
>>> server_store.set("first_key", "first_value")
>>> client_store.get("first_key")

__init__(self: torch._C._distributed_c10d.TCPStore, host_name: str, port: SupportsInt, world_size: SupportsInt | None = None, is_master: bool = False, timeout: datetime.timedelta = datetime.timedelta(seconds=300), wait_for_workers: bool = True, multi_tenant: bool = False, master_listen_fd: SupportsInt | None = None, use_libuv: bool = True) -> None#

Creates a new TCPStore.

property host#

Gets the hostname on which the store listens for requests.

property libuvBackend#

Returns True if it's using the libuv backend.

property port#

Gets the port number on which the store listens for requests.

class torch.distributed.HashStore#

A thread-safe store implementation based on an underlying hashmap. This store can be used within the same process (for example, by other threads), but cannot be used across processes.

Example::
>>> import torch.distributed as dist
>>> store = dist.HashStore()
>>> # store can be used from other threads
>>> # Use any of the store methods after initialization
>>> store.set("first_key", "first_value")

__init__(self: torch._C._distributed_c10d.HashStore) -> None#

Creates a new HashStore.

class torch.distributed.FileStore#

A store implementation that uses a file to store the underlying key-value pairs.

Parameters
  • file_name (str) – Path of the file in which to store the key-value pairs.

  • world_size (int, optional) – The total number of processes using the store. Default is -1 (a negative value indicates a non-fixed number of store users).

Example::
>>> import torch.distributed as dist
>>> store1 = dist.FileStore("/tmp/filestore", 2)
>>> store2 = dist.FileStore("/tmp/filestore", 2)
>>> # Use any of the store methods from either store after initialization
>>> store1.set("first_key", "first_value")
>>> store2.get("first_key")

__init__(self: torch._C._distributed_c10d.FileStore, file_name: str, world_size: SupportsInt = -1) -> None#

Creates a new FileStore.

property path#

Gets the path of the file used by FileStore to store key-value pairs.

class torch.distributed.PrefixStore#

A wrapper around any of the 3 key-value stores (TCPStore, FileStore, and HashStore) that adds a prefix to each key inserted to the store.

Parameters
  • prefix (str) – The prefix string that is prepended to each key before being inserted into the store.

  • store (torch.distributed.Store) – A store object that forms the underlying key-value store.

__init__(self: torch._C._distributed_c10d.PrefixStore, prefix: str, store: torch._C._distributed_c10d.Store) -> None#

Creates a new PrefixStore.

property underlying_store#

Gets the underlying store object that PrefixStore wraps around.
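
As an illustration of why the prefix matters, the sketch below lets two components share one underlying store without their keys colliding. It is an assumed usage pattern, not an official example: the prefixes "trainer" and "evaluator" are made up, and HashStore is chosen only so that the snippet runs in a single process.

```python
import torch.distributed as dist

# One underlying store shared by two logical components.
base_store = dist.HashStore()
trainer_store = dist.PrefixStore("trainer", base_store)
evaluator_store = dist.PrefixStore("evaluator", base_store)

# Both components use the same key name; the prefixes keep them apart.
trainer_store.set("status", "running")
evaluator_store.set("status", "idle")

trainer_status = trainer_store.get("status")
evaluator_status = evaluator_store.get("status")
print(trainer_status, evaluator_status)
```

Each wrapper only sees the value written through its own prefix, even though both values live in the same underlying store.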

Profiling Collective Communication#

Note that you can use torch.profiler (recommended, only available after 1.8.1) or torch.autograd.profiler to profile collective communication and point-to-point communication APIs mentioned here. All out-of-the-box backends (gloo, nccl, mpi) are supported and collective communication usage will be rendered as expected in profiling output/traces. Profiling your code is the same as any regular torch operator:

import torch
import torch.distributed as dist

with torch.profiler.profile():
    tensor = torch.randn(20, 10)
    dist.all_reduce(tensor)

Please refer to the profiler documentation for a full overview of profiler features.

Multi-GPU collective functions#

Warning

The multi-GPU functions (which stand for multiple GPUs per CPU thread) are deprecated. As of today, PyTorch Distributed's preferred programming model is one device per thread, as exemplified by the APIs in this document. If you are a backend developer and want to support multiple devices per thread, please contact PyTorch Distributed's maintainers.

Object collectives#

Warning

Object collectives have a number of serious limitations. Read further to determine if they are safe to use for your use case.

Object collectives are a set of collective-like operations that work on arbitrary Python objects, as long as they can be pickled. There are various collective patterns implemented (e.g. broadcast, all_gather, …) but they each roughly follow this pattern:

  1. convert the input object into a pickle (raw bytes), then shove it into a byte tensor

  2. communicate the size of this byte tensor to peers (first collective operation)

  3. allocate appropriately sized tensor to perform the real collective

  4. communicate the object data (second collective operation)

  5. convert raw data back into Python (unpickle)
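
The five steps above can be mimicked in a single process with the standard library alone. This is a toy model under stated assumptions, not the implementation used by torch.distributed: a bytearray stands in for the byte tensor, plain function calls stand in for the two collective operations, and pack_object and unpack_object are hypothetical names.

```python
import pickle
import struct


def pack_object(obj):
    """Sender side: steps 1 and 2."""
    payload = pickle.dumps(obj)                     # step 1: object -> raw bytes
    size_header = struct.pack("<q", len(payload))   # step 2: size sent first
    return size_header, payload


def unpack_object(size_header, payload):
    """Receiver side: steps 3 to 5."""
    (size,) = struct.unpack("<q", size_header)
    buffer = bytearray(size)             # step 3: allocate a buffer of the right size
    buffer[:] = payload                  # step 4: receive the object data
    return pickle.loads(bytes(buffer))   # step 5: raw bytes -> Python object


sent = {"epoch": 3, "metrics": [0.5, 0.25]}
header, payload = pack_object(sent)
received = unpack_object(header, payload)
print(received)
```

The receiver cannot allocate its buffer before it knows the payload size, which is why every object collective costs two communication rounds plus a pickle and an unpickle.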

Object collectives sometimes have surprising performance or memory characteristics that lead to long runtimes or OOMs, and thus they should be used with caution. Here are some common issues.

Asymmetric pickle/unpickle time - Pickling objects can be slow, depending on the number, type and size of the objects. When the collective has a fan-in (e.g. gather_object), the receiving rank(s) must unpickle N times more objects than the sending rank(s) had to pickle, which can cause other ranks to time out on their next collective.

Inefficient tensor communication - Tensors should be sent via regular collective APIs, not object collective APIs. It is possible to send Tensors via object collective APIs, but they will be serialized and deserialized (including a CPU-sync and device-to-host copy in the case of non-CPU tensors), and in almost every case other than debugging or troubleshooting code, it would be worth the trouble to refactor the code to use non-object collectives instead.

Unexpected tensor devices - If you still want to send tensors via object collectives, there is another aspect specific to cuda (and possibly other accelerators) tensors. If you pickle a tensor that is currently on cuda:3, and then unpickle it, you will get another tensor on cuda:3 regardless of which process you are on, or which CUDA device is the 'default' device for that process. With regular tensor collective APIs, 'output tensors' will always be on the same, local device, which is generally what you'd expect.

Unpickling a tensor will implicitly activate a CUDA context if it is the first time a GPU is used by the process, which can waste significant amounts of GPU memory. This issue can be avoided by moving tensors to CPU before passing them as inputs to an object collective.

Third-party backends#

Besides the builtin GLOO/MPI/NCCL backends, PyTorch distributed supports third-party backends through a run-time register mechanism. For references on how to develop a third-party backend through C++ Extension, please refer to Tutorials - Custom C++ and CUDA Extensions and test/cpp_extensions/cpp_c10d_extension.cpp. The capabilities of third-party backends are determined by their own implementations.

The new backend derives from c10d::ProcessGroup and registers the backend name and the instantiating interface through torch.distributed.Backend.register_backend() when imported.

When manually importing this backend and invoking torch.distributed.init_process_group() with the corresponding backend name, the torch.distributed package runs on the new backend.

Warning

Support for third-party backends is experimental and subject to change.

Launch utility#

The torch.distributed package also provides a launch utility in torch.distributed.launch. This helper utility can be used to launch multiple processes per node for distributed training.

Module torch.distributed.launch.

torch.distributed.launch is a module that spawns multiple distributed training processes on each of the training nodes.

Warning

This module is going to be deprecated in favor of torchrun.

The utility can be used for single-node distributed training, in which one or more processes per node will be spawned. The utility can be used for either CPU training or GPU training. If the utility is used for GPU training, each distributed process will be operating on a single GPU, which can improve single-node training performance. It can also be used in multi-node distributed training, by spawning multiple processes on each node, which can improve multi-node training performance as well. This is especially beneficial for systems with multiple InfiniBand interfaces that have direct-GPU support, since all of them can be utilized for aggregated communication bandwidth.

In both cases of single-node distributed training or multi-node distributed training, this utility will launch the given number of processes per node (--nproc-per-node). If used for GPU training, this number needs to be less than or equal to the number of GPUs on the current system (nproc_per_node), and each process will be operating on a single GPU from GPU 0 to GPU (nproc_per_node - 1).

How to use this module:

  1. Single-Node multi-process distributed training

python -m torch.distributed.launch --nproc-per-node=NUM_GPUS_YOU_HAVE YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3 and all other arguments of your training script)

  2. Multi-Node multi-process distributed training: (e.g. two nodes)

Node 1: (IP: 192.168.1.1, and has a free port: 1234)

python -m torch.distributed.launch --nproc-per-node=NUM_GPUS_YOU_HAVE --nnodes=2 --node-rank=0 --master-addr="192.168.1.1" --master-port=1234 YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3 and all other arguments of your training script)

Node 2:

python -m torch.distributed.launch --nproc-per-node=NUM_GPUS_YOU_HAVE --nnodes=2 --node-rank=1 --master-addr="192.168.1.1" --master-port=1234 YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3 and all other arguments of your training script)

  3. To look up what optional arguments this module offers:

python -m torch.distributed.launch --help

Important Notices:

1. This utility and multi-process distributed (single-node or multi-node) GPU training currently only achieves the best performance using the NCCL distributed backend. Thus NCCL backend is the recommended backend to use for GPU training.

2. In your training program, you must parse the command-line argument: --local-rank=LOCAL_PROCESS_RANK, which will be provided by this module. If your training program uses GPUs, you should ensure that your code only runs on the GPU device of LOCAL_PROCESS_RANK. This can be done by:

Parsing the local_rank argument

>>> import argparse
>>> parser = argparse.ArgumentParser()
>>> parser.add_argument("--local-rank", "--local_rank", type=int)
>>> args = parser.parse_args()

Set your device to local rank using either

>>> torch.cuda.set_device(args.local_rank)  # before your code runs

or

>>> with torch.cuda.device(args.local_rank):
>>>     # your code to run
>>>     ...

Changed in version 2.0.0: The launcher passes the --local-rank=<rank> argument to your script. From PyTorch 2.0.0 onwards, the dashed --local-rank is preferred over the previously used underscored --local_rank.

For backward compatibility, it may be necessary for users to handle both cases in their argument parsing code. This means including both "--local-rank" and "--local_rank" in the argument parser. If only "--local_rank" is provided, the launcher will trigger an error: "error: unrecognized arguments: --local-rank=<rank>". For training code that only supports PyTorch 2.0.0+, including "--local-rank" should be sufficient.

3. In your training program, you are supposed to call the following function at the beginning to start the distributed backend. It is strongly recommended that init_method=env://. Other init methods (e.g. tcp://) may work, but env:// is the one that is officially supported by this module.

>>> torch.distributed.init_process_group(backend='YOUR BACKEND',
>>>                                      init_method='env://')

4. In your training program, you can either use regular distributed functions or use the torch.nn.parallel.DistributedDataParallel() module. If your training program uses GPUs for training and you would like to use the torch.nn.parallel.DistributedDataParallel() module, here is how to configure it.

>>> model = torch.nn.parallel.DistributedDataParallel(model,
>>>                                                   device_ids=[args.local_rank],
>>>                                                   output_device=args.local_rank)

Please ensure that the device_ids argument is set to be the only GPU device id that your code will be operating on. This is generally the local rank of the process. In other words, device_ids needs to be [args.local_rank], and output_device needs to be args.local_rank in order to use this utility.

5. Another way to pass local_rank to the subprocesses is via the environment variable LOCAL_RANK. This behavior is enabled when you launch the script with --use-env=True. You must adjust the subprocess example above to replace args.local_rank with os.environ['LOCAL_RANK']; the launcher will not pass --local-rank when you specify this flag.
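
A training script that has to run under both launch modes can accept either source of the local rank. The sketch below is one way to do that using only the standard library; resolve_local_rank is a hypothetical helper name, and the choice to prefer the command-line argument over LOCAL_RANK is an assumption of this sketch, not launcher behavior.

```python
import argparse
import os


def resolve_local_rank(argv=None, environ=None):
    """Return the local rank from --local-rank/--local_rank or LOCAL_RANK.

    Hypothetical helper: the command-line argument wins when both are given.
    """
    environ = os.environ if environ is None else environ
    parser = argparse.ArgumentParser()
    # Accept both the dashed (2.0.0+) and the underscored (older) spelling.
    parser.add_argument("--local-rank", "--local_rank", type=int, default=None)
    # Ignore the script's other arguments here.
    args, _unknown = parser.parse_known_args(argv)
    if args.local_rank is not None:
        return args.local_rank
    if "LOCAL_RANK" in environ:
        return int(environ["LOCAL_RANK"])
    raise RuntimeError("neither --local-rank nor LOCAL_RANK was provided")


from_dashed = resolve_local_rank(["--local-rank=2"], {})
from_underscored = resolve_local_rank(["--local_rank", "1"], {})
from_env = resolve_local_rank(["--lr", "0.1"], {"LOCAL_RANK": "3"})
print(from_dashed, from_underscored, from_env)
```

Passing argv and environ explicitly keeps the helper easy to test; in a real script both can be left as None so that sys.argv and os.environ are used.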

Warning

local_rank is NOT globally unique: it is only unique per process on a machine. Thus, don't use it to decide if you should, e.g., write to a networked filesystem. See pytorch/pytorch#12042 for an example of how things can go wrong if you don't do this correctly.
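
To make the warning concrete, the sketch below derives a globally unique rank from the launcher arguments. It assumes that every node runs the same number of processes and that ranks are numbered contiguously node by node; global_rank and is_global_leader are hypothetical helper names. In a running job, torch.distributed.get_rank() is the direct way to obtain the global rank.

```python
def global_rank(node_rank, nproc_per_node, local_rank):
    """Global rank under the assumption of equally sized, contiguous nodes."""
    if not 0 <= local_rank < nproc_per_node:
        raise ValueError("local_rank must be in [0, nproc_per_node)")
    return node_rank * nproc_per_node + local_rank


def is_global_leader(node_rank, nproc_per_node, local_rank):
    """True for exactly one process in the whole job."""
    return global_rank(node_rank, nproc_per_node, local_rank) == 0


# Two nodes with four processes each: local_rank 0 exists once per node.
leaders_by_local_rank = [node for node in range(2)]
leaders_by_global_rank = [
    node for node in range(2) if is_global_leader(node, 4, local_rank=0)
]
print(leaders_by_local_rank, leaders_by_global_rank)
```

Checking local_rank == 0 would select one writer per node, so two processes would write to the shared filesystem; checking the global rank selects exactly one.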

Spawn utility#

The Multiprocessing package - torch.multiprocessing package also provides a spawn function in torch.multiprocessing.spawn(). This helper function can be used to spawn multiple processes. It works by passing in the function that you want to run and spawns N processes to run it. This can be used for multiprocess distributed training as well.

For references on how to use it, please refer to PyTorch example - ImageNet implementation

Note that this function requires Python 3.4 or higher.

Debugging torch.distributed applications#

Debugging distributed applications can be challenging due to hard to understand hangs, crashes, or inconsistent behavior across ranks. torch.distributed provides a suite of tools to help debug training applications in a self-serve fashion:

Python Breakpoint#

It is extremely convenient to use python's debugger in a distributed environment, but because it does not work out of the box many people do not use it at all. PyTorch offers a customized wrapper around pdb that streamlines the process.

torch.distributed.breakpoint makes this process easy. Internally, it customizes pdb's breakpoint behavior in the following ways but otherwise behaves as normal pdb.

  1. Attaches the debugger only on one rank (specified by the user).

  2. Ensures all other ranks stop, by using a torch.distributed.barrier() that will release once the debugged rank issues a continue.

  3. Reroutes stdin from the child process such that it connects to your terminal.

To use it, simply issue torch.distributed.breakpoint(rank) on all ranks, using the same value for rank in each case.

Monitored Barrier#

As of v1.10, torch.distributed.monitored_barrier() exists as an alternative to torch.distributed.barrier(). When not all ranks call into torch.distributed.monitored_barrier() within the provided timeout, it fails with helpful information about which rank may be faulty. torch.distributed.monitored_barrier() implements a host-side barrier using send/recv communication primitives in a process similar to acknowledgements, allowing rank 0 to report which rank(s) failed to acknowledge the barrier in time. As an example, consider the following function where rank 1 fails to call into torch.distributed.monitored_barrier() (in practice this could be due to an application bug or hang in a previous collective):

import os
from datetime import timedelta

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank):
    dist.init_process_group("nccl", rank=rank, world_size=2)
    # monitored barrier requires gloo process group to perform host-side sync.
    group_gloo = dist.new_group(backend="gloo")
    if rank not in [1]:
        dist.monitored_barrier(group=group_gloo, timeout=timedelta(seconds=2))


if __name__ == "__main__":
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29501"
    mp.spawn(worker, nprocs=2, args=())

The following error message is produced on rank 0, allowing the user to determine which rank(s) may be faulty and investigate further:

RuntimeError: Rank 1 failed to pass monitoredBarrier in 2000 ms
 Original exception:
[gloo/transport/tcp/pair.cc:598] Connection closed by peer [2401:db00:eef0:1100:3560:0:1c05:25d]:8594
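
The acknowledgement idea behind the monitored barrier can be mimicked with threads and a queue. This is a toy model for intuition only, not how the gloo-based barrier is implemented: threads stand in for ranks, a queue.Queue stands in for send/recv, and monitored_barrier_sim is a hypothetical name.

```python
import queue
import threading
import time


def monitored_barrier_sim(world_size, arriving_ranks, timeout_s):
    """Return the ranks that did not acknowledge rank 0 before the timeout."""
    acks = queue.Queue()

    def worker(rank):
        acks.put(rank)  # a rank that reaches the barrier sends its acknowledgement

    threads = [
        threading.Thread(target=worker, args=(rank,))
        for rank in arriving_ranks
        if rank != 0
    ]
    for thread in threads:
        thread.start()

    # Rank 0 collects acknowledgements until all arrive or the deadline passes.
    pending = set(range(1, world_size))
    deadline = time.monotonic() + timeout_s
    while pending:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        try:
            pending.discard(acks.get(timeout=remaining))
        except queue.Empty:
            break

    for thread in threads:
        thread.join()
    return sorted(pending)


missing = monitored_barrier_sim(world_size=4, arriving_ranks=[0, 1, 3], timeout_s=0.2)
print(f"ranks that failed to pass the barrier: {missing}")
```

Because rank 0 tracks which acknowledgements are still outstanding, it can name the late ranks instead of simply timing out.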

TORCH_DISTRIBUTED_DEBUG#

With TORCH_CPP_LOG_LEVEL=INFO, the environment variable TORCH_DISTRIBUTED_DEBUG can be used to trigger additional useful logging and collective synchronization checks to ensure all ranks are synchronized appropriately. TORCH_DISTRIBUTED_DEBUG can be set to either OFF (default), INFO, or DETAIL depending on the debugging level required. Please note that the most verbose option, DETAIL, may impact the application performance and thus should only be used when debugging issues.

Setting TORCH_DISTRIBUTED_DEBUG=INFO will result in additional debug logging when models trained with torch.nn.parallel.DistributedDataParallel() are initialized, and TORCH_DISTRIBUTED_DEBUG=DETAIL will additionally log runtime performance statistics for a select number of iterations. These runtime statistics include data such as forward time, backward time, gradient communication time, etc. As an example, given the following application:

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


class TwoLinLayerNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.a = torch.nn.Linear(10, 10, bias=False)
        self.b = torch.nn.Linear(10, 1, bias=False)

    def forward(self, x):
        a = self.a(x)
        b = self.b(x)
        return (a, b)


def worker(rank):
    dist.init_process_group("nccl", rank=rank, world_size=2)
    torch.cuda.set_device(rank)
    print("init model")
    model = TwoLinLayerNet().cuda()
    print("init ddp")
    ddp_model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank])

    inp = torch.randn(10, 10).cuda()
    print("train")

    for _ in range(20):
        output = ddp_model(inp)
        loss = output[0] + output[1]
        loss.sum().backward()


if __name__ == "__main__":
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29501"
    os.environ["TORCH_CPP_LOG_LEVEL"] = "INFO"
    os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"  # set to DETAIL for runtime logging.
    mp.spawn(worker, nprocs=2, args=())

The following logs are rendered at initialization time:

I0607 16:10:35.739390 515217 logger.cpp:173] [Rank 0]: DDP Initialized with:
broadcast_buffers: 1
bucket_cap_bytes: 26214400
find_unused_parameters: 0
gradient_as_bucket_view: 0
is_multi_device_module: 0
iteration: 0
num_parameter_tensors: 2
output_device: 0
rank: 0
total_parameter_size_bytes: 440
world_size: 2
backend_name: nccl
bucket_sizes: 440
cuda_visible_devices: N/A
device_ids: 0
dtypes: float
master_addr: localhost
master_port: 29501
module_name: TwoLinLayerNet
nccl_async_error_handling: N/A
nccl_blocking_wait: N/A
nccl_debug: WARN
nccl_ib_timeout: N/A
nccl_nthreads: N/A
nccl_socket_ifname: N/A
torch_distributed_debug: INFO

The following logs are rendered during runtime (when TORCH_DISTRIBUTED_DEBUG=DETAIL is set):

I0607 16:18:58.085681 544067 logger.cpp:344] [Rank 1 / 2] Training TwoLinLayerNet unused_parameter_size=0
 Avg forward compute time: 40838608
 Avg backward compute time: 5983335
Avg backward comm. time: 4326421
 Avg backward comm/comp overlap time: 4207652
I0607 16:18:58.085693 544066 logger.cpp:344] [Rank 0 / 2] Training TwoLinLayerNet unused_parameter_size=0
 Avg forward compute time: 42850427
 Avg backward compute time: 3885553
Avg backward comm. time: 2357981
 Avg backward comm/comp overlap time: 2234674

In addition, TORCH_DISTRIBUTED_DEBUG=INFO enhances crash logging in torch.nn.parallel.DistributedDataParallel() due to unused parameters in the model. Currently, find_unused_parameters=True must be passed into torch.nn.parallel.DistributedDataParallel() initialization if there are parameters that may be unused in the forward pass, and as of v1.10, all model outputs are required to be used in loss computation as torch.nn.parallel.DistributedDataParallel() does not support unused parameters in the backwards pass. These constraints are challenging especially for larger models, thus when crashing with an error, torch.nn.parallel.DistributedDataParallel() will log the fully qualified name of all parameters that went unused. For example, in the above application, if we modify loss to be instead computed as loss = output[1], then TwoLinLayerNet.a does not receive a gradient in the backwards pass, and thus results in DDP failing. On a crash, the user is passed information about parameters which went unused, which may be challenging to manually find for large models:

RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one. This error indicates that your module has parameters that were not used in producing loss. You can enable unused parameter detection by passing the keyword argument `find_unused_parameters=True` to `torch.nn.parallel.DistributedDataParallel`, and by
making sure all `forward` function outputs participate in calculating loss.
If you already have done the above, then the distributed data parallel module wasn't able to locate the output tensors in the return value of your module's `forward` function. Please include the loss function and the structure of the return value of `forward` of your module when reporting this issue (e.g. list, dict, iterable).
Parameters which did not receive grad for rank 0: a.weight
Parameter indices which did not receive grad for rank 0: 0

Setting TORCH_DISTRIBUTED_DEBUG=DETAIL will trigger additional consistency and synchronization checks on every collective call issued by the user either directly or indirectly (such as DDP allreduce). This is done by creating a wrapper process group that wraps all process groups returned by torch.distributed.init_process_group() and torch.distributed.new_group() APIs. As a result, these APIs will return a wrapper process group that can be used exactly like a regular process group, but performs consistency checks before dispatching the collective to an underlying process group. Currently, these checks include a torch.distributed.monitored_barrier(), which ensures all ranks complete their outstanding collective calls and reports ranks which are stuck. Next, the collective itself is checked for consistency by ensuring all collective functions match and are called with consistent tensor shapes. If this is not the case, a detailed error report is included when the application crashes, rather than a hang or uninformative error message. As an example, consider the following function which has mismatched input shapes into torch.distributed.all_reduce():

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank):
    dist.init_process_group("nccl", rank=rank, world_size=2)
    torch.cuda.set_device(rank)
    tensor = torch.randn(10 if rank == 0 else 20).cuda()
    dist.all_reduce(tensor)
    torch.cuda.synchronize(device=rank)


if __name__ == "__main__":
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29501"
    os.environ["TORCH_CPP_LOG_LEVEL"] = "INFO"
    os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
    mp.spawn(worker, nprocs=2, args=())

With the NCCL backend, such an application would likely result in a hang which can be challenging to root-cause in nontrivial scenarios. If the user enables TORCH_DISTRIBUTED_DEBUG=DETAIL and reruns the application, the following error message reveals the root cause:

work = default_pg.allreduce([tensor], opts)
RuntimeError: Error when verifying shape tensors for collective ALLREDUCE on rank 0. This likely indicates that input shapes into the collective are mismatched across ranks. Got shapes:  10
[ torch.LongTensor{1} ]

Note

For fine-grained control of the debug level during runtime the functions torch.distributed.set_debug_level(), torch.distributed.set_debug_level_from_env(), and torch.distributed.get_debug_level() can also be used.

In addition, TORCH_DISTRIBUTED_DEBUG=DETAIL can be used in conjunction with TORCH_SHOW_CPP_STACKTRACES=1 to log the entire callstack when a collective desynchronization is detected. These collective desynchronization checks will work for all applications that use c10d collective calls backed by process groups created with the torch.distributed.init_process_group() and torch.distributed.new_group() APIs.
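
The consistency check on collective name and tensor shape can be pictured as comparing one small fingerprint per rank. The sketch below is a toy model of that comparison only: how the wrapper process group actually exchanges this information between ranks is not modelled, and find_mismatched_ranks and the fingerprint layout are hypothetical.

```python
def find_mismatched_ranks(fingerprints):
    """Return ranks whose (collective, shape) differs from rank 0's.

    `fingerprints` maps rank -> (collective_name, tensor_shape).
    """
    reference = fingerprints[0]
    return sorted(
        rank for rank, fingerprint in fingerprints.items() if fingerprint != reference
    )


# Same situation as the all_reduce example: rank 0 uses 10 elements, rank 1 uses 20.
calls = {
    0: ("ALLREDUCE", (10,)),
    1: ("ALLREDUCE", (20,)),
}
mismatched = find_mismatched_ranks(calls)
print(f"ranks disagreeing with rank 0: {mismatched}")
```

Detecting the disagreement before the collective is dispatched is what turns a silent hang into an error message that names the offending shapes.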

Logging#

In addition to explicit debugging support via torch.distributed.monitored_barrier() and TORCH_DISTRIBUTED_DEBUG, the underlying C++ library of torch.distributed also outputs log messages at various levels. These messages can be helpful to understand the execution state of a distributed training job and to troubleshoot problems such as network connection failures. The following matrix shows how the log level can be adjusted via the combination of TORCH_CPP_LOG_LEVEL and TORCH_DISTRIBUTED_DEBUG environment variables.

| TORCH_CPP_LOG_LEVEL | TORCH_DISTRIBUTED_DEBUG | Effective Log Level |
|---------------------|-------------------------|---------------------|
| ERROR               | ignored                 | Error               |
| WARNING             | ignored                 | Warning             |
| INFO                | ignored                 | Info                |
| INFO                | INFO                    | Debug               |
| INFO                | DETAIL                  | Trace (a.k.a. All)  |
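
The matrix can also be read as a small lookup function. The sketch below encodes one reading of the table: effective_log_level is a hypothetical helper, and treating the third row ("ignored" next to INFO) as the case where TORCH_DISTRIBUTED_DEBUG is OFF or unset is an assumption of this sketch.

```python
def effective_log_level(cpp_log_level, distributed_debug="OFF"):
    """Combine the two environment variable values as in the matrix above."""
    cpp = cpp_log_level.upper()
    debug = distributed_debug.upper()
    if cpp in ("ERROR", "WARNING"):
        # TORCH_DISTRIBUTED_DEBUG is ignored below the INFO level.
        return {"ERROR": "Error", "WARNING": "Warning"}[cpp]
    if cpp == "INFO":
        return {"OFF": "Info", "INFO": "Debug", "DETAIL": "Trace"}[debug]
    raise ValueError(f"level not covered by the matrix: {cpp_log_level}")


levels = [
    effective_log_level("WARNING", "DETAIL"),
    effective_log_level("INFO"),
    effective_log_level("INFO", "INFO"),
    effective_log_level("INFO", "DETAIL"),
]
print(levels)
```

The first call shows the practical consequence of the table: setting TORCH_DISTRIBUTED_DEBUG=DETAIL alone does not raise log verbosity unless TORCH_CPP_LOG_LEVEL is INFO.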

Distributed components raise custom Exception types derived from RuntimeError:

  • torch.distributed.DistError: This is the base type of all distributed exceptions.

  • torch.distributed.DistBackendError: This exception is thrown when a backend-specific error occurs. For example, if the NCCL backend is used and the user attempts to use a GPU that is not available to the NCCL library.

  • torch.distributed.DistNetworkError: This exception is thrown when networking libraries encounter errors (ex: Connection reset by peer).

  • torch.distributed.DistStoreError: This exception is thrown when the Store encounters an error (ex: TCPStore timeout).

class torch.distributed.DistError#

Exception raised when an error occurs in the distributed library

class torch.distributed.DistBackendError#

Exception raised when a backend error occurs in distributed

class torch.distributed.DistNetworkError#

Exception raised when a network error occurs in distributed

class torch.distributed.DistStoreError#

Exception raised when an error occurs in the distributed store

If you are running single node training, it may be convenient to interactively breakpoint your script. We offer a way to conveniently breakpoint a single rank:

torch.distributed.breakpoint(rank=0, skip=0, timeout_s=3600)#

Set a breakpoint, but only on a single rank. All other ranks will wait for you to be done with the breakpoint before continuing.

Parameters
  • rank (int) – Which rank to break on. Default: 0.

  • skip (int) – Skip the first skip calls to this breakpoint. Default: 0.
