Distributed RPC Framework#
Created On: Nov 14, 2019 | Last Updated On: Jul 09, 2025
The distributed RPC framework provides mechanisms for multi-machine model training through a set of primitives to allow for remote communication, and a higher-level API to automatically differentiate models split across several machines.
Warning
APIs in the RPC package are stable and in maintenance mode.
Warning
CUDA support is a beta feature. Not all features of the RPC package are yet compatible with CUDA support and thus their use is discouraged. These unsupported features include: RRefs, JIT compatibility, dist autograd and dist optimizer, and profiling.
Note
Please refer to the PyTorch Distributed Overview (https://pytorch.org/tutorials/beginner/dist_overview.html) for a brief introduction to all features related to distributed training.
Basics#
The distributed RPC framework makes it easy to run functions remotely, supports referencing remote objects without copying the real data around, and provides autograd and optimizer APIs to transparently run backward and update parameters across RPC boundaries. These features can be categorized into four sets of APIs.
- Remote Procedure Call (RPC) supports running a function on the specified destination worker with the given arguments and getting the return value back or creating a reference to the return value. There are three main RPC APIs: rpc_sync() (synchronous), rpc_async() (asynchronous), and remote() (asynchronous and returns a reference to the remote return value). Use the synchronous API if the user code cannot proceed without the return value. Otherwise, use the asynchronous API to get a future, and wait on the future when the return value is needed on the caller. The remote() API is useful when the requirement is to create something remotely without ever needing to fetch it to the caller. Consider a driver process that is setting up a parameter server and a trainer. The driver can create an embedding table on the parameter server and then share the reference to the embedding table with the trainer, but will never use the embedding table locally itself. In this case, rpc_sync() and rpc_async() are not appropriate, as they always imply that the return value will be returned to the caller immediately or in the future.
- Remote Reference (RRef) serves as a distributed shared pointer to a local or remote object. It can be shared with other workers and reference counting will be handled transparently. Each RRef only has one owner and the object only lives on that owner. Non-owner workers holding RRefs can get copies of the object from the owner by explicitly requesting it. This is useful when a worker needs to access some data object, but is itself neither the creator (the caller of remote()) nor the owner of the object. The distributed optimizer, as we will discuss below, is one example of such use cases.
- Distributed Autograd stitches together local autograd engines on all the workers involved in the forward pass, and automatically reaches out to them during the backward pass to compute gradients. This is especially helpful if the forward pass needs to span multiple machines when conducting, e.g., distributed model parallel training, parameter-server training, etc. With this feature, user code no longer needs to worry about how to send gradients across RPC boundaries and in which order the local autograd engines should be launched, which can become quite complicated when there are nested and inter-dependent RPC calls in the forward pass.
- Distributed Optimizer’s constructor takes an Optimizer() (e.g., SGD(), Adagrad(), etc.) and a list of parameter RRefs, creates an Optimizer() instance on each distinct RRef owner, and updates parameters accordingly when running step(). When you have distributed forward and backward passes, parameters and gradients will be scattered across multiple workers, and hence an optimizer is required on each of the involved workers. Distributed Optimizer wraps all those local optimizers into one, and provides a concise constructor and step() API.
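The difference between blocking for a value and holding a future until the value is needed can be tried out without a cluster. The sketch below is only a local analogy built on Python's standard concurrent.futures module, not the RPC API itself; slow_add is a hypothetical stand-in for a function that would run on a remote worker.

```python
from concurrent.futures import ThreadPoolExecutor


def slow_add(x, y):
    # Hypothetical stand-in for work that would run on a remote worker.
    return x + y


with ThreadPoolExecutor(max_workers=2) as pool:
    # Blocking style (analogous to rpc_sync): the caller cannot proceed
    # without the value, so it waits for the result right away.
    sync_result = pool.submit(slow_add, 1, 2).result()

    # Future style (analogous to rpc_async): launch several calls first,
    # then wait only at the point where the values are needed.
    fut1 = pool.submit(slow_add, 3, 4)
    fut2 = pool.submit(slow_add, 5, 6)
    async_result = fut1.result() + fut2.result()

print(sync_result, async_result)
```

The analogy stops at remote(): a local future always delivers its value to the caller, whereas an RRef lets the value stay on the owner.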
RPC#
Before using RPC and distributed autograd primitives, initialization must take place. To initialize the RPC framework, use init_rpc(), which initializes the RPC framework, the RRef framework, and distributed autograd.
- torch.distributed.rpc.init_rpc(name,backend=None,rank=-1,world_size=None,rpc_backend_options=None)[source]#
Initializes RPC primitives such as the local RPC agent and distributed autograd, which immediately makes the current process ready to send and receive RPCs.
- Parameters
name (str) – a globally unique name of this node (e.g., Trainer3, ParameterServer2, Master, Worker1). The name can only contain numbers, letters, underscores, colons, and/or dashes, and must be shorter than 128 characters.
backend (BackendType, optional) – the type of RPC backend implementation. The supported value is BackendType.TENSORPIPE (the default). See Backends for more information.
rank (int) – a globally unique id/rank of this node.
world_size (int) – the number of workers in the group.
rpc_backend_options (RpcBackendOptions, optional) – the options passed to the RpcAgent constructor. It must be an agent-specific subclass of RpcBackendOptions and contains agent-specific initialization configurations. By default, for all agents, it sets the default timeout to 60 seconds and performs the rendezvous with an underlying process group initialized using init_method="env://", meaning that environment variables MASTER_ADDR and MASTER_PORT need to be set properly. See Backends for more information and to find which options are available.
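The naming rule for name can be checked before calling init_rpc(). The following is a minimal sketch, not part of the RPC package: is_valid_worker_name is a hypothetical helper, and it assumes that "alphabet" means ASCII letters and that an empty name is not acceptable.

```python
import re

# Assumed pattern: ASCII letters, digits, underscore, colon, and dash,
# with a length of 1 to 127 characters ("shorter than 128").
_WORKER_NAME_RE = re.compile(r"[A-Za-z0-9_:\-]{1,127}")


def is_valid_worker_name(name: str) -> bool:
    """Hypothetical pre-check mirroring the documented naming rule."""
    return _WORKER_NAME_RE.fullmatch(name) is not None


print(is_valid_worker_name("Trainer3"))   # allowed characters only
print(is_valid_worker_name("bad name"))   # contains a space
```

Such a check only catches malformed names early; uniqueness across the group still has to be guaranteed by the application.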
The following APIs allow users to remotely execute functions as well as create references (RRefs) to remote data objects. In these APIs, when passing a Tensor as an argument or a return value, the destination worker will try to create a Tensor with the same meta (i.e., shape, stride, etc.). We intentionally disallow transmitting CUDA tensors because it might crash if the device lists on source and destination workers do not match. In such cases, applications can always explicitly move the input tensors to CPU on the caller and move them to the desired devices on the callee if necessary.
- torch.distributed.rpc.rpc_sync(to,func,args=None,kwargs=None,timeout=-1.0)[source]#
Make a blocking RPC call to run function func on worker to. RPC messages are sent and received in parallel to execution of Python code. This method is thread-safe.
- Parameters
to (str or WorkerInfo or int) – name/rank/WorkerInfo of the destination worker.
func (Callable) – a callable function, such as Python callables, builtin operators (e.g. add()) and annotated TorchScript functions.
args (tuple) – the argument tuple for the func invocation.
kwargs (dict) – a dictionary of keyword arguments for the func invocation.
timeout (float, optional) – timeout in seconds to use for this RPC. If the RPC does not complete in this amount of time, an exception indicating it has timed out will be raised. A value of 0 indicates an infinite timeout, i.e. a timeout error will never be raised. If not provided, the default value set during initialization or with _set_rpc_timeout is used.
- Returns
Returns the result of running func with args and kwargs.
- Example::
Make sure that MASTER_ADDR and MASTER_PORT are set properly on both workers. Refer to the init_process_group() API for more details. For example,

export MASTER_ADDR=localhost
export MASTER_PORT=5678

Then run the following code in two different processes:

>>> # On worker 0:
>>> import torch
>>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker0", rank=0, world_size=2)
>>> ret = rpc.rpc_sync("worker1", torch.add, args=(torch.ones(2), 3))
>>> rpc.shutdown()

>>> # On worker 1:
>>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker1", rank=1, world_size=2)
>>> rpc.shutdown()
Below is an example of running a TorchScript function using RPC.
>>> # On both workers:
>>> import torch
>>> @torch.jit.script
>>> def my_script_add(tensor: torch.Tensor, scalar: int):
>>>     return torch.add(tensor, scalar)

>>> # On worker 0:
>>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker0", rank=0, world_size=2)
>>> ret = rpc.rpc_sync("worker1", my_script_add, args=(torch.ones(2), 3))
>>> rpc.shutdown()

>>> # On worker 1:
>>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker1", rank=1, world_size=2)
>>> rpc.shutdown()
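The timeout rules described for rpc_sync() (a default of -1.0 meaning "not provided", and 0 meaning "never time out") can be summarized in a few lines. This is an illustrative sketch only: effective_timeout is a hypothetical helper, not a function of the RPC package, and it assumes that a default of 0 also means an infinite timeout.

```python
import math

UNSET_TIMEOUT = -1.0  # the documented default value of the `timeout` argument


def effective_timeout(timeout: float, default_timeout: float) -> float:
    """Hypothetical helper: the timeout, in seconds, that applies to one call."""
    if timeout == UNSET_TIMEOUT:
        # Not provided: fall back to the default set during initialization.
        timeout = default_timeout
    if timeout == 0:
        # Zero means a timeout error is never raised.
        return math.inf
    return timeout


print(effective_timeout(-1.0, 60.0))  # falls back to the default
print(effective_timeout(0, 60.0))     # infinite timeout
print(effective_timeout(5.0, 60.0))   # per-call override
```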
- torch.distributed.rpc.rpc_async(to,func,args=None,kwargs=None,timeout=-1.0)[source]#
Make a non-blocking RPC call to run function func on worker to. RPC messages are sent and received in parallel to execution of Python code. This method is thread-safe. This method will immediately return a Future that can be awaited on.
- Parameters
to (str or WorkerInfo or int) – name/rank/WorkerInfo of the destination worker.
func (Callable) – a callable function, such as Python callables, builtin operators (e.g. add()) and annotated TorchScript functions.
args (tuple) – the argument tuple for the func invocation.
kwargs (dict) – a dictionary of keyword arguments for the func invocation.
timeout (float, optional) – timeout in seconds to use for this RPC. If the RPC does not complete in this amount of time, an exception indicating it has timed out will be raised. A value of 0 indicates an infinite timeout, i.e. a timeout error will never be raised. If not provided, the default value set during initialization or with _set_rpc_timeout is used.
- Returns
Returns a Future object that can be waited on. When completed, the return value of func on args and kwargs can be retrieved from the Future object.
Warning
Using GPU tensors as arguments or return values of func is not supported since we don’t support sending GPU tensors over the wire. You need to explicitly copy GPU tensors to CPU before using them as arguments or return values of func.
Warning
The rpc_async API does not copy storages of argument tensors until sending them over the wire, which could be done by a different thread depending on the RPC backend type. The caller should make sure that the contents of those tensors stay intact until the returned Future completes.
- Example::
Make sure that MASTER_ADDR and MASTER_PORT are set properly on both workers. Refer to the init_process_group() API for more details. For example,

export MASTER_ADDR=localhost
export MASTER_PORT=5678

Then run the following code in two different processes:

>>> # On worker 0:
>>> import torch
>>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker0", rank=0, world_size=2)
>>> fut1 = rpc.rpc_async("worker1", torch.add, args=(torch.ones(2), 3))
>>> fut2 = rpc.rpc_async("worker1", min, args=(1, 2))
>>> result = fut1.wait() + fut2.wait()
>>> rpc.shutdown()

>>> # On worker 1:
>>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker1", rank=1, world_size=2)
>>> rpc.shutdown()
Below is an example of running a TorchScript function using RPC.
>>> # On both workers:
>>> import torch
>>> @torch.jit.script
>>> def my_script_add(tensor: torch.Tensor, scalar: int):
>>>     return torch.add(tensor, scalar)

>>> # On worker 0:
>>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker0", rank=0, world_size=2)
>>> fut = rpc.rpc_async("worker1", my_script_add, args=(torch.ones(2), 3))
>>> ret = fut.wait()
>>> rpc.shutdown()

>>> # On worker 1:
>>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker1", rank=1, world_size=2)
>>> rpc.shutdown()
- torch.distributed.rpc.remote(to,func,args=None,kwargs=None,timeout=-1.0)[source]#
Make a remote call to run func on worker to and return an RRef to the result value immediately. Worker to will be the owner of the returned RRef, and the worker calling remote is a user. The owner manages the global reference count of its RRef, and the owner RRef is only destructed when globally there are no living references to it.
- Parameters
to (str or WorkerInfo or int) – name/rank/WorkerInfo of the destination worker.
func (Callable) – a callable function, such as Python callables, builtin operators (e.g. add()) and annotated TorchScript functions.
args (tuple) – the argument tuple for the func invocation.
kwargs (dict) – a dictionary of keyword arguments for the func invocation.
timeout (float, optional) – timeout in seconds for this remote call. If the creation of this RRef on worker to is not successfully processed on this worker within this timeout, then the next time there is an attempt to use the RRef (such as to_here()), a timeout will be raised indicating this failure. A value of 0 indicates an infinite timeout, i.e. a timeout error will never be raised. If not provided, the default value set during initialization or with _set_rpc_timeout is used.
- Returns
A user RRef instance to the result value. Use the blocking API torch.distributed.rpc.RRef.to_here() to retrieve the result value locally.
Warning
The remote API does not copy storages of argument tensors until sending them over the wire, which could be done by a different thread depending on the RPC backend type. The caller should make sure that the contents of those tensors stay intact until the returned RRef is confirmed by the owner, which can be checked using the torch.distributed.rpc.RRef.confirmed_by_owner() API.
Warning
Errors such as timeouts for the remote API are handled on a best-effort basis: when a remote call initiated by remote fails, such as with a timeout error, the error is handled and set on the resulting RRef asynchronously. If the RRef has not been used by the application before this handling (such as by a to_here or fork call), then future uses of the RRef will appropriately raise errors. However, it is possible that the user application will use the RRef before the errors are handled. In this case, errors may not be raised as they have not yet been handled.
Example:
Make sure that MASTER_ADDR and MASTER_PORT are set properly on both workers. Refer to the init_process_group() API for more details. For example,

export MASTER_ADDR=localhost
export MASTER_PORT=5678

Then run the following code in two different processes:

>>> # On worker 0:
>>> import torch
>>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker0", rank=0, world_size=2)
>>> rref1 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3))
>>> rref2 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 1))
>>> x = rref1.to_here() + rref2.to_here()
>>> rpc.shutdown()

>>> # On worker 1:
>>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker1", rank=1, world_size=2)
>>> rpc.shutdown()

Below is an example of running a TorchScript function using RPC.

>>> # On both workers:
>>> import torch
>>> @torch.jit.script
>>> def my_script_add(tensor: torch.Tensor, scalar: int):
>>>     return torch.add(tensor, scalar)

>>> # On worker 0:
>>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker0", rank=0, world_size=2)
>>> rref = rpc.remote("worker1", my_script_add, args=(torch.ones(2), 3))
>>> rref.to_here()
>>> rpc.shutdown()

>>> # On worker 1:
>>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker1", rank=1, world_size=2)
>>> rpc.shutdown()
- torch.distributed.rpc.get_worker_info(worker_name=None)[source]#
Get the WorkerInfo of a given worker name. Use this WorkerInfo to avoid passing an expensive string on every invocation.
- Parameters
worker_name (str) – the string name of a worker. If None, return the id of the current worker. (default None)
- Returns
WorkerInfo instance for the given worker_name, or WorkerInfo of the current worker if worker_name is None.
- torch.distributed.rpc.shutdown(graceful=True,timeout=0)[source]#
Perform a shutdown of the RPC agent, and then destroy the RPC agent. This stops the local agent from accepting outstanding requests, and shuts down the RPC framework by terminating all RPC threads. If graceful=True, this will block until all local and remote RPC processes reach this method and wait for all outstanding work to complete. Otherwise, if graceful=False, this is a local shutdown, and it does not wait for other RPC processes to reach this method.
Warning
For Future objects returned by rpc_async(), future.wait() should not be called after shutdown().
- Parameters
graceful (bool) – Whether to do a graceful shutdown or not. If True, this will 1) wait until there are no pending system messages for UserRRefs and delete them; 2) block until all local and remote RPC processes have reached this method and wait for all outstanding work to complete.
- Example::
Make sure that MASTER_ADDR and MASTER_PORT are set properly on both workers. Refer to the init_process_group() API for more details. For example,

export MASTER_ADDR=localhost
export MASTER_PORT=5678

Then run the following code in two different processes:

>>> # On worker 0:
>>> import torch
>>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker0", rank=0, world_size=2)
>>> # do some work
>>> result = rpc.rpc_sync("worker1", torch.add, args=(torch.ones(1), 1))
>>> # ready to shutdown
>>> rpc.shutdown()

>>> # On worker 1:
>>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker1", rank=1, world_size=2)
>>> # wait for worker 0 to finish work, and then shutdown.
>>> rpc.shutdown()
- class torch.distributed.rpc.WorkerInfo#
A structure that encapsulates information of a worker in the system. Contains the name and ID of the worker. This class is not meant to be constructed directly; rather, an instance can be retrieved through get_worker_info() and the result can be passed in to functions such as rpc_sync(), rpc_async(), and remote() to avoid copying a string on every invocation.
- property id#
Globally unique id to identify the worker.
- property name#
The name of the worker.
The RPC package also provides decorators which allow applications to specify how a given function should be treated on the callee side.
- torch.distributed.rpc.functions.async_execution(fn)[source]#
A decorator for a function indicating that the return value of the function is guaranteed to be a Future object and this function can run asynchronously on the RPC callee. More specifically, the callee extracts the Future returned by the wrapped function and installs subsequent processing steps as a callback to that Future. The installed callback will read the value from the Future when completed and send the value back as the RPC response. That also means the returned Future only exists on the callee side and is never sent through RPC. This decorator is useful when the wrapped function’s (fn) execution needs to pause and resume due to, e.g., containing rpc_async() or waiting for other signals.
Note
To enable asynchronous execution, applications must pass the function object returned by this decorator to RPC APIs. If RPC detects attributes installed by this decorator, it knows that this function returns a Future object and will handle that accordingly. However, this does not mean this decorator has to be the outermost one when defining a function. For example, when combined with @staticmethod or @classmethod, @rpc.functions.async_execution needs to be the inner decorator to allow the target function to be recognized as a static or class function. This target function can still execute asynchronously because, when accessed, the static or class method preserves attributes installed by @rpc.functions.async_execution.
- Example::
The returned Future object can come from rpc_async(), then(), or the Future constructor. The example below shows directly using the Future returned by then().

>>> import torch
>>> from torch.distributed import rpc
>>>
>>> # omitting setup and shutdown RPC
>>>
>>> # On all workers
>>> @rpc.functions.async_execution
>>> def async_add_chained(to, x, y, z):
>>>     # This function runs on "worker1" and returns immediately when
>>>     # the callback is installed through the `then(cb)` API. In the
>>>     # mean time, the `rpc_async` to "worker2" can run concurrently.
>>>     # When the return value of that `rpc_async` arrives at
>>>     # "worker1", "worker1" will run the lambda function accordingly
>>>     # and set the value for the previously returned `Future`, which
>>>     # will then trigger RPC to send the result back to "worker0".
>>>     return rpc.rpc_async(to, torch.add, args=(x, y)).then(
>>>         lambda fut: fut.wait() + z
>>>     )
>>>
>>> # On worker0
>>> ret = rpc.rpc_sync(
>>>     "worker1",
>>>     async_add_chained,
>>>     args=("worker2", torch.ones(2), 1, 1)
>>> )
>>> print(ret)  # prints tensor([3., 3.])
When combined with TorchScript decorators, this decorator must be the outermost one.

>>> import torch
>>> from torch import Tensor
>>> from torch.futures import Future
>>> from torch.distributed import rpc
>>>
>>> # omitting setup and shutdown RPC
>>>
>>> # On all workers
>>> @torch.jit.script
>>> def script_add(x: Tensor, y: Tensor) -> Tensor:
>>>     return x + y
>>>
>>> @rpc.functions.async_execution
>>> @torch.jit.script
>>> def async_add(to: str, x: Tensor, y: Tensor) -> Future[Tensor]:
>>>     return rpc.rpc_async(to, script_add, (x, y))
>>>
>>> # On worker0
>>> ret = rpc.rpc_sync(
>>>     "worker1",
>>>     async_add,
>>>     args=("worker2", torch.ones(2), 1)
>>> )
>>> print(ret)  # prints tensor([2., 2.])
When combined with a static or class method, this decorator must be the inner one.

>>> import torch
>>> from torch.distributed import rpc
>>>
>>> # omitting setup and shutdown RPC
>>>
>>> # On all workers
>>> class AsyncExecutionClass:
>>>
>>>     @staticmethod
>>>     @rpc.functions.async_execution
>>>     def static_async_add(to, x, y, z):
>>>         return rpc.rpc_async(to, torch.add, args=(x, y)).then(
>>>             lambda fut: fut.wait() + z
>>>         )
>>>
>>>     @classmethod
>>>     @rpc.functions.async_execution
>>>     def class_async_add(cls, to, x, y, z):
>>>         ret_fut = torch.futures.Future()
>>>         rpc.rpc_async(to, torch.add, args=(x, y)).then(
>>>             lambda fut: ret_fut.set_result(fut.wait() + z)
>>>         )
>>>         return ret_fut
>>>
>>>     @rpc.functions.async_execution
>>>     def bound_async_add(self, to, x, y, z):
>>>         return rpc.rpc_async(to, torch.add, args=(x, y)).then(
>>>             lambda fut: fut.wait() + z
>>>         )
>>>
>>> # On worker0
>>> ret = rpc.rpc_sync(
>>>     "worker1",
>>>     AsyncExecutionClass.static_async_add,
>>>     args=("worker2", torch.ones(2), 1, 2)
>>> )
>>> print(ret)  # prints tensor([4., 4.])
>>>
>>> ret = rpc.rpc_sync(
>>>     "worker1",
>>>     AsyncExecutionClass.class_async_add,
>>>     args=("worker2", torch.ones(2), 1, 2)
>>> )
>>> print(ret)  # prints tensor([4., 4.])
This decorator also works with RRef helpers, i.e., torch.distributed.rpc.RRef.rpc_sync(), torch.distributed.rpc.RRef.rpc_async(), and torch.distributed.rpc.RRef.remote().

>>> import torch
>>> from torch.distributed import rpc
>>>
>>> # reuse the AsyncExecutionClass class above
>>> rref = rpc.remote("worker1", AsyncExecutionClass)
>>> ret = rref.rpc_sync().static_async_add("worker2", torch.ones(2), 1, 2)
>>> print(ret)  # prints tensor([4., 4.])
>>>
>>> rref = rpc.remote("worker1", AsyncExecutionClass)
>>> ret = rref.rpc_async().static_async_add("worker2", torch.ones(2), 1, 2).wait()
>>> print(ret)  # prints tensor([4., 4.])
>>>
>>> rref = rpc.remote("worker1", AsyncExecutionClass)
>>> ret = rref.remote().static_async_add("worker2", torch.ones(2), 1, 2).to_here()
>>> print(ret)  # prints tensor([4., 4.])
Backends#
The RPC module can leverage different backends to perform the communication between the nodes. The backend to be used can be specified in the init_rpc() function, by passing a certain value of the BackendType enum. Regardless of what backend is used, the rest of the RPC API won’t change. Each backend also defines its own subclass of the RpcBackendOptions class, an instance of which can also be passed to init_rpc() to configure the backend’s behavior.
- class torch.distributed.rpc.BackendType(value)#
An enum class of available backends.
PyTorch ships with a builtin BackendType.TENSORPIPE backend. Additional ones can be registered using the register_backend() function.
- class torch.distributed.rpc.RpcBackendOptions#
An abstract structure encapsulating the options passed into the RPC backend. An instance of this class can be passed in to init_rpc() in order to initialize RPC with specific configurations, such as the RPC timeout and init_method to be used.
- property init_method#
URL specifying how to initialize the process group. Default is env://
- property rpc_timeout#
A float indicating the timeout to use for all RPCs. If an RPC does not complete in this timeframe, it will complete with an exception indicating that it has timed out.
TensorPipe Backend#
The TensorPipe agent, which is the default, leverages the TensorPipe library, which provides a natively point-to-point communication primitive specifically suited for machine learning that fundamentally addresses some of the limitations of Gloo. Compared to Gloo, it has the advantage of being asynchronous, which allows a large number of transfers to occur simultaneously, each at its own speed, without blocking each other. It will only open pipes between pairs of nodes when needed, on demand, and when one node fails only its incident pipes will be closed, while all other ones will keep working as normal. In addition, it is able to support multiple different transports (TCP, of course, but also shared memory, NVLink, InfiniBand, …) and can automatically detect their availability and negotiate the best transport to use for each pipe.
The TensorPipe backend comes with a TCP-based transport, just like Gloo. It is also able to automatically chunk and multiplex large tensors over multiple sockets and threads in order to achieve very high bandwidths. The agent will be able to pick the best transport on its own, with no intervention required.
Example:
import os
from torch.distributed import rpc
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '29500'

rpc.init_rpc(
    "worker1",
    rank=0,
    world_size=2,
    rpc_backend_options=rpc.TensorPipeRpcBackendOptions(
        num_worker_threads=8,
        rpc_timeout=20  # 20 second timeout
    )
)

# omitting init_rpc invocation on worker2
- class torch.distributed.rpc.TensorPipeRpcBackendOptions(*, num_worker_threads=16, rpc_timeout=60.0, init_method='env://', device_maps=None, devices=None, _transports=None, _channels=None)[source]#
The backend options for TensorPipeAgent, derived from RpcBackendOptions.
- Parameters
num_worker_threads (int, optional) – The number of threads in the thread-pool used by TensorPipeAgent to execute requests (default: 16).
rpc_timeout (float, optional) – The default timeout, in seconds, for RPC requests (default: 60 seconds). If the RPC has not completed in this timeframe, an exception indicating so will be raised. Callers can override this timeout for individual RPCs in rpc_sync() and rpc_async() if necessary.
init_method (str, optional) – The URL to initialize the distributed store used for rendezvous. It takes any value accepted for the same argument of init_process_group() (default: env://).
device_maps (Dict[str, Dict], optional) – Device placement mappings from this worker to the callee. The key is the callee worker name and the value is the dictionary (Dict of int, str, or torch.device) that maps this worker’s devices to the callee worker’s devices. (default: None)
devices (List[int, str, or torch.device], optional) – all local CUDA devices used by the RPC agent. By default, it will be initialized to all local devices from its own device_maps and corresponding devices from its peers’ device_maps. When processing CUDA RPC requests, the agent will properly synchronize CUDA streams for all devices in this List.
- property device_maps#
The device map locations.
- property devices#
All devices used by the local agent.
- property init_method#
URL specifying how to initialize the process group. Default is env://
- property num_worker_threads#
The number of threads in the thread-pool used by TensorPipeAgent to execute requests.
- property rpc_timeout#
A float indicating the timeout to use for all RPCs. If an RPC does not complete in this timeframe, it will complete with an exception indicating that it has timed out.
- set_device_map(to,device_map)[source]#
Set device mapping between each RPC caller and callee pair. This function can be called multiple times to incrementally add device placement configurations.
- Parameters
to (str) – Callee name.
device_map (Dict of int, str, or torch.device) – Device placement mappings from this worker to the callee. This map must be invertible.
Example
>>> # both workers
>>> import torch
>>> import torch.distributed.rpc as rpc
>>> from torch.distributed.rpc import TensorPipeRpcBackendOptions
>>>
>>> def add(x, y):
>>>     print(x)  # tensor([1., 1.], device='cuda:1')
>>>     return x + y, (x + y).to(2)
>>>
>>> # on worker 0
>>> options = TensorPipeRpcBackendOptions(
>>>     num_worker_threads=8,
>>>     device_maps={"worker1": {0: 1}}
>>>     # maps worker0's cuda:0 to worker1's cuda:1
>>> )
>>> options.set_device_map("worker1", {1: 2})
>>> # maps worker0's cuda:1 to worker1's cuda:2
>>>
>>> rpc.init_rpc(
>>>     "worker0",
>>>     rank=0,
>>>     world_size=2,
>>>     backend=rpc.BackendType.TENSORPIPE,
>>>     rpc_backend_options=options
>>> )
>>>
>>> x = torch.ones(2)
>>> rets = rpc.rpc_sync("worker1", add, args=(x.to(0), 1))
>>> # The first argument will be moved to cuda:1 on worker1. When
>>> # sending the return value back, it will follow the inverse of
>>> # the device map, and hence will be moved back to cuda:0 and
>>> # cuda:1 on worker0
>>> print(rets[0])  # tensor([2., 2.], device='cuda:0')
>>> print(rets[1])  # tensor([2., 2.], device='cuda:1')
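The requirement that a device map be invertible means that no two caller devices may map to the same callee device, since return values travel back along the inverse map. A minimal sketch of such a check follows; invert_device_map is a hypothetical helper, not part of the RPC package, and it assumes device identifiers have already been normalized to one form (for example, plain integer indices).

```python
def invert_device_map(device_map: dict) -> dict:
    """Return the callee-to-caller map, or raise if the map is not invertible."""
    inverse = {}
    for src, dst in device_map.items():
        if dst in inverse:
            # Two caller devices target the same callee device, so the
            # return path would be ambiguous.
            raise ValueError(
                f"devices {inverse[dst]!r} and {src!r} both map to {dst!r}"
            )
        inverse[dst] = src
    return inverse


print(invert_device_map({0: 1, 1: 2}))
```

For the map used in the example above, the inverse sends the callee's cuda:1 back to the caller's cuda:0 and the callee's cuda:2 back to the caller's cuda:1.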
- set_devices(devices)[source]#
Set local devices used by the TensorPipe RPC agent. When processing CUDA RPC requests, the TensorPipe RPC agent will properly synchronize CUDA streams for all devices in this List.
- Parameters
devices (List of int, str, or torch.device) – local devices used by the TensorPipe RPC agent.
Note
The RPC framework does not automatically retry any rpc_sync(), rpc_async() and remote() calls. The reason is that there is no way the RPC framework can determine whether an operation is idempotent or not and whether it is safe to retry. As a result, it is the application’s responsibility to deal with failures and retry if necessary. RPC communication is based on TCP and as a result failures could happen due to network failures or intermittent network connectivity issues. In such scenarios, the application needs to retry appropriately with reasonable backoffs to ensure the network isn’t overwhelmed by aggressive retries.
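One way an application might implement such retries is a small wrapper with exponential backoff and jitter. The sketch below is illustrative only: call_with_retries and all of its parameters are hypothetical, the choice of RuntimeError as the retryable exception type is an assumption, and it should only wrap operations that the application knows are safe to repeat.

```python
import random
import time


def call_with_retries(fn, *args, retries=3, base_delay=0.5, max_delay=8.0,
                      retry_on=(RuntimeError,), sleep=time.sleep, **kwargs):
    """Call fn(*args, **kwargs), retrying on the given exception types.

    Hypothetical helper: only use it for operations that are safe to repeat.
    """
    for attempt in range(retries + 1):
        try:
            return fn(*args, **kwargs)
        except retry_on:
            if attempt == retries:
                raise  # out of attempts: surface the last error
            # Exponential backoff, capped at max_delay, with jitter so that
            # many callers do not retry in lockstep.
            delay = min(max_delay, base_delay * (2 ** attempt))
            sleep(delay * random.uniform(0.5, 1.0))
```

Under these assumptions, an idempotent call could be issued as call_with_retries(rpc.rpc_sync, "worker1", torch.add, args=(torch.ones(2), 3)); the sleep parameter exists so that the delay can be replaced in tests.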
RRef#
Warning
RRefs are not currently supported when using CUDA tensors.
An RRef (Remote REFerence) is a reference to a value of some type T (e.g. Tensor) on a remote worker. This handle keeps the referenced remote value alive on the owner, but there is no implication that the value will be transferred to the local worker in the future. RRefs can be used in multi-machine training by holding references to nn.Modules that exist on other workers, and calling the appropriate functions to retrieve or modify their parameters during training. See Remote Reference Protocol for more details.
- class torch.distributed.rpc.PyRRef(RRef)#
A class encapsulating a reference to a value of some type on a remote worker. This handle will keep the referenced remote value alive on the worker. A UserRRef will be deleted when 1) there are no references to it in both the application code and the local RRef context, or 2) the application has called a graceful shutdown. Invoking methods on a deleted RRef leads to undefined behavior. The RRef implementation only offers best-effort error detection, and applications should not use UserRRefs after rpc.shutdown().
Warning
RRefs can only be serialized and deserialized by the RPC module. Serializing and deserializing RRefs without RPC (e.g., Python pickle, torch save() / load(), JIT save() / load(), etc.) will lead to errors.
- Parameters
value (object) – The value to be wrapped by this RRef.
type_hint (Type, optional) – Python type that should be passed to the TorchScript compiler as a type hint for value.
- Example::
The following examples skip RPC initialization and shutdown code for simplicity. Refer to the RPC docs for those details.

Create an RRef using rpc.remote

>>> import torch
>>> import torch.distributed.rpc as rpc
>>> rref = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3))
>>> # get a copy of value from the RRef
>>> x = rref.to_here()

Create an RRef from a local object

>>> import torch
>>> from torch.distributed.rpc import RRef
>>> x = torch.zeros(2, 2)
>>> rref = RRef(x)

Share an RRef with other workers

>>> # On both worker0 and worker1:
>>> def f(rref):
>>>     return rref.to_here() + 1

>>> # On worker0:
>>> import torch
>>> import torch.distributed.rpc as rpc
>>> from torch.distributed.rpc import RRef
>>> rref = RRef(torch.zeros(2, 2))
>>> # the following RPC shares the rref with worker1, reference
>>> # count is automatically updated.
>>> rpc.rpc_sync("worker1", f, args=(rref,))
- backward(self: torch._C._distributed_rpc.PyRRef, dist_autograd_ctx_id: SupportsInt = -1, retain_graph: bool = False) → None#
Runs the backward pass using the RRef as the root of the backward pass. If dist_autograd_ctx_id is provided, we perform a distributed backward pass using the provided ctx_id starting from the owner of the RRef. In this case, get_gradients() should be used to retrieve the gradients. If dist_autograd_ctx_id is not provided (the default, -1), it is assumed that this is a local autograd graph and we only perform a local backward pass. In the local case, the node calling this API has to be the owner of the RRef. The value of the RRef is expected to be a scalar Tensor.
- Parameters
dist_autograd_ctx_id (int, optional) – The distributed autograd context id for which we should retrieve the gradients (default: -1).
retain_graph (bool, optional) – If False, the graph used to compute the grad will be freed. Note that in nearly all cases setting this option to True is not needed and often can be worked around in a much more efficient way. Usually, you need to set this to True to run backward multiple times (default: False).
- Example::
>>> import torch.distributed.autograd as dist_autograd
>>> with dist_autograd.context() as context_id:
>>>     rref.backward(context_id)
- confirmed_by_owner(self: torch._C._distributed_rpc.PyRRef) → bool#
Returns whether this RRef has been confirmed by the owner. OwnerRRef always returns true, while UserRRef only returns true when the owner knows about this UserRRef.
- is_owner(self: torch._C._distributed_rpc.PyRRef) → bool#
Returns whether or not the current node is the owner of this RRef.
- local_value(self: torch._C._distributed_rpc.PyRRef) → object#
If the current node is the owner, returns a reference to the local value. Otherwise, throws an exception.
- owner(self: torch._C._distributed_rpc.PyRRef) → torch._C._distributed_rpc.WorkerInfo#
Returns worker information of the node that owns this RRef.
- owner_name(self: torch._C._distributed_rpc.PyRRef) → str#
Returns worker name of the node that owns this RRef.
- remote(self: torch._C._distributed_rpc.PyRRef, timeout: SupportsFloat = -1.0) → object#
Create a helper proxy to easily launch a remote using the owner of the RRef as the destination to run functions on the object referenced by this RRef. More specifically, rref.remote().func_name(*args, **kwargs) is the same as the following:
>>> def run(rref, func_name, args, kwargs):
>>>     return getattr(rref.local_value(), func_name)(*args, **kwargs)
>>>
>>> rpc.remote(rref.owner(), run, args=(rref, func_name, args, kwargs))
- Parameters
timeout (float, optional) – Timeout for rref.remote(). If the creation of this RRef is not successfully completed within the timeout, then the next time there is an attempt to use the RRef (such as to_here), a timeout will be raised. If not provided, the default RPC timeout will be used. Please see rpc.remote() for specific timeout semantics for RRef.
- Example::
>>> import torch
>>> from torch.distributed import rpc
>>> rref = rpc.remote("worker1", torch.add, args=(torch.zeros(2, 2), 1))
>>> rref.remote().size().to_here()  # returns torch.Size([2, 2])
>>> rref.remote().view(1, 4).to_here()  # returns tensor([[1., 1., 1., 1.]])
- rpc_async(self: torch._C._distributed_rpc.PyRRef, timeout: SupportsFloat = -1.0) → object#
Create a helper proxy to easily launch an rpc_async using the owner of the RRef as the destination to run functions on the object referenced by this RRef. More specifically, rref.rpc_async().func_name(*args, **kwargs) is the same as the following:
>>> def run(rref, func_name, args, kwargs):
>>>     return getattr(rref.local_value(), func_name)(*args, **kwargs)
>>>
>>> rpc.rpc_async(rref.owner(), run, args=(rref, func_name, args, kwargs))
- Parameters
timeout (float, optional) – Timeout for rref.rpc_async(). If the call does not complete within this timeframe, an exception indicating so will be raised. If this argument is not provided, the default RPC timeout will be used.
- Example::
>>> import torch
>>> from torch.distributed import rpc
>>> rref = rpc.remote("worker1", torch.add, args=(torch.zeros(2, 2), 1))
>>> rref.rpc_async().size().wait()  # returns torch.Size([2, 2])
>>> rref.rpc_async().view(1, 4).wait()  # returns tensor([[1., 1., 1., 1.]])
- rpc_sync(self: torch._C._distributed_rpc.PyRRef, timeout: SupportsFloat = -1.0) → object#
Create a helper proxy to easily launch an rpc_sync using the owner of the RRef as the destination to run functions on the object referenced by this RRef. More specifically, rref.rpc_sync().func_name(*args, **kwargs) is the same as the following:
>>> def run(rref, func_name, args, kwargs):
>>>     return getattr(rref.local_value(), func_name)(*args, **kwargs)
>>>
>>> rpc.rpc_sync(rref.owner(), run, args=(rref, func_name, args, kwargs))
- Parameters
timeout (float, optional) – Timeout for rref.rpc_sync(). If the call does not complete within this timeframe, an exception indicating so will be raised. If this argument is not provided, the default RPC timeout will be used.
- Example::
>>> import torch
>>> from torch.distributed import rpc
>>> rref = rpc.remote("worker1", torch.add, args=(torch.zeros(2, 2), 1))
>>> rref.rpc_sync().size()  # returns torch.Size([2, 2])
>>> rref.rpc_sync().view(1, 4)  # returns tensor([[1., 1., 1., 1.]])
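The proxy helpers above all reduce to the same dispatch: look up a method by name on the owner's local value and forward the arguments. The following is a minimal, torch-free sketch of that pattern. ToyRRef, ToyProxy, and call_now are hypothetical stand-ins that are not part of torch.distributed.rpc, and a thread pool plays the role of the asynchronous launcher, so its futures use .result() rather than .wait().

```python
from concurrent.futures import ThreadPoolExecutor


class ToyRRef:
    """Hypothetical stand-in for an owner-side RRef; not part of torch."""

    def __init__(self, value):
        self._value = value

    def local_value(self):
        return self._value


def run(rref, func_name, args, kwargs):
    # Same helper as in the docs: look up the method by name on the
    # referenced object and call it with the forwarded arguments.
    return getattr(rref.local_value(), func_name)(*args, **kwargs)


class ToyProxy:
    """Turns proxy.func_name(*args, **kwargs) into launch(run, rref, ...)."""

    def __init__(self, rref, launch):
        self._rref = rref
        self._launch = launch

    def __getattr__(self, func_name):
        # Only reached for names that are not real attributes of the proxy.
        def call(*args, **kwargs):
            return self._launch(run, self._rref, func_name, args, kwargs)

        return call


def call_now(fn, *args):
    # Synchronous launcher: stands in for an rpc_sync to the owner.
    return fn(*args)


rref = ToyRRef([3, 1, 2])

sync_proxy = ToyProxy(rref, call_now)
sync_count = sync_proxy.count(1)  # same as run(rref, "count", (1,), {})

with ThreadPoolExecutor(max_workers=1) as pool:
    # Asynchronous launcher: stands in for rpc_async and returns a future.
    async_proxy = ToyProxy(rref, pool.submit)
    async_index = async_proxy.index(2).result()
```

In the real API the launcher is rpc.rpc_sync, rpc.rpc_async, or rpc.remote with rref.owner() as the destination, as the equivalences above state.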
- to_here(self: torch._C._distributed_rpc.PyRRef, timeout: SupportsFloat = -1.0) → object#
Blocking call that copies the value of the RRef from the owner to the local node and returns it. If the current node is the owner, returns a reference to the local value.
- Parameters
timeout (float, optional) – Timeout for to_here. If the call does not complete within this timeframe, an exception indicating so will be raised. If this argument is not provided, the default RPC timeout (60s) will be used.
RemoteModule#
Warning
RemoteModule is not currently supported when using CUDA tensors.
RemoteModule is an easy way to create an nn.Module remotely on a different process. The actual module resides on a remote host, but the local host has a handle to this module and can invoke it much like a regular nn.Module. The invocation, however, incurs RPC calls to the remote end and can be performed asynchronously if needed via additional APIs supported by RemoteModule.
- class torch.distributed.nn.api.remote_module.RemoteModule(*args, **kwargs)#
A RemoteModule instance can only be created after RPC initialization.
It creates a user-specified module on a specified remote node. It behaves like a regular nn.Module except that the forward method is executed on the remote node. It takes care of autograd recording to ensure the backward pass propagates gradients back to the corresponding remote module.
It generates two methods, forward_async and forward, based on the signature of the forward method of module_cls. forward_async runs asynchronously and returns a Future. The arguments of forward_async and forward are the same as the forward method of the module returned by the module_cls.
For example, if module_cls returns an instance of nn.Linear, which has the forward method signature def forward(input: Tensor) -> Tensor:, the generated RemoteModule will have 2 methods with the signatures:
def forward(input: Tensor) -> Tensor:
def forward_async(input: Tensor) -> Future[Tensor]:
- Parameters
remote_device (str) – Device on the destination worker where we’d like to place this module. The format should be “<workername>/<device>”, where the device field can be parsed as torch.device type, e.g., “trainer0/cpu”, “trainer0”, “ps0/cuda:0”. The device field is optional and defaults to “cpu”.
module_cls (nn.Module) – Class for the module to be created remotely. For example,
>>> class MyModule(nn.Module):
>>>     def forward(self, input):
>>>         return input + 1
>>>
>>> module_cls = MyModule
args (Sequence, optional) – args to be passed to module_cls.
kwargs (Dict, optional) – kwargs to be passed to module_cls.
- Returns
A remote module instance which wraps the Module created by the user-provided module_cls. It has a blocking forward method and an asynchronous forward_async method that returns a future of the forward call on the user-provided module on the remote side.
- Example::
Run the following code in two different processes:
>>> # On worker 0:
>>> import torch
>>> import torch.distributed.rpc as rpc
>>> from torch import nn, Tensor
>>> from torch.distributed.nn.api.remote_module import RemoteModule
>>>
>>> rpc.init_rpc("worker0", rank=0, world_size=2)
>>> remote_linear_module = RemoteModule(
>>>     "worker1/cpu", nn.Linear, args=(20, 30),
>>> )
>>> input = torch.randn(128, 20)
>>> ret_fut = remote_linear_module.forward_async(input)
>>> ret = ret_fut.wait()
>>> rpc.shutdown()
>>> # On worker 1:
>>> import torch
>>> import torch.distributed.rpc as rpc
>>>
>>> rpc.init_rpc("worker1", rank=1, world_size=2)
>>> rpc.shutdown()
Furthermore, a more practical example that is combined with DistributedDataParallel (DDP) can be found in this tutorial.
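As a rough illustration of the remote_device format, the sketch below splits a “<workername>/<device>” string and applies the documented “cpu” default. split_remote_device is a hypothetical helper written for this example. It is not the parser RemoteModule uses internally, and it does not validate the device field against torch.device.

```python
def split_remote_device(remote_device):
    """Hypothetical helper: split "<workername>/<device>" into its parts."""
    worker, sep, device = remote_device.partition("/")
    if not worker:
        raise ValueError(f"missing worker name in {remote_device!r}")
    # The device field is optional; the docs give "cpu" as the default.
    return worker, (device if sep and device else "cpu")


# The three example strings from the parameter description.
parsed = [
    split_remote_device(s) for s in ("trainer0/cpu", "trainer0", "ps0/cuda:0")
]
```

Here parsed holds one (worker, device) pair per input string, with the bare “trainer0” form falling back to “cpu”.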
- get_module_rref()#
Return an RRef (RRef[nn.Module]) pointing to the remote module.
- Return type
RRef[Module]
- remote_parameters(recurse=True)#
Return a list of RRef pointing to the remote module’s parameters. This can typically be used in conjunction with DistributedOptimizer.
- Parameters
recurse (bool) – if True, then returns parameters of the remote module and all submodules of the remote module. Otherwise, returns only parameters that are direct members of the remote module.
- Returns
A list of RRef (List[RRef[nn.Parameter]]) to remote module’s parameters.
- Return type
list[torch.distributed.rpc.api.RRef[torch.nn.parameter.Parameter]]
Distributed Autograd Framework#
Warning
Distributed autograd is not currently supported when using CUDA tensors.
This module provides an RPC-based distributed autograd framework that can be used for applications such as model parallel training. In short, applications may send and receive gradient recording tensors over RPC. In the forward pass, we record when gradient recording tensors are sent over RPC and during the backward pass we use this information to perform a distributed backward pass using RPC. For more details see Distributed Autograd Design.
- torch.distributed.autograd.backward(context_id: int, roots: List[Tensor], retain_graph=False) → None#
Kicks off the distributed backward pass using the provided roots. This currently implements the FAST mode algorithm which assumes all RPC messages sent in the same distributed autograd context across workers would be part of the autograd graph during the backward pass.
We use the provided roots to discover the autograd graph and compute appropriate dependencies. This method blocks until the entire autograd computation is done.
We accumulate the gradients in the appropriate torch.distributed.autograd.context on each of the nodes. The autograd context to be used is looked up given the context_id that is passed in when torch.distributed.autograd.backward() is called. If there is no valid autograd context corresponding to the given ID, we throw an error. You can retrieve the accumulated gradients using the get_gradients() API.
- Parameters
context_id (int) – The autograd context id for which we should retrieve the gradients.
roots (list) – Tensors which represent the roots of the autograd computation. All the tensors should be scalars.
retain_graph (bool, optional) – If False, the graph used to compute the grad will be freed. Note that in nearly all cases setting this option to True is not needed and often can be worked around in a much more efficient way. Usually, you need to set this to True to run backward multiple times.
- Example::
>>> import torch.distributed.autograd as dist_autograd
>>> # model, loss_func, and target are placeholders defined elsewhere
>>> with dist_autograd.context() as context_id:
>>>     pred = model.forward()
>>>     loss = loss_func(pred, target)
>>>     dist_autograd.backward(context_id, [loss])
- class torch.distributed.autograd.context#
Context object to wrap forward and backward passes when using distributed autograd. The context_id generated in the with statement is required to uniquely identify a distributed backward pass on all workers. Each worker stores metadata associated with this context_id, which is required to correctly execute a distributed autograd pass.
- Example::
>>> import torch
>>> import torch.distributed.rpc as rpc
>>> import torch.distributed.autograd as dist_autograd
>>> with dist_autograd.context() as context_id:
>>>     t1 = torch.rand((3, 3), requires_grad=True)
>>>     t2 = torch.rand((3, 3), requires_grad=True)
>>>     loss = rpc.rpc_sync("worker1", torch.add, args=(t1, t2)).sum()
>>>     dist_autograd.backward(context_id, [loss])
- torch.distributed.autograd.get_gradients(context_id: int) → Dict[Tensor, Tensor]#
Retrieves a map from Tensor to the appropriate gradient for that Tensor accumulated in the provided context corresponding to the given context_id as part of the distributed autograd backward pass.
- Parameters
context_id (int) – The autograd context id for which we should retrieve the gradients.
- Returns
A map where the key is the Tensor and the value is the associated gradient for that Tensor.
- Example::
>>> import torch
>>> import torch.distributed.autograd as dist_autograd
>>> with dist_autograd.context() as context_id:
>>>     t1 = torch.rand((3, 3), requires_grad=True)
>>>     t2 = torch.rand((3, 3), requires_grad=True)
>>>     loss = t1 + t2
>>>     dist_autograd.backward(context_id, [loss.sum()])
>>>     grads = dist_autograd.get_gradients(context_id)
>>>     print(grads[t1])
>>>     print(grads[t2])
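The bookkeeping implied by context_id can be pictured with a toy model: gradients are accumulated in a per-context store, and an unknown id is an error. This is only a sketch under stated assumptions. toy_context, toy_accumulate, and toy_get_gradients are hypothetical names, plain floats stand in for tensors, and dropping the store when the with block exits is a simplification chosen for the example rather than a description of the library's internals.

```python
import itertools
from collections import defaultdict
from contextlib import contextmanager

_next_id = itertools.count(1)
_grads_by_context = {}  # context_id -> {parameter name: accumulated gradient}


@contextmanager
def toy_context():
    """Hypothetical stand-in for dist_autograd.context()."""
    context_id = next(_next_id)
    _grads_by_context[context_id] = defaultdict(float)
    try:
        yield context_id
    finally:
        # Toy assumption: state tied to the id is dropped on exit.
        del _grads_by_context[context_id]


def toy_accumulate(context_id, name, grad):
    # Mirrors the documented rule that an invalid context id is an error.
    if context_id not in _grads_by_context:
        raise RuntimeError(f"no valid autograd context for id {context_id}")
    _grads_by_context[context_id][name] += grad


def toy_get_gradients(context_id):
    # Return a plain-dict snapshot of what this context has accumulated.
    return dict(_grads_by_context[context_id])


with toy_context() as first_id:
    toy_accumulate(first_id, "w", 1.0)
    toy_accumulate(first_id, "w", 0.5)  # second contribution, same context
    first = toy_get_gradients(first_id)

with toy_context() as second_id:
    toy_accumulate(second_id, "w", 2.0)
    second = toy_get_gradients(second_id)
```

Each with block sees only its own contributions, which is why the examples above read gradients through get_gradients(context_id) inside the block that ran the backward pass.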
Distributed Optimizer#
See the torch.distributed.optim page for documentation on distributed optimizers.
Design Notes#
The distributed autograd design note covers the design of the RPC-based distributed autograd framework that is useful for applications such as model parallel training.
The RRef design note covers the design of the RRef (Remote REFerence) protocol used to refer to values on remote workers by the framework.
Tutorials#
The RPC tutorials introduce users to the RPC framework, provide several example applications using torch.distributed.rpc APIs, and demonstrate how to use the profiler to profile RPC-based workloads.