BIGBALLON/CIFAR-ZOO

PyTorch implementation of CNNs for the CIFAR benchmark.
Status: Archive (final test with PyTorch 1.7; no longer maintained. I would recommend using pycls, powered by FAIR, which is a simple and flexible codebase for image classification.)
This repository contains the PyTorch code for multiple CNN architectures and improvement methods based on the following papers. Hope the implementation and results will be helpful for your research!!
- Architecture
  - (lenet) LeNet-5, convolutional neural networks
  - (alexnet) ImageNet Classification with Deep Convolutional Neural Networks
  - (vgg) Very Deep Convolutional Networks for Large-Scale Image Recognition
  - (resnet) Deep Residual Learning for Image Recognition
  - (preresnet) Identity Mappings in Deep Residual Networks
  - (resnext) Aggregated Residual Transformations for Deep Neural Networks
  - (densenet) Densely Connected Convolutional Networks
  - (senet) Squeeze-and-Excitation Networks
  - (bam) BAM: Bottleneck Attention Module
  - (cbam) CBAM: Convolutional Block Attention Module
  - (genet) Gather-Excite: Exploiting Feature Context in Convolutional Neural Networks
  - (sknet) SKNet: Selective Kernel Networks
- Regularization
  - shake-shake
  - cutout
  - mixup
- Learning Rate Scheduler
  - step decay
  - cosine decay
  - htd (hyperbolic-tangent decay)
- Python (>= 3.6)
- PyTorch (>= 1.1.0)
- Tensorboard (>= 1.4.0) (for visualization)
- Other dependencies (pyyaml, easydict)
pip install -r requirements.txt
Simply run the cmd for training:
```bash
## 1 GPU for lenet
CUDA_VISIBLE_DEVICES=0 python -u train.py --work-path ./experiments/cifar10/lenet

## resume from ckpt
CUDA_VISIBLE_DEVICES=0 python -u train.py --work-path ./experiments/cifar10/lenet --resume

## 2 GPUs for resnet1202
CUDA_VISIBLE_DEVICES=0,1 python -u train.py --work-path ./experiments/cifar10/preresnet1202

## 4 GPUs for densenet190bc
CUDA_VISIBLE_DEVICES=0,1,2,3 python -u train.py --work-path ./experiments/cifar10/densenet190bc

## 1 GPU for vgg19 inference
CUDA_VISIBLE_DEVICES=0 python -u eval.py --work-path ./experiments/cifar10/vgg19
```
We use the yaml file `config.yaml` to save the parameters; check any file in `./experiments` for more details.
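For illustration, here is a minimal sketch of how such a `config.yaml` could be loaded with the listed dependencies (pyyaml and easydict). This is not the repository's actual loader, and the path below is just an example taken from the commands above.

```python
# Minimal sketch, not the repository's actual config loader.
# Assumes the work path contains a config.yaml, as described above.
import yaml
from easydict import EasyDict

with open("./experiments/cifar10/lenet/config.yaml") as f:  # hypothetical example path
    config = EasyDict(yaml.safe_load(f))

# EasyDict exposes the parsed fields with attribute-style access,
# e.g. config.<section>.<param> instead of config["<section>"]["<param>"].
print(config)
```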
You can see the training curve via tensorboard: `tensorboard --logdir path-to-event --port your-port`.

The training log will be dumped via logging; check `log.txt` in your work path.
architecture | params | batch size | epoch | C10 test acc (%) | C100 test acc (%) |
---|---|---|---|---|---|
Lecun | 62K | 128 | 250 | 67.46 | 34.10 |
alexnet | 2.4M | 128 | 250 | 75.56 | 38.67 |
vgg19 | 20M | 128 | 250 | 93.00 | 72.07 |
preresnet20 | 0.27M | 128 | 250 | 91.88 | 67.03 |
preresnet110 | 1.7M | 128 | 250 | 94.24 | 72.96 |
preresnet1202 | 19.4M | 128 | 250 | 94.74 | 75.28 |
densenet100bc | 0.76M | 64 | 300 | 95.08 | 77.55 |
densenet190bc | 25.6M | 64 | 300 | 96.11 | 82.59 |
resnext29_16x64d | 68.1M | 128 | 300 | 95.94 | 83.18 |
se_resnext29_16x64d | 68.6M | 128 | 300 | 96.15 | 83.65 |
cbam_resnext29_16x64d | 68.7M | 128 | 300 | 96.27 | 83.62 |
ge_resnext29_16x64d | 70.0M | 128 | 300 | 96.21 | 83.57 |
PS: the default data augmentation methods are `RandomCrop` + `RandomHorizontalFlip` + `Normalize`, and the √ marks which additional method is used. 🍰
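As a reference, a typical CIFAR-10 pipeline with these three default transforms looks roughly like the sketch below. The crop padding and the mean/std values shown are the commonly used CIFAR-10 statistics, not necessarily the exact values in this repository's configs.

```python
# Illustrative only: a common CIFAR-10 augmentation pipeline with the three
# default transforms mentioned above (values are typical, not repo-specific).
import torchvision.transforms as transforms

cifar10_mean = (0.4914, 0.4822, 0.4465)  # commonly used CIFAR-10 channel means
cifar10_std = (0.2470, 0.2435, 0.2616)   # commonly used CIFAR-10 channel stds

train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),   # random 32x32 crop after zero-padding
    transforms.RandomHorizontalFlip(),      # flip left-right with probability 0.5
    transforms.ToTensor(),
    transforms.Normalize(cifar10_mean, cifar10_std),
])
```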
| architecture | epoch | cutout | mixup | C10 test acc (%) |
| --- | --- | --- | --- | --- |
| preresnet20 | 250 | | | 91.88 |
| preresnet20 | 250 | √ | | 92.57 |
| preresnet20 | 250 | | √ | 92.71 |
| preresnet20 | 250 | √ | √ | 92.66 |
| preresnet110 | 250 | | | 94.24 |
| preresnet110 | 250 | √ | | 94.67 |
| preresnet110 | 250 | | √ | 94.94 |
| preresnet110 | 250 | √ | √ | 95.66 |
| se_resnext29_16x64d | 300 | | | 96.15 |
| se_resnext29_16x64d | 300 | √ | | 96.60 |
| se_resnext29_16x64d | 300 | | √ | 96.86 |
| se_resnext29_16x64d | 300 | √ | √ | 97.03 |
| cbam_resnext29_16x64d | 300 | √ | √ | 97.16 |
| ge_resnext29_16x64d | 300 | √ | √ | 97.19 |
| -- | -- | -- | -- | -- |
| shake_resnet26_2x64d | 1800 | | | 96.94 |
| shake_resnet26_2x64d | 1800 | √ | | 97.20 |
| shake_resnet26_2x64d | 1800 | | √ | 97.42 |
| shake_resnet26_2x64d | 1800 | √ | √ | 97.71 |
PS: `shake_resnet26_2x64d` achieved a 97.71% test accuracy with `cutout` and `mixup`!! It's cool, right?
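For readers unfamiliar with these two regularizers, the sketch below illustrates the general idea (it follows the cited mixup-cifar10 and Cutout repositories in spirit, but is not this repository's code): mixup blends pairs of training examples and their labels with a Beta-distributed coefficient, while cutout zeroes out a random square patch of the input.

```python
# Illustrative sketch of mixup and cutout; not the code used in this repository.
import numpy as np
import torch

def mixup_data(x, y, alpha=1.0):
    """Blend a batch with a shuffled copy of itself; returns mixed inputs,
    the two label sets, and the mixing coefficient lambda ~ Beta(alpha, alpha)."""
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1.0
    index = torch.randperm(x.size(0))
    mixed_x = lam * x + (1 - lam) * x[index]
    return mixed_x, y, y[index], lam

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """The loss is the same convex combination of the two targets' losses."""
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)

def cutout(img, length=16):
    """Zero out one random length x length square patch of a CHW image tensor."""
    _, h, w = img.shape
    cy, cx = np.random.randint(h), np.random.randint(w)
    y1, y2 = max(cy - length // 2, 0), min(cy + length // 2, h)
    x1, x2 = max(cx - length // 2, 0), min(cx + length // 2, w)
    img[:, y1:y2, x1:x2] = 0.0
    return img
```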
| architecture | epoch | step decay | cosine | htd(-6,3) | cutout | mixup | C10 test acc (%) |
| --- | --- | --- | --- | --- | --- | --- | --- |
| preresnet20 | 250 | √ | | | | | 91.88 |
| preresnet20 | 250 | | √ | | | | 92.13 |
| preresnet20 | 250 | | | √ | | | 92.44 |
| preresnet20 | 250 | | | √ | √ | √ | 93.30 |
| preresnet110 | 250 | √ | | | | | 94.24 |
| preresnet110 | 250 | | √ | | | | 94.48 |
| preresnet110 | 250 | | | √ | | | 94.82 |
| preresnet110 | 250 | | | √ | √ | √ | 95.88 |
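The cosine and htd(-6,3) schedules compared above can be written as simple per-epoch formulas; the sketch below is illustrative only (not the repository's scheduler code), with htd following the hyperbolic-tangent decay formula using bounds (L, U) = (-6, 3).

```python
# Illustrative formulas for the cosine and hyperbolic-tangent decay (htd)
# schedules compared in the table above; not the repository's scheduler code.
import math

def cosine_lr(base_lr, epoch, total_epochs):
    """Cosine annealing: lr falls smoothly from base_lr to ~0 over training."""
    return 0.5 * base_lr * (1 + math.cos(math.pi * epoch / total_epochs))

def htd_lr(base_lr, epoch, total_epochs, lower=-6.0, upper=3.0):
    """Hyperbolic-tangent decay with bounds (L, U) = (-6, 3), as in htd(-6,3):
    lr = base_lr / 2 * (1 - tanh(L + (U - L) * t / T))."""
    t = epoch / total_epochs
    return 0.5 * base_lr * (1 - math.tanh(lower + (upper - lower) * t))
```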
The provided code was adapted from the following repositories:
- kuangliu/pytorch-cifar
- bearpaw/pytorch-classification
- timgaripov/swa
- xgastaldi/shake-shake
- uoguelph-mlrg/Cutout
- facebookresearch/mixup-cifar10
- BIGBALLON/cifar-10-cnn
- BayesWatch/pytorch-GENet
- Jongchan/attention-module
- pppLang/SKNet
Feel free to contact me if you have any suggestions or questions. Issues are welcome; please create a PR if you find any bugs or want to contribute. 😊
```
@misc{bigballon2019cifarzoo,
  author = {Wei Li},
  title = {CIFAR-ZOO: PyTorch implementation of CNNs for CIFAR dataset},
  howpublished = {\url{https://github.com/BIGBALLON/CIFAR-ZOO}},
  year = {2019}
}
```