Intriguing Properties of Data Attribution on Diffusion Models (ICLR 2024)
[Project Page] | [arXiv] | [Data Repository]
We report counter-intuitive observations that theoretically unjustified design choices for attributing diffusion models empirically outperform previous baselines by a large margin.
Counterfactual visualizations on CIFAR-2 and ArtBench-2 (figures omitted here).
Check `quickstart.ipynb` to conduct data attribution on pre-trained diffusion models loaded directly from Hugging Face!
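If you just want a feel for what loading such a pre-trained model looks like, here is a minimal sketch using the `diffusers` library; the model id is only an example, and the notebook itself is what performs the attribution.

```python
# Minimal sketch (not from the repo): load a pre-trained unconditional DDPM
# from Hugging Face with the diffusers library and sample one image.
from diffusers import DDPMPipeline

pipeline = DDPMPipeline.from_pretrained("google/ddpm-cifar10-32")  # example model id
pipeline.to("cuda")  # optional; runs on CPU as well, just slower

# Draw a single sample; DDPMPipeline returns PIL images.
image = pipeline(batch_size=1, num_inference_steps=1000).images[0]
image.save("sample.png")
```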
To get started, follow these steps:
- Clone the GitHub Repository: Begin by cloning the repository using the command:
git clone https://github.com/sail-sg/D-TRAK.git
- Set Up Python Environment: Ensure you have Python 3.8, then create and activate a conda environment named `dtrak`:
conda create -n dtrak python=3.8 -y
conda activate dtrak
- Install Dependencies: Install the necessary dependencies by running:
pip install -r requirements.txt
We provide the commands to run experiments on CIFAR-2; it is easy to transfer them to other datasets.
Data pre-processing:
cd CIFAR2
Run `00_EDA.ipynb` to create dataset splits and subsets of the training set.
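For intuition, the subset construction used later for the LDS benchmark roughly amounts to drawing random subsets of the training indices. The sketch below is illustrative only (the real splits come from the notebook); the sizes mirror the `5000-0.5` argument and the 64 subsets used in the scripts below, and the output file name is hypothetical.

```python
# Illustrative sketch only: the actual splits are produced by 00_EDA.ipynb.
import pickle
import numpy as np

rng = np.random.default_rng(0)
n_train, keep_frac, n_subsets = 5000, 0.5, 64  # matches "5000-0.5" and 64 retrained models

# Each subset keeps a random 50% of the training indices.
subsets = [
    rng.choice(n_train, size=int(n_train * keep_frac), replace=False)
    for _ in range(n_subsets)
]

# Hypothetical output path; the repo's notebook defines its own file names.
with open("lds_subsets.pkl", "wb") as f:
    pickle.dump(subsets, f)
```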
Train a diffusion model and generate images:
bash scripts/run_train.sh 0 18888 5000-0.5
bash scripts/run_gen.sh 0 0 5000-0.5
Construct the LDS benchmark:
Train 64 models corresponding to 64 subsets of the training set
bash scripts/run_lds_val_sub.sh 0 18888 5000-0.5 0 63
Evaluate the model outputs on the validation set
bash scripts/run_eval_lds_val_sub.sh 0 0 5000-0.5 idx_val.pkl 0 63
bash scripts/run_eval_lds_val_sub.sh 0 1 5000-0.5 idx_val.pkl 0 63
bash scripts/run_eval_lds_val_sub.sh 0 2 5000-0.5 idx_val.pkl 0 63
Evaluate the model outputs on the generation set
bash scripts/run_eval_lds_val_sub.sh 0 0 5000-0.5 idx_gen.pkl 0 63
bash scripts/run_eval_lds_val_sub.sh 0 1 5000-0.5 idx_gen.pkl 0 63
bash scripts/run_eval_lds_val_sub.sh 0 2 5000-0.5 idx_gen.pkl 0 63
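For reference, the LDS (linear datamodeling score) evaluation roughly measures, for each target sample, the Spearman rank correlation between the output predicted by summing attribution scores over a subset and the output of the model actually retrained on that subset, averaged over targets. A minimal sketch, assuming `scipy` and arrays shaped as described in the comments (not the repo's evaluation code):

```python
# Minimal LDS sketch (illustrative).
# attributions:      (n_targets, n_train) attribution scores for each target sample.
# subset_masks:      (n_subsets, n_train) boolean masks of the retraining subsets.
# retrained_outputs: (n_subsets, n_targets) model outputs measured after retraining.
import numpy as np
from scipy.stats import spearmanr

def lds_score(attributions, subset_masks, retrained_outputs):
    # Predicted output for each (subset, target): sum of attributions over the subset.
    predicted = subset_masks.astype(float) @ attributions.T  # (n_subsets, n_targets)
    # Spearman correlation per target across subsets, then averaged.
    scores = [
        spearmanr(predicted[:, j], retrained_outputs[:, j]).correlation
        for j in range(predicted.shape[1])
    ]
    return float(np.mean(scores))
```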
Compute gradients:
We shard the training set into 5 parts, each containing 1000 examples.
Use the following commands to compute the gradients to be used for TRAK.
bash scripts/run_grad.sh 0 0 5000-0.5 idx-train.pkl 0 ddpm/checkpoint-8000 loss uniform 10 32768
bash scripts/run_grad.sh 0 0 5000-0.5 idx-train.pkl 1 ddpm/checkpoint-8000 loss uniform 10 32768
bash scripts/run_grad.sh 0 0 5000-0.5 idx-train.pkl 2 ddpm/checkpoint-8000 loss uniform 10 32768
bash scripts/run_grad.sh 0 0 5000-0.5 idx-train.pkl 3 ddpm/checkpoint-8000 loss uniform 10 32768
bash scripts/run_grad.sh 0 0 5000-0.5 idx-train.pkl 4 ddpm/checkpoint-8000 loss uniform 10 32768
bash scripts/run_grad.sh 0 0 5000-0.5 idx-val.pkl 0 ddpm/checkpoint-8000 loss uniform 10 32768
bash scripts/run_grad.sh 0 0 5000-0.5 idx-gen.pkl 0 ddpm/checkpoint-8000 loss uniform 10 32768
Use the following commands to compute the gradients to be used for D-TRAK.
bash scripts/run_grad.sh 0 0 5000-0.5 idx-train.pkl 0 ddpm/checkpoint-8000 mean-squared-l2-norm uniform 10 32768
bash scripts/run_grad.sh 0 0 5000-0.5 idx-train.pkl 1 ddpm/checkpoint-8000 mean-squared-l2-norm uniform 10 32768
bash scripts/run_grad.sh 0 0 5000-0.5 idx-train.pkl 2 ddpm/checkpoint-8000 mean-squared-l2-norm uniform 10 32768
bash scripts/run_grad.sh 0 0 5000-0.5 idx-train.pkl 3 ddpm/checkpoint-8000 mean-squared-l2-norm uniform 10 32768
bash scripts/run_grad.sh 0 0 5000-0.5 idx-train.pkl 4 ddpm/checkpoint-8000 mean-squared-l2-norm uniform 10 32768
bash scripts/run_grad.sh 0 0 5000-0.5 idx-val.pkl 0 ddpm/checkpoint-8000 mean-squared-l2-norm uniform 10 32768
bash scripts/run_grad.sh 0 0 5000-0.5 idx-gen.pkl 0 ddpm/checkpoint-8000 mean-squared-l2-norm uniform 10 32768
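The only difference between the two sets of commands is the model-output function whose per-example gradients are computed: `loss` is the standard denoising objective, while `mean-squared-l2-norm` corresponds to D-TRAK's choice of the squared norm of the predicted noise itself. A rough sketch of the two functions, assuming a noise-prediction network `eps_model(x_t, t)` (hypothetical signature) and standard DDPM noising; the repo's `run_grad.sh` selects between them by name:

```python
# Illustrative sketch of the two model-output functions; shown batch-averaged,
# per-example variants differ only in the reduction.
import torch
import torch.nn.functional as F

def model_output(eps_model, x0, t, alphas_cumprod, kind="loss"):
    noise = torch.randn_like(x0)
    a_bar = alphas_cumprod[t].view(-1, 1, 1, 1)
    x_t = a_bar.sqrt() * x0 + (1 - a_bar).sqrt() * noise  # forward diffusion
    eps_pred = eps_model(x_t, t)
    if kind == "loss":
        # Standard diffusion training loss: ||eps - eps_theta(x_t, t)||^2
        return F.mse_loss(eps_pred, noise)
    elif kind == "mean-squared-l2-norm":
        # D-TRAK's alternative output: ||eps_theta(x_t, t)||^2
        return eps_pred.pow(2).mean()
    raise ValueError(kind)
```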
Compute the TRAK/D-TRAK attributions and evaluate the LDS scores:
Run the notebooks in `methods/04_if`.
The implementations of other baselines can also be found in `methods`.
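At its core, the TRAK-style attribution computed in those notebooks has the form Phi_test (Phi_train^T Phi_train)^{-1} Phi_train^T built from the randomly projected gradients saved above; the notebooks handle the remaining details (multiple timesteps, normalization, and so on). A minimal sketch, assuming the projected gradient features are already loaded as numpy arrays and using an optional ridge term:

```python
# Minimal TRAK-style attribution sketch (illustrative; see methods/04_if for the
# full procedure used in the repo).
import numpy as np

def trak_scores(train_feats, test_feats, lam=1e-2):
    """train_feats: (n_train, d) projected gradients of the chosen output function.
    test_feats:  (n_test, d) projected gradients for the samples to attribute.
    Returns an (n_test, n_train) attribution matrix."""
    d = train_feats.shape[1]
    kernel = train_feats.T @ train_feats + lam * np.eye(d)      # (d, d)
    return test_feats @ np.linalg.solve(kernel, train_feats.T)  # (n_test, n_train)
```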
Counterfactual evaluation:
Data pre-processing:
Run this notebook first to get the indices of the training examples to be removed.
Retrain models after removing the top-influential training examples:
bash scripts/run_counter.sh 0 18888 5000-0.5 0 59
Generate images using the retrained models
Measure the l2 distance and the CLIP cosine similarity between the images generated before and after retraining (a minimal sketch of both metrics is given below).
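Both metrics compare an image generated by the original model with the corresponding image generated by a retrained model. A minimal sketch, assuming PIL images and the CLIP implementation from the `transformers` library; the checkpoint name is illustrative and the repo's own scripts may preprocess differently.

```python
# Illustrative sketch of the two counterfactual metrics.
import numpy as np
import torch
from transformers import CLIPModel, CLIPProcessor

def l2_distance(img_a, img_b):
    """Pixel-space l2 distance between two same-sized PIL images."""
    a = np.asarray(img_a, dtype=np.float32) / 255.0
    b = np.asarray(img_b, dtype=np.float32) / 255.0
    return float(np.linalg.norm(a - b))

def clip_cosine_similarity(img_a, img_b, name="openai/clip-vit-base-patch32"):
    """Cosine similarity between CLIP image embeddings of two PIL images."""
    model = CLIPModel.from_pretrained(name)
    processor = CLIPProcessor.from_pretrained(name)
    inputs = processor(images=[img_a, img_b], return_tensors="pt")
    with torch.no_grad():
        feats = model.get_image_features(**inputs)
    feats = feats / feats.norm(dim=-1, keepdim=True)
    return float((feats[0] * feats[1]).sum())
```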
If you find this project useful in your research, please consider citing our paper:
@inproceedings{zheng2023intriguing,
  title={Intriguing Properties of Data Attribution on Diffusion Models},
  author={Zheng, Xiaosen and Pang, Tianyu and Du, Chao and Jiang, Jing and Lin, Min},
  booktitle={International Conference on Learning Representations (ICLR)},
  year={2024},
}