google-deepmind/chex
Chex is a library of utilities for helping to write reliable JAX code.
This includes utils to help:
- Instrument your code (e.g. assertions, warnings)
- Debug (e.g. transforming `pmaps` into `vmaps` within a context manager).
- Test JAX code across many `variants` (e.g. jitted vs non-jitted).
You can install the latest released version of Chex from PyPI via:
pip install chex
or you can install the latest development version from GitHub:
pip install git+https://github.com/deepmind/chex.git
Dataclass (dataclass.py)
Dataclasses are a popular construct introduced in Python 3.7 that lets you specify typed data structures with minimal boilerplate code. They are not, however, compatible with JAX and dm-tree out of the box.
In Chex we provide a JAX-friendly dataclass implementation that reuses Python dataclasses.
The Chex implementation of `dataclass` registers dataclasses as internal PyTree nodes to ensure compatibility with JAX data structures.
In addition, we provide a class wrapper that exposes dataclasses as `collections.Mapping` descendants, which allows them to be processed (e.g. (un-)flattened) by `dm-tree` methods as usual Python dictionaries. See the `@mappable_dataclass` docstring for more details.
Example:
```python
import chex
import jax
import jax.numpy as jnp
import tree  # dm-tree

@chex.dataclass
class Parameters:
  x: chex.ArrayDevice
  y: chex.ArrayDevice

parameters = Parameters(
    x=jnp.ones((2, 2)),
    y=jnp.ones((1, 2)),
)

# Dataclasses can be treated as JAX pytrees
jax.tree_util.tree_map(lambda x: 2.0 * x, parameters)

# and as mappings by dm-tree
tree.flatten(parameters)
```
NOTE: Unlike standard Python 3.7 dataclasses, Chex dataclasses cannot be constructed using positional arguments. They support construction arguments provided in the same format as the Python dict constructor. Dataclasses can be converted to and from tuples with the `to_tuple` and `from_tuple` methods if necessary.
```python
parameters = Parameters(
    jnp.ones((2, 2)),
    jnp.ones((1, 2)),
)
# ValueError: Mappable dataclass constructor doesn't support positional args.
```
Assertions (asserts.py)
One limitation of PyType annotations for JAX is that they do not support the specification of `DeviceArray` ranks, shapes or dtypes. Chex includes a number of functions that allow flexible and concise specification of these properties.
E.g. suppose you want to ensure that all tensors `t1`, `t2`, `t3` have the same shape, and that tensors `t4`, `t5` have rank `2` and (`3` or `4`), respectively.
```python
chex.assert_equal_shape([t1, t2, t3])
chex.assert_rank([t4, t5], [2, {3, 4}])
```
More examples:
```python
from chex import assert_shape, assert_rank, ...

assert_shape(x, (2, 3))                 # x has shape (2, 3)
assert_shape([x, y], [(), (2, 3)])      # x is scalar and y has shape (2, 3)

assert_rank(x, 0)                       # x is scalar
assert_rank([x, y], [0, 2])             # x is scalar and y is a rank-2 array
assert_rank([x, y], {0, 2})             # x and y are scalar OR rank-2 arrays

assert_type(x, int)                     # x has type `int` (x can be an array)
assert_type([x, y], [int, float])       # x has type `int` and y has type `float`

assert_equal_shape([x, y, z])           # x, y, and z have equal shapes

assert_trees_all_close(tree_x, tree_y)  # values and structure of trees match
assert_tree_all_finite(tree_x)          # all tree_x leaves are finite

assert_devices_available(2, 'gpu')      # 2 GPUs available
assert_tpu_available()                  # at least 1 TPU available

assert_numerical_grads(f, (x, y), j)    # f^{(j)}(x, y) matches numerical grads
```
See the `asserts.py` documentation to find all supported assertions.
If you cannot find a specific assertion, please consider making a pull request or opening an issue on the bug tracker.
All Chex assertions support the following optional kwargs for manipulating the emitted exception messages:
- `custom_message`: A string to include in the emitted exception messages.
- `include_default_message`: Whether to include the default Chex message in the emitted exception messages.
- `exception_type`: An exception type to use. `AssertionError` by default.
For example, the following code:
```python
dataset = load_dataset()
params = init_params()
for i in range(num_steps):
  params = update_params(params, dataset.sample())
  chex.assert_tree_all_finite(params,
                              custom_message=f'Failed at iteration {i}.',
                              exception_type=ValueError)
```
will raise a `ValueError` that includes a step number when `params` get polluted with `NaNs` or `None`s.
Chex divides all assertions into 2 classes: static and value assertions.
- static assertions use anything except concrete values of tensors. Examples: `assert_shape`, `assert_trees_all_equal_dtypes`, `assert_max_traces`.
- value assertions require access to tensor values, which are not available during JAX tracing (see How JAX primitives work), thus such assertions need special treatment in jitted code.
To enable value assertions in a jitted function, it can be decorated with the `chex.chexify()` wrapper. Example:
```python
@chex.chexify
@jax.jit
def logp1_abs_safe(x: chex.Array) -> chex.Array:
  chex.assert_tree_all_finite(x)
  return jnp.log(jnp.abs(x) + 1)

logp1_abs_safe(jnp.ones(2))  # OK
logp1_abs_safe(jnp.array([jnp.nan, 3]))  # FAILS (in async mode)

# The error will be raised either at the next line OR at the next
# `logp1_abs_safe` call. See the docs for more details on async mode.
logp1_abs_safe.wait_checks()  # Wait for the (async) computation to complete.
```
See this docstring for more detail on `chex.chexify()`.
JAX re-traces a JIT'ted function every time the structure of the passed arguments changes. Often this behavior is inadvertent and leads to a significant performance drop which is hard to debug. The `@chex.assert_max_traces` decorator asserts that the function is not re-traced more than `n` times during program execution.
The global trace counter can be cleared by calling `chex.clear_trace_counter()`. This function can be used to isolate unit tests relying on `@chex.assert_max_traces`.
Examples:
```python
@jax.jit
@chex.assert_max_traces(n=1)
def fn_sum_jitted(x, y):
  return x + y

fn_sum_jitted(jnp.zeros(3), jnp.zeros(3))  # tracing for the 1st time - OK
fn_sum_jitted(jnp.zeros([6, 7]), jnp.zeros([6, 7]))  # AssertionError!
```
Can be used with `jax.pmap()` as well:
```python
def fn_sub(x, y):
  return x - y

fn_sub_pmapped = jax.pmap(chex.assert_max_traces(fn_sub, n=10))
```
See the How JAX primitives work section for more information about tracing.
Warnings (warnings.py)
In addition to hard assertions Chex also offers utilities to add commonwarnings, such as specific types of deprecation warnings.
Test variants (variants.py)
JAX relies extensively on code transformation and compilation, meaning that it can be hard to ensure that code is properly tested. For instance, just testing a Python function using JAX code will not cover the actual code path that is executed when jitted, and that path will also differ depending on whether the code is jitted for CPU, GPU, or TPU. This has been a source of obscure and hard-to-catch bugs where XLA changes would lead to undesirable behaviours that, however, only manifest in one specific code transformation.
Variants make it easy to ensure that unit tests cover different ‘variations’ of a function, by providing a simple decorator that can be used to repeat any test under all (or a subset) of the relevant code transformations.
E.g. suppose you want to test the output of a function `fn` with or without jit. You can use `chex.variants` to run the test with both the jitted and non-jitted version of the function by simply decorating a test method with `@chex.variants`, and then using `self.variant(fn)` in place of `fn` in the body of the test.
```python
def fn(x, y):
  return x + y
...

class ExampleTest(chex.TestCase):

  @chex.variants(with_jit=True, without_jit=True)
  def test(self):
    var_fn = self.variant(fn)
    self.assertEqual(fn(1, 2), 3)
    self.assertEqual(var_fn(1, 2), fn(1, 2))
```
If you define the function in the test method, you may also use `self.variant` as a decorator in the function definition. For example:
```python
class ExampleTest(chex.TestCase):

  @chex.variants(with_jit=True, without_jit=True)
  def test(self):
    @self.variant
    def var_fn(x, y):
      return x + y

    self.assertEqual(var_fn(1, 2), 3)
```
Example of parameterized test:
```python
from absl.testing import parameterized

# Could also be:
#  `class ExampleParameterizedTest(chex.TestCase, parameterized.TestCase):`
#  `class ExampleParameterizedTest(chex.TestCase):`
class ExampleParameterizedTest(parameterized.TestCase):

  @chex.variants(with_jit=True, without_jit=True)
  @parameterized.named_parameters(
      ('case_positive', 1, 2, 3),
      ('case_negative', -1, -2, -3),
  )
  def test(self, arg_1, arg_2, expected):
    @self.variant
    def var_fn(x, y):
      return x + y

    self.assertEqual(var_fn(arg_1, arg_2), expected)
```
Chex currently supports the following variants:
- `with_jit` -- applies `jax.jit()` transformation to the function.
- `without_jit` -- uses the function as is, i.e. identity transformation.
- `with_device` -- places all arguments (except those specified in the `ignore_argnums` argument) into device memory before applying the function.
- `without_device` -- places all arguments in RAM before applying the function.
- `with_pmap` -- applies `jax.pmap()` transformation to the function (see notes below).
See documentation in variants.py for more details on the supported variants. More examples can be found in variants_test.py.
- Test classes that use `@chex.variants` must inherit from `chex.TestCase` (or any other base class that unrolls test generators within `TestCase`, e.g. `absl.testing.parameterized.TestCase`).
- [`jax.vmap`] All variants can be applied to a vmapped function; please see an example in variants_test.py (`test_vmapped_fn_named_params` and `test_pmap_vmapped_fn`).
- [`@chex.all_variants`] You can get all supported variants by using the decorator `@chex.all_variants`.
- [`with_pmap` variant] `jax.pmap(fn)` (doc) performs a parallel map of `fn` onto multiple devices. Most tests run in a single-device environment (i.e. having access to a single CPU or GPU), in which case `jax.pmap` is functionally equivalent to `jax.jit`, so the `with_pmap` variant is skipped by default (although it works fine with a single device). Below we describe a way to properly test `fn` if it is supposed to be used in multi-device environments (TPUs or multiple CPUs/GPUs). To disable skipping `with_pmap` variants in case of a single device, add `--chex_skip_pmap_variant_if_single_device=false` to your test command.
Fakes (fake.py)
Debugging in JAX is made more difficult by code transformations such as `jit` and `pmap`, which introduce optimizations that make code hard to inspect and trace. It can also be difficult to disable those transformations during debugging as they can be called at several places in the underlying code. Chex provides tools to globally replace `jax.jit` with a no-op transformation and `jax.pmap` with a (non-parallel) `jax.vmap`, in order to more easily debug code in a single-device context.
For example, you can use Chex to fake `pmap` and have it replaced with a `vmap`. This can be achieved by wrapping your code with a context manager:
```python
with chex.fake_pmap():
  @jax.pmap
  def fn(inputs):
    ...

  # Function will be vmapped over inputs
  fn(inputs)
```
The same functionality can also be invoked with `start` and `stop`:
```python
fake_pmap = chex.fake_pmap()
fake_pmap.start()
... your jax code ...
fake_pmap.stop()
```
In addition, you can fake a real multi-device test environment with a multi-threaded CPU. See section Faking multi-device test environments for more details.
See documentation in fake.py and examples in fake_test.py for more details.
Faking multi-device test environments
In situations where you do not have easy access to multiple devices, you can still test parallel computation using single-device multi-threading.
In particular, one can force XLA to use a single CPU's threads as separate devices, i.e. to fake a real multi-device environment with a multi-threaded one. These two options are theoretically equivalent from the XLA perspective because they expose the same interface and use identical abstractions.
Chex has a flag `chex_n_cpu_devices` that specifies the number of CPU threads to use as XLA devices.
To set up a multi-threaded XLA environment for `absl` tests, define a `setUpModule` function in your test module:
```python
def setUpModule():
  chex.set_n_cpu_devices()
```
Now you can launch your test with `python test.py --chex_n_cpu_devices=N` to run it in multi-device regime. Note that all tests within a module will have access to `N` devices.
More examples can be found in variants_test.py, fake_test.py and fake_set_n_cpu_devices_test.py.
Chex comes with a small utility that allows you to package a collection ofdimension sizes into a single object. The basic idea is:
```python
dims = chex.Dimensions(B=batch_size, T=sequence_len, E=embedding_dim)
...
chex.assert_shape(arr, dims['BTE'])
```
String lookups are translated into integer tuples. For instance, let's say `batch_size == 3`, `sequence_len == 5` and `embedding_dim == 7`, then
```python
dims['BTE'] == (3, 5, 7)
dims['B'] == (3,)
dims['TTBEE'] == (5, 5, 3, 7, 7)
...
```
You can also assign dimension sizes dynamically as follows:
```python
dims['XY'] = some_matrix.shape
dims.Z = 13
```
For more examples, see the chex.Dimensions documentation.
This repository is part of the DeepMind JAX Ecosystem. To cite Chex, please use the DeepMind JAX Ecosystem citation.