Code for the NeurIPS 2022 paper "Optimal Brain Compression: A Framework for Accurate Post-Training Quantization and Pruning".
IST-DASLab/OBC
This repository contains efficient implementations of ExactOBS for quantization and for unstructured, block and N:M pruning, introduced in the NeurIPS 2022 paper "Optimal Brain Compression: A Framework for Accurate Post-Training Quantization and Pruning".
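ExactOBS builds on the Optimal Brain Surgeon (OBS) update. The layer-wise loss `||WX - ŴX||²` (X being the layer inputs) decomposes over the rows of W, which all share the Hessian `H = 2XXᵀ`. Removing weight `q` from a row `w` increases the loss by `w_q² / (2[H⁻¹]_qq)` and is optimally compensated by `Δw = -(w_q / [H⁻¹]_qq) · H⁻¹[:, q]`. The NumPy sketch below applies this greedily to a single row and eliminates the pruned row/column from `H⁻¹` after each step. It is a didactic illustration only: the name `obs_prune_row` and the dampening term are assumptions, and it is not the optimized code in `trueobs.py`.

```python
import numpy as np


def obs_prune_row(w, X, k, damp=1e-4):
    """Greedily zero k entries of one weight row w (shape [d]) using OBS updates.

    X has shape [d, n] (n calibration inputs). The loss is ||w0 @ X - w @ X||^2,
    whose Hessian is H = 2 X X^T. Hypothetical helper for illustration.
    """
    w = np.asarray(w, dtype=np.float64).copy()
    d = w.shape[0]
    H = 2.0 * X @ X.T
    # Assumed dampening: a small diagonal term keeps H invertible.
    H += damp * np.mean(np.diag(H)) * np.eye(d)
    Hinv = np.linalg.inv(H)
    pruned = np.zeros(d, dtype=bool)
    for _ in range(k):
        # Loss increase w_q^2 / (2 [H^-1]_qq); already-pruned weights are excluded.
        diag = np.where(pruned, 1.0, np.diag(Hinv))
        scores = np.where(pruned, np.inf, w ** 2 / (2.0 * diag))
        q = int(np.argmin(scores))
        # Optimal compensation of the remaining weights.
        w -= (w[q] / Hinv[q, q]) * Hinv[:, q]
        # Remove row/column q from the inverse Hessian (one Gaussian elimination step).
        Hinv -= np.outer(Hinv[:, q], Hinv[q, :]) / Hinv[q, q]
        pruned[q] = True
        w[pruned] = 0.0  # clear floating-point residue on pruned entries
    return w, pruned
```

Because the inverse Hessian is kept exact after every removal, the surviving weights end up at the least-squares optimum for the final mask, which is what naive zeroing of the same weights misses.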
- `trueobs.py`: efficient implementations of ExactOBS for all compression types
- `main_trueobs.py`: code to run ExactOBS
- `postproc.py`: post-processing operations like statistics correction
- `database.py`: generating databases for non-uniform compression
- `spdy.py`: implementation of the DP algorithm for finding non-uniform compression configurations; adapted from code provided by the authors of SPDY [9]
- `modelutils.py`: model utilities
- `datautils.py`: data utilities
- `quant.py`: quantization utilities
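`quant.py` is described only as quantization utilities. For intuition about what a flag like `--wbits 4` means, here is a minimal NumPy sketch of a generic asymmetric min-max quantizer with `2**bits` levels and round-to-nearest. The function `minmax_quantize` is hypothetical, and the repository's actual grid choice (per-channel vs. per-tensor ranges, range search, activation handling) may differ.

```python
import numpy as np


def minmax_quantize(x, bits=4):
    """Round x onto an asymmetric uniform grid with 2**bits levels (hypothetical helper)."""
    maxq = 2 ** bits - 1
    # Include 0 in the range so that zero stays exactly representable.
    xmin = min(float(x.min()), 0.0)
    xmax = max(float(x.max()), 0.0)
    if xmax == xmin:
        # Degenerate all-zero input: nothing to quantize.
        return np.zeros_like(x, dtype=np.float64)
    scale = (xmax - xmin) / maxq           # grid step
    zero = np.round(-xmin / scale)         # integer zero point
    q = np.clip(np.round(x / scale) + zero, 0, maxq)  # integer codes in [0, maxq]
    return scale * (q - zero)              # dequantized values
```

Each value is mapped to one of at most 16 levels for 4 bits, with a rounding error of at most half a grid step.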
NOTE: The code as provided here only fully supports torchvision ResNet variants (the full integration of YOLO and BERT models is omitted due to large amounts of complex dependencies).
First, make sure ImageNet is located/linked to `../imagenet` (alternatively, you can specify the `--datapath` argument for all commands).
```bash
# Quantize weights and activations
python main_trueobs.py rn18 imagenet quant --wbits 4 --abits 4 --save rn18_4w4a.pth

# Prune to the N:M pattern
python main_trueobs.py rn18 imagenet nmprune --prunen 2 --prunem 4 --save rn18_24.pth

# Generate an unstructured pruning database
mkdir models_unstr
python main_trueobs.py rn18 imagenet unstr --sparse-dir models_unstr

# Generate a 4-block pruning database
mkdir models_blocked
python main_trueobs.py rn18 imagenet blocked --sparse-dir models_blocked

# Quantize a 2:4 pruned model
python main_trueobs.py rn18 imagenet quant --wbits 4 --abits 4 --load rn18_24.pth --save rn18_24_4w4a.pth
```
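The `--prunen 2 --prunem 4` flags request the 2:4 pattern, in which every group of M = 4 consecutive weights holds at most N = 2 nonzeros. The NumPy sketch below only illustrates that constraint. It builds a 2:4 mask by weight magnitude, which is a swapped-in criterion (ExactOBS instead selects pruned weights by their OBS loss increase and updates the remaining ones), and it assumes groups run along each row of a 2-D weight matrix. Both function names are hypothetical.

```python
import numpy as np


def nm_mask(w, n=2, m=4):
    """Keep the n largest-magnitude entries in each group of m consecutive weights per row.

    Magnitude is used only to illustrate the N:M pattern (hypothetical helper).
    """
    rows, cols = w.shape
    assert cols % m == 0, "row length must be divisible by m"
    groups = np.abs(w).reshape(rows, cols // m, m)
    # Indices of the (m - n) smallest entries in every group are zeroed.
    drop = np.argsort(groups, axis=-1)[..., : m - n]
    mask = np.ones_like(groups, dtype=bool)
    np.put_along_axis(mask, drop, False, axis=-1)
    return mask.reshape(rows, cols)


def is_nm_sparse(w, n=2, m=4):
    """Check that every group of m consecutive weights has at most n nonzeros."""
    rows, cols = w.shape
    nonzeros = (w.reshape(rows, cols // m, m) != 0).sum(axis=-1)
    return bool(np.all(nonzeros <= n))
```

A 2:4 mask keeps exactly half of the weights, in a layout that sparse tensor-core kernels can exploit.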
```bash
# Batchnorm tuning
python postproc.py rn18 imagenet rn18_24.pth --bnt

# Statistics correction
python postproc.py rn18 imagenet rn18_24.pth --statcorr --statcorr-samples 1024
```
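Batchnorm tuning commonly refers to re-estimating the batchnorm statistics on calibration data after compression, since pruning or quantizing earlier layers shifts the activations that the stored running mean and variance were computed for. The NumPy sketch below shows only the basic bookkeeping under that assumption: pooling per-channel mean and variance over several calibration batches of `[N, C, H, W]` activations. The class `RunningBNStats` is hypothetical, and this does not reproduce what `postproc.py` does for `--bnt` or `--statcorr`.

```python
import numpy as np


class RunningBNStats:
    """Pooled per-channel mean/variance over batches shaped [N, C, H, W] (hypothetical helper)."""

    def __init__(self, channels):
        self.count = 0
        self.sum = np.zeros(channels)
        self.sumsq = np.zeros(channels)

    def update(self, x):
        # Aggregate over batch and spatial axes; keep the channel axis.
        self.count += x.shape[0] * x.shape[2] * x.shape[3]
        self.sum += x.sum(axis=(0, 2, 3))
        self.sumsq += (x ** 2).sum(axis=(0, 2, 3))

    def finalize(self):
        # Population statistics over all activations seen so far.
        mean = self.sum / self.count
        var = self.sumsq / self.count - mean ** 2
        return mean, var
```

The resulting `mean` and `var` would replace a batchnorm layer's running statistics, so that normalization matches the compressed model's activations rather than the dense model's.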
```bash
mkdir scores

# Unstructured pruning
# Setup database
mkdir models_unstr
python main_trueobs.py rn18 imagenet unstr --sparse-dir models_unstr
# Compute corresponding losses
python database.py rn18 imagenet unstr loss
# Run DP algorithm to determine per-layer compression targets
python spdy.py rn18 imagenet 2 unstr --dp
# Stitch profile, apply batchnorm resetting and compute validation accuracy
python postproc.py rn18 imagenet rn18_unstr_200x_dp.txt --database unstr --bnt

# Mixed quantization + 2:4 pruning
mkdir models_nm
mkdir models_quant
mkdir models_nm_quant
python main_trueobs.py rn18 imagenet nmprune --save models_nm/rn18_24.pth
python main_trueobs.py rn18 imagenet quant --wbits 8 --abits 8 --save models_quant/rn18_8w8a.pth
python main_trueobs.py rn18 imagenet quant --wbits 4 --abits 4 --save models_quant/rn18_4w4a.pth
python main_trueobs.py rn18 imagenet quant --wbits 8 --abits 8 --load models_nm/rn18_24.pth --save models_nm_quant/rn18_24_8w8a.pth
python main_trueobs.py rn18 imagenet quant --wbits 4 --abits 4 --load models_nm/rn18_24.pth --save models_nm_quant/rn18_24_4w4a.pth
python database.py rn18 imagenet mixed loss
python spdy.py rn18 imagenet 8 mixed --dp
python postproc.py rn18 imagenet rn18_mixed_800x_dp.txt --database mixed --bnt
```
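The `--dp` step chooses one compression level per layer from the database. As a hedged illustration of that kind of search (the DP step only, not SPDY's actual procedure, whose cost model and surrounding search may differ), the sketch below solves a small multiple-choice knapsack: each layer offers a few candidate levels, each with a non-negative integer cost and an estimated loss, and exactly one level is picked per layer so that the total loss is minimal under a total cost budget. The function `choose_levels` and the cost/loss inputs are hypothetical.

```python
def choose_levels(options, budget):
    """Pick one (cost, loss) option per layer, minimizing total loss with total cost <= budget.

    options[l] lists (integer_cost, loss) pairs for layer l. Returns
    (total_loss, list of chosen option indices). Hypothetical helper.
    """
    INF = float("inf")
    # best[c] = minimal loss over the layers processed so far with total cost exactly c.
    best = [0.0] + [INF] * budget
    choices = []  # per layer: back-pointers cost -> (option index, previous cost)
    for layer_opts in options:
        new = [INF] * (budget + 1)
        back = [None] * (budget + 1)
        for c_prev, loss_prev in enumerate(best):
            if loss_prev == INF:
                continue
            for i, (cost, loss) in enumerate(layer_opts):
                c = c_prev + cost
                if c <= budget and loss_prev + loss < new[c]:
                    new[c] = loss_prev + loss
                    back[c] = (i, c_prev)
        best = new
        choices.append(back)
    c = min(range(budget + 1), key=lambda k: best[k])
    if best[c] == INF:
        raise ValueError("no configuration fits the budget")
    total = best[c]
    # Walk the back-pointers from the last layer to the first.
    picked = []
    for back in reversed(choices):
        i, c = back[c]
        picked.append(i)
    return total, picked[::-1]
```

For example, with two layers offering `[(4, 0.0), (2, 0.1), (1, 0.5)]` and `[(4, 0.0), (2, 0.2), (1, 0.4)]` and a budget of 5, the cheapest-loss split is the middle level for both layers. The DP runs in time proportional to layers × budget × options, which stays small when costs are discretized coarsely.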
Before using our BERT integration, please download our pretrained checkpoints and move them to the `bertsquad` folder. Then you should be able to use most features described above by passing `bertsquad` (or `bertsquad6` for smaller variants) as the model name and `squad` as the dataset name. The code was tested with `transformers==4.21.2` and `datasets==1.17.0`.
```bibtex
@article{frantar2022obc,
  title={{Optimal Brain Compression:} A Framework for Accurate Post-Training Quantization and Pruning},
  author={Frantar, Elias and Singh, Sidak Pal and Alistarh, Dan},
  journal={Advances in Neural Information Processing Systems},
  volume={35},
  year={2022}
}
```