nunofernandes-plight/pytorch-image-models

PyTorch image models, scripts, pretrained weights -- (SE)ResNet/ResNeXT, DPN, MobileNet-V3/V2/V1, MNASNet, Single-Path NAS, FBNet, and more

Introduction

For each competition, personal, or freelance project involving images and Convolutional Neural Networks, I build on top of an evolving collection of code and models. This repo contains a (somewhat) cleaned-up and pared-down iteration of that code. Hopefully it'll be of use to others.

The work of many others is present here. I've tried to make sure all source material is acknowledged:

Models

I've included a few of my favourite models, but this is not an exhaustive collection. You can't do better than Cadene's collection in that regard. Most models do have pretrained weights from their respective sources or original authors.

  • ResNet/ResNeXt (from torchvision, with ResNeXt mods by myself)
    • ResNet-18, ResNet-34, ResNet-50, ResNet-101, ResNet-152, ResNeXt50 (32x4d), ResNeXt101 (32x4d and 64x4d)
  • DenseNet (from torchvision)
    • DenseNet-121, DenseNet-169, DenseNet-201, DenseNet-161
  • Squeeze-and-Excitation ResNet/ResNeXt (from Cadene, with some pretrained weight additions by myself)
    • SENet-154, SE-ResNet-18, SE-ResNet-34, SE-ResNet-50, SE-ResNet-101, SE-ResNet-152, SE-ResNeXt-26 (32x4d), SE-ResNeXt50 (32x4d), SE-ResNeXt101 (32x4d)
  • Inception-ResNet-V2 and Inception-V4 (from Cadene)
  • Xception (from Cadene)
  • PNasNet (from Cadene)
  • DPN (from me, weights hosted by Cadene)
    • DPN-68, DPN-68b, DPN-92, DPN-98, DPN-131, DPN-107
  • Generic MobileNet (from my standalone GenMobileNet): a generic model that implements many of the mobile-optimized, architecture-search-derived models that use similar DepthwiseSeparable and InvertedResidual blocks
    • MNASNet B1, A1 (Squeeze-Excite), and Small
    • MobileNet-V1
    • MobileNet-V2
    • MobileNet-V3 (work in progress, validating config)
    • ChamNet (details hard to find, currently an educated guess)
    • FBNet-C (TODO: A/B variants)
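The DepthwiseSeparable and InvertedResidual blocks mentioned above are what keep these mobile models cheap. As a rough illustration, the parameter savings can be worked out in plain Python. This sketch counts only conv weights (no bias, BatchNorm, activations, or squeeze-excite), and the helper names are hypothetical, not part of this repo:

```python
def conv_params(c_in: int, c_out: int, k: int, groups: int = 1) -> int:
    """Weight count of a bias-free 2D conv layer with a k x k kernel."""
    return (c_in // groups) * c_out * k * k


def standard_conv_params(c_in: int, c_out: int, k: int) -> int:
    # A single dense k x k convolution.
    return conv_params(c_in, c_out, k)


def depthwise_separable_params(c_in: int, c_out: int, k: int) -> int:
    # k x k depthwise conv (one filter per input channel) + 1x1 pointwise conv.
    return conv_params(c_in, c_in, k, groups=c_in) + conv_params(c_in, c_out, 1)


def inverted_residual_params(c_in: int, c_out: int, k: int, expand_ratio: int) -> int:
    # 1x1 expansion -> k x k depthwise -> 1x1 linear projection.
    c_mid = c_in * expand_ratio
    return (conv_params(c_in, c_mid, 1)
            + conv_params(c_mid, c_mid, k, groups=c_mid)
            + conv_params(c_mid, c_out, 1))


if __name__ == "__main__":
    print(standard_conv_params(64, 64, 3))          # 36864
    print(depthwise_separable_params(64, 64, 3))    # 4672
    print(inverted_residual_params(24, 24, 3, 6))   # 8208
```

At 64 channels with a 3x3 kernel, the depthwise-separable version needs roughly 8x fewer weights than a dense conv, which is why these architectures can afford the wide expansion inside each inverted residual block.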

The full list of model strings that can be passed to the model factory via the --model arg of the train, validation, and inference scripts:

chamnetv1_100, chamnetv2_100
densenet121, densenet161, densenet169, densenet201
dpn107, dpn131, dpn68, dpn68b, dpn92, dpn98
fbnetc_100
inception_resnet_v2, inception_v4
mnasnet_050, mnasnet_075, mnasnet_100, mnasnet_140, mnasnet_small
mobilenetv1_100, mobilenetv2_100, mobilenetv3_050, mobilenetv3_075, mobilenetv3_100
pnasnet5large
resnet101, resnet152, resnet18, resnet34, resnet50
resnext101_32x4d, resnext101_64x4d, resnext152_32x4d, resnext50_32x4d
semnasnet_050, semnasnet_075, semnasnet_100, semnasnet_140
seresnet101, seresnet152, seresnet18, seresnet34, seresnet50
seresnext101_32x4d, seresnext26_32x4d, seresnext50_32x4d
spnasnet_100
tflite_mnasnet_100, tflite_semnasnet_100
xception
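One common way to map strings like these to model constructors is a name registry. The sketch below only illustrates that pattern: register, build_model, _MODEL_REGISTRY and the dict-returning stub constructors are all hypothetical names invented here, not this repo's actual factory code, and a real constructor would return a torch.nn.Module:

```python
from typing import Callable, Dict

# Hypothetical registry: --model string -> constructor function.
_MODEL_REGISTRY: Dict[str, Callable[..., object]] = {}


def register(name: str):
    """Decorator that records a constructor under a model string."""
    def wrap(fn):
        if name in _MODEL_REGISTRY:
            raise ValueError(f"duplicate model name: {name}")
        _MODEL_REGISTRY[name] = fn
        return fn
    return wrap


@register("resnet18")
def resnet18(num_classes: int = 1000, pretrained: bool = False):
    # Stub: stands in for building (and optionally loading weights into) a network.
    return {"arch": "resnet18", "num_classes": num_classes, "pretrained": pretrained}


@register("seresnet34")
def seresnet34(num_classes: int = 1000, pretrained: bool = False):
    return {"arch": "seresnet34", "num_classes": num_classes, "pretrained": pretrained}


def build_model(name: str, **kwargs):
    """Look up a model string (as passed via --model) and call its constructor."""
    try:
        fn = _MODEL_REGISTRY[name]
    except KeyError:
        raise ValueError(
            f"unknown model {name!r}; choose from {sorted(_MODEL_REGISTRY)}") from None
    return fn(**kwargs)


if __name__ == "__main__":
    print(build_model("seresnet34", num_classes=10))
```

The registry keeps the scripts decoupled from individual model files: adding a model only requires decorating its constructor, and an unknown --model value fails with the list of valid names.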

Features

Several (less common) features that I often use in my projects are included. Many of them are the reason I maintain my own set of models instead of using others' via pip:

  • All models have a common default configuration interface and API for
    • accessing/changing the classifier: get_classifier and reset_classifier
    • doing a forward pass on just the features: forward_features
    • these make it easy to write consistent network wrappers that work with any of the models
  • All models have a consistent pretrained weight loader that adapts the last linear layer if necessary, and converts from 3-channel to 1-channel input if desired
  • The train script works in several process/GPU modes:
    • NVIDIA DDP w/ a single GPU per process, multiple processes with APEX present (AMP mixed-precision optional)
    • PyTorch DistributedDataParallel w/ multi-GPU, single process (AMP disabled as it crashes when enabled)
    • PyTorch w/ single GPU, single process (AMP optional)
  • A dynamic global pool implementation that allows selecting from average pooling, max pooling, average + max, or concat([average, max]) at model creation. All global pooling is adaptive average by default and compatible with pretrained weights.
  • A 'Test Time Pool' wrapper that can wrap any of the included models and usually provides improved performance when doing inference with input images larger than the training size. The idea was adapted from the original DPN implementation (https://github.com/cypw/DPNs) when I ported it.
  • Training schedules and techniques that provide competitive results (Cosine LR, Random Erasing, Label Smoothing, etc.)
  • Mixup (as in https://arxiv.org/abs/1710.09412), currently being implemented/tested
  • An inference script that dumps output to CSV, provided as an example
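The dynamic global pool options listed above differ mainly in how each channel's HxW map is reduced and in the width of the resulting feature vector. The pure-Python sketch below illustrates the four modes on nested lists; the mode names ("avg", "max", "avgmax", "catavgmax") are assumptions, and "average + max" is interpreted here as the mean of the two pooled values:

```python
from typing import List

FeatureMap = List[List[List[float]]]  # [channels][height][width]


def _avg(plane: List[List[float]]) -> float:
    values = [v for row in plane for v in row]
    return sum(values) / len(values)


def _max(plane: List[List[float]]) -> float:
    return max(v for row in plane for v in row)


def global_pool(features: FeatureMap, pool_type: str = "avg") -> List[float]:
    """Reduce each channel's HxW plane to one value (or two, for concat)."""
    avg = [_avg(p) for p in features]
    if pool_type == "avg":
        return avg
    mx = [_max(p) for p in features]
    if pool_type == "max":
        return mx
    if pool_type == "avgmax":
        # Mean of the two pooled vectors: output width stays C.
        return [0.5 * (a + m) for a, m in zip(avg, mx)]
    if pool_type == "catavgmax":
        # Concatenation: output width doubles to 2C.
        return avg + mx
    raise ValueError(f"unknown pool_type: {pool_type}")


if __name__ == "__main__":
    fmap = [[[1.0, 2.0], [3.0, 6.0]], [[0.0, 0.0], [0.0, 4.0]]]
    for mode in ("avg", "max", "avgmax", "catavgmax"):
        print(mode, global_pool(fmap, mode))
```

Only the concat mode changes the classifier's input size (2C instead of C), so a pretrained final linear layer cannot be reused as-is with it, whereas the other modes keep the width the pretrained weights expect.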

Custom Weights

I've leveraged the training scripts in this repository to train a few of the models with missing weights to good levels of performance. These numbers are all for 224x224 training and validation image sizing with the usual 87.5% validation crop.

Model | Prec@1 (Err) | Prec@5 (Err) | Param # | Image Scaling
---|---|---|---|---
ResNeXt-50 (32x4d) | 78.512 (21.488) | 94.042 (5.958) | 25M | bicubic
SE-ResNeXt-26 (32x4d) | 77.104 (22.896) | 93.316 (6.684) | 16.8M | bicubic
SE-ResNet-34 | 74.808 (25.192) | 92.124 (7.876) | 22M | bilinear
SE-ResNet-18 | 71.742 (28.258) | 90.334 (9.666) | 11.8M | bicubic
Single-Path NASNet 1.00 | 74.084 (25.916) | 91.818 (8.182) | 4.42M | bilinear
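The "usual 87.5% validation crop" mentioned above means the image is resized so its shorter side is 224 / 0.875 = 256 pixels, and a 224x224 center crop is then taken. A small sketch of that arithmetic follows. The function names are hypothetical, and the exact rounding used by this repo's transforms is an assumption:

```python
import math
from typing import Tuple


def eval_resize_and_crop(img_size: int = 224, crop_pct: float = 0.875) -> Tuple[int, int]:
    """Return (shorter-side resize, center-crop size) for a crop_pct-style eval transform."""
    resize = int(math.floor(img_size / crop_pct))
    return resize, img_size


def center_crop_box(width: int, height: int, crop: int) -> Tuple[int, int, int, int]:
    """(left, top, right, bottom) pixel box of a centered crop x crop region."""
    left = (width - crop) // 2
    top = (height - crop) // 2
    return left, top, left + crop, top + crop


if __name__ == "__main__":
    print(eval_resize_and_crop())           # (256, 224)
    # A 4:3 image whose shorter side was resized to 256 is 341 pixels wide.
    print(center_crop_box(341, 256, 224))   # (58, 16, 282, 240)
```

The Err values in the tables are simply 100 minus the corresponding Prec value.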

Ported Weights

Model | Prec@1 (Err) | Prec@5 (Err) | Param # | Image Scaling | Source
---|---|---|---|---|---
MNASNet 1.00 (B1) | 72.398 (27.602) | 90.930 (9.070) | 4.36M | bicubic | Google TFLite
SE-MNASNet 1.00 (A1) | 73.086 (26.914) | 91.336 (8.664) | 3.87M | bicubic | Google TFLite

NOTE: For some reason, I can't hit the stated accuracy with my implementation of MNASNet and Google's TFLite weights. Using a TF equivalent of 'SAME' padding was important to get > 70%, but something small is still missing. My attempts to train these models from scratch have so far leveled off in the same 72-73% range.
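TensorFlow's 'SAME' padding can be asymmetric when the stride is greater than 1: any odd leftover pixel goes on the bottom/right. A single symmetric padding= value on a PyTorch conv cannot reproduce that. The sketch below computes the TF-style padding for one spatial dimension. The helper name is hypothetical, and applying it in PyTorch would need an explicit pad (e.g. torch.nn.functional.pad) before a conv created with padding=0:

```python
from typing import Tuple


def same_padding(in_size: int, kernel: int, stride: int = 1,
                 dilation: int = 1) -> Tuple[int, int]:
    """TF-style 'SAME' padding along one spatial dim: (pad_before, pad_after)."""
    out_size = -(-in_size // stride)            # ceil(in_size / stride)
    effective_k = (kernel - 1) * dilation + 1   # kernel extent including dilation
    total = max((out_size - 1) * stride + effective_k - in_size, 0)
    pad_before = total // 2
    # Any odd extra pixel goes after (bottom/right), as in TensorFlow.
    return pad_before, total - pad_before


if __name__ == "__main__":
    print(same_padding(224, 3, stride=2))   # (0, 1): asymmetric
    print(same_padding(224, 5, stride=2))   # (1, 2): asymmetric
    print(same_padding(112, 3, stride=1))   # (1, 1): symmetric
```

For the stride-2 3x3 case, a symmetric padding=1 would shift the sampling grid by one pixel relative to TF. This is the kind of small mismatch that costs accuracy when TF-trained weights are loaded.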

Script Usage

Training

The variety of training args is large, and not all combinations of options (or even all options) have been fully tested. For the training dataset, specify the base folder that contains the train and validation folders.

To train an SE-ResNet-34 on ImageNet, distributed locally across 4 GPUs with one process per GPU, using a cosine LR schedule and random erasing with 50% probability and per-pixel random values:

./distributed_train.sh 4 /data/imagenet --model seresnet34 --sched cosine --epochs 150 --warmup-epochs 5 --lr 0.4 --reprob 0.5 --remode pixel --batch-size 256 -j 4

NOTE: NVIDIA APEX must be installed to run per-process distributed training via DDP, or to enable AMP mixed precision with the --amp flag.
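The training command above combines a linear warmup (--warmup-epochs 5) with a cosine decay from --lr 0.4 over --epochs 150. The function below sketches only that overall shape. The warmup starting LR, the minimum LR, per-epoch (rather than per-step) updates, and whether the cosine spans the warmup epochs are all assumptions, and this repo's scheduler may differ in those details:

```python
import math


def cosine_lr(epoch: int, base_lr: float = 0.4, epochs: int = 150,
              warmup_epochs: int = 5, warmup_lr: float = 1e-4,
              min_lr: float = 0.0) -> float:
    """Per-epoch LR: linear warmup to base_lr, then cosine decay to min_lr."""
    if epoch < warmup_epochs:
        # Linear ramp from warmup_lr toward base_lr.
        return warmup_lr + (base_lr - warmup_lr) * epoch / warmup_epochs
    # Fraction of the post-warmup schedule completed, in [0, 1].
    progress = (epoch - warmup_epochs) / max(1, epochs - warmup_epochs)
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * progress))


if __name__ == "__main__":
    for e in (0, 2, 5, 50, 100, 149):
        print(e, round(cosine_lr(e), 5))
```

The warmup keeps a large LR like 0.4 from destabilizing the randomly initialized network in the first few epochs. After that, the cosine gives a long high-LR phase followed by a gentle anneal.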

Validation / Inference

The validation and inference scripts are used similarly. validate.py outputs metrics on a validation set, while inference.py outputs the top-k class ids to a CSV. Specify the folder containing the validation images, not the base folder as in the training script.

To validate with the model's pretrained weights (if they exist):

python validate.py /imagenet/validation/ --model seresnext26_32x4d --pretrained

To run inference from a checkpoint:

python inference.py /imagenet/validation/ --model mobilenetv3_100 --checkpoint ./output/model_best.pth.tar
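As an illustration of what "top-k class ids in a CSV" means, the sketch below turns per-image class scores into rows holding the filename followed by the k best class indices. The column layout and the function names are assumptions, not necessarily the exact format inference.py writes:

```python
import csv
import io
from typing import List, Sequence


def topk_ids(scores: Sequence[float], k: int = 5) -> List[int]:
    """Indices of the k highest scores, best first."""
    return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]


def write_topk_csv(filenames: Sequence[str],
                   all_scores: Sequence[Sequence[float]], k: int = 5) -> str:
    """Return CSV text with one row per image: filename, then its top-k class ids."""
    buf = io.StringIO()
    writer = csv.writer(buf)
    for name, scores in zip(filenames, all_scores):
        writer.writerow([name, *topk_ids(scores, k)])
    return buf.getvalue()


if __name__ == "__main__":
    text = write_topk_csv(["a.jpg"], [[0.1, 0.7, 0.05, 0.15]], k=2)
    print(text.strip())   # a.jpg,1,3
```

In a real pipeline, the scores would come from the model's output logits (softmax is unnecessary for ranking), and the rows would be written to a file instead of a string buffer.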

TODO

A number of additions are planned for various projects, including:

  • Find optimal training hyperparams and create/port pretrained weights for the generic MobileNet variants
  • Do a model performance (speed + accuracy) benchmark across all models (make it runnable as a script)
  • More training experiments
  • Make the folder/file layout compatible with usage as a module
  • Add usage examples to comments, and good hyperparameters for training
  • Comments, cleanup, and the usual things that get pushed back
