Intriguing Properties of Data Attribution on Diffusion Models (ICLR 2024)
[Project Page] | [arXiv] | [Data Repository]
We report counter-intuitive observations that theoretically unjustified design choices for attributing diffusion models empirically outperform previous baselines by a large margin.
Counterfactual visualizations on CIFAR-2 and ArtBench-2 (figures omitted here; see the project page).
Check quickstart.ipynb to conduct data attribution on pre-trained diffusion models loaded directly from Hugging Face!
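For orientation, here is a minimal, hypothetical sketch (not the notebook itself) of the starting point: loading a pre-trained DDPM from the Hugging Face Hub with diffusers. The model id below is an illustrative choice, not necessarily the one used in quickstart.ipynb.

```python
# Hypothetical sketch: load a pre-trained DDPM and sample a few images to attribute.
# The model id is an illustrative choice; quickstart.ipynb is the reference.
import torch
from diffusers import DDPMPipeline

device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = DDPMPipeline.from_pretrained("google/ddpm-cifar10-32").to(device)

# The U-Net's parameters are what TRAK/D-TRAK differentiate through;
# sampling with a fixed seed gives reproducible generation-set queries.
images = pipe(batch_size=4, generator=torch.Generator(device).manual_seed(0)).images
```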
To get started, follow these steps:
- Clone the GitHub Repository: Begin by cloning the repository using the command:
git clone https://github.com/sail-sg/D-TRAK.git
- Set Up Python Environment: Ensure you have a Python 3.8 environment; create and activate one with the following commands (replace dtrak with your preferred environment name):
conda create -n dtrak python=3.8 -y
conda activate dtrak
- Install Dependencies: Install the necessary dependencies by running:
pip install -r requirements.txt
We provide the commands to run experiments on CIFAR-2. It is easy to transfer them to other datasets.
Data pre-processing:
cd CIFAR2
Run 00_EDA.ipynb to create dataset splits and subsets of the training set.
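The notebook is the source of truth for the exact splits; the hypothetical sketch below only illustrates the general idea of sampling a two-class CIFAR-10 subset and saving an index file. The class pair, subset size, and file name are placeholders.

```python
# Hypothetical sketch of building a CIFAR-2-style split: pick two CIFAR-10
# classes and save the selected training indices. Class pair, sizes, and
# file name are placeholders; 00_EDA.ipynb defines the real splits.
import pickle
import numpy as np
from torchvision.datasets import CIFAR10

train = CIFAR10(root="./data", train=True, download=True)
labels = np.array(train.targets)
classes = [1, 7]  # placeholder class pair
candidates = np.where(np.isin(labels, classes))[0]
rng = np.random.default_rng(42)
idx_train = rng.choice(candidates, size=5000, replace=False)

with open("idx-train.pkl", "wb") as f:
    pickle.dump(np.sort(idx_train), f)
```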
Train a diffusion model and generate images:
bash scripts/run_train.sh 0 18888 5000-0.5
bash scripts/run_gen.sh 0 0 5000-0.5
Construct the LDS benchmark:
Train 64 models corresponding to 64 subsets of the training set
bash scripts/run_lds_val_sub.sh 0 18888 5000-0.5 0 63
Evaluate the model outputs on the validation set
bash scripts/run_eval_lds_val_sub.sh 0 0 5000-0.5 idx_val.pkl 0 63
bash scripts/run_eval_lds_val_sub.sh 0 1 5000-0.5 idx_val.pkl 0 63
bash scripts/run_eval_lds_val_sub.sh 0 2 5000-0.5 idx_val.pkl 0 63
Evaluate the model outputs on the generation set
bash scripts/run_eval_lds_val_sub.sh 0 0 5000-0.5 idx_gen.pkl 0 63
bash scripts/run_eval_lds_val_sub.sh 0 1 5000-0.5 idx_gen.pkl 0 63
bash scripts/run_eval_lds_val_sub.sh 0 2 5000-0.5 idx_gen.pkl 0 63
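These 64 retrained models and their recorded outputs are what the LDS (linear datamodeling score) is later computed against: for each validation or generated sample, an attribution method predicts how the model output changes on each training subset, and the LDS is the rank correlation between those predictions and the measured outputs. A rough sketch of the metric, with assumed array shapes and names (not the repository's API):

```python
# Rough sketch of the LDS metric; shapes and names are assumptions.
import numpy as np
from scipy.stats import spearmanr

def lds(scores, subset_masks, retrained_outputs):
    """scores: (n_train, n_query) attribution scores per training-example/query pair.
    subset_masks: (n_subsets, n_train) 0/1 membership of the 64 retraining subsets.
    retrained_outputs: (n_subsets, n_query) model outputs measured after retraining."""
    predicted = subset_masks.astype(float) @ scores  # additive prediction per subset
    per_query = [spearmanr(predicted[:, j], retrained_outputs[:, j])[0]
                 for j in range(scores.shape[1])]
    return float(np.mean(per_query))
```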
Compute gradients:
We shard the training set into 5 parts, each with 1,000 examples.
Use the following commands to compute the gradients to be used for TRAK.
bash scripts/run_grad.sh 0 0 5000-0.5 idx-train.pkl 0 ddpm/checkpoint-8000 loss uniform 10 32768
bash scripts/run_grad.sh 0 0 5000-0.5 idx-train.pkl 1 ddpm/checkpoint-8000 loss uniform 10 32768
bash scripts/run_grad.sh 0 0 5000-0.5 idx-train.pkl 2 ddpm/checkpoint-8000 loss uniform 10 32768
bash scripts/run_grad.sh 0 0 5000-0.5 idx-train.pkl 3 ddpm/checkpoint-8000 loss uniform 10 32768
bash scripts/run_grad.sh 0 0 5000-0.5 idx-train.pkl 4 ddpm/checkpoint-8000 loss uniform 10 32768
bash scripts/run_grad.sh 0 0 5000-0.5 idx-val.pkl 0 ddpm/checkpoint-8000 loss uniform 10 32768
bash scripts/run_grad.sh 0 0 5000-0.5 idx-gen.pkl 0 ddpm/checkpoint-8000 loss uniform 10 32768
Use the following commands to compute the gradients to be used for D-TRAK.
bash scripts/run_grad.sh 0 0 5000-0.5 idx-train.pkl 0 ddpm/checkpoint-8000 mean-squared-l2-norm uniform 10 32768
bash scripts/run_grad.sh 0 0 5000-0.5 idx-train.pkl 1 ddpm/checkpoint-8000 mean-squared-l2-norm uniform 10 32768
bash scripts/run_grad.sh 0 0 5000-0.5 idx-train.pkl 2 ddpm/checkpoint-8000 mean-squared-l2-norm uniform 10 32768
bash scripts/run_grad.sh 0 0 5000-0.5 idx-train.pkl 3 ddpm/checkpoint-8000 mean-squared-l2-norm uniform 10 32768
bash scripts/run_grad.sh 0 0 5000-0.5 idx-train.pkl 4 ddpm/checkpoint-8000 mean-squared-l2-norm uniform 10 32768
bash scripts/run_grad.sh 0 0 5000-0.5 idx-val.pkl 0 ddpm/checkpoint-8000 mean-squared-l2-norm uniform 10 32768
bash scripts/run_grad.sh 0 0 5000-0.5 idx-gen.pkl 0 ddpm/checkpoint-8000 mean-squared-l2-norm uniform 10 32768
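Both sets of commands compute per-example gradients of a scalar model-output function and compress them with a random projection (presumably the trailing 32768 is the projection dimension); the only difference is the output function argument, loss for TRAK versus mean-squared-l2-norm (the squared norm of the noise prediction) for D-TRAK. A schematic sketch of that step; the dense projection and small k below are for illustration only and are not the repository's implementation:

```python
# Schematic sketch of projected per-example gradients; names, the dense
# projection, and the tiny k are assumptions, not the repo's code.
import torch

def make_projector(dim_in, k=4096, seed=0):
    # Fixed Gaussian projection. A dense matrix at the full 32768 dimension
    # would be far too large for a real U-Net; treat this purely as a schematic.
    g = torch.Generator().manual_seed(seed)
    P = torch.randn(dim_in, k, generator=g) / k ** 0.5
    return lambda v: v @ P

def projected_grad(model, scalar_output, projector):
    # scalar_output: the chosen output function evaluated on one example, e.g.
    # the denoising loss (TRAK) or the squared norm of the predicted noise (D-TRAK).
    params = [p for p in model.parameters() if p.requires_grad]
    grads = torch.autograd.grad(scalar_output, params)
    flat = torch.cat([g.reshape(-1) for g in grads])
    return projector(flat)
```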
Compute the TRAK/D-TRAK attributions and evaluate the LDS scores
Run the notebooks in methods/04_if.
The implementations of other baselines can also be found in methods.
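Schematically, those notebooks turn the projected gradients into attribution scores with a TRAK-style estimator: a kernel built from the training-example features is inverted and applied to the query features. A hedged sketch with assumed shapes and names (the notebooks in methods/ are the reference implementation):

```python
# Hedged sketch of a TRAK-style score computation from projected gradients;
# variable names, the regularizer, and shapes are assumptions.
import torch

def trak_style_scores(phi_train, phi_query, lam=0.0):
    """phi_train: (n_train, k) projected training gradients.
    phi_query: (n_query, k) projected gradients of validation/generated samples.
    Returns an (n_train, n_query) score matrix."""
    k = phi_train.shape[1]
    kernel = phi_train.T @ phi_train + lam * torch.eye(k, dtype=phi_train.dtype)
    scores = phi_query @ torch.linalg.solve(kernel, phi_train.T)
    return scores.T
```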
Data pre-processing
Run this notebook first to get the indices of the training examples to be removed.
Retrain models after removing the top-influential training examples
bash scripts/run_counter.sh 0 18888 5000-0.5 0 59
Generate images using the retrained models
Measure l2 distance
Measure CLIP cosine similarity
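Both metrics compare an image generated by the original model with the image generated from the same seed by a retrained model: pixel-space l2 distance and cosine similarity between CLIP image embeddings. A hedged sketch follows; the CLIP checkpoint and preprocessing here are assumptions, not necessarily what the repository's scripts use.

```python
# Hedged sketch of the two counterfactual metrics; the CLIP checkpoint is an
# assumption, not necessarily the one used by the repository's scripts.
import numpy as np
import torch
import torch.nn.functional as F
from transformers import CLIPModel, CLIPProcessor

clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
clip_proc = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

def l2_distance(img_a, img_b):
    # img_a, img_b: PIL images of the same size; distance in [0, 1]-scaled pixel space.
    a = np.asarray(img_a, dtype=np.float32) / 255.0
    b = np.asarray(img_b, dtype=np.float32) / 255.0
    return float(np.linalg.norm(a - b))

def clip_cosine(img_a, img_b):
    inputs = clip_proc(images=[img_a, img_b], return_tensors="pt")
    with torch.no_grad():
        emb = clip_model.get_image_features(**inputs)
    return F.cosine_similarity(emb[0:1], emb[1:2]).item()
```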
If you find this project useful in your research, please consider citing our paper:
@inproceedings{zheng2023intriguing,
  title={Intriguing Properties of Data Attribution on Diffusion Models},
  author={Zheng, Xiaosen and Pang, Tianyu and Du, Chao and Jiang, Jing and Lin, Min},
  booktitle={International Conference on Learning Representations (ICLR)},
  year={2024},
}