hoya012/swa-tutorials-pytorch
Stochastic Weight Averaging Tutorials using pytorch. Based on PyTorch 1.6 Official Features (Stochastic Weight Averaging), this repository implements a classification codebase using a custom dataset.
- author: hoya012
- last update: 2020.10.23
- Need to install PyTorch and Captum
```
pip install -r requirements.txt
```
This dataset contains around 25k images of size 150x150, distributed across 6 categories: {'buildings': 0, 'forest': 1, 'glacier': 2, 'mountain': 3, 'sea': 4, 'street': 5}.
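The category mapping above is just a name-to-index dict; here is a minimal, illustrative label encoder (the helper name is an assumption, not code from this repo):

```python
# Label mapping taken from the dataset description above.
label_map = {'buildings': 0, 'forest': 1, 'glacier': 2,
             'mountain': 3, 'sea': 4, 'street': 5}

def encode(category):
    """Map a category name to its integer class index."""
    return label_map[category]

print(encode('glacier'))  # 2
print(sorted(label_map, key=label_map.get))  # category names in class order 0..5
```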
- Make a `data` folder and move the dataset into the `data` folder.
- ImageNet Pretrained ResNet-18 from torchvision.models
- Batch Size 256 / Epochs 120 / Initial Learning Rate 0.0001
- Training Augmentation: Resize((256, 256)), RandomHorizontalFlip()
- Adam + Cosine Learning rate scheduling with warmup
- Hardware: a single NVIDIA GTX 1080 Ti (Pascal) GPU
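The Adam + cosine schedule with warmup listed above can be sketched as a plain function of the epoch; the warmup length here is an illustrative assumption, not the repo's exact setting:

```python
import math

def lr_at(epoch, base_lr=1e-4, warmup_epochs=5, total_epochs=120):
    """Linear warmup to base_lr, then cosine decay toward zero."""
    if epoch < warmup_epochs:
        return base_lr * (epoch + 1) / warmup_epochs
    progress = (epoch - warmup_epochs) / (total_epochs - warmup_epochs)
    return 0.5 * base_lr * (1 + math.cos(math.pi * progress))

# LR ramps up for the first epochs, peaks at base_lr, then decays smoothly.
print(lr_at(0))    # ~2e-05 (warmup)
print(lr_at(4))    # ~1e-04 (peak, end of warmup)
print(lr_at(119))  # ~0 (end of cosine decay)
```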
```
python main.py --checkpoint_name baseline;
```
In PyTorch 1.6, Stochastic Weight Averaging is very easy to use. Thanks, PyTorch!
- PyTorch's official guide:
```python
from torch.optim.swa_utils import AveragedModel, SWALR
from torch.optim.lr_scheduler import CosineAnnealingLR

loader, optimizer, model, loss_fn = ...
swa_model = AveragedModel(model)
scheduler = CosineAnnealingLR(optimizer, T_max=100)
swa_start = 5
swa_scheduler = SWALR(optimizer, swa_lr=0.05)

for epoch in range(100):
    for input, target in loader:
        optimizer.zero_grad()
        loss_fn(model(input), target).backward()
        optimizer.step()
    if epoch > swa_start:
        swa_model.update_parameters(model)
        swa_scheduler.step()
    else:
        scheduler.step()

# Update bn statistics for the swa_model at the end
torch.optim.swa_utils.update_bn(loader, swa_model)
# Use swa_model to make predictions on test data
preds = swa_model(test_input)
```
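`AveragedModel` keeps an equal running average of every parameter set it is given via `update_parameters`. Its default update rule can be sketched without torch (a minimal scalar illustration, not the library source):

```python
def update_average(avg, new, n_averaged):
    """Equal average, as AveragedModel's default avg_fn computes:
    avg <- avg + (new - avg) / (n_averaged + 1)."""
    return avg + (new - avg) / (n_averaged + 1)

# Simulate averaging one scalar "weight" captured at three epochs.
w_avg, n = 0.0, 0
for w in [1.0, 2.0, 3.0]:
    w_avg = update_average(w_avg, w, n)
    n += 1
print(w_avg)  # 2.0, the mean of 1, 2, 3
```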
- My own implementations
```python
# in main.py
""" define model and learning rate scheduler for stochastic weight averaging """
swa_model = torch.optim.swa_utils.AveragedModel(model)
swa_scheduler = SWALR(optimizer, swa_lr=args.swa_lr)

...

# in learning/trainer.py
for batch_idx, (inputs, labels) in enumerate(data_loader):
    if not args.decay_type == 'swa':
        self.scheduler.step()
    else:
        if epoch <= args.swa_start:
            self.scheduler.step()

if epoch > args.swa_start and args.decay_type == 'swa':
    self.swa_model.update_parameters(self.model)
    self.swa_scheduler.step()

...

# in main.py
swa_model = swa_model.cpu()
torch.optim.swa_utils.update_bn(train_loader, swa_model)
swa_model = swa_model.cuda()
```
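The `cpu()` / `update_bn()` / `cuda()` sequence at the end exists because `update_bn` resets each BatchNorm layer's running statistics and recomputes them in a single pass over the training data, so they match the averaged weights. The momentum-free statistic it accumulates is just a cumulative mean, which can be sketched without torch:

```python
def cumulative_mean(batches):
    """Recompute a BatchNorm-style running mean from scratch over all
    batches, as update_bn effectively does (momentum=None -> cumulative)."""
    total, count = 0.0, 0
    for batch in batches:
        total += sum(batch)
        count += len(batch)
    return total / count

print(cumulative_mean([[1.0, 2.0], [3.0, 5.0]]))  # 2.75
```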
```
python main.py --checkpoint_name swa --decay_type swa --swa_start 90 --swa_lr 5e-5;
```
- B : Baseline
- SWA : Stochastic Weight Averaging
- SWA_{swa_start}_{swa_lr}
| Algorithm | Test Accuracy (%) |
|---|---|
| B | 94.10 |
| SWA_90_0.05 | 80.53 |
| SWA_90_1e-4 | 94.20 |
| SWA_90_5e-4 | 93.87 |
| SWA_90_1e-5 | 94.23 |
| SWA_90_5e-5 | 94.57 |
| SWA_75_5e-5 | 94.27 |
| SWA_60_5e-5 | 94.33 |
- Baseline Code: https://github.com/hoya012/carrier-of-tricks-for-classification-pytorch
- Gradual Warmup Scheduler: https://github.com/ildoonet/pytorch-gradual-warmup-lr
- PyTorch Stochastic Weight Averaging: https://pytorch.org/blog/pytorch-1.6-now-includes-stochastic-weight-averaging