Train script#
Created On: May 04, 2021 | Last Updated On: Feb 09, 2023
If your train script works with `torch.distributed.launch` it will continue working with `torchrun` with these differences:
1. No need to manually pass `RANK`, `WORLD_SIZE`, `MASTER_ADDR`, and `MASTER_PORT`.
2. `rdzv_backend` and `rdzv_endpoint` can be provided. For most users this will be set to `c10d` (see rendezvous). The default `rdzv_backend` creates a non-elastic rendezvous where `rdzv_endpoint` holds the master address.
3. Make sure you have a `load_checkpoint(path)` and `save_checkpoint(path)` logic in your script. When any number of workers fail we restart all the workers with the same program arguments so you will lose progress up to the most recent checkpoint (see elastic launch).
4. The `use_env` flag has been removed. If you were parsing local rank by parsing the `--local-rank` option, you need to get the local rank from the environment variable `LOCAL_RANK` (e.g. `int(os.environ["LOCAL_RANK"])`).
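The last point above can be sketched as follows; `get_local_rank` is a hypothetical helper (not part of torchrun), shown only to illustrate reading `LOCAL_RANK` from the environment instead of a `--local-rank` argument:

```python
import os

def get_local_rank() -> int:
    # torchrun exports LOCAL_RANK (along with RANK, WORLD_SIZE, etc.)
    # into each worker's environment.
    # Defaulting to 0 lets the script also run standalone (an assumption
    # of this sketch, not required behavior).
    return int(os.environ.get("LOCAL_RANK", "0"))
```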
Below is an expository example of a training script that checkpoints on each epoch, so the worst-case progress lost on failure is one full epoch's worth of training.
```python
def main():
    args = parse_args(sys.argv[1:])
    state = load_checkpoint(args.checkpoint_path)
    initialize(state)

    # torch.distributed.run ensures that this will work
    # by exporting all the env vars needed to initialize the process group
    torch.distributed.init_process_group(backend=args.backend)

    for i in range(state.epoch, state.total_num_epochs):
        for batch in iter(state.dataset):
            train(batch, state.model)

        state.epoch += 1
        save_checkpoint(state)
```
For concrete examples of torchelastic-compliant train scripts, visit our examples page.