torchrun (Elastic Launch)
Created On: May 04, 2021 | Last Updated On: Aug 26, 2021
Module torch.distributed.run.
torch.distributed.run is a module that spawns multiple distributed training processes on each of the training nodes.
torchrun is a Python console script for the main module torch.distributed.run, declared in the entry_points configuration in setup.py. It is equivalent to invoking python -m torch.distributed.run.
torchrun can be used for single-node distributed training, in which one or more processes per node will be spawned. It can be used for either CPU training or GPU training. If it is used for GPU training, each distributed process will operate on a single GPU. This can improve single-node training performance. torchrun can also be used in multi-node distributed training, by spawning multiple processes on each node, which likewise improves multi-node distributed training performance. This is especially beneficial for systems with multiple InfiniBand interfaces that have direct-GPU support, since all of them can be utilized for aggregated communication bandwidth.
In both single-node and multi-node distributed training, torchrun will launch the given number of processes per node (--nproc-per-node). If used for GPU training, this number needs to be less than or equal to the number of GPUs on the current system (nproc_per_node), and each process will operate on a single GPU from GPU 0 to GPU (nproc_per_node - 1).
Changed in version 2.0.0: torchrun will pass the --local-rank=<rank> argument to your script. From PyTorch 2.0.0 onwards, the dashed --local-rank is preferred over the previously used underscored --local_rank.
For backward compatibility, it may be necessary for users to handle both cases in their argument parsing code. This means including both "--local-rank" and "--local_rank" in the argument parser. If only "--local_rank" is provided, torchrun will trigger an error: "error: unrecognized arguments: --local-rank=<rank>". For training code that only supports PyTorch 2.0.0+, including "--local-rank" should be sufficient.
>>> import argparse
>>> parser = argparse.ArgumentParser()
>>> parser.add_argument("--local-rank", "--local_rank", type=int)
>>> args = parser.parse_args()
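To see that one option with two aliases accepts both spellings, the sketch below passes explicit argument lists to parse_args instead of reading sys.argv. argparse derives the attribute name local_rank from the first long option, so both forms land in the same field. Falling back to the LOCAL_RANK environment variable when neither flag is given is an assumption added for illustration, not something the text above requires.

```python
import argparse
import os


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
    # Both spellings are aliases of a single option; the destination
    # attribute is "local_rank" (derived from "--local-rank").
    # Assumption: default to the LOCAL_RANK env var if no flag is passed.
    parser.add_argument(
        "--local-rank",
        "--local_rank",
        type=int,
        default=int(os.environ.get("LOCAL_RANK", 0)),
    )
    return parser


parser = build_parser()
new_style = parser.parse_args(["--local-rank=1"])  # PyTorch >= 2.0 spelling
old_style = parser.parse_args(["--local_rank=3"])  # legacy spelling
print(new_style.local_rank, old_style.local_rank)  # 1 3
```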
Usage
Single-node multi-worker
torchrun --standalone --nnodes=1 --nproc-per-node=$NUM_TRAINERS YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...)
Note
--nproc-per-node may be "gpu" (spawn one process per GPU), "cpu" (spawn one process per CPU), "xpu" (spawn one process per XPU), "auto" (equivalent to "gpu" if CUDA is available, else equivalent to "xpu" if XPU is available, else equivalent to "cpu"), or an integer specifying the number of processes. See torch.distributed.run.determine_local_world_size for more details.
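The fallback order for "auto" described in the note can be sketched as a small function. This is a hypothetical, simplified model written for illustration; it is not the actual torch.distributed.run.determine_local_world_size, and device counts are passed in explicitly so it runs without accelerators. In a real script the CUDA count would typically come from torch.cuda.device_count().

```python
def resolve_nproc_per_node(value: str, cuda_count: int, xpu_count: int, cpu_count: int) -> int:
    """Hypothetical mapping of an --nproc-per-node value to a process count."""
    if value.isdigit():
        return int(value)  # explicit integer: use as-is
    if value == "auto":
        # Prefer CUDA GPUs, then XPUs, then fall back to CPUs.
        value = "gpu" if cuda_count > 0 else "xpu" if xpu_count > 0 else "cpu"
    counts = {"gpu": cuda_count, "xpu": xpu_count, "cpu": cpu_count}
    if value not in counts:
        raise ValueError(f"unsupported --nproc-per-node value: {value!r}")
    if counts[value] == 0:
        raise ValueError(f"--nproc-per-node={value} but no such devices were found")
    return counts[value]


print(resolve_nproc_per_node("auto", cuda_count=0, xpu_count=0, cpu_count=8))  # 8
```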
Stacked single-node multi-worker
To run multiple instances (separate jobs) of single-node, multi-worker training on the same host, we need to make sure that each instance (job) is set up on a different port to avoid port conflicts (or worse, two jobs being merged as a single job). To do this you have to run with --rdzv-backend=c10d and specify a different port by setting --rdzv-endpoint=localhost:$PORT_k. For --nnodes=1, it's often convenient to let torchrun pick a free random port automatically (by passing port 0) instead of manually assigning different ports for each run.
torchrun --rdzv-backend=c10d --rdzv-endpoint=localhost:0 --nnodes=1 --nproc-per-node=$NUM_TRAINERS YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...)
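If you do assign ports manually, the only requirement stated above is that each stacked job gets its own rendezvous endpoint. The sketch below builds one torchrun command line per job; the base port 29500 and the script name train.py are hypothetical placeholders. The resulting argument lists could be handed to subprocess.Popen to start the jobs side by side.

```python
from __future__ import annotations

import shlex


def stacked_job_argv(job_index: int, base_port: int, nproc_per_node: int, script: str) -> list[str]:
    port = base_port + job_index  # one distinct rendezvous port per job
    return [
        "torchrun",
        "--rdzv-backend=c10d",
        f"--rdzv-endpoint=localhost:{port}",
        "--nnodes=1",
        f"--nproc-per-node={nproc_per_node}",
        script,
    ]


# Hypothetical: three independent jobs on the same host.
commands = [stacked_job_argv(k, base_port=29500, nproc_per_node=2, script="train.py") for k in range(3)]
for argv in commands:
    print(shlex.join(argv))
```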
Fault tolerant (fixed number of workers, no elasticity, tolerates 3 failures)
torchrun --nnodes=$NUM_NODES --nproc-per-node=$NUM_TRAINERS --max-restarts=3 --rdzv-id=$JOB_ID --rdzv-backend=c10d --rdzv-endpoint=$HOST_NODE_ADDR YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...)
HOST_NODE_ADDR, in the form <host>[:<port>] (e.g. node1.example.com:29400), specifies the node and the port on which the C10d rendezvous backend should be instantiated and hosted. It can be any node in your training cluster, but ideally you should pick a node that has high bandwidth.
Note
If no port number is specified in HOST_NODE_ADDR, the port defaults to 29400.
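The <host>[:<port>] form with its default port can be illustrated with a small parser. This is a hypothetical helper for scripts that need to interpret HOST_NODE_ADDR themselves, not torchrun's own parsing code; it assumes the host is a hostname or IPv4 address, since bare IPv6 literals contain colons and would need bracket handling.

```python
from __future__ import annotations

DEFAULT_RDZV_PORT = 29400  # default port stated in the note above


def parse_host_node_addr(addr: str, default_port: int = DEFAULT_RDZV_PORT) -> tuple[str, int]:
    """Split '<host>[:<port>]' into (host, port). Assumes no IPv6 literals."""
    host, _, port = addr.partition(":")
    if not host:
        raise ValueError(f"missing host in {addr!r}")
    return host, int(port) if port else default_port


print(parse_host_node_addr("node1.example.com"))  # ('node1.example.com', 29400)
```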
Elastic (min=1, max=4, tolerates up to 3 membership changes or failures)
torchrun --nnodes=1:4 --nproc-per-node=$NUM_TRAINERS --max-restarts=3 --rdzv-id=$JOB_ID --rdzv-backend=c10d --rdzv-endpoint=$HOST_NODE_ADDR YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...)
HOST_NODE_ADDR, in the form <host>[:<port>] (e.g. node1.example.com:29400), specifies the node and the port on which the C10d rendezvous backend should be instantiated and hosted. It can be any node in your training cluster, but ideally you should pick a node that has high bandwidth.
Note
If no port number is specified in HOST_NODE_ADDR, the port defaults to 29400.
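The elastic example differs from the fault-tolerant one mainly in --nnodes=1:4, a <min>:<max> range instead of a single count. The sketch below shows how such a value splits into bounds and when a job counts as elastic (min != max). It is a hypothetical helper, and the validation rule (min >= 1, max >= min) is an assumption made for illustration.

```python
from __future__ import annotations


def parse_nnodes(spec: str) -> tuple[int, int]:
    """Parse '--nnodes' as either 'N' or 'MIN:MAX' (hypothetical helper)."""
    min_str, sep, max_str = spec.partition(":")
    min_nodes = int(min_str)
    max_nodes = int(max_str) if sep else min_nodes
    if min_nodes < 1 or max_nodes < min_nodes:
        raise ValueError(f"invalid --nnodes value: {spec!r}")
    return min_nodes, max_nodes


def is_elastic(spec: str) -> bool:
    # Elastic means the world size may change between restarts.
    min_nodes, max_nodes = parse_nnodes(spec)
    return min_nodes != max_nodes


print(parse_nnodes("1:4"), is_elastic("1:4"))  # (1, 4) True
```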
Note on rendezvous backend
For multi-node training you need to specify:
--rdzv-id: A unique job id (shared by all nodes participating in the job)
--rdzv-backend: An implementation of torch.distributed.elastic.rendezvous.RendezvousHandler
--rdzv-endpoint: The endpoint where the rendezvous backend is running; usually in the form host:port.
Currently c10d (recommended), etcd-v2, and etcd (legacy) rendezvous backends are supported out of the box. To use etcd-v2 or etcd, set up an etcd server with the v2 API enabled (e.g. --enable-v2).
Warning
etcd-v2 and etcd rendezvous use etcd API v2. You MUST enable the v2 API on the etcd server. Our tests use etcd v3.4.3.
Warning
For etcd-based rendezvous we recommend using etcd-v2 over etcd, which is functionally equivalent but uses a revised implementation. etcd is in maintenance mode and will be removed in a future version.
Definitions
Node - A physical instance or a container; maps to the unit that the job manager works with.
Worker - A worker in the context of distributed training.
WorkerGroup - The set of workers that execute the same function (e.g. trainers).
LocalWorkerGroup - A subset of the workers in the worker group running on the same node.
RANK - The rank of the worker within a worker group.
WORLD_SIZE - The total number of workers in a worker group.
LOCAL_RANK - The rank of the worker within a local worker group.
LOCAL_WORLD_SIZE - The size of the local worker group.
rdzv_id - A user-defined id that uniquely identifies the worker group for a job. This id is used by each node to join as a member of a particular worker group.
rdzv_backend - The backend of the rendezvous (e.g. c10d). This is typically a strongly consistent key-value store.
rdzv_endpoint - The rendezvous backend endpoint; usually in the form <host>:<port>.
A Node runs LOCAL_WORLD_SIZE workers, which comprise a LocalWorkerGroup. The union of all LocalWorkerGroups in the nodes in the job comprises the WorkerGroup.
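Because the WorkerGroup is the union of every node's LocalWorkerGroup, WORLD_SIZE is the sum of the per-node LOCAL_WORLD_SIZE values, which under torchrun's homogeneity requirement reduces to nnodes × nproc_per_node. The worked example below uses a hypothetical job of 2 nodes with 4 workers each. It deliberately does not derive RANK from LOCAL_RANK, since rank assignment is not guaranteed to be stable.

```python
from __future__ import annotations


def world_size(local_world_sizes: list[int]) -> int:
    # WORLD_SIZE counts every worker in the WorkerGroup, i.e. the sum
    # of the LOCAL_WORLD_SIZE of each node's LocalWorkerGroup.
    return sum(local_world_sizes)


# Hypothetical job: 2 nodes, each launched with --nproc-per-node=4.
nnodes, nproc_per_node = 2, 4
local_sizes = [nproc_per_node] * nnodes  # homogeneous, as torchrun assumes
print(world_size(local_sizes))            # 8
print(list(range(nproc_per_node)))        # LOCAL_RANK values on each node: [0, 1, 2, 3]
```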
Environment Variables
The following environment variables are made available to you in your script:
LOCAL_RANK - The local rank.
RANK - The global rank.
GROUP_RANK - The rank of the worker group. A number between 0 and max_nnodes. When running a single worker group per node, this is the rank of the node.
ROLE_RANK - The rank of the worker across all the workers that have the same role. The role of the worker is specified in the WorkerSpec.
LOCAL_WORLD_SIZE - The local world size (i.e. the number of workers running locally); equals the --nproc-per-node value specified on torchrun.
WORLD_SIZE - The world size (total number of workers in the job).
ROLE_WORLD_SIZE - The total number of workers that were launched with the same role specified in WorkerSpec.
MASTER_ADDR - The FQDN of the host that is running the worker with rank 0; used to initialize the Torch Distributed backend.
MASTER_PORT - The port on the MASTER_ADDR that can be used to host the C10d TCP store.
TORCHELASTIC_RESTART_COUNT - The number of worker group restarts so far.
TORCHELASTIC_MAX_RESTARTS - The configured maximum number of restarts.
TORCHELASTIC_RUN_ID - Equal to the rendezvous run_id (e.g. unique job id).
PYTHON_EXEC - System executable override. If provided, the Python user script will use the value of PYTHON_EXEC as the executable. sys.executable is used by default.
Deployment
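A training script can collect the variables listed above in one place instead of scattering os.environ lookups. The sketch below reads a subset of them into a frozen dataclass; the class name TorchrunEnv and the sample values are hypothetical. Taking a mapping argument (defaulting to os.environ) lets the parsing be exercised without launching torchrun.

```python
from __future__ import annotations

import os
from dataclasses import dataclass
from typing import Mapping


@dataclass(frozen=True)
class TorchrunEnv:
    """Hypothetical container for a subset of the torchrun-provided variables."""
    local_rank: int
    rank: int
    group_rank: int
    local_world_size: int
    world_size: int
    master_addr: str
    master_port: int
    restart_count: int
    max_restarts: int
    run_id: str

    @classmethod
    def from_env(cls, env: Mapping[str, str] | None = None) -> "TorchrunEnv":
        env = os.environ if env is None else env
        return cls(
            local_rank=int(env["LOCAL_RANK"]),
            rank=int(env["RANK"]),
            group_rank=int(env["GROUP_RANK"]),
            local_world_size=int(env["LOCAL_WORLD_SIZE"]),
            world_size=int(env["WORLD_SIZE"]),
            master_addr=env["MASTER_ADDR"],
            master_port=int(env["MASTER_PORT"]),
            restart_count=int(env["TORCHELASTIC_RESTART_COUNT"]),
            max_restarts=int(env["TORCHELASTIC_MAX_RESTARTS"]),
            run_id=env["TORCHELASTIC_RUN_ID"],
        )


# Hypothetical values, as they might look for one worker of a 2x4 job.
sample = {
    "LOCAL_RANK": "1", "RANK": "5", "GROUP_RANK": "1",
    "LOCAL_WORLD_SIZE": "4", "WORLD_SIZE": "8",
    "MASTER_ADDR": "node0.example.com", "MASTER_PORT": "29500",
    "TORCHELASTIC_RESTART_COUNT": "0", "TORCHELASTIC_MAX_RESTARTS": "3",
    "TORCHELASTIC_RUN_ID": "job-123",
}
cfg = TorchrunEnv.from_env(sample)
print(cfg.rank, cfg.world_size)  # 5 8
```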
(Not needed for the C10d backend) Start the rendezvous backend server and get the endpoint (to be passed as --rdzv-endpoint to torchrun).
Single-node multi-worker: Start torchrun on the host to start the agent process, which creates and monitors a local worker group.
Multi-node multi-worker: Start torchrun with the same arguments on all the nodes participating in training.
When using a job/cluster manager, the entry point command to the multi-node job should be torchrun.
Failure Modes
Worker failure: For a training job with n workers, if k <= n workers fail, all workers are stopped and restarted up to max_restarts.
Agent failure: An agent failure results in a local worker group failure. It is up to the job manager to fail the entire job (gang semantics) or attempt to replace the node. Both behaviors are supported by the agent.
Node failure: Same as agent failure.
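The worker-failure policy (stop everything, restart the whole group, give up after max_restarts) can be modeled as a short loop. This is a toy simulation for intuition only, not the elastic agent's implementation; run_worker_group is a hypothetical callback that receives the current restart count (the value a worker would see as TORCHELASTIC_RESTART_COUNT) and returns True if the whole group finished successfully.

```python
from __future__ import annotations

from typing import Callable


def run_with_restarts(run_worker_group: Callable[[int], bool], max_restarts: int) -> int:
    """Toy model: rerun the whole group until success or restarts are exhausted."""
    restart_count = 0
    while True:
        if run_worker_group(restart_count):
            return restart_count  # number of restarts that were needed
        if restart_count >= max_restarts:
            raise RuntimeError(f"worker group failed after {restart_count} restarts")
        restart_count += 1  # all workers are stopped and restarted together


attempts: list[int] = []


def flaky_group(restart_count: int) -> bool:
    attempts.append(restart_count)
    return restart_count >= 2  # hypothetical: fails twice, then succeeds


print(run_with_restarts(flaky_group, max_restarts=3))  # 2
```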
Membership Changes
Node departure (scale-down): The agent is notified of the departure, all existing workers are stopped, a new WorkerGroup is formed, and all workers are started with a new RANK and WORLD_SIZE.
Node arrival (scale-up): The new node is admitted to the job, all existing workers are stopped, a new WorkerGroup is formed, and all workers are started with a new RANK and WORLD_SIZE.
Important Notices
This utility and multi-process distributed (single-node or multi-node) GPU training currently only achieve the best performance using the NCCL distributed backend. Thus the NCCL backend is the recommended backend for GPU training.
The environment variables necessary to initialize a Torch process group are provided to you by this module, so there is no need for you to pass RANK manually. To initialize a process group in your training script, simply run:
>>> import torch.distributed as dist
>>> dist.init_process_group(backend="nccl")  # or backend="gloo", e.g. for CPU training
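Since init_process_group uses the env:// init method by default, a script accidentally started with plain python instead of torchrun fails because the rendezvous variables are absent. A hypothetical pre-flight check, sketched below, turns that into an explicit message. It assumes the four variables documented for env:// (MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE), all of which torchrun sets.

```python
from __future__ import annotations

import os
from typing import Mapping

# Variables the env:// init method reads; torchrun provides all of them.
REQUIRED_VARS = ("MASTER_ADDR", "MASTER_PORT", "RANK", "WORLD_SIZE")


def missing_rendezvous_vars(env: Mapping[str, str] | None = None) -> list[str]:
    """Return the required variables that are not set (hypothetical helper)."""
    env = os.environ if env is None else env
    return [name for name in REQUIRED_VARS if name not in env]


missing = missing_rendezvous_vars({})
if missing:
    print("not launched via torchrun? missing:", ", ".join(missing))
```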
In your training program, you can either use regular distributed functions or use the torch.nn.parallel.DistributedDataParallel() module. If your training program uses GPUs for training and you would like to use the torch.nn.parallel.DistributedDataParallel() module, here is how to configure it.
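import os, torch
local_rank = int(os.environ["LOCAL_RANK"])  # set by torchrun for each worker
# `model` is your nn.Module, already moved to the GPU with index local_rank
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank)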
Please ensure that the device_ids argument is set to the only GPU device id that your code will be operating on. This is generally the local rank of the process. In other words, device_ids needs to be [int(os.environ["LOCAL_RANK"])], and output_device needs to be int(os.environ["LOCAL_RANK"]) in order to use this utility.
On failures or membership changes, ALL surviving workers are killed immediately. Make sure to checkpoint your progress. The frequency of checkpoints should depend on your job's tolerance for lost work.
This module only supports homogeneous LOCAL_WORLD_SIZE. That is, it is assumed that all nodes run the same number of local workers (per role).
RANK is NOT stable. Between restarts, the local workers on a node can be assigned a different range of ranks than before. NEVER hard code any assumptions about the stability of ranks or some correlation between RANK and LOCAL_RANK.
When using elasticity (min_size != max_size), DO NOT hard code assumptions about WORLD_SIZE, as the world size can change as nodes are allowed to leave and join.
It is recommended for your script to have the following structure:
def main():
    load_checkpoint(checkpoint_path)
    initialize()
    train()

def train():
    for batch in iter(dataset):
        train_step(batch)

        if should_checkpoint:
            save_checkpoint(checkpoint_path)
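The structure above can be made concrete with a runnable stdlib sketch. Here a JSON file holding a step counter stands in for real model and optimizer state (which would normally be saved with torch.save), and the helper names mirror the outline but are hypothetical. The checkpoint is written to a temporary file and atomically renamed, so a worker killed mid-write does not leave a truncated checkpoint. In a real multi-worker job you would typically save from a single rank only.

```python
from __future__ import annotations

import json
import os
import tempfile


def load_checkpoint(path: str) -> dict:
    # First launch: no checkpoint yet, start from step 0.
    if not os.path.exists(path):
        return {"step": 0}
    with open(path) as f:
        return json.load(f)


def save_checkpoint(path: str, state: dict) -> None:
    # Write to a temp file in the same directory, then atomically replace.
    directory = os.path.dirname(os.path.abspath(path))
    fd, tmp_path = tempfile.mkstemp(dir=directory)
    with os.fdopen(fd, "w") as f:
        json.dump(state, f)
    os.replace(tmp_path, path)


def train(path: str, total_steps: int, checkpoint_every: int) -> dict:
    state = load_checkpoint(path)  # resume wherever the last run stopped
    for step in range(state["step"], total_steps):
        # train_step(batch) would run here
        state["step"] = step + 1
        if state["step"] % checkpoint_every == 0:
            save_checkpoint(path, state)
    return state


demo_dir = tempfile.mkdtemp()
ckpt = os.path.join(demo_dir, "ckpt.json")
first = train(ckpt, total_steps=10, checkpoint_every=5)    # fresh start
resumed = train(ckpt, total_steps=12, checkpoint_every=5)  # a "restart" resumes at step 10
print(first["step"], resumed["step"])  # 10 12
```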
(Recommended) On worker errors, this tool will summarize the details of the error (e.g. time, rank, host, pid, traceback, etc.). On each node, the first error (by timestamp) is heuristically reported as the "Root Cause" error. To get tracebacks as part of this error summary printout, you must decorate the main entry point function in your training script as shown in the example below. If it is not decorated, the summary will not include the traceback of the exception and will only contain the exit code. For details on torchelastic error handling see: https://pytorch.org/docs/stable/elastic/errors.html
from torch.distributed.elastic.multiprocessing.errors import record

@record
def main():
    # do train
    pass

if __name__ == "__main__":
    main()