Distrax

Distrax is a lightweight library of probability distributions and bijectors. It acts as a JAX-native reimplementation of a subset of TensorFlow Probability (TFP), with some new features and emphasis on extensibility.

Installation

You can install the latest released version of Distrax from PyPI via:

pip install distrax

or you can install the latest development version from GitHub:

pip install git+https://github.com/deepmind/distrax.git

To run the tests or examples you will need to install additional requirements.

Design Principles

The general design principles for the DeepMind JAX Ecosystem are addressed in this blog. Additionally, Distrax places emphasis on the following:

  1. Readability. Distrax implementations are intended to be self-contained and read as close to the underlying math as possible.
  2. Extensibility. We have made it as simple as possible for users to define their own distribution or bijector. This is useful for example in reinforcement learning, where users may wish to define custom behavior for probabilistic agent policies.
  3. Compatibility. Distrax is not intended as a replacement for TFP, and TFP contains many advanced features that we do not intend to replicate. To this end, we have made the APIs for distributions and bijectors as cross-compatible as possible, and provide utilities for transforming between equivalent Distrax and TFP classes.

Features

Distributions

Distributions in Distrax are simple to define and use, particularly if you're used to TFP. Let's compare the two side-by-side:

import distrax
import jax
import jax.numpy as jnp

from tensorflow_probability.substrates import jax as tfp

tfd = tfp.distributions

key = jax.random.PRNGKey(1234)
mu = jnp.array([-1., 0., 1.])
sigma = jnp.array([0.1, 0.2, 0.3])

dist_distrax = distrax.MultivariateNormalDiag(mu, sigma)
dist_tfp = tfd.MultivariateNormalDiag(mu, sigma)

samples = dist_distrax.sample(seed=key)

# Both print 1.775
print(dist_distrax.log_prob(samples))
print(dist_tfp.log_prob(samples))

In addition to behaving consistently, Distrax distributions and TFP distributions are cross-compatible. For example:

mu_0 = jnp.array([-1., 0., 1.])
sigma_0 = jnp.array([0.1, 0.2, 0.3])
dist_distrax = distrax.MultivariateNormalDiag(mu_0, sigma_0)

mu_1 = jnp.array([1., 2., 3.])
sigma_1 = jnp.array([0.2, 0.3, 0.4])
dist_tfp = tfd.MultivariateNormalDiag(mu_1, sigma_1)

# Both print 85.237
print(dist_distrax.kl_divergence(dist_tfp))
print(tfd.kl_divergence(dist_distrax, dist_tfp))

Distrax distributions implement the method sample_and_log_prob, which provides samples and their log probability in one line. For some distributions, this is more efficient than calling sample and log_prob separately:

mu = jnp.array([-1., 0., 1.])
sigma = jnp.array([0.1, 0.2, 0.3])
dist_distrax = distrax.MultivariateNormalDiag(mu, sigma)

samples = dist_distrax.sample(seed=key, sample_shape=())
log_prob = dist_distrax.log_prob(samples)

# A one-line equivalent of the above is:
samples, log_prob = dist_distrax.sample_and_log_prob(seed=key, sample_shape=())

TFP distributions can be passed to Distrax meta-distributions as inputs. For example:

key = jax.random.PRNGKey(1234)
mu = jnp.array([-1., 0., 1.])
sigma = jnp.array([0.2, 0.3, 0.4])

dist_tfp = tfd.Normal(mu, sigma)
metadist_distrax = distrax.Independent(dist_tfp, reinterpreted_batch_ndims=1)

samples = metadist_distrax.sample(seed=key)
print(metadist_distrax.log_prob(samples))  # Prints 0.38871175

To use Distrax distributions in TFP meta-distributions, Distrax provides the wrapper to_tfp. A wrapped Distrax distribution can be directly used in TFP:

key = jax.random.PRNGKey(1234)

distrax_dist = distrax.Normal(0., 1.)
wrapped_dist = distrax.to_tfp(distrax_dist)

metadist_tfp = tfd.Sample(wrapped_dist, sample_shape=[3])

samples = metadist_tfp.sample(seed=key)
print(metadist_tfp.log_prob(samples))  # Prints -3.3409896

Bijectors

A "bijector" in Distrax is an invertible function that knows how to compute itsJacobian determinant. Bijectors can be used to create complex distributions bytransforming simpler ones. Distrax bijectors are functionally similar to TFPbijectors, with a few API differences. Here is an example comparing the two:

import distrax
import jax.numpy as jnp

from tensorflow_probability.substrates import jax as tfp

tfb = tfp.bijectors
tfd = tfp.distributions

# Same distribution.
distrax.Transformed(distrax.Normal(loc=0., scale=1.), distrax.Tanh())
tfd.TransformedDistribution(tfd.Normal(loc=0., scale=1.), tfb.Tanh())

Additionally, Distrax bijectors can be composed and inverted:

bij_distrax = distrax.Tanh()
bij_tfp = tfb.Tanh()

# Same bijector.
inv_bij_distrax = distrax.Inverse(bij_distrax)
inv_bij_tfp = tfb.Invert(bij_tfp)

# These are both the identity bijector.
distrax.Chain([bij_distrax, inv_bij_distrax])
tfb.Chain([bij_tfp, inv_bij_tfp])

All TFP bijectors can be passed to Distrax, and can be freely composed with Distrax bijectors. For example, all of the following will work:

distrax.Inverse(tfb.Tanh())
distrax.Chain([tfb.Tanh(), distrax.Tanh()])
distrax.Transformed(tfd.Normal(loc=0., scale=1.), tfb.Tanh())

Distrax bijectors can also be passed to TFP, but first they must be transformed with to_tfp:

bij_distrax = distrax.to_tfp(distrax.Tanh())

tfb.Invert(bij_distrax)
tfb.Chain([tfb.Tanh(), bij_distrax])
tfd.TransformedDistribution(tfd.Normal(loc=0., scale=1.), bij_distrax)

Distrax also comes with Lambda, a convenient wrapper for turning simple JAX functions into bijectors. Here are a few Lambda examples with their TFP equivalents:

distrax.Lambda(lambda x: x)                    # tfb.Identity()
distrax.Lambda(lambda x: 2*x + 3)              # tfb.Chain([tfb.Shift(3), tfb.Scale(2)])
distrax.Lambda(jnp.sinh)                       # tfb.Sinh()
distrax.Lambda(lambda x: jnp.sinh(2*x + 3))    # tfb.Chain([tfb.Sinh(), tfb.Shift(3), tfb.Scale(2)])

Unlike TFP, bijectors in Distrax do not take event_ndims as an argument when they compute the Jacobian determinant. Instead, Distrax assumes that the number of event dimensions is statically known to every bijector, and uses Block to lift bijectors to a different number of dimensions. For example:

x = jnp.zeros([2, 3, 4])

# In TFP, `event_ndims` can be passed to the bijector.
bij_tfp = tfb.Tanh()
ld_1 = bij_tfp.forward_log_det_jacobian(x, event_ndims=0)  # Shape = [2, 3, 4]

# Distrax assumes `Tanh` is a scalar bijector by default.
bij_distrax = distrax.Tanh()
ld_2 = bij_distrax.forward_log_det_jacobian(x)  # ld_1 == ld_2

# With `event_ndims=2`, TFP sums the last 2 dimensions of the log det.
ld_3 = bij_tfp.forward_log_det_jacobian(x, event_ndims=2)  # Shape = [2]

# Distrax treats the number of dimensions statically.
bij_distrax = distrax.Block(bij_distrax, ndims=2)
ld_4 = bij_distrax.forward_log_det_jacobian(x)  # ld_3 == ld_4

Distrax bijectors implement the method forward_and_log_det (some bijectors additionally implement inverse_and_log_det), which returns the forward mapping and its log Jacobian determinant in one line. For some bijectors, this is more efficient than calling forward and forward_log_det_jacobian separately. (Analogously, when available, inverse_and_log_det can be more efficient than inverse and inverse_log_det_jacobian.)

x = jnp.zeros([2, 3, 4])
bij_distrax = distrax.Tanh()

y = bij_distrax.forward(x)
ld = bij_distrax.forward_log_det_jacobian(x)

# A one-line equivalent of the above is:
y, ld = bij_distrax.forward_and_log_det(x)

Jitting Distrax

Distrax distributions and bijectors can be passed as arguments to jitted functions. User-defined distributions and bijectors get this property for free by subclassing distrax.Distribution and distrax.Bijector respectively. For example:

mu_0 = jnp.array([-1., 0., 1.])
sigma_0 = jnp.array([0.1, 0.2, 0.3])
dist_0 = distrax.MultivariateNormalDiag(mu_0, sigma_0)

mu_1 = jnp.array([1., 2., 3.])
sigma_1 = jnp.array([0.2, 0.3, 0.4])
dist_1 = distrax.MultivariateNormalDiag(mu_1, sigma_1)

jitted_kl = jax.jit(lambda d_0, d_1: d_0.kl_divergence(d_1))

# Both print 85.237
print(jitted_kl(dist_0, dist_1))
print(dist_0.kl_divergence(dist_1))

A note about vmap and pmap

The serialization logic that enables Distrax objects to be passed as arguments to jitted functions also enables functions to map over them as data using jax.vmap and jax.pmap.

However, support for this behavior is experimental and incomplete. Use caution when applying jax.vmap or jax.pmap to functions that take Distrax objects as arguments, or return Distrax objects.

Simple objects such as distrax.Categorical may behave correctly under these transformations, but more complex objects such as distrax.MultivariateNormalDiag may generate exceptions when used as inputs or outputs of a vmap-ed or pmap-ed function.
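
As a minimal sketch of what this looks like for a simple case (the function name entropy_fn is chosen only for this example, and whether a given distribution class supports this remains experimental as noted above):

import jax
import jax.numpy as jnp
import distrax

# A batched Categorical: its logits array has shape [4, 3].
batched_dist = distrax.Categorical(logits=jnp.zeros([4, 3]))

# `jax.vmap` maps over the distribution's array leaves, so each call of
# `entropy_fn` sees a Categorical whose logits have shape [3].
entropy_fn = lambda dist: dist.entropy()
entropies = jax.vmap(entropy_fn)(batched_dist)  # Shape = [4]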

Subclassing Distributions and Bijectors

User-defined distributions can be created by subclassing distrax.Distribution. This can be achieved by implementing only a few methods:

class MyDistribution(distrax.Distribution):

  def __init__(self, ...):
    ...

  def _sample_n(self, key, n):
    samples = ...
    return samples

  def log_prob(self, value):
    log_prob = ...
    return log_prob

  def event_shape(self):
    event_shape = ...
    return event_shape

  def _sample_n_and_log_prob(self, key, n):
    # Optional. Only when more efficient implementation is possible.
    samples, log_prob = ...
    return samples, log_prob

Similarly, more complicated bijectors can be created by subclassing distrax.Bijector. This can be achieved by implementing only one or two class methods:

class MyBijector(distrax.Bijector):

  def __init__(self, ...):
    super().__init__(...)

  def forward_and_log_det(self, x):
    y = ...
    logdet = ...
    return y, logdet

  def inverse_and_log_det(self, y):
    # Optional. Can be omitted if inverse methods are not needed.
    x = ...
    logdet = ...
    return x, logdet
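
To make the pattern above concrete, here is a minimal sketch of a toy scalar bijector that shifts its input by a constant; the class name ShiftByConstant and the values used are choices made for this illustration, not part of the skeleton above:

import jax.numpy as jnp
import distrax

class ShiftByConstant(distrax.Bijector):
  """Toy bijector that adds a fixed constant elementwise."""

  def __init__(self, shift):
    # `event_ndims_in=0` declares a scalar (elementwise) bijector.
    super().__init__(event_ndims_in=0)
    self._shift = shift

  def forward_and_log_det(self, x):
    y = x + self._shift
    logdet = jnp.zeros_like(x)  # a pure shift has unit Jacobian
    return y, logdet

  def inverse_and_log_det(self, y):
    x = y - self._shift
    logdet = jnp.zeros_like(y)
    return x, logdet

bij = ShiftByConstant(3.)
y, ld = bij.forward_and_log_det(jnp.array([0., 1., 2.]))  # y == [3., 4., 5.]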

Examples

The examples directory contains some representative examples of full programs that use Distrax.

hmm.py demonstrates how to use distrax.HMM to combine distributions that model the initial states, transitions, and observation distributions of a Hidden Markov Model, and infer the latent rates and state transitions in a changing noisy signal.

vae.py contains an example implementation of a variational auto-encoder that is trained to model the binarized MNIST dataset as a joint distrax.Bernoulli distribution over the pixels.

flow.py illustrates a simple example of modelling MNIST data using distrax.MaskedCoupling layers to implement a normalizing flow, and training the model with gradient descent.
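
As a rough orientation for hmm.py, here is a minimal sketch of building and querying a distrax.HMM. The keyword argument names (init_dist, trans_dist, obs_dist) and the sample/viterbi calls are assumptions about the HMM utility's interface rather than something stated above; see hmm.py for the authoritative usage:

import jax
import jax.numpy as jnp
import distrax

# Two hidden states with Gaussian observations. The keyword names below are
# assumptions about distrax.HMM's constructor.
hmm = distrax.HMM(
    init_dist=distrax.Categorical(probs=jnp.array([0.5, 0.5])),
    trans_dist=distrax.Categorical(probs=jnp.array([[0.95, 0.05],
                                                    [0.05, 0.95]])),
    obs_dist=distrax.Normal(loc=jnp.array([0., 5.]), scale=jnp.array([1., 1.])))

# Assumed interface: sample a state/observation sequence, then decode it.
states, observations = hmm.sample(seed=jax.random.PRNGKey(0), seq_len=50)
decoded_states = hmm.viterbi(observations)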

Acknowledgements

We greatly appreciate the ongoing support of the TensorFlow Probability authors in assisting with the design and cross-compatibility of Distrax.

Special thanks to Aleyna Kara and Kevin Murphy for contributing the code upon which the Hidden Markov Model and associated example are based.

Citing Distrax

This repository is part of the DeepMind JAX Ecosystem. To cite Distrax, please use the citation:

@software{deepmind2020jax,
  title = {The {D}eep{M}ind {JAX} {E}cosystem},
  author = {DeepMind and Babuschkin, Igor and Baumli, Kate and Bell, Alison and Bhupatiraju, Surya and Bruce, Jake and Buchlovsky, Peter and Budden, David and Cai, Trevor and Clark, Aidan and Danihelka, Ivo and Dedieu, Antoine and Fantacci, Claudio and Godwin, Jonathan and Jones, Chris and Hemsley, Ross and Hennigan, Tom and Hessel, Matteo and Hou, Shaobo and Kapturowski, Steven and Keck, Thomas and Kemaev, Iurii and King, Michael and Kunesch, Markus and Martens, Lena and Merzic, Hamza and Mikulik, Vladimir and Norman, Tamara and Papamakarios, George and Quan, John and Ring, Roman and Ruiz, Francisco and Sanchez, Alvaro and Sartran, Laurent and Schneider, Rosalia and Sezener, Eren and Spencer, Stephen and Srinivasan, Srivatsan and Stanojevi\'{c}, Milo\v{s} and Stokowiec, Wojciech and Wang, Luyu and Zhou, Guangyao and Viola, Fabio},
  url = {http://github.com/deepmind},
  year = {2020},
}
