Train script#

Created On: May 04, 2021 | Last Updated On: Feb 09, 2023

If your train script works with torch.distributed.launch, it will continue working with torchrun, with these differences:

  1. No need to manually pass RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT.

  2. rdzv_backend and rdzv_endpoint can be provided. For most users this will be set to c10d (see rendezvous). The default rdzv_backend creates a non-elastic rendezvous where rdzv_endpoint holds the master address.

  3. Make sure you have load_checkpoint(path) and save_checkpoint(path) logic in your script. When any number of workers fail, we restart all the workers with the same program arguments, so you will lose progress up to the most recent checkpoint (see elastic launch).

  4. The use_env flag has been removed. If you were obtaining the local rank by parsing the --local-rank option, you now need to read it from the LOCAL_RANK environment variable (e.g. int(os.environ["LOCAL_RANK"])).
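Points 1 and 4 above can be sketched as follows. The environment variable names (RANK, WORLD_SIZE, LOCAL_RANK) are the ones torchrun exports; the helper function itself is a hypothetical convenience, not part of any PyTorch API:

```python
import os

def get_dist_info():
    """Read the rank information that torchrun exports as env vars.

    torchrun sets RANK, WORLD_SIZE, and LOCAL_RANK for every worker,
    so no manual --local-rank argument parsing is needed.
    """
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    local_rank = int(os.environ["LOCAL_RANK"])
    return rank, world_size, local_rank
```

Since torchrun sets these variables before your script starts, the helper can be called anywhere, including before the process group is initialized.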

Below is an expository example of a training script that checkpoints on each epoch; hence the worst-case progress lost on failure is one full epoch's worth of training.

    def main():
        args = parse_args(sys.argv[1:])
        state = load_checkpoint(args.checkpoint_path)
        initialize(state)
        # torch.distributed.run ensures that this will work
        # by exporting all the env vars needed to initialize the process group
        torch.distributed.init_process_group(backend=args.backend)

        for i in range(state.epoch, state.total_num_epochs):
            for batch in iter(state.dataset):
                train(batch, state.model)
            state.epoch += 1
            save_checkpoint(state)
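The load_checkpoint / save_checkpoint helpers are left abstract in the example. Below is one minimal sketch of how they might be implemented, using pickle purely for illustration; a real PyTorch script would typically use torch.save / torch.load and persist model and optimizer state dicts. The TrainState fields are assumptions, and the path is passed explicitly here (the example presumably keeps it inside state):

```python
import os
import pickle
from dataclasses import dataclass, field

@dataclass
class TrainState:
    epoch: int = 0
    total_num_epochs: int = 10
    # Model, optimizer, and dataset handles would live here in a real script.
    extra: dict = field(default_factory=dict)

def save_checkpoint(state, path):
    # Write atomically: dump to a temp file, then rename, so a crash
    # mid-write never leaves a corrupt checkpoint behind.
    tmp = path + ".tmp"
    with open(tmp, "wb") as f:
        pickle.dump(state, f)
    os.replace(tmp, path)

def load_checkpoint(path):
    # On the very first run there is no checkpoint yet; start fresh.
    if not os.path.exists(path):
        return TrainState()
    with open(path, "rb") as f:
        return pickle.load(f)
```

Because every restarted worker re-runs main() with the same arguments, load_checkpoint must handle both the "no checkpoint yet" first run and the "resume from epoch N" case, which is what the os.path.exists branch covers.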

For concrete examples of torchelastic-compliant train scripts, visit our examples page.