YangLabHKUST/DGLSB
Official code for the ICML 2021 paper Deep Generative Learning via Schrödinger Bridge, by Gefei Wang, Yuling Jiao, Qian Xu, Yang Wang and Can Yang.
For the CIFAR-10 dataset, to train both the density ratio estimator and the score estimator:
python main.py --config cifar10.yml --doc cifar10 --sigma_sq 1.0 --tau 2.0
To train only the density ratio estimator:
python main.py --config cifar10.yml --doc cifar10 --sigma_sq 1.0 --tau 2.0 --train_d_only
To train only the score estimator:
python main.py --config cifar10.yml --doc cifar10 --sigma_sq 1.0 --tau 2.0 --train_s_only
For the CelebA dataset, the config file is celeba.yml, and the recommended hyperparameters are --sigma_sq 4.0 --tau 8.0.
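The training commands differ only in the config file, the two hyperparameters (--sigma_sq, --tau) and one optional mode flag, so they can be generated by a small wrapper. The following is a minimal sketch, assuming the flags documented here are the only ones needed. `PRESETS`, `MODE_FLAGS`, `build_train_command` and `run` are hypothetical names, not part of the repository, and `--doc celeba` is an assumed run name (this README only shows `--doc cifar10`; in the DDIM codebase this work builds on, `--doc` names the log folder).

```python
import subprocess
import sys

# Recommended settings taken from this README.
# ASSUMPTION: "--doc celeba" is a hypothetical run/log-folder name.
PRESETS = {
    "cifar10": {"config": "cifar10.yml", "doc": "cifar10", "sigma_sq": 1.0, "tau": 2.0},
    "celeba": {"config": "celeba.yml", "doc": "celeba", "sigma_sq": 4.0, "tau": 8.0},
}

# Training mode -> extra flag. "both" trains the density ratio and score estimators.
MODE_FLAGS = {
    "both": [],
    "density_ratio": ["--train_d_only"],
    "score": ["--train_s_only"],
}


def build_train_command(dataset, mode="both"):
    """Return the argv list for one training run (hypothetical helper)."""
    p = PRESETS[dataset]
    return [
        sys.executable, "main.py",
        "--config", p["config"],
        "--doc", p["doc"],
        "--sigma_sq", str(p["sigma_sq"]),
        "--tau", str(p["tau"]),
        *MODE_FLAGS[mode],
    ]


def run(dataset, mode="both", dry_run=True):
    """Print the command (dry run) or launch it and return its exit code."""
    cmd = build_train_command(dataset, mode)
    if dry_run:
        print(" ".join(cmd))
        return 0
    return subprocess.run(cmd, check=False).returncode
```

For example, `run("celeba", "score")` only prints the command for training the CelebA score estimator; pass `dry_run=False` to actually launch it. Defaulting to a dry run avoids accidentally starting a long training job.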
To sample 50,000 images for FID evaluation:
python main.py --config cifar10.yml --doc cifar10 --sample --fid
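FID compares generated and real images through the mean and covariance of their feature vectors (typically Inception-v3 activations): FID = ||mu1 - mu2||^2 + Tr(S1 + S2 - 2 (S1 S2)^(1/2)). The sketch below computes only this Fréchet distance from already-extracted features; it does not reproduce this repository's evaluation script, and feature extraction is assumed to happen elsewhere. `feature_stats` and `frechet_distance` are hypothetical names. Common implementations compute the matrix square root with `scipy.linalg.sqrtm`; this version uses a symmetric eigendecomposition so that it needs only NumPy.

```python
import numpy as np


def _psd_sqrt(mat):
    """Square root of a symmetric positive semi-definite matrix via eigh."""
    vals, vecs = np.linalg.eigh(mat)
    vals = np.clip(vals, 0.0, None)  # clear tiny negative round-off
    return (vecs * np.sqrt(vals)) @ vecs.T


def feature_stats(features):
    """Mean and covariance of an (n_samples, dim) feature array."""
    features = np.asarray(features, dtype=float)
    return features.mean(axis=0), np.cov(features, rowvar=False)


def frechet_distance(mu1, sigma1, mu2, sigma2):
    """Fréchet distance between Gaussians N(mu1, sigma1) and N(mu2, sigma2)."""
    mu1, mu2 = np.asarray(mu1, dtype=float), np.asarray(mu2, dtype=float)
    sigma1, sigma2 = np.asarray(sigma1, dtype=float), np.asarray(sigma2, dtype=float)
    diff = mu1 - mu2
    root1 = _psd_sqrt(sigma1)
    # Tr((S1 S2)^(1/2)) equals Tr((S1^(1/2) S2 S1^(1/2))^(1/2)); the inner
    # matrix is symmetric PSD, so eigh-based square roots apply.
    cross = _psd_sqrt(root1 @ sigma2 @ root1)
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * np.trace(cross))
```

Usage: with `mu_g, s_g = feature_stats(gen_feats)` and `mu_r, s_r = feature_stats(real_feats)`, the score is `frechet_distance(mu_g, s_g, mu_r, s_r)`; lower is better, and identical statistics give 0.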
This package is developed by Gefei Wang (gwangas@connect.ust.hk).
Please contact Gefei (gwangas@connect.ust.hk), Yuling (yulingjiaomath@whu.edu.cn) or Can (macyang@ust.hk) with any enquiries.
This implementation is based on https://github.com/ermongroup/ddim.
@InProceedings{pmlr-v139-wang21l,
  title = {Deep Generative Learning via Schr{ö}dinger Bridge},
  author = {Wang, Gefei and Jiao, Yuling and Xu, Qian and Wang, Yang and Yang, Can},
  booktitle = {Proceedings of the 38th International Conference on Machine Learning},
  pages = {10794--10804},
  year = {2021},
  editor = {Meila, Marina and Zhang, Tong},
  volume = {139},
  series = {Proceedings of Machine Learning Research},
  month = {18--24 Jul},
  publisher = {PMLR}
}