Distrax is a lightweight library of probability distributions and bijectors. It acts as a JAX-native reimplementation of a subset of TensorFlow Probability (TFP), with some new features and emphasis on extensibility.
You can install the latest released version of Distrax from PyPI via:

```
pip install distrax
```

or you can install the latest development version from GitHub:

```
pip install git+https://github.com/deepmind/distrax.git
```
To run the tests or examples you will need to install additional requirements.
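For example, from a clone of the repository, and assuming the requirements files live under `requirements/` (an assumption about the repository layout; adjust the paths if it differs):

```
pip install -r requirements/requirements-tests.txt     # for the tests
pip install -r requirements/requirements-examples.txt  # for the examples
```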
The general design principles for the DeepMind JAX Ecosystem are addressed in this blog post. Additionally, Distrax places emphasis on the following:

- Readability. Distrax implementations are intended to be self-contained and read as close to the underlying math as possible.
- Extensibility. We have made it as simple as possible for users to define their own distribution or bijector. This is useful for example in reinforcement learning, where users may wish to define custom behavior for probabilistic agent policies.
- Compatibility. Distrax is not intended as a replacement for TFP, and TFP contains many advanced features that we do not intend to replicate. To this end, we have made the APIs for distributions and bijectors as cross-compatible as possible, and provide utilities for transforming between equivalent Distrax and TFP classes.
Distributions in Distrax are simple to define and use, particularly if you're used to TFP. Let's compare the two side by side:
```python
import distrax
import jax
import jax.numpy as jnp

from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions

key = jax.random.PRNGKey(1234)
mu = jnp.array([-1., 0., 1.])
sigma = jnp.array([0.1, 0.2, 0.3])

dist_distrax = distrax.MultivariateNormalDiag(mu, sigma)
dist_tfp = tfd.MultivariateNormalDiag(mu, sigma)

samples = dist_distrax.sample(seed=key)

# Both print 1.775
print(dist_distrax.log_prob(samples))
print(dist_tfp.log_prob(samples))
```
In addition to behaving consistently, Distrax distributions and TFP distributions are cross-compatible. For example:
```python
mu_0 = jnp.array([-1., 0., 1.])
sigma_0 = jnp.array([0.1, 0.2, 0.3])
dist_distrax = distrax.MultivariateNormalDiag(mu_0, sigma_0)

mu_1 = jnp.array([1., 2., 3.])
sigma_1 = jnp.array([0.2, 0.3, 0.4])
dist_tfp = tfd.MultivariateNormalDiag(mu_1, sigma_1)

# Both print 85.237
print(dist_distrax.kl_divergence(dist_tfp))
print(tfd.kl_divergence(dist_distrax, dist_tfp))
```
Distrax distributions implement the method `sample_and_log_prob`, which provides samples and their log probability in one line. For some distributions, this is more efficient than calling `sample` and `log_prob` separately:
```python
mu = jnp.array([-1., 0., 1.])
sigma = jnp.array([0.1, 0.2, 0.3])
dist_distrax = distrax.MultivariateNormalDiag(mu, sigma)

samples = dist_distrax.sample(seed=key, sample_shape=())
log_prob = dist_distrax.log_prob(samples)

# A one-line equivalent of the above is:
samples, log_prob = dist_distrax.sample_and_log_prob(seed=key, sample_shape=())
```
TFP distributions can be passed to Distrax meta-distributions as inputs. For example:
```python
key = jax.random.PRNGKey(1234)
mu = jnp.array([-1., 0., 1.])
sigma = jnp.array([0.2, 0.3, 0.4])
dist_tfp = tfd.Normal(mu, sigma)

metadist_distrax = distrax.Independent(dist_tfp, reinterpreted_batch_ndims=1)
samples = metadist_distrax.sample(seed=key)
print(metadist_distrax.log_prob(samples))  # Prints 0.38871175
```
To use Distrax distributions in TFP meta-distributions, Distrax provides the wrapper `to_tfp`. A wrapped Distrax distribution can be directly used in TFP:
```python
key = jax.random.PRNGKey(1234)

distrax_dist = distrax.Normal(0., 1.)
wrapped_dist = distrax.to_tfp(distrax_dist)

metadist_tfp = tfd.Sample(wrapped_dist, sample_shape=[3])
samples = metadist_tfp.sample(seed=key)
print(metadist_tfp.log_prob(samples))  # Prints -3.3409896
```
A "bijector" in Distrax is an invertible function that knows how to compute itsJacobian determinant. Bijectors can be used to create complex distributions bytransforming simpler ones. Distrax bijectors are functionally similar to TFPbijectors, with a few API differences. Here is an example comparing the two:
```python
import distrax
import jax.numpy as jnp

from tensorflow_probability.substrates import jax as tfp
tfb = tfp.bijectors
tfd = tfp.distributions

# Same distribution.
distrax.Transformed(distrax.Normal(loc=0., scale=1.), distrax.Tanh())
tfd.TransformedDistribution(tfd.Normal(loc=0., scale=1.), tfb.Tanh())
```
Additionally, Distrax bijectors can be composed and inverted:
```python
bij_distrax = distrax.Tanh()
bij_tfp = tfb.Tanh()

# Same bijector.
inv_bij_distrax = distrax.Inverse(bij_distrax)
inv_bij_tfp = tfb.Invert(bij_tfp)

# These are both the identity bijector.
distrax.Chain([bij_distrax, inv_bij_distrax])
tfb.Chain([bij_tfp, inv_bij_tfp])
```
All TFP bijectors can be passed to Distrax, and can be freely composed with Distrax bijectors. For example, all of the following will work:
```python
distrax.Inverse(tfb.Tanh())
distrax.Chain([tfb.Tanh(), distrax.Tanh()])
distrax.Transformed(tfd.Normal(loc=0., scale=1.), tfb.Tanh())
```
Distrax bijectors can also be passed to TFP, but first they must be transformed with `to_tfp`:
```python
bij_distrax = distrax.to_tfp(distrax.Tanh())

tfb.Invert(bij_distrax)
tfb.Chain([tfb.Tanh(), bij_distrax])
tfd.TransformedDistribution(tfd.Normal(loc=0., scale=1.), bij_distrax)
```
Distrax also comes with `Lambda`, a convenient wrapper for turning simple JAX functions into bijectors. Here are a few `Lambda` examples with their TFP equivalents:
```python
distrax.Lambda(lambda x: x)                  # tfb.Identity()
distrax.Lambda(lambda x: 2*x + 3)            # tfb.Chain([tfb.Shift(3), tfb.Scale(2)])
distrax.Lambda(jnp.sinh)                     # tfb.Sinh()
distrax.Lambda(lambda x: jnp.sinh(2*x + 3))  # tfb.Chain([tfb.Sinh(), tfb.Shift(3), tfb.Scale(2)])
```
Unlike TFP, bijectors in Distrax do not take `event_ndims` as an argument when they compute the Jacobian determinant. Instead, Distrax assumes that the number of event dimensions is statically known to every bijector, and uses `Block` to lift bijectors to a different number of dimensions. For example:
```python
x = jnp.zeros([2, 3, 4])

# In TFP, `event_ndims` can be passed to the bijector.
bij_tfp = tfb.Tanh()
ld_1 = bij_tfp.forward_log_det_jacobian(x, event_ndims=0)  # Shape = [2, 3, 4]

# Distrax assumes `Tanh` is a scalar bijector by default.
bij_distrax = distrax.Tanh()
ld_2 = bij_distrax.forward_log_det_jacobian(x)  # ld_1 == ld_2

# With `event_ndims=2`, TFP sums the last 2 dimensions of the log det.
ld_3 = bij_tfp.forward_log_det_jacobian(x, event_ndims=2)  # Shape = [2]

# Distrax treats the number of dimensions statically.
bij_distrax = distrax.Block(bij_distrax, ndims=2)
ld_4 = bij_distrax.forward_log_det_jacobian(x)  # ld_3 == ld_4
```
Distrax bijectors implement the method `forward_and_log_det` (some bijectors additionally implement `inverse_and_log_det`), which makes it possible to obtain the forward mapping and its log Jacobian determinant in one line. For some bijectors, this is more efficient than calling `forward` and `forward_log_det_jacobian` separately. (Analogously, when available, `inverse_and_log_det` can be more efficient than `inverse` and `inverse_log_det_jacobian`.)
```python
x = jnp.zeros([2, 3, 4])
bij_distrax = distrax.Tanh()

y = bij_distrax.forward(x)
ld = bij_distrax.forward_log_det_jacobian(x)

# A one-line equivalent of the above is:
y, ld = bij_distrax.forward_and_log_det(x)
```
Distrax distributions and bijectors can be passed as arguments to jitted functions. User-defined distributions and bijectors get this property for free by subclassing `distrax.Distribution` and `distrax.Bijector`, respectively. For example:
```python
mu_0 = jnp.array([-1., 0., 1.])
sigma_0 = jnp.array([0.1, 0.2, 0.3])
dist_0 = distrax.MultivariateNormalDiag(mu_0, sigma_0)

mu_1 = jnp.array([1., 2., 3.])
sigma_1 = jnp.array([0.2, 0.3, 0.4])
dist_1 = distrax.MultivariateNormalDiag(mu_1, sigma_1)

jitted_kl = jax.jit(lambda d_0, d_1: d_0.kl_divergence(d_1))

# Both print 85.237
print(jitted_kl(dist_0, dist_1))
print(dist_0.kl_divergence(dist_1))
```
The serialization logic that enables Distrax objects to be passed as arguments to jitted functions also enables functions to map over them as data using `jax.vmap` and `jax.pmap`.
However, support for this behavior is experimental and incomplete. Use caution when applying `jax.vmap` or `jax.pmap` to functions that take Distrax objects as arguments, or return Distrax objects.
Simple objects such as `distrax.Categorical` may behave correctly under these transformations, but more complex objects such as `distrax.MultivariateNormalDiag` may generate exceptions when used as inputs or outputs of a `vmap`-ed or `pmap`-ed function.
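As a hedged sketch of the kind of usage this is intended to support (the shapes and the choice of `entropy` here are illustrative, not taken from the library's documentation), something like the following may work for simple distributions but is not guaranteed to:

```python
import jax
import jax.numpy as jnp
import distrax

# A Categorical whose logits carry a leading batch dimension of size 2.
batched = distrax.Categorical(logits=jnp.array([[0., 1., 2.], [2., 1., 0.]]))

# Map a function over the distribution as data. This relies on the
# experimental pytree support described above, so it may fail for more
# complex distribution classes.
entropies = jax.vmap(lambda d: d.entropy())(batched)
print(entropies)  # One entropy value per batch element.
```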
User-defined distributions can be created by subclassing `distrax.Distribution`. This can be achieved by implementing only a few methods:
```python
class MyDistribution(distrax.Distribution):

  def __init__(self, ...):
    ...

  def _sample_n(self, key, n):
    samples = ...
    return samples

  def log_prob(self, value):
    log_prob = ...
    return log_prob

  def event_shape(self):
    event_shape = ...
    return event_shape

  def _sample_n_and_log_prob(self, key, n):
    # Optional. Only when a more efficient implementation is possible.
    samples, log_prob = ...
    return samples, log_prob
```
Similarly, more complicated bijectors can be created by subclassing `distrax.Bijector`. This can be achieved by implementing only one or two class methods:
```python
class MyBijector(distrax.Bijector):

  def __init__(self, ...):
    super().__init__(...)

  def forward_and_log_det(self, x):
    y = ...
    logdet = ...
    return y, logdet

  def inverse_and_log_det(self, y):
    # Optional. Can be omitted if inverse methods are not needed.
    x = ...
    logdet = ...
    return x, logdet
```
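As a concrete illustration of this pattern, here is a minimal sketch of a scalar affine bijector. The class and its parameters are invented for this example, and it assumes `distrax.Bijector.__init__` accepts `event_ndims_in` for scalar bijectors; Distrax already ships similar built-in bijectors, so treat this as a sketch rather than reference code:

```python
import jax.numpy as jnp
import distrax


class MyScalarAffine(distrax.Bijector):
  """Elementwise bijector computing y = scale * x + shift (illustrative only)."""

  def __init__(self, shift, scale):
    super().__init__(event_ndims_in=0)  # Acts on scalar events.
    self._shift = shift
    self._scale = scale

  def forward_and_log_det(self, x):
    y = self._scale * x + self._shift
    # The log |det J| of an elementwise affine map is log|scale| per element.
    logdet = jnp.full_like(x, jnp.log(jnp.abs(self._scale)))
    return y, logdet

  def inverse_and_log_det(self, y):
    x = (y - self._shift) / self._scale
    logdet = jnp.full_like(y, -jnp.log(jnp.abs(self._scale)))
    return x, logdet


bij = MyScalarAffine(shift=3., scale=2.)
y, ld = bij.forward_and_log_det(jnp.array([0., 1.]))
print(y)   # [3. 5.]
print(ld)  # [0.6931472 0.6931472]
```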
The examples directory contains some representative examples of full programs that use Distrax.
- `hmm.py` demonstrates how to use `distrax.HMM` to combine distributions that model the initial states, transitions, and observation distributions of a Hidden Markov Model, and infer the latent rates and state transitions in a changing noisy signal.
- `vae.py` contains an example implementation of a variational auto-encoder that is trained to model the binarized MNIST dataset as a joint `distrax.Bernoulli` distribution over the pixels.
- `flow.py` illustrates a simple example of modelling MNIST data using `distrax.MaskedCoupling` layers to implement a normalizing flow, and training the model with gradient descent.
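As a hedged sketch of how to run one of these scripts (the `examples/` and `requirements/` paths are assumptions about the repository layout, not taken from the text above), from a clone of the repository:

```
git clone https://github.com/deepmind/distrax.git
cd distrax
pip install -e .
pip install -r requirements/requirements-examples.txt
python examples/flow.py
```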
We greatly appreciate the ongoing support of the TensorFlow Probability authors in assisting with the design and cross-compatibility of Distrax.
Special thanks to Aleyna Kara and Kevin Murphy for contributing the code upon which the Hidden Markov Model and associated example are based.
This repository is part of the DeepMind JAX Ecosystem. To cite Distrax, please use the citation:
```bibtex
@software{deepmind2020jax,
  title = {The {D}eep{M}ind {JAX} {E}cosystem},
  author = {DeepMind and Babuschkin, Igor and Baumli, Kate and Bell, Alison and Bhupatiraju, Surya and Bruce, Jake and Buchlovsky, Peter and Budden, David and Cai, Trevor and Clark, Aidan and Danihelka, Ivo and Dedieu, Antoine and Fantacci, Claudio and Godwin, Jonathan and Jones, Chris and Hemsley, Ross and Hennigan, Tom and Hessel, Matteo and Hou, Shaobo and Kapturowski, Steven and Keck, Thomas and Kemaev, Iurii and King, Michael and Kunesch, Markus and Martens, Lena and Merzic, Hamza and Mikulik, Vladimir and Norman, Tamara and Papamakarios, George and Quan, John and Ring, Roman and Ruiz, Francisco and Sanchez, Alvaro and Sartran, Laurent and Schneider, Rosalia and Sezener, Eren and Spencer, Stephen and Srinivasan, Srivatsan and Stanojevi\'{c}, Milo\v{s} and Stokowiec, Wojciech and Wang, Luyu and Zhou, Guangyao and Viola, Fabio},
  url = {http://github.com/deepmind},
  year = {2020},
}
```