# CIFAR-ZOO

PyTorch implementation of CNNs for the CIFAR benchmark.
**Status: Archived** (final test with PyTorch 1.7; no longer maintained. I would recommend you use [pycls](https://github.com/facebookresearch/pycls), powered by FAIR, which is a simple and flexible codebase for image classification.)
This repository contains PyTorch code for multiple CNN architectures and improvement methods based on the following papers. Hope the implementations and results are helpful for your research!
- Architecture
  - (lenet) LeNet-5, convolutional neural networks
  - (alexnet) ImageNet Classification with Deep Convolutional Neural Networks
  - (vgg) Very Deep Convolutional Networks for Large-Scale Image Recognition
  - (resnet) Deep Residual Learning for Image Recognition
  - (preresnet) Identity Mappings in Deep Residual Networks
  - (resnext) Aggregated Residual Transformations for Deep Neural Networks
  - (densenet) Densely Connected Convolutional Networks
  - (senet) Squeeze-and-Excitation Networks
  - (bam) BAM: Bottleneck Attention Module
  - (cbam) CBAM: Convolutional Block Attention Module
  - (genet) Gather-Excite: Exploiting Feature Context in Convolutional Neural Networks
  - (sknet) SKNet: Selective Kernel Networks
- Regularization
- Learning Rate Scheduler
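To make the architecture list concrete, the building block behind preresnet ("Identity Mappings in Deep Residual Networks") can be sketched in PyTorch roughly as follows. This is a minimal illustration under common conventions (3x3 convs, projection shortcut on shape change), not the repo's exact module:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class PreActBlock(nn.Module):
    """Pre-activation residual block: BN -> ReLU -> Conv, twice,
    with an identity (or 1x1 projection) shortcut."""

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.shortcut = None
        if stride != 1 or in_planes != planes:
            # projection shortcut when the spatial size or width changes
            self.shortcut = nn.Conv2d(in_planes, planes, kernel_size=1,
                                      stride=stride, bias=False)

    def forward(self, x):
        out = F.relu(self.bn1(x))
        shortcut = self.shortcut(out) if self.shortcut is not None else x
        out = self.conv1(out)
        out = self.conv2(F.relu(self.bn2(out)))
        return out + shortcut
```

A CIFAR-style preresnet then stacks three stages of such blocks (widths 16/32/64) followed by global average pooling and a linear classifier.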
- Python (>= 3.6)
- PyTorch (>= 1.1.0)
- TensorBoard (>= 1.4.0) (for visualization)
- Other dependencies (pyyaml, easydict)

```shell
pip install -r requirements.txt
```
Simply run the following commands for training:

```shell
## 1 GPU for lenet
CUDA_VISIBLE_DEVICES=0 python -u train.py --work-path ./experiments/cifar10/lenet

## resume from ckpt
CUDA_VISIBLE_DEVICES=0 python -u train.py --work-path ./experiments/cifar10/lenet --resume

## 2 GPUs for resnet1202
CUDA_VISIBLE_DEVICES=0,1 python -u train.py --work-path ./experiments/cifar10/preresnet1202

## 4 GPUs for densenet190bc
CUDA_VISIBLE_DEVICES=0,1,2,3 python -u train.py --work-path ./experiments/cifar10/densenet190bc

## 1 GPU for vgg19 inference
CUDA_VISIBLE_DEVICES=0 python -u eval.py --work-path ./experiments/cifar10/vgg19
```
We use a yaml file (`config.yaml`) to store the parameters; check any of the files in `./experiments` for details.
You can watch the training curves via TensorBoard: `tensorboard --logdir path-to-event --port your-port`. The training log is dumped via logging; check `log.txt` in your work path.
| architecture | params | batch size | epoch | C10 test acc (%) | C100 test acc (%) |
| --- | --- | --- | --- | --- | --- |
| Lecun | 62K | 128 | 250 | 67.46 | 34.10 |
| alexnet | 2.4M | 128 | 250 | 75.56 | 38.67 |
| vgg19 | 20M | 128 | 250 | 93.00 | 72.07 |
| preresnet20 | 0.27M | 128 | 250 | 91.88 | 67.03 |
| preresnet110 | 1.7M | 128 | 250 | 94.24 | 72.96 |
| preresnet1202 | 19.4M | 128 | 250 | 94.74 | 75.28 |
| densenet100bc | 0.76M | 64 | 300 | 95.08 | 77.55 |
| densenet190bc | 25.6M | 64 | 300 | 96.11 | 82.59 |
| resnext29_16x64d | 68.1M | 128 | 300 | 95.94 | 83.18 |
| se_resnext29_16x64d | 68.6M | 128 | 300 | 96.15 | 83.65 |
| cbam_resnext29_16x64d | 68.7M | 128 | 300 | 96.27 | 83.62 |
| ge_resnext29_16x64d | 70.0M | 128 | 300 | 96.21 | 83.57 |
PS: the default data augmentation methods are `RandomCrop` + `RandomHorizontalFlip` + `Normalize`, and the √ marks which additional method is used. 🍰
| architecture | epoch | cutout | mixup | C10 test acc (%) |
| --- | --- | --- | --- | --- |
| preresnet20 | 250 | | | 91.88 |
| preresnet20 | 250 | √ | | 92.57 |
| preresnet20 | 250 | | √ | 92.71 |
| preresnet20 | 250 | √ | √ | 92.66 |
| preresnet110 | 250 | | | 94.24 |
| preresnet110 | 250 | √ | | 94.67 |
| preresnet110 | 250 | | √ | 94.94 |
| preresnet110 | 250 | √ | √ | 95.66 |
| se_resnext29_16x64d | 300 | | | 96.15 |
| se_resnext29_16x64d | 300 | √ | | 96.60 |
| se_resnext29_16x64d | 300 | | √ | 96.86 |
| se_resnext29_16x64d | 300 | √ | √ | 97.03 |
| cbam_resnext29_16x64d | 300 | √ | √ | 97.16 |
| ge_resnext29_16x64d | 300 | √ | √ | 97.19 |
| shake_resnet26_2x64d | 1800 | | | 96.94 |
| shake_resnet26_2x64d | 1800 | √ | | 97.20 |
| shake_resnet26_2x64d | 1800 | | √ | 97.42 |
| shake_resnet26_2x64d | 1800 | √ | √ | 97.71 |
PS: `shake_resnet26_2x64d` achieved **97.71%** test accuracy with `cutout` and `mixup`!! It's cool, right?
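Mixup trains on convex combinations of input pairs and their labels. A minimal version, in the spirit of facebookresearch/mixup-cifar10 (a sketch, not a verbatim copy of that code), looks like:

```python
import numpy as np
import torch

def mixup_data(x, y, alpha=1.0):
    """Mix a batch with a shuffled copy of itself.

    Returns mixed inputs, both label sets, and the mixing weight lam, so the
    loss can be computed as lam * loss(pred, y_a) + (1 - lam) * loss(pred, y_b).
    """
    lam = float(np.random.beta(alpha, alpha)) if alpha > 0 else 1.0
    index = torch.randperm(x.size(0))
    mixed_x = lam * x + (1 - lam) * x[index]
    return mixed_x, y, y[index], lam

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Loss for a mixed batch: the same convex combination, applied to the labels."""
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)
```

In the training loop you would mix each batch before the forward pass and replace the plain criterion call with `mixup_criterion`.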
| architecture | epoch | step decay | cosine | htd(-6,3) | cutout | mixup | C10 test acc (%) |
| --- | --- | --- | --- | --- | --- | --- | --- |
| preresnet20 | 250 | √ | | | | | 91.88 |
| preresnet20 | 250 | | √ | | | | 92.13 |
| preresnet20 | 250 | | | √ | | | 92.44 |
| preresnet20 | 250 | | | √ | √ | √ | 93.30 |
| preresnet110 | 250 | √ | | | | | 94.24 |
| preresnet110 | 250 | | √ | | | | 94.48 |
| preresnet110 | 250 | | | √ | | | 94.82 |
| preresnet110 | 250 | | | √ | √ | √ | 95.88 |
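The three schedules compared above can be sketched as plain functions of the epoch. This is an illustrative sketch, not the repo's scheduler code: the step-decay milestones below are made-up defaults, while htd follows the hyperbolic-tangent decay formula with the lower/upper bounds (-6, 3) from the column header:

```python
import math

def step_decay_lr(base_lr, epoch, milestones=(100, 150), gamma=0.1):
    """Multiply the lr by gamma at each milestone epoch (illustrative milestones)."""
    return base_lr * gamma ** sum(epoch >= m for m in milestones)

def cosine_lr(base_lr, epoch, total_epochs):
    """Cosine annealing from base_lr down to 0 over the whole run."""
    return 0.5 * base_lr * (1 + math.cos(math.pi * epoch / total_epochs))

def htd_lr(base_lr, epoch, total_epochs, lower=-6.0, upper=3.0):
    """Hyperbolic-tangent decay: tanh sweeps from tanh(lower) ~ -1 to
    tanh(upper) ~ +1, so the lr moves from ~base_lr down to ~0."""
    t = epoch / total_epochs
    return 0.5 * base_lr * (1 - math.tanh(lower + (upper - lower) * t))
```

Compared with cosine annealing, htd(-6,3) keeps the lr near `base_lr` for longer before a steeper late decay, which is where its small accuracy edge in the table comes from.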
The provided code was adapted from:
- kuangliu/pytorch-cifar
- bearpaw/pytorch-classification
- timgaripov/swa
- xgastaldi/shake-shake
- uoguelph-mlrg/Cutout
- facebookresearch/mixup-cifar10
- BIGBALLON/cifar-10-cnn
- BayesWatch/pytorch-GENet
- Jongchan/attention-module
- pppLang/SKNet
Feel free to contact me if you have any suggestions or questions. Issues are welcome, and please create a PR if you find any bugs or want to contribute. 😊
```
@misc{bigballon2019cifarzoo,
  author       = {Wei Li},
  title        = {CIFAR-ZOO: PyTorch implementation of CNNs for CIFAR dataset},
  howpublished = {\url{https://github.com/BIGBALLON/CIFAR-ZOO}},
  year         = {2019}
}
```