Variational autoencoder implemented in TensorFlow and PyTorch (including inverse autoregressive flow).
Reference implementation for a variational autoencoder in TensorFlow and PyTorch.
I recommend the PyTorch version. It includes an example of a more expressive variational family, the inverse autoregressive flow.
Variational inference is used to fit the model to binarized MNIST handwritten digit images. An inference network (encoder) amortizes inference and shares parameters across datapoints. The likelihood is parameterized by a generative network (decoder).
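As a sketch of the pieces described above, here is a minimal NumPy illustration of a mean-field VAE's ELBO: an encoder mapping x to a diagonal Gaussian q(z|x), a decoder mapping z to Bernoulli pixel logits, the reparameterization trick, and the analytic KL term. This is not the repository's code; all names, sizes, and the linear encoder/decoder are illustrative stand-ins.

```python
import numpy as np

rng = np.random.default_rng(0)

def encoder(x, W_mu, W_logvar):
    # Inference network: maps x to the mean-field Gaussian q(z|x).
    return x @ W_mu, x @ W_logvar

def decoder(z, W_dec):
    # Generative network: maps z to Bernoulli logits over pixels.
    return z @ W_dec

def elbo(x, W_mu, W_logvar, W_dec):
    mu, logvar = encoder(x, W_mu, W_logvar)
    eps = rng.standard_normal(mu.shape)
    z = mu + np.exp(0.5 * logvar) * eps          # reparameterization trick
    logits = decoder(z, W_dec)
    # Bernoulli log-likelihood log p(x|z), summed over pixels
    log_px_z = np.sum(x * logits - np.logaddexp(0.0, logits), axis=1)
    # Analytic KL(q(z|x) || N(0, I)) for a diagonal Gaussian
    kl = 0.5 * np.sum(np.exp(logvar) + mu**2 - 1.0 - logvar, axis=1)
    return np.mean(log_px_z - kl)

D, Z = 784, 8                                    # pixel and latent dimensions
x = (rng.random((16, D)) < 0.5).astype(float)    # stand-in for binarized MNIST
W_mu = rng.normal(0, 0.01, (D, Z))
W_logvar = rng.normal(0, 0.01, (D, Z))
W_dec = rng.normal(0, 0.01, (Z, D))
print(elbo(x, W_mu, W_logvar, W_dec))            # large negative value at init
```

Maximizing this ELBO over the encoder and decoder parameters is what the training scripts do, with neural networks in place of the linear maps here.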
Blog post: https://jaan.io/what-is-variational-autoencoder-vae-tutorial/
Importance sampling is used to estimate the marginal likelihood on Hugo Larochelle's Binary MNIST dataset. The final marginal likelihood on the test set is -97.10 nats, comparable to published numbers.
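The importance-sampling estimator behind that log p(x) number can be sketched in a few lines (a stand-alone NumPy illustration, not the repository's code): draw K samples z_k from q(z|x) and average the importance weights p(x, z_k)/q(z_k|x) in log space for numerical stability.

```python
import numpy as np

def log_marginal_estimate(log_p_xz, log_q_z):
    """Importance-sampling estimate of log p(x).

    log p(x) ~= log (1/K) sum_k exp(log p(x, z_k) - log q(z_k|x)),
    with z_k drawn from q(z|x); computed stably via log-sum-exp.
    """
    log_w = log_p_xz - log_q_z                   # log importance weights
    m = np.max(log_w)
    return m + np.log(np.mean(np.exp(log_w - m)))

# Sanity check: if p(x, z_k) = 0.5 * q(z_k|x) for every sample, the true
# marginal is exactly 0.5, so the estimate is log 0.5 regardless of K.
log_q = np.array([-1.3, -0.2, -4.1, -2.7])       # arbitrary log q(z_k|x) values
log_p = log_q + np.log(0.5)
print(log_marginal_estimate(log_p, log_q))       # -0.6931...
```

With more samples K, this estimate is a tighter lower bound on log p(x) than the ELBO, which is why the "log p(x) estimate" lines in the logs below sit above the ELBO lines.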
```
$ python train_variational_autoencoder_pytorch.py --variational mean-field --use_gpu --data_dir $DAT --max_iterations 30000 --log_interval 10000
Step 0        Train ELBO estimate: -558.027   Validation ELBO estimate: -384.432   Validation log p(x) estimate: -355.430   Speed: 2.72e+06 examples/s
Step 10000    Train ELBO estimate: -111.323   Validation ELBO estimate: -109.048   Validation log p(x) estimate: -103.746   Speed: 2.64e+04 examples/s
Step 20000    Train ELBO estimate: -103.013   Validation ELBO estimate: -107.655   Validation log p(x) estimate: -101.275   Speed: 2.63e+04 examples/s
Step 29999    Test ELBO estimate: -106.642    Test log p(x) estimate: -100.309
Total time: 2.49 minutes
```

Using a non-mean-field, more expressive variational posterior approximation (inverse autoregressive flow, https://arxiv.org/abs/1606.04934), the test marginal log-likelihood improves to -95.33 nats:
```
$ python train_variational_autoencoder_pytorch.py --variational flow
step:     0   train elbo: -578.35
step:     0   valid elbo: -407.06   valid log p(x): -367.88
step: 10000   train elbo: -106.63
step: 10000   valid elbo: -110.12   valid log p(x): -104.00
step: 20000   train elbo: -101.51
step: 20000   valid elbo: -105.02   valid log p(x): -99.11
step: 30000   train elbo: -98.70
step: 30000   valid elbo: -103.76   valid log p(x): -97.71
```

Using JAX (Anaconda environment is in environment-jax.yml) gives a 3x speedup over PyTorch:
```
$ python train_variational_autoencoder_jax.py --variational mean-field
Step 0        Train ELBO estimate: -566.059   Validation ELBO estimate: -565.755   Validation log p(x) estimate: -557.914   Speed: 2.56e+11 examples/s
Step 10000    Train ELBO estimate: -98.560    Validation ELBO estimate: -105.725   Validation log p(x) estimate: -98.973    Speed: 7.03e+04 examples/s
Step 20000    Train ELBO estimate: -109.794   Validation ELBO estimate: -105.756   Validation log p(x) estimate: -97.914   Speed: 4.26e+04 examples/s
Step 29999    Test ELBO estimate: -104.867    Test log p(x) estimate: -96.716
Total time: 0.810 minutes
```

Inverse autoregressive flow in JAX:
```
$ python train_variational_autoencoder_jax.py --variational flow
Step 0        Train ELBO estimate: -727.404   Validation ELBO estimate: -726.977   Validation log p(x) estimate: -713.389   Speed: 2.56e+11 examples/s
Step 10000    Train ELBO estimate: -100.093   Validation ELBO estimate: -106.985   Validation log p(x) estimate: -99.565    Speed: 2.57e+04 examples/s
Step 20000    Train ELBO estimate: -113.073   Validation ELBO estimate: -108.057   Validation log p(x) estimate: -98.841    Speed: 3.37e+04 examples/s
Step 29999    Test ELBO estimate: -106.803    Test log p(x) estimate: -97.620
Total time: 2.350 minutes
```

(The small gap between mean-field and inverse autoregressive flow here may be due to several factors, chiefly the lack of convolutions in this implementation. Residual blocks are used in https://arxiv.org/pdf/1606.04934.pdf to get the ELBO closer to -80 nats.)
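One step of the inverse autoregressive flow can be sketched as follows: a NumPy illustration of the gated update from https://arxiv.org/abs/1606.04934, using a simple triangular mask as a stand-in for the paper's MADE-style autoregressive network. Names and sizes here are illustrative, not the repository's code.

```python
import numpy as np

def iaf_step(z, W_m, W_s):
    # Mask so that m_i and s_i depend only on z_j with j < i
    # (a stand-in for the masked autoregressive network in the paper).
    mask = np.triu(np.ones_like(W_m), k=1)
    m = z @ (W_m * mask)
    s = z @ (W_s * mask) + 1.0               # +1 bias: start near the identity map
    gate = 1.0 / (1.0 + np.exp(-s))          # sigmoid gate
    z_new = gate * z + (1.0 - gate) * m      # the IAF update
    # The Jacobian dz'/dz is triangular, so log|det| is just the sum of log gates.
    log_det = np.sum(np.log(gate), axis=-1)
    return z_new, log_det

rng = np.random.default_rng(0)
Z = 8                                        # latent dimension
z = rng.standard_normal((4, Z))              # a batch of 4 latent samples
z_new, log_det = iaf_step(z, rng.normal(0, 0.1, (Z, Z)), rng.normal(0, 0.1, (Z, Z)))
print(z_new.shape, log_det.shape)            # (4, 8) (4,)
```

Stacking several such steps (with the dimension ordering permuted between steps) turns the mean-field Gaussian sample into a more expressive posterior, at the cost of only the cheap triangular log-determinant above.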
- Run `python train_variational_autoencoder_tensorflow.py`
- Install ImageMagick (Homebrew on Mac: https://formulae.brew.sh/formula/imagemagick, or Chocolatey on Windows: https://community.chocolatey.org/packages/imagemagick.app)
- Go to the directory where the .jpg files are saved and run the ImageMagick command to generate the .gif:

```
convert -delay 20 -loop 0 *.jpg latent-space.gif
```
- add multiple GPU / TPU option
- add jaxtyping support to the PyTorch and JAX implementations for runtime type checking (using @beartype decorators)