Diffusion Classifier leverages pretrained diffusion models to perform zero-shot classification without additional training
This is the official implementation of the ICCV 2023 paper *Your Diffusion Model is Secretly a Zero-Shot Classifier* by Alexander Li, Mihir Prabhudesai, Shivam Duggal, Ellis Brown, and Deepak Pathak.
The recent wave of large-scale text-to-image diffusion models has dramatically increased our text-based image generation abilities. These models can generate realistic images for a staggering variety of prompts and exhibit impressive compositional generalization abilities. Almost all use cases thus far have solely focused on sampling; however, diffusion models can also provide conditional density estimates, which are useful for tasks beyond image generation. In this paper, we show that the density estimates from large-scale text-to-image diffusion models like Stable Diffusion can be leveraged to perform zero-shot classification without any additional training. Our generative approach to classification, which we call *Diffusion Classifier*, attains strong results on a variety of benchmarks and outperforms alternative methods of extracting knowledge from diffusion models. Although a gap remains between generative and discriminative approaches on zero-shot recognition tasks, our diffusion-based approach has significantly stronger multimodal compositional reasoning ability than competing discriminative approaches. Finally, we use Diffusion Classifier to extract standard classifiers from class-conditional diffusion models trained on ImageNet. Our models achieve strong classification performance using only weak augmentations and exhibit qualitatively better "effective robustness" to distribution shift. Overall, our results are a step toward using generative over discriminative models for downstream tasks.
Create a conda environment with the following command:
conda env create -f environment.yml
If this takes too long, `conda config --set solver libmamba` sets conda to use the `libmamba` solver and can speed up installation.
To run zero-shot classification with Stable Diffusion:

python eval_prob_adaptive.py --dataset cifar10 --split test --n_trials 1 \
  --to_keep 5 1 --n_samples 50 500 --loss l1 \
  --prompt_path prompts/cifar10_prompts.csv

This command reads potential prompts from a csv file and evaluates the epsilon-prediction loss for each prompt using Stable Diffusion. It should work on a variety of GPUs, from as small as a 2080 Ti or 3080 to as large as a 3090 or A6000. Losses are saved separately for each test image in the log directory. For the command above, the log directory is `data/cifar10/v2-0_1trials_5_1keep_50_500samples_l1`. Accuracy can be computed by running:
python scripts/print_acc.py data/cifar10/v2-0_1trials_5_1keep_50_500samples_l1
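For intuition, the decision rule behind these saved losses is simple: Diffusion Classifier predicts the class whose prompt yields the lowest average epsilon-prediction error. A minimal sketch of that rule (the repo's `scripts/print_acc.py` handles the actual log format):

```python
# Minimal sketch of the classification rule, assuming `errors` is an
# (n_classes, n_samples) array of epsilon-prediction losses for one image.
import numpy as np

def predict(errors: np.ndarray) -> int:
    # Lower average loss => higher conditional likelihood => predicted class.
    return int(np.argmin(errors.mean(axis=1)))
```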
Commands to run Diffusion Classifier on each dataset are here. If evaluation on your use case is taking too long, there are a few options:
- Parallelize evaluation across multiple workers. Try using the `--n_workers` and `--worker_idx` flags.
- Play around with the evaluation strategy (e.g. `--n_samples` and `--to_keep`).
- Evaluate on a smaller subset of the dataset. Saving a npy array of test set indices and using the `--subset_path` flag can be useful for this (see the sketch after this list).
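For example, a subset file is just a saved array of test-set indices; the filename and subset size below are illustrative:

```python
# Hypothetical subset file: evaluate on 1000 random CIFAR-10 test images.
import numpy as np

rng = np.random.default_rng(0)
subset = np.sort(rng.choice(10000, size=1000, replace=False))
np.save("cifar10_subset.npy", subset)
# then add: --subset_path cifar10_subset.npy
```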
To evaluate on your own dataset:
- Create a csv file with the prompts that you want to evaluate, making sure to match up the correct prompts with the correct class labels. See `scripts/write_cifar10_prompts.py` for an example (a hypothetical sketch also follows this list). Note that you can use multiple prompts per class.
- Run the command above, changing the `--dataset` and `--prompt_path` flags to match your use case.
- Play around with the evaluation strategy on a small subset of the dataset to reduce evaluation time.
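As a hedged illustration, a prompt csv for a made-up 3-class dataset might be written like this; the column names are assumptions, so defer to `scripts/write_cifar10_prompts.py` for the exact format the repo expects:

```python
# Hypothetical prompt file for a 3-class dataset; the header row is assumed,
# so check scripts/write_cifar10_prompts.py for the format the repo expects.
import csv

classes = ["cat", "dog", "rabbit"]
with open("pets_prompts.csv", "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["prompt", "classidx"])  # assumed header
    for idx, name in enumerate(classes):
        writer.writerow([f"a photo of a {name}.", idx])
```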
To classify ImageNet images with a class-conditional DiT model, download the DiT repository within the `diffusion-classifier` folder:
git clone git@github.com:facebookresearch/DiT.git
First, save a consistent set of noise (epsilon) that will be used for all image-class pairs:
python scripts/save_noise.py --img_size 256
Then, compute and save the epsilon-prediction error for each class:
python eval_prob_dit.py --dataset imagenet --split test \
  --noise_path noise_256.pt --randomize_noise \
  --batch_size 32 --cls CLS --t_interval 4 --extra dit256 --save_vb

For ImageNet, this command needs to be run with CLS ranging from 0 to 999 (see the driver sketch below). This is currently a very expensive process, so we recommend using the `--subset_path` flag to evaluate on a smaller subset of the dataset. We also plan on releasing an adaptive version that greatly reduces the computation time per test image.
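One simple way to sweep all 1000 classes is a serial driver like the sketch below (illustrative only; in practice you would shard the CLS range across GPUs or workers):

```python
# Illustrative serial sweep over all ImageNet classes using the command above.
import subprocess

for cls in range(1000):
    subprocess.run(
        ["python", "eval_prob_dit.py", "--dataset", "imagenet", "--split", "test",
         "--noise_path", "noise_256.pt", "--randomize_noise",
         "--batch_size", "32", "--cls", str(cls),
         "--t_interval", "4", "--extra", "dit256", "--save_vb"],
        check=True,
    )
```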
Finally, compute the accuracy using the saved errors:
python scripts/print_dit_acc.py data/imagenet_dit256 --dataset imagenet
We show the commands to run DiT on all ImageNet variants here.
To run Diffusion Classifier on Winoground, first save a consistent set of noise (epsilon) that will be used for all image-caption pairs:
python scripts/save_noise.py --img_size 512
Then, evaluate on Winoground:
python run_winoground.py --model sd --version 2-0 --t_interval 1 --batch_size 32 --noise_path noise_512.pt --randomize_noise --interpolation bicubic
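Winoground performance is reported with the benchmark's standard text, image, and group scores. A minimal sketch of those metrics for one example, where `s[i][j]` is a generic image-caption score with higher meaning a better match (for Diffusion Classifier, a lower epsilon-prediction error plays the role of a higher score):

```python
# Standard Winoground metrics for one example: s[i][j] scores image i
# against caption j, higher = better match.
def winoground_metrics(s):
    text_ok = s[0][0] > s[0][1] and s[1][1] > s[1][0]   # each image picks its caption
    image_ok = s[0][0] > s[1][0] and s[1][1] > s[0][1]  # each caption picks its image
    return text_ok, image_ok, text_ok and image_ok
```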
To run CLIP or OpenCLIP baselines:
python run_winoground.py --model clip --version ViT-L/14

python run_winoground.py --model openclip --version ViT-H-14
If you find this work useful in your research, please cite:
@misc{li2023diffusion,
    title={Your Diffusion Model is Secretly a Zero-Shot Classifier},
    author={Alexander C. Li and Mihir Prabhudesai and Shivam Duggal and Ellis Brown and Deepak Pathak},
    year={2023},
    eprint={2303.16203},
    archivePrefix={arXiv},
    primaryClass={cs.LG}
}