Official PyTorch Implementation of "Scalable Diffusion Models with Transformers"
facebookresearch/DiT
Paper | Project Page | Run DiT-XL/2


This repo contains PyTorch model definitions, pre-trained weights and training/sampling code for our paper exploring diffusion models with transformers (DiTs). You can find more visualizations on our project page.
Scalable Diffusion Models with Transformers
William Peebles, Saining Xie
UC Berkeley, New York University
We train latent diffusion models, replacing the commonly-used U-Net backbone with a transformer that operates on latent patches. We analyze the scalability of our Diffusion Transformers (DiTs) through the lens of forward pass complexity as measured by Gflops. We find that DiTs with higher Gflops (through increased transformer depth/width or increased number of input tokens) consistently have lower FID. In addition to good scalability properties, our DiT-XL/2 models outperform all prior diffusion models on the class-conditional ImageNet 512×512 and 256×256 benchmarks, achieving a state-of-the-art FID of 2.27 on the latter.
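The "/2" in DiT-XL/2 is the patch size, and it determines how many input tokens the transformer processes. The sketch below shows that arithmetic. It assumes an autoencoder that shrinks each spatial side by 8 (the VAE used for DiT's latent space) and a square latent. The helper name `num_tokens` is hypothetical.

```python
def num_tokens(image_size: int, patch_size: int, vae_downsample: int = 8) -> int:
    """Count the latent patches (tokens) a DiT processes for a square image.

    Assumes the VAE shrinks each spatial side by `vae_downsample` and that the
    resulting latent side is divisible by `patch_size`.
    """
    latent_side = image_size // vae_downsample
    if latent_side % patch_size != 0:
        raise ValueError("latent side must be divisible by the patch size")
    patches_per_side = latent_side // patch_size
    return patches_per_side * patches_per_side


if __name__ == "__main__":
    # Token counts for the two benchmark resolutions and three patch sizes.
    for image_size in (256, 512):
        for patch_size in (2, 4, 8):
            print(f"{image_size}x{image_size}, p={patch_size}: "
                  f"{num_tokens(image_size, patch_size)} tokens")
```

Halving the patch size quadruples the token count. This is how the number of input tokens, and therefore Gflops, grows without changing transformer depth or width.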
This repository contains:
- 🪐 A simple PyTorch implementation of DiT
- ⚡️ Pre-trained class-conditional DiT models trained on ImageNet (512x512 and 256x256)
- 💥 A self-contained Hugging Face Space and Colab notebook for running pre-trained DiT-XL/2 models
- 🛸 A DiT training script using PyTorch DDP
An implementation of DiT directly in Hugging Face diffusers can also be found here.
First, download and set up the repo:
```bash
git clone https://github.com/facebookresearch/DiT.git
cd DiT
```

We provide an `environment.yml` file that can be used to create a Conda environment. If you only want to run pre-trained models locally on CPU, you can remove the `cudatoolkit` and `pytorch-cuda` requirements from the file.

```bash
conda env create -f environment.yml
conda activate DiT
```
Pre-trained DiT checkpoints. You can sample from our pre-trained DiT models with `sample.py`. Weights for our pre-trained DiT model will be automatically downloaded depending on the model you use. The script has various arguments to switch between the 256x256 and 512x512 models, adjust sampling steps, change the classifier-free guidance scale, etc. For example, to sample from our 512x512 DiT-XL/2 model, you can use:

```bash
python sample.py --image-size 512 --seed 1
```
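The classifier-free guidance scale controls how far each denoising step moves from the unconditional prediction toward the class-conditional one. Below is a minimal sketch of the standard guidance formula on plain Python lists, which stand in for the model's output tensors. The function name `apply_cfg` is hypothetical. This sketch is not a copy of the repository's model code, which may differ in details.

```python
from typing import List


def apply_cfg(eps_cond: List[float], eps_uncond: List[float], scale: float) -> List[float]:
    """Combine conditional and unconditional noise predictions.

    Standard classifier-free guidance: eps_uncond + scale * (eps_cond - eps_uncond).
    scale = 1.0 returns the conditional prediction unchanged; larger values
    extrapolate further in the conditional direction.
    """
    if len(eps_cond) != len(eps_uncond):
        raise ValueError("predictions must have the same length")
    return [u + scale * (c - u) for c, u in zip(eps_cond, eps_uncond)]


cond = [0.5, -1.0, 2.0]
uncond = [0.0, -0.5, 1.0]
print(apply_cfg(cond, uncond, 1.0))  # [0.5, -1.0, 2.0], same as cond
print(apply_cfg(cond, uncond, 4.0))  # [2.0, -2.5, 5.0]
```

With a scale of 1, the formula reduces to the conditional prediction. This is why the FID notes below describe `cfg-scale=1` as sampling "without guidance".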
For convenience, our pre-trained DiT models can be downloaded directly here as well:
| DiT Model | Image Resolution | FID-50K | Inception Score | Gflops |
|---|---|---|---|---|
| XL/2 | 256x256 | 2.27 | 278.24 | 119 |
| XL/2 | 512x512 | 3.04 | 240.82 | 525 |
Custom DiT checkpoints. If you've trained a new DiT model with `train.py` (see below), you can add the `--ckpt` argument to use your own checkpoint instead. For example, to sample from the EMA weights of a custom 256x256 DiT-L/4 model, run:

```bash
python sample.py --model DiT-L/4 --image-size 256 --ckpt /path/to/model.pt
```
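The EMA weights mentioned above are an exponential moving average of the training weights, typically updated after each optimizer step. The sketch below shows that update on a dictionary of scalar "parameters" that stand in for tensors. The function name `ema_step` and the decay of 0.9999 are illustrative assumptions, so check `train.py` for the values it actually uses.

```python
from typing import Dict


def ema_step(ema: Dict[str, float], params: Dict[str, float], decay: float = 0.9999) -> None:
    """In-place EMA update: ema <- decay * ema + (1 - decay) * params."""
    for name, value in params.items():
        ema[name] = decay * ema[name] + (1.0 - decay) * value


# Toy run: the EMA copy starts at 0.0 and slowly tracks a parameter held at 1.0.
ema_weights = {"w": 0.0}
live_weights = {"w": 1.0}
for _ in range(10_000):
    ema_step(ema_weights, live_weights)
print(round(ema_weights["w"], 4))
```

With a decay of 0.9999, the average effectively spans the last ~1 / (1 - decay) = 10,000 updates. This window smooths out step-to-step noise in the raw weights, at the cost of lagging behind them.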
We provide a training script for DiT in `train.py`. This script can be used to train class-conditional DiT models, but it can be easily modified to support other types of conditioning. To launch DiT-XL/2 (256x256) training with `N` GPUs on one node:

```bash
torchrun --nnodes=1 --nproc_per_node=N train.py --model DiT-XL/2 --data-path /path/to/imagenet/train
```
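With `torchrun`, each of the `N` processes trains on its own slice of the batch. The sketch below shows one common way to split a fixed global batch across processes and to derive a distinct seed for each process from a single global seed. The function name, the global batch of 256 in the example and the seed formula are assumptions for illustration. Check `train.py`'s arguments to see what it actually does.

```python
from typing import Tuple


def rank_setup(global_batch_size: int, global_seed: int, world_size: int, rank: int) -> Tuple[int, int]:
    """Return (per-process batch size, per-process seed) for one DDP rank.

    Requires the global batch to divide evenly across processes and gives
    every rank a distinct seed derived from the shared global seed.
    """
    if global_batch_size % world_size != 0:
        raise ValueError("global batch size must be divisible by the number of processes")
    if not 0 <= rank < world_size:
        raise ValueError("rank must be in [0, world_size)")
    local_batch = global_batch_size // world_size
    seed = global_seed * world_size + rank
    return local_batch, seed


# Example: 8 GPUs sharing a global batch of 256 with global seed 42.
for r in range(8):
    print(r, rank_setup(256, 42, 8, r))
```

Because each seed is computed as `global_seed * world_size + rank`, runs launched with different global seeds on the same number of GPUs never share per-rank seeds.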
We've trained DiT-XL/2 and DiT-B/4 models from scratch with the PyTorch training script to verify that it reproduces the original JAX results up to several hundred thousand training iterations. Across our experiments, the PyTorch-trained models give similar (and sometimes slightly better) results compared to the JAX-trained models up to reasonable random variation. Some data points:
| DiT Model | Train Steps | FID-50K (JAX Training) | FID-50K (PyTorch Training) | PyTorch Global Training Seed |
|---|---|---|---|---|
| XL/2 | 400K | 19.5 | 18.1 | 42 |
| B/4 | 400K | 68.4 | 68.9 | 42 |
| B/4 | 400K | 68.4 | 68.3 | 100 |
These models were trained at 256x256 resolution; we used 8x A100s to train XL/2 and 4x A100s to train B/4. Note that FID here is computed with 250 DDPM sampling steps, with the `mse` VAE decoder and without guidance (`cfg-scale=1`).
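For reference, FID measures the distance between Gaussians fitted to Inception-v3 features of real and generated images, and lower is better. FID-50K uses 50,000 generated samples. Let $\mu_r, \mu_g$ be the feature means and $\Sigma_r, \Sigma_g$ the feature covariances for real and generated images. The standard definition is:

$$
\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2 + \operatorname{Tr}\left( \Sigma_r + \Sigma_g - 2 \left( \Sigma_r \Sigma_g \right)^{1/2} \right)
$$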
TF32 Note (important for A100 users). When we ran the above tests, TF32 matmuls were disabled per PyTorch's defaults. We've enabled them at the top of `train.py` and `sample.py` because they make training and sampling much faster on A100s (and should on other Ampere GPUs too), but note that the use of TF32 may lead to some differences compared to the above results.
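TF32 keeps FP32's 8-bit exponent but only 10 of its 23 mantissa bits for matmul inputs. That lost precision is where the small numerical differences come from. In PyTorch, TF32 is controlled by `torch.backends.cuda.matmul.allow_tf32` and `torch.backends.cudnn.allow_tf32`. The sketch below emulates the precision loss on a single value by zeroing the low 13 mantissa bits of its FP32 bit pattern. This is a truncation-based approximation, since hardware may round rather than truncate, and `truncate_to_tf32` is a hypothetical helper.

```python
import struct


def truncate_to_tf32(x: float) -> float:
    """Approximate TF32 precision: keep sign, 8 exponent bits and the top 10 mantissa bits.

    Works on the FP32 bit pattern and zeroes the low 13 mantissa bits
    (truncation; actual hardware rounding may differ).
    """
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    bits &= 0xFFFFE000  # clear mantissa bits 0..12
    (y,) = struct.unpack("<f", struct.pack("<I", bits))
    return y


for value in (1.0, 1.0 + 2**-10, 1.0 + 2**-11, 3.14159265):
    approx = truncate_to_tf32(value)
    print(f"{value!r:>22} -> {approx!r:>22}  (abs error {abs(value - approx):.2e})")
```

The relative error per input is bounded by about 2^-10 (roughly 0.1%). That is small enough for training, but it is enough to shift FID numbers slightly compared with full FP32 matmuls.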
Training (and sampling) could likely be sped up significantly by:
- using Flash Attention in the DiT model
- using `torch.compile` in PyTorch 2.0
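Flash Attention gets much of its speedup by never materializing the full attention matrix. It processes keys and values in blocks, and a running maximum and running normalizer keep the softmax exact. The pure-Python sketch below illustrates only this online-softmax idea for a single query. It is not the FlashAttention kernel, and the function names and block sizes are hypothetical.

```python
import math
from typing import List, Sequence


def naive_attention(q: Sequence[float], keys: List[Sequence[float]],
                    values: List[Sequence[float]]) -> List[float]:
    """Reference: softmax(q . k / sqrt(d))-weighted sum of values for one query."""
    d = len(q)
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [sum(w * v[j] for w, v in zip(weights, values)) / total
            for j in range(len(values[0]))]


def blockwise_attention(q: Sequence[float], keys: List[Sequence[float]],
                        values: List[Sequence[float]], block_size: int = 2) -> List[float]:
    """Same result, streaming over key/value blocks with a running max and normalizer."""
    d = len(q)
    running_max = -math.inf
    running_sum = 0.0                # sum of exp(score - running_max) seen so far
    acc = [0.0] * len(values[0])     # unnormalized weighted sum of values
    for start in range(0, len(keys), block_size):
        k_blk = keys[start:start + block_size]
        v_blk = values[start:start + block_size]
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in k_blk]
        new_max = max(running_max, max(scores))
        # Rescale earlier partial results so they are relative to the new maximum.
        scale = math.exp(running_max - new_max)
        running_sum *= scale
        acc = [a * scale for a in acc]
        for s, v in zip(scores, v_blk):
            w = math.exp(s - new_max)
            running_sum += w
            acc = [a + w * vj for a, vj in zip(acc, v)]
        running_max = new_max
    return [a / running_sum for a in acc]


q = [0.1, 0.2]
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 0.5], [0.3, -0.2]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]
print(naive_attention(q, keys, values))
print(blockwise_attention(q, keys, values, block_size=2))
```

Both functions return the same output up to floating-point error. The blockwise version only ever holds one block of scores at a time, which is what lets a fused GPU kernel keep the work in fast on-chip memory.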
Basic features that would be nice to add:
- Monitor FID and other metrics
- Generate and save samples from the EMA model periodically
- Resume training from a checkpoint
- AMP/bfloat16 support
🔥 Feature Update: Check out this repository at https://github.com/chuanyangjin/fast-DiT to preview a selection of training speed acceleration and memory saving features, including gradient checkpointing, mixed precision training and pre-extracted VAE features. With these advancements, we have achieved a training speed of 0.84 steps/sec for DiT-XL/2 using just a single A100 GPU.
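To put the quoted throughput in perspective, the snippet below converts steps per second into wall-clock time for the 400K-step budget from the reproduction table above. It assumes a constant rate for the whole run and ignores any differences in batch size or setup between the two codebases. The helper name is hypothetical.

```python
def wall_clock_hours(num_steps: int, steps_per_sec: float) -> float:
    """Hours needed to run `num_steps` at a constant `steps_per_sec`."""
    if steps_per_sec <= 0:
        raise ValueError("steps_per_sec must be positive")
    return num_steps / steps_per_sec / 3600.0


hours = wall_clock_hours(400_000, 0.84)
print(f"{hours:.1f} hours (~{hours / 24:.1f} days) on one GPU at 0.84 steps/sec")
```

This works out to roughly 132 hours, or about five and a half days, on a single A100.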
We include a `sample_ddp.py` script which samples a large number of images from a DiT model in parallel. This script generates a folder of samples as well as a `.npz` file which can be directly used with ADM's TensorFlow evaluation suite to compute FID, Inception Score and other metrics. For example, to sample 50K images from our pre-trained DiT-XL/2 model over `N` GPUs, run:

```bash
torchrun --nnodes=1 --nproc_per_node=N sample_ddp.py --model DiT-XL/2 --num-fid-samples 50000
```

There are several additional options; see `sample_ddp.py` for details.
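When sampling for FID across `N` GPUs, the requested number of samples has to be split into equal per-process batches. One straightforward approach rounds the total up to a multiple of the global batch and discards the surplus afterwards, and the sketch below implements it. The function and parameter names are hypothetical, and `sample_ddp.py` may handle this differently.

```python
import math
from typing import Tuple


def plan_sampling(num_samples: int, per_proc_batch: int, world_size: int) -> Tuple[int, int, int]:
    """Return (total samples generated, samples per process, batches per process).

    The total is rounded up to a multiple of the global batch
    (per_proc_batch * world_size) so every process runs the same number of
    full batches; samples beyond `num_samples` can be discarded afterwards.
    """
    global_batch = per_proc_batch * world_size
    total = math.ceil(num_samples / global_batch) * global_batch
    per_proc = total // world_size
    return total, per_proc, per_proc // per_proc_batch


print(plan_sampling(50_000, 32, 8))  # (50176, 6272, 196)
```

Generating a few extra samples keeps every process in lockstep. Only the first 50,000 samples would then be written into the `.npz` used for evaluation.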
Our models were originally trained in JAX on TPUs. The weights in this repo are ported directly from the JAX models. There may be minor differences in results stemming from sampling with different floating point precisions. We re-evaluated our ported PyTorch weights at FP32, and they actually perform marginally better than sampling in JAX (2.21 FID versus 2.27 in the paper).
```bibtex
@article{Peebles2022DiT,
  title={Scalable Diffusion Models with Transformers},
  author={William Peebles and Saining Xie},
  year={2022},
  journal={arXiv preprint arXiv:2212.09748},
}
```
We thank Kaiming He, Ronghang Hu, Alexander Berg, Shoubhik Debnath, Tim Brooks, Ilija Radosavovic and Tete Xiao for helpful discussions. William Peebles is supported by the NSF Graduate Research Fellowship.
This codebase borrows from OpenAI's diffusion repos, most notably ADM.
The code and model weights are licensed under CC-BY-NC. See `LICENSE.txt` for details.

