Autobatching for Bayesian inference
This notebook demonstrates a simple Bayesian inference example where autobatching makes user code easier to write, easier to read, and less likely to include bugs.
Inspired by a notebook by @davmre.
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
import jax.scipy as jsp
from jax import random
import numpy as np
import scipy as sp
Generate a fake binary classification dataset
np.random.seed(10009)

num_features = 10
num_points = 100
true_beta = np.random.randn(num_features).astype(jnp.float32)
all_x = np.random.randn(num_points, num_features).astype(jnp.float32)
y = (np.random.rand(num_points) < sp.special.expit(all_x.dot(true_beta))).astype(jnp.int32)
y
array([0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 1, 0,
       1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0,
       1, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0,
       0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1,
       1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0], dtype=int32)
Write the log-joint function for the model
We’ll write a non-batched version, a manually batched version, and an autobatched version.
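Before diving into the model, here is a toy illustration (not part of the model below) of what `jax.vmap` does: you write a function for a single example, and `vmap` maps it over a leading batch axis for you.

```python
import jax
import jax.numpy as jnp

def dot(v, w):
    # Written for single vectors; no batch dimension in sight.
    return jnp.dot(v, w)

batched_v = jnp.arange(6.).reshape(2, 3)   # batch of 2 vectors
batched_w = jnp.ones((2, 3))

# `vmap` maps `dot` over the leading axis of both arguments.
batched_dot = jax.vmap(dot)
print(batched_dot(batched_v, batched_w))   # per-example dot products: 3.0 and 12.0
```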
Non-batched
def log_joint(beta):
    result = 0.
    # Note that no `axis` parameter is provided to `jnp.sum`.
    result = result + jnp.sum(jsp.stats.norm.logpdf(beta, loc=0., scale=1.))
    result = result + jnp.sum(-jnp.log(1 + jnp.exp(-(2*y - 1) * jnp.dot(all_x, beta))))
    return result
log_joint(np.random.randn(num_features))
Array(-213.23558, dtype=float32)
# This doesn't work, because we didn't write `log_joint()` to handle batching.
try:
    batch_size = 10
    batched_test_beta = np.random.randn(batch_size, num_features)

    log_joint(np.random.randn(batch_size, num_features))
except ValueError as e:
    print("Caught expected exception " + str(e))
Caught expected exception Incompatible shapes for broadcasting: shapes=[(100,), (100, 10)]
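To see exactly where the shapes go wrong, here is a NumPy-only sketch of the offending broadcast (shapes taken from the model above; the arrays are dummies):

```python
import numpy as np

y = np.zeros(100)                 # labels, shape (100,)
all_x = np.zeros((100, 10))       # features, shape (100, 10)
beta_batch = np.zeros((10, 10))   # a batch of 10 parameter vectors

# With a batched beta, `all_x.dot(beta)` has shape (100, 10) instead of (100,).
# Broadcasting aligns trailing axes, so (100,) vs (100, 10) compares 100
# against 10 -- incompatible, hence the exception above.
try:
    (2 * y - 1) * all_x.dot(beta_batch)
except ValueError as e:
    print("Broadcast error:", e)
```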
Manually batched
def batched_log_joint(beta):
    result = 0.
    # Here (and below) `sum` needs an `axis` parameter. At best, forgetting to set axis
    # or setting it incorrectly yields an error; at worst, it silently changes the
    # semantics of the model.
    result = result + jnp.sum(jsp.stats.norm.logpdf(beta, loc=0., scale=1.),
                              axis=-1)
    # Note the multiple transposes. Getting this right is not rocket science,
    # but it's also not totally mindless. (I didn't get it right on the first
    # try.)
    result = result + jnp.sum(-jnp.log(1 + jnp.exp(-(2*y - 1) * jnp.dot(all_x, beta.T).T)),
                              axis=-1)
    return result
batch_size = 10
batched_test_beta = np.random.randn(batch_size, num_features)

batched_log_joint(batched_test_beta)
Array([-147.84032, -207.02205, -109.26076, -243.80833, -163.02908, -143.84848, -160.28772, -113.7717 , -126.60544, -190.81989], dtype=float32)
Autobatched with vmap
It just works.
vmap_batched_log_joint = jax.vmap(log_joint)
vmap_batched_log_joint(batched_test_beta)
Array([-147.84032, -207.02205, -109.26076, -243.80833, -163.02908, -143.84848, -160.28772, -113.7717 , -126.60544, -190.81989], dtype=float32)
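By default `jax.vmap` maps over the leading axis of every argument; when only some arguments carry a batch axis, the `in_axes` parameter controls this. A standalone sketch (not needed for this model, where there is only one argument):

```python
import jax
import jax.numpy as jnp

def affine(w, x):
    # Written for a single weight vector and a single input.
    return jnp.dot(w, x)

w = jnp.ones(3)                        # one weight vector, unbatched
xs = jnp.arange(12.).reshape(4, 3)     # batch of 4 inputs

# `in_axes=(None, 0)` says: don't map over `w`, map over axis 0 of `x`.
batched = jax.vmap(affine, in_axes=(None, 0))
print(batched(w, xs).shape)            # (4,)
```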
Self-contained variational inference example
A little code is copied from above.
Set up the (batched) log-joint function
@jax.jit
def log_joint(beta):
    result = 0.
    # Note that no `axis` parameter is provided to `jnp.sum`.
    result = result + jnp.sum(jsp.stats.norm.logpdf(beta, loc=0., scale=10.))
    result = result + jnp.sum(-jnp.log(1 + jnp.exp(-(2*y - 1) * jnp.dot(all_x, beta))))
    return result

batched_log_joint = jax.jit(jax.vmap(log_joint))
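Note that `jax.jit` and `jax.vmap` compose freely, in either order. A tiny standalone check (toy function, not the model above):

```python
import jax
import jax.numpy as jnp

def sum_sq(x):
    # Per-example function: sum of squares of one vector.
    return jnp.sum(x ** 2)

xs = jnp.arange(6.).reshape(3, 2)

# Both orderings of jit and vmap compute the same batched result.
a = jax.jit(jax.vmap(sum_sq))(xs)
b = jax.vmap(jax.jit(sum_sq))(xs)
print(a, b)   # per-row sums of squares: 1, 13, 41 in both cases
```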
Define the ELBO and its gradient
def elbo(beta_loc, beta_log_scale, epsilon):
    beta_sample = beta_loc + jnp.exp(beta_log_scale) * epsilon
    return jnp.mean(batched_log_joint(beta_sample), 0) + jnp.sum(beta_log_scale - 0.5 * np.log(2*np.pi))

elbo = jax.jit(elbo)
elbo_val_and_grad = jax.jit(jax.value_and_grad(elbo, argnums=(0, 1)))
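The line `beta_sample = beta_loc + jnp.exp(beta_log_scale) * epsilon` is the reparameterization trick: a draw from N(loc, scale²) is rewritten as a deterministic function of the variational parameters and standard-normal noise, so gradients flow through the sample. A standalone sketch of why this works (toy function, separate from the ELBO above):

```python
import jax
import jax.numpy as jnp
from jax import random

def mean_sample(loc, log_scale, eps):
    # Deterministic transform of standard-normal noise into
    # N(loc, exp(log_scale)^2) samples, averaged to a scalar.
    return jnp.mean(loc + jnp.exp(log_scale) * eps)

eps = random.normal(random.key(0), (1000,))

# Because the samples are a differentiable function of `loc`, we can
# differentiate right through the sampling step; here the gradient of the
# mean sample with respect to `loc` is exactly 1.
g = jax.grad(mean_sample, argnums=0)(0.5, -1.0, eps)
print(g)   # 1.0
```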
Optimize the ELBO using SGD
def normal_sample(key, shape):
    """Convenience function for quasi-stateful RNG."""
    new_key, sub_key = random.split(key)
    return new_key, random.normal(sub_key, shape)

normal_sample = jax.jit(normal_sample, static_argnums=(1,))

key = random.key(10003)

beta_loc = jnp.zeros(num_features, jnp.float32)
beta_log_scale = jnp.zeros(num_features, jnp.float32)

step_size = 0.01
batch_size = 128
epsilon_shape = (batch_size, num_features)
for i in range(1000):
    key, epsilon = normal_sample(key, epsilon_shape)
    elbo_val, (beta_loc_grad, beta_log_scale_grad) = elbo_val_and_grad(
        beta_loc, beta_log_scale, epsilon)
    beta_loc += step_size * beta_loc_grad
    beta_log_scale += step_size * beta_log_scale_grad
    if i % 10 == 0:
        print('{}\t{}'.format(i, elbo_val))
0	-175.5615997314453
10	-112.76364135742188
20	-102.41358184814453
30	-100.27793884277344
40	-99.55818176269531
50	-98.18000793457031
60	-98.60237121582031
70	-97.69735717773438
80	-97.53225708007812
90	-97.17939758300781
100	-97.09412384033203
110	-97.4031753540039
120	-97.0446548461914
130	-97.20584106445312
140	-96.89036560058594
150	-96.91874694824219
160	-97.00558471679688
170	-97.45591735839844
180	-96.7357177734375
190	-96.95585632324219
200	-97.51350402832031
210	-96.92330932617188
220	-97.0315933227539
230	-96.88632202148438
240	-96.9697036743164
250	-97.35342407226562
260	-97.07598876953125
270	-97.24360656738281
280	-97.23468017578125
290	-97.02444458007812
300	-97.00311279296875
310	-97.07694244384766
320	-97.33139038085938
330	-97.15113830566406
340	-97.28958129882812
350	-97.41972351074219
360	-96.95799255371094
370	-97.36982727050781
380	-97.00273132324219
390	-97.10067749023438
400	-97.13653564453125
410	-96.87237548828125
420	-97.24083709716797
430	-97.04019165039062
440	-96.68864440917969
450	-97.19795989990234
460	-97.18959045410156
470	-97.09814453125
480	-97.11341857910156
490	-97.20771789550781
500	-97.39350128173828
510	-97.25328063964844
520	-97.20199584960938
530	-96.95065307617188
540	-97.37591552734375
550	-96.98526763916016
560	-97.01451873779297
570	-96.97328186035156
580	-97.04313659667969
590	-97.38459777832031
600	-97.3158187866211
610	-97.10185241699219
620	-97.22990417480469
630	-97.18515014648438
640	-97.1563720703125
650	-97.13624572753906
660	-97.0641860961914
670	-97.17774963378906
680	-97.31779479980469
690	-97.42807006835938
700	-97.18154907226562
710	-97.57279968261719
720	-96.99563598632812
730	-97.15852355957031
740	-96.85628509521484
750	-96.8902587890625
760	-97.11228942871094
770	-97.214111328125
780	-96.99479675292969
790	-97.30390930175781
800	-96.98690795898438
810	-97.12832641601562
820	-97.51512145996094
830	-97.4146728515625
840	-96.89874267578125
850	-96.84567260742188
860	-97.2318344116211
870	-97.24137115478516
880	-96.74853515625
890	-97.09489440917969
900	-97.13866424560547
910	-96.79051971435547
920	-97.06621551513672
930	-97.14911651611328
940	-97.26902770996094
950	-97.0196533203125
960	-96.95348358154297
970	-97.138916015625
980	-97.60130310058594
990	-97.25077056884766
Display the results
Coverage isn’t quite as good as we might like, but it’s not bad, and nobody said variational inference was exact.
plt.figure(figsize=(7, 7))
plt.plot(true_beta, beta_loc, '.', label='Approximated Posterior Means')
plt.plot(true_beta, beta_loc + 2*jnp.exp(beta_log_scale), 'r.',
         label=r'Approximated Posterior $2\sigma$ Error Bars')
plt.plot(true_beta, beta_loc - 2*jnp.exp(beta_log_scale), 'r.')

plot_scale = 3
plt.plot([-plot_scale, plot_scale], [-plot_scale, plot_scale], 'k')
plt.xlabel('True beta')
plt.ylabel('Estimated beta')
plt.legend(loc='best')
<matplotlib.legend.Legend at 0x72bc557010a0>

