Flops counter for neural networks in pytorch framework
sovrasov/flops-counter.pytorch
This tool is designed to compute the theoretical amount of multiply-add operations in neural networks. It can also compute the number of parameters and print the per-layer computational cost of a given network.
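As a rough illustration of what "multiply-add operations" means for common layers, the sketch below counts MACs for a fully connected layer and a 2D convolution from their shapes alone. It is a minimal sketch assuming the common convention of one MAC per weight per output position, with bias terms ignored. The helpers `linear_macs`, `conv2d_out_size` and `conv2d_macs` are hypothetical and not part of `ptflops`, whose per-layer accounting may differ in details such as bias handling.

```python
def linear_macs(batch, in_features, out_features):
    """MACs of y = x @ W.T for x of shape (batch, in_features); bias is ignored."""
    return batch * in_features * out_features


def conv2d_out_size(size, kernel, stride=1, padding=0, dilation=1):
    """Output size of a convolution along one spatial dimension (PyTorch formula)."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def conv2d_macs(batch, in_ch, out_ch, h, w, kernel, stride=1, padding=0, groups=1):
    """MACs of a square-kernel Conv2d: every output element needs
    kernel * kernel * (in_ch / groups) multiply-adds; bias is ignored."""
    h_out = conv2d_out_size(h, kernel, stride, padding)
    w_out = conv2d_out_size(w, kernel, stride, padding)
    per_output = kernel * kernel * (in_ch // groups)
    return batch * out_ch * h_out * w_out * per_output


# A 3x3, 64->64 convolution on a 56x56 feature map (padding keeps the size)
conv = conv2d_macs(1, 64, 64, 56, 56, kernel=3, padding=1)
# A 2048->1000 classifier head
fc = linear_macs(1, 2048, 1000)
print(f"conv: {conv / 1e6:.1f} MMac, fc: {fc / 1e6:.3f} MMac")
```

Many papers report FLOPs as roughly twice the MAC count, since each multiply-add is one multiplication plus one addition.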
`ptflops` has two backends, `pytorch` and `aten`. The `pytorch` backend is a legacy one: it considers `nn.Modules` only. However, it's still useful, since it provides better per-layer analytics for CNNs. In all other cases it's recommended to use the `aten` backend, which considers aten operations and therefore covers more model architectures (including transformers). The default backend is `aten`. Please don't use the `pytorch` backend for transformer architectures.
Operations considered by the `aten` backend:
- aten.mm, aten.matmul, aten.addmm, aten.bmm
- aten.convolution
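Because the `aten` backend works at the level of these operator calls, the cost of each call follows from its input shapes. Below is a minimal sketch of the usual count (one MAC per multiply-accumulate in each inner product) for the matrix-multiplication ops listed. The function names are hypothetical, `aten.addmm` is treated like `aten.mm` with the bias addition ignored, and `ptflops`' own formulas may differ in such details.

```python
from math import prod


def mm_macs(a_shape, b_shape):
    """aten.mm: (n, k) @ (k, m) -> (n, m); k MACs per output element."""
    (n, k), (k2, m) = a_shape, b_shape
    if k != k2:
        raise ValueError("inner dimensions must match")
    return n * k * m


def addmm_macs(bias_shape, a_shape, b_shape):
    """aten.addmm: bias + a @ b. Only the product is counted; the bias add is ignored."""
    return mm_macs(a_shape, b_shape)


def bmm_macs(a_shape, b_shape):
    """aten.bmm: (b, n, k) @ (b, k, m) -> (b, n, m)."""
    (b, n, k), (b2, k2, m) = a_shape, b_shape
    if b != b2 or k != k2:
        raise ValueError("batch and inner dimensions must match")
    return b * n * k * m


def matmul_macs(a_shape, b_shape):
    """aten.matmul for inputs with >= 2 dims: leading dims broadcast like a batch."""
    *batch_a, n, k = a_shape
    *batch_b, k2, m = b_shape
    if k != k2:
        raise ValueError("inner dimensions must match")
    # Right-align the batch dims and take the larger size at each position
    # (assumes the shapes are broadcast-compatible).
    length = max(len(batch_a), len(batch_b))
    batch_a = [1] * (length - len(batch_a)) + batch_a
    batch_b = [1] * (length - len(batch_b)) + batch_b
    batch = prod(max(x, y) for x, y in zip(batch_a, batch_b))
    return batch * n * k * m


# Attention scores in a ViT-B/16-sized layer: 12 heads, 197 tokens, head dim 64
scores = matmul_macs((1, 12, 197, 64), (1, 12, 64, 197))
print(f"{scores / 1e6:.1f} MMac")
```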
Usage tips for the `aten` backend:
- Use `verbose=True` to see the operations which were not considered during complexity computation.
- This backend prints per-module statistics only for modules directly nested into the root `nn.Module`. Deeper modules at the second level of nesting are not shown in the per-layer statistics.
- The `ignore_modules` option forces `ptflops` to ignore the listed modules. This can be useful for research purposes. For instance, one can drop all convolutions from the counting process by specifying `ignore_modules=[torch.ops.aten.convolution, torch.ops.aten._convolution]`.
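To make the two statistics-related tips concrete, here is a toy model of the bookkeeping they describe: op-level records are rolled up to the modules directly nested in the root, and ops named in an ignore list are skipped. This is a minimal sketch, not `ptflops` code. The record format and `per_top_level_module` are hypothetical, ops are identified by plain strings rather than `torch.ops.aten.*` objects, and the MAC values are computed with the usual shape formulas for the first layers of ResNet-18.

```python
from collections import defaultdict

# Hypothetical op-level records: (qualified module name, aten op, MACs).
# Values follow the standard conv/linear formulas for ResNet-18 at 224x224.
records = [
    ("conv1", "aten.convolution", 118_013_952),           # 7x7, 3->64, stride 2
    ("layer1.0.conv1", "aten.convolution", 115_605_504),  # 3x3, 64->64 at 56x56
    ("layer1.0.conv2", "aten.convolution", 115_605_504),
    ("fc", "aten.addmm", 512_000),                        # 512->1000
]


def per_top_level_module(records, ignore_ops=()):
    """Sum MACs per direct child of the root module, skipping ignored ops."""
    totals = defaultdict(int)
    for name, op, macs in records:
        if op in ignore_ops:
            continue
        top = name.split(".", 1)[0]  # "layer1.0.conv1" -> "layer1"
        totals[top] += macs
    return dict(totals)


print(per_top_level_module(records))
print(per_top_level_module(records, ignore_ops={"aten.convolution"}))
```

Nested modules such as `layer1.0.conv1` only appear folded into `layer1`, which mirrors the limitation of the per-layer output described in the tips.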
Layers supported by the `pytorch` backend:
- Conv1d/2d/3d (including grouping)
- ConvTranspose1d/2d/3d (including grouping)
- BatchNorm1d/2d/3d, GroupNorm, InstanceNorm1d/2d/3d, LayerNorm
- Activations (ReLU, PReLU, ELU, ReLU6, LeakyReLU, GELU)
- Linear
- Upsample
- Poolings (AvgPool1d/2d/3d, MaxPool1d/2d/3d and adaptive ones)
Experimental support:
- RNN, LSTM, GRU (NLH layout is assumed)
- RNNCell, LSTMCell, GRUCell
- torch.nn.MultiheadAttention
- torchvision.ops.DeformConv2d
- visual transformers from `timm`
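In PyTorch's RNN notation, NLH means inputs ordered as (batch size N, sequence length L, features H), i.e. the layout produced with `batch_first=True`. The sketch below shows how the dominant cost of an LSTM scales under that layout. `lstm_macs` is a hypothetical helper, not `ptflops`' implementation: it counts only the gate matrix products and ignores biases and the elementwise gate arithmetic, which a real counter may include.

```python
def lstm_macs(input_shape, hidden_size, num_layers=1):
    """Approximate MACs of a unidirectional nn.LSTM for an input of shape (N, L, H_in).

    Per time step and batch element, W_ih is (4*H, H_in) and W_hh is (4*H, H),
    so each step costs 4*H*(H_in + H) MACs. Biases and elementwise ops are ignored.
    """
    batch, seq_len, input_size = input_shape  # NLH: batch first, then sequence length
    total = 0
    for layer in range(num_layers):
        # Layers after the first consume the previous layer's hidden state
        layer_input = input_size if layer == 0 else hidden_size
        per_step = 4 * hidden_size * (layer_input + hidden_size)
        total += batch * seq_len * per_step
    return total


macs = lstm_macs((8, 100, 128), hidden_size=256)
print(f"{macs / 1e6:.1f} MMac")
```

Feeding the same tensor in (L, N, H) order would silently swap batch size and sequence length in such a count; for this sketch the total happens to be symmetric in N and L, but the per-item figures would not be.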
Usage tips for the `pytorch` backend:
- This backend doesn't take into account some of the `torch.nn.functional.*` and `tensor.*` operations, so unsupported operations do not contribute to the final complexity estimate. See `FUNCTIONAL_MAPPING` and `TENSOR_OPS_MAPPING` in `ptflops/pytorch_ops.py` to check the supported ops. Sometimes functional-level hooks conflict with hooks for `nn.Module` (for instance, custom ones). In that case, counting of these ops can be disabled by passing `backend_specific_config={"count_functional" : False}`.
- `ptflops` launches a given model on a random tensor and estimates the amount of computation during inference. Complicated models can have several inputs, some of which may be optional. To construct non-trivial input, one can use the `input_constructor` argument of `get_model_complexity_info`. `input_constructor` is a function that takes the input spatial resolution as a tuple and returns a dict with named input arguments of the model. This dict is then passed to the model as keyword arguments.
- The `verbose` parameter lets you see which modules don't contribute to the final numbers.
- The `ignore_modules` option forces `ptflops` to ignore the listed modules. This can be useful for research purposes. For instance, one can drop all convolutions from the counting process by specifying `ignore_modules=[torch.nn.Conv2d]`.
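The `input_constructor` contract can be sketched without PyTorch: the function receives the resolution tuple and returns a dict whose keys match the model's argument names, and the dict is unpacked as keyword arguments. In the sketch below, nested lists stand in for tensors and `TwoInputModel` is a hypothetical model. With a real model, the constructor would return tensors instead, for example `torch.zeros(1, *input_res)`.

```python
def zeros(*shape):
    """Nested-list stand-in for torch.zeros so the sketch runs without PyTorch."""
    if len(shape) == 1:
        return [0.0] * shape[0]
    return [zeros(*shape[1:]) for _ in range(shape[0])]


class TwoInputModel:
    """Hypothetical model: a required image and an optional mask."""

    def __call__(self, image, mask=None):
        return {"batch": len(image), "channels": len(image[0]), "used_mask": mask is not None}


def input_constructor(input_res):
    """Build named model inputs from the resolution tuple, e.g. (3, 224, 224)."""
    channels, height, width = input_res
    return {
        "image": zeros(1, channels, height, width),  # batch of one image
        "mask": zeros(1, height, width),             # optional input, supplied explicitly
    }


# The dict returned by the constructor is unpacked as keyword arguments.
inputs = input_constructor((3, 8, 8))
result = TwoInputModel()(**inputs)
print(result)  # {'batch': 1, 'channels': 3, 'used_mask': True}
```

Returning named arguments rather than a positional tuple is what lets the optional `mask` input be supplied or omitted explicitly.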
Requirements: PyTorch >= 2.0. Use `pip install ptflops==0.7.2.2` to work with torch 1.x.
From PyPI:
pip install ptflops
From this repository:
pip install --upgrade git+https://github.com/sovrasov/flops-counter.pytorch.git
Usage example:

```python
import torchvision.models as models
import torch
from ptflops import get_model_complexity_info

with torch.cuda.device(0):
    net = models.densenet161()

    # Legacy pytorch backend: hooks on nn.Modules, detailed per-layer stats for CNNs
    macs, params = get_model_complexity_info(net, (3, 224, 224), as_strings=True, backend='pytorch',
                                             print_per_layer_stat=True, verbose=True)
    print('{:<30} {:<8}'.format('Computational complexity: ', macs))
    print('{:<30} {:<8}'.format('Number of parameters: ', params))

    # Default aten backend: counts aten operations, covers more architectures
    macs, params = get_model_complexity_info(net, (3, 224, 224), as_strings=True, backend='aten',
                                             print_per_layer_stat=True, verbose=True)
    print('{:<30} {:<8}'.format('Computational complexity: ', macs))
    print('{:<30} {:<8}'.format('Number of parameters: ', params))
```
If ptflops was useful for your paper or tech report, please cite me:
```bibtex
@online{ptflops,
  author = {Vladislav Sovrasov},
  title = {ptflops: a flops counting tool for neural networks in pytorch framework},
  year = {2018-2024},
  url = {https://github.com/sovrasov/flops-counter.pytorch},
}
```

Thanks to @warmspringwinds and Horace He for the initial version of the script.
Benchmark on torchvision models:

| Model | Input Resolution | Params(M) | MACs(G) (`pytorch`) | MACs(G) (`aten`) |
|---|---|---|---|---|
| alexnet | 224x224 | 61.10 | 0.72 | 0.71 |
| convnext_base | 224x224 | 88.59 | 15.43 | 15.38 |
| densenet121 | 224x224 | 7.98 | 2.90 | |
| efficientnet_b0 | 224x224 | 5.29 | 0.41 | |
| efficientnet_v2_m | 224x224 | 54.14 | 5.43 | |
| googlenet | 224x224 | 13.00 | 1.51 | |
| inception_v3 | 224x224 | 27.16 | 5.75 | 5.71 |
| maxvit_t | 224x224 | 30.92 | 5.48 | |
| mnasnet1_0 | 224x224 | 4.38 | 0.33 | |
| mobilenet_v2 | 224x224 | 3.50 | 0.32 | |
| mobilenet_v3_large | 224x224 | 5.48 | 0.23 | |
| regnet_y_1_6gf | 224x224 | 11.20 | 1.65 | |
| resnet18 | 224x224 | 11.69 | 1.83 | 1.81 |
| resnet50 | 224x224 | 25.56 | 4.13 | 4.09 |
| resnext50_32x4d | 224x224 | 25.03 | 4.29 | |
| shufflenet_v2_x1_0 | 224x224 | 2.28 | 0.15 | |
| squeezenet1_0 | 224x224 | 1.25 | 0.84 | 0.82 |
| vgg16 | 224x224 | 138.36 | 15.52 | 15.48 |
| vit_b_16 | 224x224 | 86.57 | 17.61 (wrong) | 16.86 |
| wide_resnet50_2 | 224x224 | 68.88 | 11.45 | |
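For plain architectures the Params(M) and MACs(G) columns can be sanity-checked by hand. The sketch below recomputes both for `vgg16` at 224x224 from its layer configuration, counting one MAC per weight per output position for convolutions and linear layers while ignoring pooling, ReLU and bias additions. It is an independent back-of-the-envelope estimate, not `ptflops` code. It reproduces the 138.36 M parameters and gives about 15.47 GMac, close to the 15.48 reported by the `aten` backend; the small gap is presumably due to bias additions and similar terms that the sketch leaves out.

```python
# VGG-16 configuration: numbers are output channels of 3x3 convolutions
# (padding 1, spatial size preserved), "M" is a 2x2 max-pool with stride 2.
VGG16_CFG = [64, 64, "M", 128, 128, "M", 256, 256, 256, "M",
             512, 512, 512, "M", 512, 512, 512, "M"]


def vgg16_cost(num_classes=1000):
    """Return (params, macs) for VGG-16 on a 224x224 RGB input.

    At 224x224 the feature map after the last pool is already 7x7, so the
    adaptive average pool before the classifier is an identity and is skipped.
    """
    params = macs = 0
    channels, size = 3, 224
    for item in VGG16_CFG:
        if item == "M":
            size //= 2  # max-pool halves the spatial size
            continue
        weights = 3 * 3 * channels * item
        params += weights + item        # kernel weights + bias
        macs += size * size * weights   # one MAC per weight per output position
        channels = item
    # Classifier: flattened 512x7x7 features -> 4096 -> 4096 -> num_classes
    features = channels * size * size
    for out_features in (4096, 4096, num_classes):
        params += features * out_features + out_features
        macs += features * out_features
        features = out_features
    return params, macs


params, macs = vgg16_cost()
print(f"Params: {params / 1e6:.2f} M, MACs: {macs / 1e9:.2f} G")
```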