wustl-cig/MBDL_Pruning


[Preprint]


Abstract

Model-based deep learning (MBDL) is a powerful methodology for designing deep models to solve imaging inverse problems. MBDL networks can be seen as iterative algorithms that estimate the desired image using a physical measurement model and a learned image prior specified using a convolutional neural network (CNN). The iterative nature of MBDL networks increases the test-time computational complexity, which limits their applicability in certain large-scale applications. Here we make two contributions to address this issue: First, we show how structured pruning can be adopted to reduce the number of parameters in MBDL networks. Second, we present three methods to fine-tune the pruned MBDL networks to mitigate potential performance loss. Each fine-tuning strategy has a unique benefit that depends on the presence of a pre-trained model and a high-quality ground truth. We show that our pruning and fine-tuning approach can accelerate image reconstruction using popular deep equilibrium learning (DEQ) and deep unfolding (DU) methods by 50% and 32%, respectively, with nearly no performance loss. This work thus offers a step forward for solving inverse problems by showing the potential of pruning to improve the scalability of MBDL.
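
As a concrete picture of the iterative structure described above, the sketch below shows one way to write an unrolled MBDL reconstruction in PyTorch: each iteration applies a data-consistency gradient step through the measurement operator, followed by a learned CNN prior. This is a minimal illustration only; the operator A, its adjoint At, the step size, and the CNNPrior module are placeholder assumptions, not the networks used in this repository.

import torch
import torch.nn as nn

class CNNPrior(nn.Module):
    # Small residual CNN standing in for the learned image prior (illustrative only).
    def __init__(self, channels=2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(channels, 64, 3, padding=1), nn.ReLU(),
            nn.Conv2d(64, channels, 3, padding=1),
        )

    def forward(self, x):
        return x + self.net(x)

class UnrolledMBDL(nn.Module):
    # K unrolled iterations combining the physical measurement model with the CNN prior.
    def __init__(self, A, At, num_iters=8, gamma=1.0):
        super().__init__()
        self.A, self.At = A, At        # forward operator and its adjoint
        self.prior = CNNPrior()        # shared learned prior
        self.num_iters = num_iters
        self.gamma = gamma

    def forward(self, y, x0):
        x = x0
        for _ in range(self.num_iters):
            grad = self.At(self.A(x) - y)            # data-consistency gradient
            x = self.prior(x - self.gamma * grad)    # learned proximal/denoising step
        return x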

Environment setup

1) Clone the repository

git clone https://github.com/wustl-cig/MBDL_Pruning
cd MBDL_Pruning

2) Download fastMRI dataset

  • Download the T2 Brain dataset from the fastMRI dataset. Test data will be released soon.
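
Once a volume is downloaded, it can be sanity-checked with a few lines of Python. The snippet assumes the standard fastMRI HDF5 layout (a per-file kspace dataset); the file path is a placeholder.

import h5py

fname = "path/to/fastmri_brain/file_brain_AXT2_sample.h5"   # placeholder path to one downloaded volume

with h5py.File(fname, "r") as f:
    print(list(f.keys()))               # typically includes 'kspace' and a reference reconstruction
    kspace = f["kspace"][()]            # multi-coil k-space with shape (slices, coils, rows, cols)
    print(kspace.shape, kspace.dtype)   # complex-valued measurements used by the forward model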

3) Download Pretrained Models

  • Download the Deep Equilibrium Model (DEQ) trained on the brain fastMRI dataset: Pretrained DEQ link. The default save directory is ./pretrained_models.

  • Download the End-to-End Variational Network (E2E-VarNet) trained on the brain fastMRI dataset: Pretrained E2EVarNet link. The default save directory is ./pretrained_models.

  • Download the Variational Network (VarNet) trained on the FFHQ 256x256 dataset: Pretrained VarNet link. The default save directory is ./pretrained_models.

4) Virtual environment setup

conda env create -f MBDL_Pruning.yml
conda activate MBDL_Pruning
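
After activating the environment, a quick check (optional, not part of the repository) confirms that PyTorch is installed and that a GPU is visible:

import torch

print("torch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("device:", torch.cuda.get_device_name(0))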

Run experiment

1) Pick one task from the configs directory:

Pruning and fine-tuning (self-supervised) of the selected network:

  • configs/deq_configs.json
  • configs/e2evarnet_configs.json
  • configs/varnet_configs.json

2) Execute the code

python main.py --task_config configs/{TASK_YAML_FILE_NAME}.yaml
# example: python main.py --task_config configs/varnet_config.yaml
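
main.py takes a single --task_config argument pointing at one of the configuration files above. As a rough sketch of what such a driver does (the control flow and config handling here are assumptions for illustration, not the repository's actual code):

import argparse
import yaml

def main():
    parser = argparse.ArgumentParser(description="Illustrative sketch of the pruning/fine-tuning driver")
    parser.add_argument("--task_config", type=str, required=True,
                        help="Path to the task configuration file")
    args = parser.parse_args()

    # yaml.safe_load also parses JSON, so either .yaml or .json configs can be read this way.
    with open(args.task_config, "r") as f:
        config = yaml.safe_load(f)

    print("Loaded task configuration:", config)
    # A real driver would now build the selected network (DEQ, E2E-VarNet, or VarNet)
    # and hand it to the pruning and fine-tuning pipeline.

if __name__ == "__main__":
    main()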

Implementation detail

main.py                          # All high-level implementations of the pruning and fine-tuning pipeline.
└── Pruning_Finetuning.py        # All low-level implementations of the pruning and fine-tuning pipeline.
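
At a high level, such a pipeline loads the pretrained MBDL network, applies structured pruning to its CNN prior, and then fine-tunes the smaller model to recover reconstruction quality. The sketch below is illustrative only; the model, data loader, loss, and hyperparameters are placeholders rather than the repository's implementation, and the paper's self-supervised fine-tuning variants would replace the supervised loss shown here.

import torch

def prune_then_finetune(model, prune_fn, train_loader, num_epochs=10, lr=1e-4, device="cuda"):
    # Illustrative pruning + fine-tuning loop (not the repository's code).
    # prune_fn: callable that removes channels/filters from the model in place,
    #           e.g. a Torch-Pruning pruner (see the sketch under 'Code references').
    model = model.to(device)

    # 1) Structured pruning of the learned prior.
    prune_fn(model)

    # 2) Fine-tune the pruned network.
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = torch.nn.L1Loss()

    model.train()
    for _ in range(num_epochs):
        for y, x_ref in train_loader:            # measurements and reference images
            y, x_ref = y.to(device), x_ref.to(device)
            x_hat = model(y)                     # reconstruct from measurements
            loss = loss_fn(x_hat, x_ref)         # supervised fine-tuning objective
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    return model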

Troubleshooting

If you encounter any issues, feel free to reach out via email at chicago@wustl.edu.

Code references

We adapt the Torch-Pruning code structure from the Torch-Pruning repo to implement network pruning.
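
For context, a minimal Torch-Pruning usage pattern looks roughly like the following. Class and argument names can differ between Torch-Pruning versions, and the model and pruning ratio here are placeholders; in this repository the pruning target would be the CNN prior inside the MBDL network.

import torch
import torch_pruning as tp
import torchvision

model = torchvision.models.resnet18(weights=None)      # placeholder model
example_inputs = torch.randn(1, 3, 224, 224)

importance = tp.importance.MagnitudeImportance(p=2)    # L2-magnitude channel importance
pruner = tp.pruner.MagnitudePruner(
    model,
    example_inputs,
    importance=importance,
    ch_sparsity=0.5,       # prune roughly 50% of channels (argument name varies by version)
)
pruner.step()              # applies structured pruning, shrinking the layers in place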

Citation

@article{park2024EfficientMBDL,
  title={Efficient Model-Based Deep Learning via Network Pruning and Fine-Tuning},
  author={Park, Chicago Y. and Gan, Weijie and Zou, Zihao and Hu, Yuyang and Sun, Zhixin and Kamilov, Ulugbek S.},
  journal={Research Square},
  year={2024},
  doi={10.21203/rs.3.rs-5286110/v1}
}
