Combining Distributed DataParallel with Distributed RPC Framework

Created On: Jul 28, 2020 | Last Updated: Jun 06, 2023 | Last Verified: Not Verified

Authors: Pritam Damania and Yi Wang


This tutorial uses a simple example to demonstrate how you can combine DistributedDataParallel (DDP) with the Distributed RPC framework, pairing distributed data parallelism with distributed model parallelism to train a simple model. Source code of the example can be found here.

Previous tutorials, Getting Started With Distributed Data Parallel and Getting Started with Distributed RPC Framework, described how to perform distributed data parallel and distributed model parallel training respectively. However, there are several training paradigms where you might want to combine these two techniques. For example:

  1. If we have a model with a sparse part (large embedding table) and a dense part (FC layers), we might want to put the embedding table on a parameter server and replicate the FC layer across multiple trainers using DistributedDataParallel. The Distributed RPC framework can be used to perform embedding lookups on the parameter server.

  2. Enable hybrid parallelism as described in the PipeDream paper. We can use the Distributed RPC framework to pipeline stages of the model across multiple workers and replicate each stage (if needed) using DistributedDataParallel.


In this tutorial we will cover case 1 mentioned above. We have a total of 4 workers in our setup as follows:

  1. 1 Master, which is responsible for creating an embedding table (nn.EmbeddingBag) on the parameter server. The master also drives the training loop on the two trainers.

  2. 1 Parameter Server, which holds the embedding table in memory and responds to RPCs from the Master and Trainers.

  3. 2 Trainers, which store an FC layer (nn.Linear) that is replicated amongst themselves using DistributedDataParallel. The trainers are also responsible for executing the forward pass, backward pass and optimizer step.


The entire training process is executed as follows:

  1. The master creates a RemoteModule that holds an embedding table on the Parameter Server.

  2. The master then kicks off the training loop on the trainers and passes the remote module to them.

  3. The trainers create a HybridModel which first performs an embedding lookup using the remote module provided by the master and then executes the FC layer which is wrapped inside DDP.

  4. The trainer executes the forward pass of the model and uses the loss to execute the backward pass using Distributed Autograd.

  5. As part of the backward pass, the gradients for the FC layer are computed first and synced to all trainers via allreduce in DDP.

  6. Next, Distributed Autograd propagates the gradients to the parameter server, where the gradients for the embedding table are updated.

  7. Finally, the Distributed Optimizer is used to update all the parameters.

Attention

You should always use Distributed Autograd for the backward pass if you’re combining DDP and RPC.
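
To make this concrete, below is a minimal, self-contained, single-process sketch (not part of the tutorial's example; the worker name, port number, and layer sizes are made up for illustration) of the pattern the trainers use later on: the forward and backward passes run inside a distributed autograd context, and dist_autograd.backward is called instead of loss.backward().

import os

import torch
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc

# Hypothetical single-process RPC setup so the sketch can run on its own.
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29502"
rpc.init_rpc("worker0", rank=0, world_size=1)

model = torch.nn.Linear(4, 2)

with dist_autograd.context() as context_id:
    loss = model(torch.randn(8, 4)).sum()
    # Instead of loss.backward(): gradients are accumulated in the
    # distributed autograd context rather than in param.grad.
    dist_autograd.backward(context_id, [loss])
    grads = dist_autograd.get_gradients(context_id)
    print({name: grads[param].shape for name, param in model.named_parameters()})

rpc.shutdown()

In the combined DDP and RPC setup below, the same context_id is also passed to the DistributedOptimizer step, so all gradients computed during one iteration (local and remote) stay grouped under one context.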

Now, let’s go through each part in detail. Firstly, we need to set up all of our workers before we can perform any training. We create 4 processes such that ranks 0 and 1 are our trainers, rank 2 is the master and rank 3 is the parameter server.

We initialize the RPC framework on all 4 workers using the TCP init_method. Once RPC initialization is done, the master creates a remote module that holds an EmbeddingBag layer on the Parameter Server using RemoteModule. The master then loops through each trainer and kicks off the training loop by calling _run_trainer on each trainer using rpc_async. Finally, the master waits for all training to finish before exiting.

The trainers first initialize a ProcessGroup for DDP with world_size=2 (for two trainers) using init_process_group. Next, they initialize the RPC framework using the TCP init_method. Note that the ports are different in RPC initialization and ProcessGroup initialization. This is to avoid port conflicts between initialization of both frameworks. Once the initialization is done, the trainers just wait for the _run_trainer RPC from the master.

The parameter server just initializes the RPC framework and waits for RPCs from the trainers and master.

def run_worker(rank, world_size):
    r"""
    A wrapper function that initializes RPC, calls the function, and shuts down
    RPC.
    """

    # We need to use different port numbers in TCP init_method for init_rpc and
    # init_process_group to avoid port conflicts.
    rpc_backend_options = TensorPipeRpcBackendOptions()
    rpc_backend_options.init_method = "tcp://localhost:29501"

    # Rank 2 is master, 3 is ps and 0 and 1 are trainers.
    if rank == 2:
        rpc.init_rpc(
            "master",
            rank=rank,
            world_size=world_size,
            rpc_backend_options=rpc_backend_options,
        )

        remote_emb_module = RemoteModule(
            "ps",
            torch.nn.EmbeddingBag,
            args=(NUM_EMBEDDINGS, EMBEDDING_DIM),
            kwargs={"mode": "sum"},
        )

        # Run the training loop on trainers.
        futs = []
        for trainer_rank in [0, 1]:
            trainer_name = "trainer{}".format(trainer_rank)
            fut = rpc.rpc_async(
                trainer_name, _run_trainer, args=(remote_emb_module, trainer_rank)
            )
            futs.append(fut)

        # Wait for all training to finish.
        for fut in futs:
            fut.wait()
    elif rank <= 1:
        # Initialize process group for Distributed DataParallel on trainers.
        dist.init_process_group(
            backend="gloo", rank=rank, world_size=2, init_method="tcp://localhost:29500"
        )

        # Initialize RPC.
        trainer_name = "trainer{}".format(rank)
        rpc.init_rpc(
            trainer_name,
            rank=rank,
            world_size=world_size,
            rpc_backend_options=rpc_backend_options,
        )

        # Trainer just waits for RPCs from master.
    else:
        rpc.init_rpc(
            "ps",
            rank=rank,
            world_size=world_size,
            rpc_backend_options=rpc_backend_options,
        )
        # The parameter server does nothing here.
        pass

    # block until all rpcs finish
    rpc.shutdown()


if __name__ == "__main__":
    # 2 trainers, 1 parameter server, 1 master.
    world_size = 4
    mp.spawn(run_worker, args=(world_size,), nprocs=world_size, join=True)

Before we discuss details of the Trainer, let’s introduce the HybridModel that the trainer uses. As described below, the HybridModel is initialized using a remote module that holds an embedding table (remote_emb_module) on the parameter server and the device to use for DDP. The initialization of the model wraps an nn.Linear layer inside DDP to replicate and synchronize this layer across all trainers.

The forward method of the model is pretty straightforward. It performs an embedding lookup on the parameter server using RemoteModule’s forward and passes its output onto the FC layer.

class HybridModel(torch.nn.Module):
    r"""
    The model consists of a sparse part and a dense part.
    1) The dense part is an nn.Linear module that is replicated across all trainers using DistributedDataParallel.
    2) The sparse part is a Remote Module that holds an nn.EmbeddingBag on the parameter server.
    This remote model can get a Remote Reference to the embedding table on the parameter server.
    """

    def __init__(self, remote_emb_module, device):
        super(HybridModel, self).__init__()
        self.remote_emb_module = remote_emb_module
        self.fc = DDP(torch.nn.Linear(16, 8).cuda(device), device_ids=[device])
        self.device = device

    def forward(self, indices, offsets):
        emb_lookup = self.remote_emb_module.forward(indices, offsets)
        return self.fc(emb_lookup.cuda(self.device))

Next, let’s look at the setup on the Trainer. The trainer first creates the HybridModel described above using a remote module that holds the embedding table on the parameter server and its own rank.

Now, we need to retrieve a list of RRefs to all the parameters that we would like to optimize with DistributedOptimizer. To retrieve the parameters for the embedding table from the parameter server, we can call RemoteModule’s remote_parameters, which basically walks through all the parameters for the embedding table and returns a list of RRefs. The trainer calls this method on the parameter server via RPC to receive a list of RRefs to the desired parameters. Since the DistributedOptimizer always takes a list of RRefs to parameters that need to be optimized, we need to create RRefs even for the local parameters of our FC layers. This is done by walking model.fc.parameters(), creating an RRef for each parameter and appending it to the list returned from remote_parameters(). Note that we cannot use model.parameters(), because it will recursively call model.remote_emb_module.parameters(), which is not supported by RemoteModule.

Finally, we create our DistributedOptimizer using all the RRefs and define a CrossEntropyLoss function.

def _run_trainer(remote_emb_module, rank):
    r"""
    Each trainer runs a forward pass which involves an embedding lookup on the
    parameter server and running nn.Linear locally. During the backward pass,
    DDP is responsible for aggregating the gradients for the dense part
    (nn.Linear) and distributed autograd ensures gradient updates are
    propagated to the parameter server.
    """

    # Setup the model.
    model = HybridModel(remote_emb_module, rank)

    # Retrieve all model parameters as rrefs for DistributedOptimizer.

    # Retrieve parameters for embedding table.
    model_parameter_rrefs = model.remote_emb_module.remote_parameters()

    # model.fc.parameters() only includes local parameters.
    # NOTE: Cannot call model.parameters() here,
    # because this will call remote_emb_module.parameters(),
    # which supports remote_parameters() but not parameters().
    for param in model.fc.parameters():
        model_parameter_rrefs.append(RRef(param))

    # Setup distributed optimizer
    opt = DistributedOptimizer(
        optim.SGD,
        model_parameter_rrefs,
        lr=0.05,
    )

    criterion = torch.nn.CrossEntropyLoss()

Now we’re ready to introduce the main training loop that is run on each trainer. get_next_batch is just a helper function to generate random inputs and targets for training. We run the training loop for multiple epochs and for each batch:

  1. Set up a Distributed Autograd Context for Distributed Autograd.

  2. Run the forward pass of the model and retrieve its output.

  3. Compute the loss based on our outputs and targets using the loss function.

  4. Use Distributed Autograd to execute a distributed backward pass using the loss.

  5. Finally, run a Distributed Optimizer step to optimize all the parameters.

    def get_next_batch(rank):
        for _ in range(10):
            num_indices = random.randint(20, 50)
            indices = torch.LongTensor(num_indices).random_(0, NUM_EMBEDDINGS)

            # Generate offsets.
            offsets = []
            start = 0
            batch_size = 0
            while start < num_indices:
                offsets.append(start)
                start += random.randint(1, 10)
                batch_size += 1

            offsets_tensor = torch.LongTensor(offsets)
            target = torch.LongTensor(batch_size).random_(8).cuda(rank)
            yield indices, offsets_tensor, target

    # Train for 100 epochs
    for epoch in range(100):
        # create distributed autograd context
        for indices, offsets, target in get_next_batch(rank):
            with dist_autograd.context() as context_id:
                output = model(indices, offsets)
                loss = criterion(output, target)

                # Run distributed backward pass
                dist_autograd.backward(context_id, [loss])

                # Run distributed optimizer
                opt.step(context_id)

                # Not necessary to zero grads as each iteration creates a different
                # distributed autograd context which hosts different grads
        print("Training done for epoch {}".format(epoch))

Source code for the entire example can be found here.