sail-sg/MDT

Masked Diffusion Transformer is the SOTA for image synthesis. (ICCV 2023)


The official codebase for *Masked Diffusion Transformer is a Strong Image Synthesizer*.

MDTv2: Faster Convergence & Stronger Performance

MDTv2 achieves superior image synthesis performance, e.g., a new SOTA FID score of 1.58 on the ImageNet dataset, and learns more than 10× faster than the previous SOTA, DiT.

MDTv2 demonstrates a 5× acceleration compared to the original MDT.

MDTv1 code

Introduction

Despite their success in image synthesis, we observe that diffusion probabilistic models (DPMs) often lack the contextual reasoning ability to learn relations among object parts in an image, leading to a slow learning process. To solve this issue, we propose a Masked Diffusion Transformer (MDT) that introduces a mask latent modeling scheme to explicitly enhance the DPMs' ability to learn contextual relations among object semantic parts in an image.

During training, MDT operates in the latent space and masks certain tokens. An asymmetric diffusion transformer is then designed to predict the masked tokens from the unmasked ones while maintaining the diffusion generation process. MDT can therefore reconstruct the full information of an image from its incomplete contextual input, enabling it to learn the associated relations among image tokens. We further improve MDT with a more efficient macro network structure and training strategy, named MDTv2.
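As a rough illustration of the mask latent modeling idea (an MAE-style shuffle-and-keep sketch under assumed tensor shapes, not the repo's actual implementation), masking latent tokens during training could look like:

```python
import torch

def mask_latent_tokens(tokens: torch.Tensor, mask_ratio: float = 0.30):
    """Randomly drop a fraction of latent tokens (illustrative sketch).

    tokens: [batch, num_tokens, dim] latent-patch embeddings (assumed layout).
    Returns the kept (unmasked) tokens and a boolean mask marking dropped ones.
    """
    B, N, D = tokens.shape
    num_keep = int(N * (1 - mask_ratio))
    # Per-sample random permutation; keep the first num_keep positions.
    noise = torch.rand(B, N, device=tokens.device)
    ids_shuffle = noise.argsort(dim=1)
    ids_keep = ids_shuffle[:, :num_keep]
    kept = torch.gather(tokens, 1, ids_keep.unsqueeze(-1).expand(-1, -1, D))
    mask = torch.ones(B, N, dtype=torch.bool, device=tokens.device)
    mask.scatter_(1, ids_keep, False)  # True marks masked-out tokens
    return kept, mask
```

The network then has to predict the masked tokens from the kept ones, which is what forces it to learn relations among image tokens.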

Experimental results show that MDTv2 achieves superior image synthesis performance, e.g., a new SOTA FID score of 1.58 on the ImageNet dataset, and learns more than 10× faster than the previous SOTA, DiT.


Performance

| Model | Dataset | Resolution | FID-50K | Inception Score |
| --- | --- | --- | --- | --- |
| MDT-XL/2 | ImageNet | 256×256 | 1.79 | 283.01 |
| MDTv2-XL/2 | ImageNet | 256×256 | 1.58 | 314.73 |

Pretrained model download

The model is hosted on Hugging Face; you can also download it with:

```python
import os
from huggingface_hub import snapshot_download

models_path = snapshot_download("shgao/MDT-XL2")
ckpt_model_path = os.path.join(models_path, "mdt_xl2_v1_ckpt.pt")
```
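Once downloaded, you can inspect the checkpoint before wiring it into the model; a minimal sketch (assuming the `.pt` file is a standard PyTorch state dict, which is not confirmed here):

```python
import torch

# Load on CPU first; move the weights to GPU after building the model.
state_dict = torch.load(ckpt_model_path, map_location="cpu")  # ckpt_model_path from the snippet above
print(len(state_dict), "entries")  # if a state dict, pass it to model.load_state_dict(...)
```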

A Hugging Face demo is available on DEMO.

(Figure: new SOTA on FID.)

Setup

Prepare a PyTorch >= 2.0 environment, then download and install this repo:

```bash
git clone https://github.com/sail-sg/MDT
cd MDT
pip install -e .
```

Install the Adan optimizer; Adan is a strong optimizer with faster convergence than AdamW (paper):

```bash
python -m pip install git+https://github.com/sail-sg/Adan.git
```
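For reference, a minimal usage sketch of Adan as a drop-in optimizer (the import follows the sail-sg/Adan README; the stand-in model and hyperparameters below are illustrative assumptions, not this repo's training settings):

```python
import torch
from adan import Adan  # provided by the sail-sg/Adan package installed above

model = torch.nn.Linear(16, 16)  # stand-in model for illustration
optimizer = Adan(model.parameters(), lr=1e-3, weight_decay=0.02)

# One illustrative optimization step on a dummy objective.
loss = model(torch.randn(4, 16)).pow(2).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```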

DATA

- For standard datasets like ImageNet and CIFAR, please refer to 'dataset' for preparation.
- When using a customized dataset, rename the image files to `ClassID_ImgID.jpg`, as ADM's dataloader parses the class ID from the file name (see the sketch below).
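For example, a hypothetical helper that applies this naming convention to a class-per-folder dataset (the folder layout and path below are assumptions; adapt the class-ID source to your data):

```python
from pathlib import Path

# Assumed layout: /dataset/custom/<class_id>/<anything>.jpg
root = Path("/dataset/custom")
for class_dir in sorted(p for p in root.iterdir() if p.is_dir()):
    for img_id, img in enumerate(sorted(class_dir.glob("*.jpg"))):
        # Rename to ClassID_ImgID.jpg so ADM's dataloader can parse the class ID.
        target = class_dir / f"{class_dir.name}_{img_id}.jpg"
        if target != img:
            assert not target.exists(), f"refusing to overwrite {target}"
            img.rename(target)
```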

Training

Training on one node (`run.sh`):

```bash
export OPENAI_LOGDIR=output_mdtv2_s2
NUM_GPUS=8
MODEL_FLAGS="--image_size 256 --mask_ratio 0.30 --decode_layer 6 --model MDTv2_S_2"
DIFFUSION_FLAGS="--diffusion_steps 1000"
TRAIN_FLAGS="--batch_size 32"
DATA_PATH=/dataset/imagenet
python -m torch.distributed.launch --nproc_per_node=$NUM_GPUS scripts/image_train.py --data_dir $DATA_PATH $MODEL_FLAGS $DIFFUSION_FLAGS $TRAIN_FLAGS
```
Training on multiple nodes (`run_ddp_master.sh` and `run_ddp_worker.sh`). Note that `$RANK` and `$MASTER_PORT` (and `$MASTER_ADDR` on the workers) must be set in the environment before launching:

```bash
# On master:
export OPENAI_LOGDIR=output_mdtv2_xl2
MODEL_FLAGS="--image_size 256 --mask_ratio 0.30 --decode_layer 4 --model MDTv2_XL_2"
DIFFUSION_FLAGS="--diffusion_steps 1000"
TRAIN_FLAGS="--batch_size 4"
DATA_PATH=/dataset/imagenet
NUM_NODE=8
GPU_PRE_NODE=8
python -m torch.distributed.launch --master_addr=$(hostname) --nnodes=$NUM_NODE --node_rank=$RANK --nproc_per_node=$GPU_PRE_NODE --master_port=$MASTER_PORT scripts/image_train.py --data_dir $DATA_PATH $MODEL_FLAGS $DIFFUSION_FLAGS $TRAIN_FLAGS

# On workers:
export OPENAI_LOGDIR=output_mdtv2_xl2
MODEL_FLAGS="--image_size 256 --mask_ratio 0.30 --decode_layer 4 --model MDTv2_XL_2"
DIFFUSION_FLAGS="--diffusion_steps 1000"
TRAIN_FLAGS="--batch_size 4"
DATA_PATH=/dataset/imagenet
NUM_NODE=8
GPU_PRE_NODE=8
python -m torch.distributed.launch --master_addr=$MASTER_ADDR --nnodes=$NUM_NODE --node_rank=$RANK --nproc_per_node=$GPU_PRE_NODE --master_port=$MASTER_PORT scripts/image_train.py --data_dir $DATA_PATH $MODEL_FLAGS $DIFFUSION_FLAGS $TRAIN_FLAGS
```

Evaluation

The evaluation code is obtained from ADM's TensorFlow evaluation suite. Please follow the instructions in the `evaluations` folder to set up the evaluation environment.

Sampling and Evaluation (`run_sample.sh`):

```bash
MODEL_PATH=output_mdtv2_xl2/mdt_xl2_v2_ckpt.pt
export OPENAI_LOGDIR=output_mdtv2_xl2_eval
NUM_GPUS=8

echo 'CFG Class-conditional sampling:'
MODEL_FLAGS="--image_size 256 --model MDTv2_XL_2 --decode_layer 4"
DIFFUSION_FLAGS="--num_sampling_steps 250 --num_samples 50000 --cfg_cond True"
echo $MODEL_FLAGS
echo $DIFFUSION_FLAGS
echo $MODEL_PATH
python -m torch.distributed.launch --nproc_per_node=$NUM_GPUS scripts/image_sample.py --model_path $MODEL_PATH $MODEL_FLAGS $DIFFUSION_FLAGS
echo $MODEL_FLAGS
echo $DIFFUSION_FLAGS
echo $MODEL_PATH
python evaluations/evaluator.py ../dataeval/VIRTUAL_imagenet256_labeled.npz $OPENAI_LOGDIR/samples_50000x256x256x3.npz

echo 'Class-conditional sampling:'
MODEL_FLAGS="--image_size 256 --model MDTv2_XL_2 --decode_layer 4"
DIFFUSION_FLAGS="--num_sampling_steps 250 --num_samples 50000"
echo $MODEL_FLAGS
echo $DIFFUSION_FLAGS
echo $MODEL_PATH
python -m torch.distributed.launch --nproc_per_node=$NUM_GPUS scripts/image_sample.py --model_path $MODEL_PATH $MODEL_FLAGS $DIFFUSION_FLAGS
echo $MODEL_FLAGS
echo $DIFFUSION_FLAGS
echo $MODEL_PATH
python evaluations/evaluator.py ../dataeval/VIRTUAL_imagenet256_labeled.npz $OPENAI_LOGDIR/samples_50000x256x256x3.npz
```
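Before running the evaluator, it can help to sanity-check the sampled `.npz`; a small sketch (assuming the sampler writes an ADM-style archive whose first stored array is the uint8 image batch, which is not confirmed here):

```python
import numpy as np

data = np.load("output_mdtv2_xl2_eval/samples_50000x256x256x3.npz")
arr = data[data.files[0]]  # assumed: first stored array is the image batch
print(arr.shape, arr.dtype)  # expect (50000, 256, 256, 3) uint8
assert arr.dtype == np.uint8 and arr.shape[1:] == (256, 256, 3)
```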

Visualization

Run `infer_mdt.py` to generate images.

Citation

```bibtex
@misc{gao2023masked,
  title={Masked Diffusion Transformer is a Strong Image Synthesizer},
  author={Shanghua Gao and Pan Zhou and Ming-Ming Cheng and Shuicheng Yan},
  year={2023},
  eprint={2303.14389},
  archivePrefix={arXiv},
  primaryClass={cs.CV}
}
```

Acknowledgement

This codebase is built based on DiT and ADM. Thanks!

