The checkify transformation
Summary: Checkify lets you add jit-able runtime error checking (e.g. out-of-bounds indexing) to your JAX code. Use the checkify.checkify transformation together with the assert-like checkify.check function to add runtime checks to JAX code:
from jax.experimental import checkify
import jax
import jax.numpy as jnp

def f(x, i):
  checkify.check(i >= 0, "index needs to be non-negative, got {i}", i=i)
  y = x[i]
  z = jnp.sin(y)
  return z

jittable_f = checkify.checkify(f)

err, z = jax.jit(jittable_f)(jnp.ones((5,)), -2)
print(err.get())
# >> index needs to be non-negative, got -2! (check failed at <...>:6 (f))
You can also use checkify to automatically add common checks:
errors = checkify.user_checks | checkify.index_checks | checkify.float_checks
checked_f = checkify.checkify(f, errors=errors)

err, z = checked_f(jnp.ones((5,)), 100)
err.throw()
# ValueError: out-of-bounds indexing at <..>:7 (f)

err, z = checked_f(jnp.ones((5,)), -1)
err.throw()
# ValueError: index needs to be non-negative! (check failed at <...>:6 (f))

err, z = checked_f(jnp.array([jnp.inf, 1]), 0)
err.throw()
# ValueError: nan generated by primitive sin at <...>:8 (f)

err, z = checked_f(jnp.array([5, 1]), 0)
err.throw()  # if no error occurred, throw does nothing!
Functionalizing checks
The assert-like check API by itself is not functionally pure: it can raise a Python Exception as a side-effect, just like assert. So it can’t be staged out with jit, pmap, pjit, or scan:
jax.jit(f)(jnp.ones((5,)), -1)  # checkify transformation not used
# ValueError: Cannot abstractly evaluate a checkify.check which was not functionalized.
But the checkify transformation functionalizes (or discharges) these effects. A checkify-transformed function returns an error value as a new output and remains functionally pure. That functionalization means checkify-transformed functions can be composed with staging and other transformations however we like:
err, z = jax.pmap(checked_f)(jnp.ones((3, 5)), jnp.array([-1, 2, 100]))
err.throw()
"""
ValueError:
..  at mapped index 0: index needs to be non-negative! (check failed at <...>:6 (f))
..  at mapped index 2: out-of-bounds indexing at <..>:7 (f)
"""
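The same holds for control flow. As a sketch (the function `cumulative_log` is a made-up example, not part of checkify), a check inside a `jax.lax.scan` body is discharged by checkify, and the resulting pure function can then be jitted:

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify

def cumulative_log(xs):
  # Hypothetical example: sum log(x) over xs using a scan.
  def body(carry, x):
    checkify.check(x > 0, "scan element must be positive, got {x}", x=x)
    return carry + jnp.log(x), None
  total, _ = jax.lax.scan(body, jnp.zeros((), xs.dtype), xs)
  return total

# checkify discharges the check inside the scan body; jit then stages it all out.
checked = jax.jit(checkify.checkify(cumulative_log))

err_ok, total = checked(jnp.array([1., 2., 3.]))
err_bad, _ = checked(jnp.array([1., -2., 3.]))
```

Because the error is just another carried value, no special handling of scan is needed on the caller's side: it inspects `err_ok` and `err_bad` exactly as in the earlier examples.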
Why does JAX need checkify?
Under some JAX transformations you can express runtime error checks with ordinary Python assertions, for example when only using jax.grad and jax.numpy:
def f(x):
  assert x > 0., "must be positive!"
  return jnp.log(x)

jax.grad(f)(0.)
# AssertionError: must be positive!
But ordinary assertions don’t work inside jit, pmap, pjit, or scan. In those cases, numeric computations are staged out rather than evaluated eagerly during Python execution, and as a result numeric values aren’t available:
jax.jit(f)(0.)
# ConcretizationTypeError: "Abstract tracer value encountered ..."
JAX transformation semantics rely on functional purity, especially when composing multiple transformations, so how can we provide an error mechanism without disrupting all that?
Beyond needing a new API, the situation is trickier still: XLA HLO doesn’t support assertions or throwing errors, so even if we had a JAX API which was able to stage out assertions, how would we lower these assertions to XLA?
You could imagine manually adding run-time checks to your function and plumbing out values representing errors:
def f_checked(x):
  error = x <= 0.
  result = jnp.log(x)
  return error, result

err, y = jax.jit(f_checked)(0.)
if err:
  raise ValueError("must be positive!")
# ValueError: "must be positive!"
The error is a regular value computed by the function, and the error is raised outside of f_checked. f_checked is functionally pure, so we know by construction that it’ll already work with jit, pmap, pjit, scan, and all of JAX’s transformations. The only problem is that this plumbing can be a pain!
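For instance, here is a sketch of the manual approach extended to two checks. The integer error codes, the `merge` helper, and the `run` wrapper are hypothetical, introduced only for this illustration; note how every new check adds another merge step, and the messages have to live outside the function:

```python
import jax
import jax.numpy as jnp

# Hypothetical error codes for this sketch (0 means "no error").
NO_ERROR, NOT_POSITIVE, TOO_LARGE = 0, 1, 2
MESSAGES = {NOT_POSITIVE: "must be positive!", TOO_LARGE: "log(x) must be <= 10!"}

def merge(err, pred, code):
  # Keep the first recorded error; otherwise record `code` if `pred` holds.
  return jnp.where((err == NO_ERROR) & pred, code, err)

def g_checked(x):
  err = jnp.int32(NO_ERROR)
  err = merge(err, x <= 0., NOT_POSITIVE)
  y = jnp.log(x)
  err = merge(err, y > 10., TOO_LARGE)
  return err, y

def run(x):
  err, y = jax.jit(g_checked)(x)
  code = int(err)  # pull the error code back to the host
  if code != NO_ERROR:
    raise ValueError(MESSAGES[code])
  return y
```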
checkify does this rewrite for you: that includes plumbing the error value through the function, rewriting checks to boolean operations and merging the result with the tracked error value, and returning the final error value as an output to the checkified function:
def f(x):
  checkify.check(x > 0., "{} must be positive!", x)  # convenient but effectful API
  return jnp.log(x)

f_checked = checkify.checkify(f)

err, x = jax.jit(f_checked)(-1.)
err.throw()
# ValueError: -1. must be positive! (check failed at <...>:2 (f))
We call this functionalizing or discharging the effect introduced by calling check. (In the “manual” example above the error value is just a boolean. checkify’s error values are conceptually similar but also track error messages and expose throw and get methods; see jax.experimental.checkify.) checkify.check also allows you to add run-time values to your error message by providing them as format arguments to the error message.
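As a sketch of the difference between the two methods, reusing the `f` above: `get` returns the formatted message as a string (or None if no check failed), while `throw` raises it as an exception (or does nothing). The exception is caught broadly here; the examples in this page show it surfacing as a ValueError.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify

def f(x):
  checkify.check(x > 0., "{} must be positive!", x)
  return jnp.log(x)

f_checked = jax.jit(checkify.checkify(f))

# Passing case: get() returns None and throw() does nothing.
err_ok, y = f_checked(2.)
ok_msg = err_ok.get()
err_ok.throw()

# Failing case: get() returns the formatted message,
# while throw() raises an exception carrying it.
err_bad, _ = f_checked(-1.)
bad_msg = err_bad.get()
try:
  err_bad.throw()
  thrown = None
except Exception as e:  # shown as a ValueError in the examples above
  thrown = str(e)
```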
You could now manually instrument your code with run-time checks, but checkify can also automatically add checks for common errors! Consider these error cases:
jnp.arange(3)[5]                # out of bounds
jnp.sin(jnp.inf)                # NaN generated
jnp.ones((5,)) / jnp.arange(5)  # division by zero
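For contrast, here is what those expressions evaluate to when nothing checks them. JAX raises no error: out-of-bounds reads are clamped to the last valid index (JAX’s documented default for index retrieval), and the floating-point cases quietly produce NaN and inf.

```python
import jax.numpy as jnp

# Without checks, none of these raise; each returns a value instead.
oob = jnp.arange(3)[5]                  # out-of-bounds read: index clamped to 2
nan_val = jnp.sin(jnp.inf)              # NaN generated
quot = jnp.ones((5,)) / jnp.arange(5)   # 1/0 gives inf in the first slot
```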
By default checkify only discharges checkify.check calls, and won’t do anything to catch errors like the above. But if you ask it to, checkify will also instrument your code with checks automatically.
def f(x, i):
  y = x[i]        # i could be out of bounds.
  z = jnp.sin(y)  # z could become NaN
  return z

errors = checkify.user_checks | checkify.index_checks | checkify.float_checks
checked_f = checkify.checkify(f, errors=errors)

err, z = checked_f(jnp.ones((5,)), 100)
err.throw()
# ValueError: out-of-bounds indexing at <..>:7 (f)

err, z = checked_f(jnp.array([jnp.inf, 1]), 0)
err.throw()
# ValueError: nan generated by primitive sin at <...>:8 (f)
The API for selecting which automatic checks to enable is based on Python sets. See jax.experimental.checkify for more details.
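As a sketch of that set-based API (using the `user_checks`, `index_checks`, `nan_checks`, and `all_checks` sets from jax.experimental.checkify), the snippet below compares the default against explicitly requested checks. It relies on `checkify.checkify` defaulting to `errors=checkify.user_checks`, which is why the first call reports nothing.

```python
import jax.numpy as jnp
from jax.experimental import checkify

def f(x, i):
  return x[i]

x = jnp.ones((5,))

# Default (user checks only): the out-of-bounds read goes unnoticed.
err_default, _ = checkify.checkify(f)(x, 100)

# Opting in to index checks catches it.
err_index, _ = checkify.checkify(f, errors=checkify.index_checks)(x, 100)

# The sets compose with ordinary set operators, e.g. everything except NaN checks.
no_nan = checkify.all_checks - checkify.nan_checks
err_sub, _ = checkify.checkify(f, errors=no_nan)(x, 100)
```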
checkify under JAX transformations
As demonstrated in the examples above, a checkified function can be happily jitted. Here are a few more examples of checkify with other JAX transformations. Note that checkified functions are functionally pure, and should trivially compose with all JAX transformations!
jit
You can safely add jax.jit to a checkified function, or checkify a jitted function; both will work.
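def f(x, i):
  return x[i]

# Index checks must be requested explicitly; the default is user checks only.
checkify_of_jit = checkify.checkify(jax.jit(f), errors=checkify.index_checks)
jit_of_checkify = jax.jit(checkify.checkify(f, errors=checkify.index_checks))

err, _ = checkify_of_jit(jnp.ones((5,)), 100)
err.get()
# out-of-bounds indexing at <..>:2 (f)

err, _ = jit_of_checkify(jnp.ones((5,)), 100)
err.get()
# out-of-bounds indexing at <..>:2 (f)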
vmap/pmap
You can vmap and pmap checkified functions (or checkify mapped functions). Mapping a checkified function will give you a mapped error, which can contain different errors for every element of the mapped dimension.
def f(x, i):
  checkify.check(i >= 0, "index needs to be non-negative!")
  return x[i]

checked_f = checkify.checkify(f, errors=checkify.all_checks)
errs, out = jax.vmap(checked_f)(jnp.ones((3, 5)), jnp.array([-1, 2, 100]))
errs.throw()
"""
ValueError:
  at mapped index 0: index needs to be non-negative! (check failed at <...>:2 (f))
  at mapped index 2: out-of-bounds indexing at <...>:3 (f)
"""
However, a checkify-of-vmap will produce a single (unmapped) error!
@jax.vmap
def f(x, i):
  checkify.check(i >= 0, "index needs to be non-negative!")
  return x[i]

checked_f = checkify.checkify(f, errors=checkify.all_checks)
err, out = checked_f(jnp.ones((3, 5)), jnp.array([-1, 2, 100]))
err.throw()
# ValueError: index needs to be non-negative! (check failed at <...>:2 (f))
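To make the difference concrete, here is a sketch with only user checks enabled (the default). With vmap-of-checkify, errors are tracked per batch element, so an all-valid batch yields no error at all; with checkify-of-vmap, a failure in any element is reported through one unmapped error.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify

def f(x, i):
  checkify.check(i >= 0, "index needs to be non-negative!")
  return x[i]

xs = jnp.ones((3, 5))

# vmap-of-checkify: one error entry per batch element.
batched = jax.vmap(checkify.checkify(f))
errs_ok, _ = batched(xs, jnp.array([0, 2, 4]))
errs_bad, _ = batched(xs, jnp.array([0, -1, 4]))

# checkify-of-vmap: a single, unmapped error.
err_single, _ = checkify.checkify(jax.vmap(f))(xs, jnp.array([0, -1, 4]))
```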
pjit
pjit of a checkified function just works; you only need to specify an additional out_shardings entry of None for the error value output.
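from jax.experimental.pjit import pjit
from jax.sharding import PartitionSpec

# Assumes `mesh` (a jax.sharding.Mesh with an axis named 'x') and a 2D
# `input_data` array are already defined.
def f(x):
  return x / x

f = checkify.checkify(f, errors=checkify.float_checks)
f = pjit(
  f,
  in_shardings=PartitionSpec('x', None),
  out_shardings=(None, PartitionSpec('x', None)))

with jax.sharding.Mesh(mesh.devices, mesh.axis_names):
  err, data = f(input_data)
err.throw()
# ValueError: divided by zero at <...>:4 (f)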
grad
Your gradient computation will also be instrumented if you apply checkify to a grad-transformed function (checkify-of-grad):
def f(x):
  return x / (1 + jnp.sqrt(x))

grad_f = jax.grad(f)

err, _ = checkify.checkify(grad_f, errors=checkify.nan_checks)(0.)
print(err.get())
# >> nan generated by primitive mul at <...>:3 (f)
Note that there’s no multiply in f, but there is a multiply in its gradient computation (and this is where the NaN is generated!). So use checkify-of-grad to add automatic checks to both forward and backward pass operations.
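A small sketch makes the same point explicit, reusing `f` from the example above: checking the forward pass alone finds nothing, plain `jax.grad` silently returns NaN, and checkify-of-grad reports the NaN. The exact wording of the message may vary between JAX versions, so only the "nan" part is relied on here.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify

def f(x):
  return x / (1 + jnp.sqrt(x))

# Forward pass alone: f(0.) == 0., and no NaN is produced.
fwd_err, fwd_out = checkify.checkify(f, errors=checkify.nan_checks)(0.)

# Without checkify, the gradient at 0. silently comes back as NaN
# (the sqrt derivative at 0 is inf, and it gets multiplied by a zero cotangent).
raw_grad = jax.grad(f)(0.)

# checkify-of-grad also instruments the backward-pass primitives.
grad_err, _ = checkify.checkify(jax.grad(f), errors=checkify.nan_checks)(0.)
```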
checkify.check calls only check the primal values of your function. If you want to use a check on a gradient value, use a custom_vjp:
@jax.custom_vjp
def assert_gradient_negative(x):
  return x

def fwd(x):
  return assert_gradient_negative(x), None

def bwd(_, grad):
  checkify.check(grad < 0, "gradient needs to be negative!")
  return (grad,)

assert_gradient_negative.defvjp(fwd, bwd)

jax.grad(assert_gradient_negative)(-1.)
# ValueError: gradient needs to be negative!
Strengths and limitations of jax.experimental.checkify
Strengths
You can use it everywhere (errors are “just values” and behave intuitively under transformations like other values)
Automatic instrumentation: you don’t need to make local modifications to your code. Instead, checkify can instrument all of it!
Limitations
Adding a lot of runtime checks can be expensive (e.g. adding a NaN check to every primitive will add a lot of operations to your computation)
Requires threading error values out of functions and manually throwing the error. If the error is not explicitly thrown, you might miss out on errors!
Throwing an error value will materialize that error value on the host, meaning it’s a blocking operation which defeats JAX’s async run-ahead.
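One way to reduce the risk of forgetting to throw is to wrap the pattern once. The sketch below uses a hypothetical decorator, `throwing`, which is not part of checkify: it jits the checkified function a single time and throws after every call. Because throwing pulls the error to the host, every call blocks, so this trades away async run-ahead for safety.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify

def throwing(f, errors=checkify.user_checks):
  """Hypothetical helper: wrap f once so that every call throws on error."""
  checked = jax.jit(checkify.checkify(f, errors=errors))

  def wrapper(*args):
    err, out = checked(*args)
    err.throw()  # blocks until the error value is available on the host
    return out

  return wrapper

@throwing
def safe_log(x):
  checkify.check(x > 0., "{} must be positive!", x)
  return jnp.log(x)
```

Wrapping once (rather than calling `checkify.checkify` inside each call) keeps the jitted function cached, so repeated calls with the same argument shapes do not retrace.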
