# MBDL_Pruning
Model-based deep learning (MBDL) is a powerful methodology for designing deep models to solve imaging inverse problems. MBDL networks can be seen as iterative algorithms that estimate the desired image using a physical measurement model and a learned image prior specified by a convolutional neural network (CNN). The iterative nature of MBDL networks increases test-time computational complexity, which limits their applicability in certain large-scale applications. Here we make two contributions to address this issue: first, we show how structured pruning can be adopted to reduce the number of parameters in MBDL networks; second, we present three methods to fine-tune the pruned MBDL networks to mitigate potential performance loss. Each fine-tuning strategy has a unique benefit that depends on the availability of a pre-trained model and high-quality ground truth. We show that our pruning and fine-tuning approach can accelerate image reconstruction using popular deep equilibrium learning (DEQ) and deep unfolding (DU) methods by 50% and 32%, respectively, with nearly no performance loss. This work thus offers a step forward for solving inverse problems by showing the potential of pruning to improve the scalability of MBDL.
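As a concrete illustration of the iterative structure described above, here is a minimal sketch of deep unfolding (DU), where each iteration applies a data-consistency gradient step followed by a learned CNN prior: x_{k+1} = CNN(x_k - gamma * A^T (A x_k - y)). The toy `SimpleCNNPrior`, the operators `forward_op`/`adjoint_op`, and the step size `gamma` are placeholder assumptions for illustration, not the DEQ/E2E-VarNet/VarNet models used in this repository.

```python
# Minimal deep-unfolding (DU) sketch: each iteration applies a data-consistency
# gradient step followed by a learned CNN prior. The CNN below is a toy
# placeholder, not a model from this repository.
import torch
import torch.nn as nn

class SimpleCNNPrior(nn.Module):
    """Toy CNN prior standing in for the learned image denoiser."""
    def __init__(self, channels: int = 1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(channels, 32, 3, padding=1), nn.ReLU(),
            nn.Conv2d(32, channels, 3, padding=1),
        )

    def forward(self, x):
        return x + self.net(x)  # residual denoising step

class UnfoldedNet(nn.Module):
    """K unrolled iterations of x <- CNN(x - gamma * A^T (A x - y))."""
    def __init__(self, forward_op, adjoint_op, num_iters: int = 8, gamma: float = 1.0):
        super().__init__()
        self.A, self.At = forward_op, adjoint_op
        self.priors = nn.ModuleList(SimpleCNNPrior() for _ in range(num_iters))
        self.gamma = gamma

    def forward(self, y, x0):
        x = x0
        for prior in self.priors:
            x = x - self.gamma * self.At(self.A(x) - y)  # data-consistency step
            x = prior(x)                                  # learned prior
        return x
```

A DEQ variant iterates the same update with a single shared prior until a fixed point, so the test-time cost scales with the number of iterations through the CNN; this is exactly the cost that pruning the CNN prior reduces.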
```
git clone https://github.com/wustl-cig/MBDL_Pruning
cd MBDL_Pruning
```
- Download the T2 brain dataset from the fastMRI dataset. Test data will be released soon.
- Download the Deep Equilibrium Model (DEQ) trained on the brain fastMRI dataset (Pretrained DEQ link). The default save directory is `./pretrained_models`.
- Download the End-to-End Variational Network (E2E-VarNet) trained on the brain fastMRI dataset (Pretrained E2EVarNet link). The default save directory is `./pretrained_models`.
- Download the Variational Network (VarNet) trained on the FFHQ 256x256 dataset (Pretrained VarNet link). The default save directory is `./pretrained_models`.
```
conda env create -f MBDL_Pruning.yml
conda activate MBDL_Pruning
```
The task configuration files are:

- `configs/deq_configs.json`
- `configs/e2evarnet_configs.json`
- `configs/varnet_configs.json`
```
python main.py --task_config configs/{TASK_YAML_FILE_NAME}.yaml
# example: python main.py --task_config configs/varnet_config.yaml
```
```
├── main.py                  # All high-level implementations of the pruning and fine-tuning pipeline.
└── Pruning_Finetuning.py    # All low-level implementations of the pruning and fine-tuning pipeline.
```
If you encounter any issues, feel free to reach out via email at chicago@wustl.edu.
We adapt the code structure of the Torch-Pruning repo to implement network pruning.
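For reference, below is a minimal sketch of structured channel pruning followed by fine-tuning using a Torch-Pruning v1.x-style API. The toy model, the 50% channel sparsity, and the placeholder fine-tuning loop are illustrative assumptions; the actual pipeline lives in `main.py` and `Pruning_Finetuning.py`.

```python
# Illustrative structured pruning + fine-tuning sketch using Torch-Pruning
# (v1.x-style API). The model and hyperparameters are assumptions; see
# main.py / Pruning_Finetuning.py for the actual pipeline.
import torch
import torch.nn as nn
import torch_pruning as tp

model = nn.Sequential(  # toy CNN standing in for the MBDL image prior
    nn.Conv2d(1, 32, 3, padding=1), nn.ReLU(),
    nn.Conv2d(32, 1, 3, padding=1),
)
example_inputs = torch.randn(1, 1, 64, 64)

# Rank channels by L2 magnitude and remove 50% of them; Torch-Pruning traces
# layer dependencies so coupled layers are pruned consistently.
pruner = tp.pruner.MagnitudePruner(
    model,
    example_inputs,
    importance=tp.importance.MagnitudeImportance(p=2),
    ch_sparsity=0.5,
    ignored_layers=[model[-1]],  # keep the output layer's channel count
)
pruner.step()

# Brief fine-tuning to recover performance lost to pruning (placeholder loop).
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
loss_fn = nn.MSELoss()
for x, target in []:  # replace with a real data loader
    optimizer.zero_grad()
    loss_fn(model(x), target).backward()
    optimizer.step()
```

Because the pruning is structured (whole channels are removed rather than individual weights zeroed), the pruned model is genuinely smaller and faster at test time without requiring sparse-kernel support.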
```
@article{park2024EfficientMBDL,
  title={Efficient Model-Based Deep Learning via Network Pruning and Fine-Tuning},
  author={Park, Chicago Y. and Gan, Weijie and Zou, Zihao and Hu, Yuyang and Sun, Zhixin and Kamilov, Ulugbek S.},
  journal={Research Square},
  year={2024},
  doi={10.21203/rs.3.rs-5286110/v1}
}
```