

hoya012/swa-tutorials-pytorch


Stochastic Weight Averaging Tutorials using pytorch. Based on PyTorch 1.6 Official Features (Stochastic Weight Averaging), this repository implements a classification codebase using a custom dataset.

  • author: hoya012
  • last update: 2020.10.23

0. Experimental Setup

0-1. Prepare Library

  • Install PyTorch and Captum:

```bash
pip install -r requirements.txt
```

0-2. Download dataset (Kaggle Intel Image Classification)

This dataset contains around 25k images of size 150x150, distributed across 6 categories: {'buildings' -> 0, 'forest' -> 1, 'glacier' -> 2, 'mountain' -> 3, 'sea' -> 4, 'street' -> 5}.

  • Make a `data` folder and move the dataset into the `data` folder.
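As a sanity check on the mapping above: torchvision's `ImageFolder` assigns indices by sorting the class-directory names alphabetically, which reproduces exactly the category indices listed for this dataset. A minimal pure-Python sketch (the directory layout under `data/` is an assumption about how the Kaggle archive extracts):

```python
# Sketch: reproduce the class -> index mapping that torchvision's
# ImageFolder derives by sorting class-directory names alphabetically.
# Assumes data/ contains one sub-folder per category.
class_dirs = ['street', 'sea', 'mountain', 'glacier', 'forest', 'buildings']
class_to_idx = {name: i for i, name in enumerate(sorted(class_dirs))}
print(class_to_idx)
# {'buildings': 0, 'forest': 1, 'glacier': 2, 'mountain': 3, 'sea': 4, 'street': 5}
```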

1. Baseline Training

  • ImageNet Pretrained ResNet-18 from torchvision.models
  • Batch Size 256 / Epochs 120 / Initial Learning Rate 0.0001
  • Training Augmentation: Resize((256, 256)), RandomHorizontalFlip()
  • Adam + Cosine Learning rate scheduling with warmup
  • Hardware: a single NVIDIA GTX 1080 Ti (Pascal) GPU

```bash
python main.py --checkpoint_name baseline;
```
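The "cosine learning rate scheduling with warmup" above can be sketched as a simple function of the training step: the learning rate ramps linearly to the base value, then follows a half-cosine down to zero. The warmup length and step counts below are illustrative assumptions, not values from this repository:

```python
import math

def lr_at(step, total_steps, base_lr=1e-4, warmup_steps=500):
    """Linear warmup to base_lr, then cosine decay toward 0."""
    if step < warmup_steps:
        # Linear ramp: (step + 1) / warmup_steps fraction of base_lr.
        return base_lr * (step + 1) / warmup_steps
    # Cosine decay over the remaining steps.
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * base_lr * (1.0 + math.cos(math.pi * progress))

print(lr_at(0, 10_000))       # small value at the start of warmup
print(lr_at(500, 10_000))     # base_lr right after warmup ends
print(lr_at(10_000, 10_000))  # ~0 at the end of training
```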

2. Stochastic Weight Averaging Training

In PyTorch 1.6, Stochastic Weight Averaging is built in and very easy to use. Thanks, PyTorch!

  • Guide from PyTorch's official tutorial
```python
import torch
from torch.optim.swa_utils import AveragedModel, SWALR
from torch.optim.lr_scheduler import CosineAnnealingLR

loader, optimizer, model, loss_fn = ...
swa_model = AveragedModel(model)
scheduler = CosineAnnealingLR(optimizer, T_max=100)
swa_start = 5
swa_scheduler = SWALR(optimizer, swa_lr=0.05)

for epoch in range(100):
    for input, target in loader:
        optimizer.zero_grad()
        loss_fn(model(input), target).backward()
        optimizer.step()
    if epoch > swa_start:
        swa_model.update_parameters(model)
        swa_scheduler.step()
    else:
        scheduler.step()

# Update bn statistics for the swa_model at the end
torch.optim.swa_utils.update_bn(loader, swa_model)
# Use swa_model to make predictions on test data
preds = swa_model(test_input)
```
  • My own implementation
```python
# in main.py
""" define model and learning rate scheduler for stochastic weight averaging """
swa_model = torch.optim.swa_utils.AveragedModel(model)
swa_scheduler = SWALR(optimizer, swa_lr=args.swa_lr)

...

# in learning/trainer.py
for batch_idx, (inputs, labels) in enumerate(data_loader):
    if not args.decay_type == 'swa':
        self.scheduler.step()
    else:
        if epoch <= args.swa_start:
            self.scheduler.step()

if epoch > args.swa_start and args.decay_type == 'swa':
    self.swa_model.update_parameters(self.model)
    self.swa_scheduler.step()

...

# in main.py
swa_model = swa_model.cpu()
torch.optim.swa_utils.update_bn(train_loader, swa_model)
swa_model = swa_model.cuda()
```

Run Script (Command Line)

```bash
python main.py --checkpoint_name swa --decay_type swa --swa_start 90 --swa_lr 5e-5;
```
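Under the hood, `AveragedModel.update_parameters` maintains an equal-weighted running mean of the weight snapshots it has seen. A pure-Python sketch of that update rule (an illustration of the averaging math, not this repository's code):

```python
def swa_update(avg, new, n_averaged):
    """Running mean used by SWA: avg <- avg + (new - avg) / (n_averaged + 1).
    With n_averaged == 0 this simply copies the first snapshot."""
    return [a + (w - a) / (n_averaged + 1) for a, w in zip(avg, new)]

# Averaging three weight snapshots yields their elementwise mean.
avg = [0.0, 0.0]
for n, snapshot in enumerate([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]):
    avg = swa_update(avg, snapshot, n)
print(avg)  # [3.0, 4.0]
```

This is why the batch-norm statistics must be recomputed afterwards with `update_bn`: the averaged weights were never used during training, so the running BN statistics stored in the model do not match them.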

3. Performance Table

  • B : Baseline
  • SWA : Stochastic Weight Averaging
    • SWA_{swa_start}_{swa_lr}
| Algorithm    | Test Accuracy (%) |
|--------------|-------------------|
| B            | 94.10             |
| SWA_90_0.05  | 80.53             |
| SWA_90_1e-4  | 94.20             |
| SWA_90_5e-4  | 93.87             |
| SWA_90_1e-5  | 94.23             |
| SWA_90_5e-5  | 94.57             |
| SWA_75_5e-5  | 94.27             |
| SWA_60_5e-5  | 94.33             |

4. Code Reference
