# Overlap_Local_SGD
Code to reproduce the experiments reported in this paper:
Jianyu Wang, Hao Liang, Gauri Joshi, "Overlap Local-SGD: An Algorithmic Approach to Hide Communication Delays in Distributed SGD," ICASSP 2020. (arXiv:2002.09539)
This repo contains the implementations of the following algorithms:
- Local SGD (Stich, ICLR 2018; Yu et al., AAAI 2019; Wang and Joshi, 2018)
- Overlap-Local-SGD (proposed in this paper)
- Elastic Averaging SGD (Zhang et al., NeurIPS 2015)
- CoCoD-SGD (Shen et al., IJCAI 2019)
- Blockwise Model-update Filtering (BMUF) (Chen and Huo, ICASSP 2016), also equivalent to SlowMo-Local SGD.
Please cite this paper if you use this code for your research/projects.
The code runs on Python 3.5 with PyTorch 1.0.0 and torchvision 0.2.1. The non-blocking communication is implemented using the Python threading package.
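To give a sense of how a background thread can hide communication behind computation, here is a minimal sketch; the helper `average_in_background` and its signature are illustrative assumptions, not the actual implementation in this repo.

```python
import threading
import torch.distributed as dist


def average_in_background(params, world_size):
    """Hypothetical helper: average a snapshot of the parameters in a
    background thread so the main thread can keep computing."""
    # Copy the parameters so local updates can continue while the copy is averaged.
    snapshot = [p.detach().clone() for p in params]

    def _all_reduce():
        for t in snapshot:
            dist.all_reduce(t, op=dist.ReduceOp.SUM)  # blocks only this thread
            t.div_(world_size)

    worker = threading.Thread(target=_all_reduce)
    worker.start()
    # The caller keeps running local SGD steps, later calls worker.join()
    # and mixes `snapshot` back into the live model.
    return worker, snapshot
```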
We implement all the above-mentioned algorithms as subclasses of `torch.optim.Optimizer`. A typical usage is shown below:
```python
import distoptim

# Before training: define the optimizer.
# One can use: 1) LocalSGD (including BMUF); 2) OverlapLocalSGD;
#              3) EASGD; 4) CoCoDSGD
# tau is the number of local updates / communication period
optimizer = distoptim.SELECTED_OPTIMIZER(tau)

...  # define model, criterion, logging, etc.

# Start training
for batch_id, (data, label) in enumerate(data_loader):
    # same as serial training
    output = model(data)             # forward
    loss = criterion(output, label)
    loss.backward()                  # backward
    optimizer.step()                 # gradient step
    optimizer.zero_grad()

    # additional line to average local models across workers;
    # communication happens after every tau iterations
    # (the optimizer has its own iteration counter inside)
    optimizer.average()
```
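As a rough mental model (an assumption, not the code in `distoptim`), the periodic averaging triggered by `optimizer.average()` in plain Local SGD amounts to something like the sketch below; Overlap-Local-SGD additionally overlaps this communication with the subsequent local updates.

```python
import torch.distributed as dist

# Conceptual sketch only: every tau local steps, all workers average
# their model parameters. Function name and arguments are hypothetical.
def average_every_tau(model, tau, local_step, world_size):
    if local_step % tau != 0:
        return  # keep training on the purely local model
    for p in model.parameters():
        dist.all_reduce(p.data, op=dist.ReduceOp.SUM)
        p.data.div_(world_size)
```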
In addition, one needs to initialize the process group as described in the PyTorch distributed documentation. In our private cluster, each machine has one GPU.
```python
# backend = gloo or nccl
# rank: 0, 1, 2, 3, ...
# size: number of workers
# h0 is the host name of worker 0; you need to change it
torch.distributed.init_process_group(backend=args.backend,
                                     init_method='tcp://h0:22000',
                                     rank=args.rank,
                                     world_size=args.size)
```
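The snippet above assumes `args.backend`, `args.rank`, and `args.size` have already been parsed; one plausible way to define them (the flag names are assumptions and may differ from the repo's scripts) is:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--backend', type=str, default='nccl',
                    help='communication backend: gloo or nccl')
parser.add_argument('--rank', type=int, required=True,
                    help='index of this worker: 0, 1, 2, ...')
parser.add_argument('--size', type=int, required=True,
                    help='total number of workers')
args = parser.parse_args()
```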
```
@article{wang2020overlap,
  title={Overlap Local-{SGD}: An Algorithmic Approach to Hide Communication Delays in Distributed {SGD}},
  author={Wang, Jianyu and Liang, Hao and Joshi, Gauri},
  journal={arXiv preprint arXiv:2002.09539},
  year={2020}
}
```