
google-deepmind/chex


Chex is a library of utilities for helping to write reliable JAX code.

This includes utils to help:

  • Instrument your code (e.g. assertions, warnings)
  • Debug (e.g. transforming pmaps into vmaps within a context manager).
  • Test JAX code across many variants (e.g. jitted vs non-jitted).

Installation

You can install the latest released version of Chex from PyPI via:

pip install chex

or you can install the latest development version from GitHub:

pip install git+https://github.com/deepmind/chex.git

Modules Overview

Dataclass (dataclass.py)

Dataclasses are a popular construct, introduced in Python 3.7, that make it easy to specify typed data structures with minimal boilerplate code. They are not, however, compatible with JAX and dm-tree out of the box.

In Chex we provide a JAX-friendly dataclass implementation that reuses Python dataclasses.

The Chex implementation of dataclass registers dataclasses as internal PyTree nodes to ensure compatibility with JAX data structures.

In addition, we provide a class wrapper that exposes dataclasses as collections.abc.Mapping descendants, which allows them to be processed (e.g. (un-)flattened) by dm-tree methods like ordinary Python dictionaries. See the @mappable_dataclass docstring for more details.

Example:

    import chex
    import jax
    import jax.numpy as jnp
    import tree  # dm-tree

    @chex.dataclass
    class Parameters:
      x: chex.ArrayDevice
      y: chex.ArrayDevice

    parameters = Parameters(
        x=jnp.ones((2, 2)),
        y=jnp.ones((1, 2)),
    )

    # Dataclasses can be treated as JAX pytrees
    jax.tree_util.tree_map(lambda x: 2.0 * x, parameters)

    # and as mappings by dm-tree
    tree.flatten(parameters)

NOTE: Unlike standard Python 3.7 dataclasses, Chex dataclasses cannot be constructed using positional arguments. They support construction arguments provided in the same format as the Python dict constructor. Dataclasses can be converted to tuples with the from_tuple and to_tuple methods if necessary.

    parameters = Parameters(
        jnp.ones((2, 2)),
        jnp.ones((1, 2)),
    )
    # ValueError: Mappable dataclass constructor doesn't support positional args.

Assertions (asserts.py)

One limitation of PyType annotations for JAX is that they do not support the specification of DeviceArray ranks, shapes or dtypes. Chex includes a number of functions that allow flexible and concise specification of these properties.

E.g. suppose you want to ensure that all tensors t1, t2, t3 have the same shape, and that tensors t4, t5 have rank 2 and (3 or 4), respectively.

    chex.assert_equal_shape([t1, t2, t3])
    chex.assert_rank([t4, t5], [2, {3, 4}])

More examples:

    from chex import assert_shape, assert_rank, ...

    assert_shape(x, (2, 3))                 # x has shape (2, 3)
    assert_shape([x, y], [(), (2, 3)])      # x is scalar and y has shape (2, 3)

    assert_rank(x, 0)                       # x is scalar
    assert_rank([x, y], [0, 2])             # x is scalar and y is a rank-2 array
    assert_rank([x, y], {0, 2})             # x and y are scalar OR rank-2 arrays

    assert_type(x, int)                     # x has type `int` (x can be an array)
    assert_type([x, y], [int, float])       # x has type `int` and y has type `float`

    assert_equal_shape([x, y, z])           # x, y, and z have equal shapes

    assert_trees_all_close(tree_x, tree_y)  # values and structure of trees match
    assert_tree_all_finite(tree_x)          # all tree_x leaves are finite

    assert_devices_available(2, 'gpu')      # 2 GPUs available
    assert_tpu_available()                  # at least 1 TPU available

    assert_numerical_grads(f, (x, y), j)    # f^{(j)}(x, y) matches numerical grads

See the asserts.py documentation to find all supported assertions.

If you cannot find a specific assertion, please consider making a pull request or opening an issue on the bug tracker.

Optional Arguments

All Chex assertions support the following optional kwargs for manipulating the emitted exception messages:

  • custom_message: A string to include in the emitted exception messages.
  • include_default_message: Whether to include the default Chex message in the emitted exception messages.
  • exception_type: An exception type to use. AssertionError by default.

For example, the following code:

    dataset = load_dataset()
    params = init_params()
    for i in range(num_steps):
      params = update_params(params, dataset.sample())
      chex.assert_tree_all_finite(params,
                                  custom_message=f'Failed at iteration {i}.',
                                  exception_type=ValueError)

will raise a ValueError that includes a step number when params get polluted with NaNs or Nones.

Static and Value (aka Runtime) Assertions

Chex divides all assertions into 2 classes: static and value assertions.

  1. static assertions use anything except concrete values of tensors. Examples: assert_shape, assert_trees_all_equal_dtypes, assert_max_traces.

  2. value assertions require access to tensor values, which are not available during JAX tracing (see How JAX primitives work), thus such assertions need special treatment in jitted code.

To enable value assertions in a jitted function, decorate it with the chex.chexify() wrapper. Example:

    @chex.chexify
    @jax.jit
    def logp1_abs_safe(x: chex.Array) -> chex.Array:
      chex.assert_tree_all_finite(x)
      return jnp.log(jnp.abs(x) + 1)

    logp1_abs_safe(jnp.ones(2))  # OK
    logp1_abs_safe(jnp.array([jnp.nan, 3]))  # FAILS (in async mode)

    # The error will be raised either at the next line OR at the next
    # `logp1_abs_safe` call. See the docs for more detail on async mode.
    logp1_abs_safe.wait_checks()  # Wait for the (async) computation to complete.

See this docstring for more detail on chex.chexify().

JAX Tracing Assertions

JAX re-traces a JIT'ted function every time the structure of the passed arguments changes. Often this behavior is inadvertent and leads to a significant performance drop that is hard to debug. The @chex.assert_max_traces decorator asserts that the function is not re-traced more than n times during program execution.

The global trace counter can be cleared by calling chex.clear_trace_counter(). This function can be used to isolate unit tests that rely on @chex.assert_max_traces.

Examples:

    @jax.jit
    @chex.assert_max_traces(n=1)
    def fn_sum_jitted(x, y):
      return x + y

    fn_sum_jitted(jnp.zeros(3), jnp.zeros(3))  # tracing for the 1st time - OK
    fn_sum_jitted(jnp.zeros([6, 7]), jnp.zeros([6, 7]))  # AssertionError!

Can be used with jax.pmap() as well:

    def fn_sub(x, y):
      return x - y

    fn_sub_pmapped = jax.pmap(chex.assert_max_traces(fn_sub, n=10))

See the How JAX primitives work section for more information about tracing.

Warnings (warnings.py)

In addition to hard assertions, Chex also offers utilities to add common warnings, such as specific types of deprecation warnings.

Test variants (variants.py)

JAX relies extensively on code transformation and compilation, meaning that it can be hard to ensure that code is properly tested. For instance, just testing a Python function using JAX code will not cover the actual code path that is executed when jitted, and that path will also differ depending on whether the code is jitted for CPU, GPU, or TPU. This has been a source of obscure and hard-to-catch bugs where XLA changes would lead to undesirable behaviours that only manifest in one specific code transformation.

Variants make it easy to ensure that unit tests cover different ‘variations’ of a function, by providing a simple decorator that can be used to repeat any test under all (or a subset) of the relevant code transformations.

E.g. suppose you want to test the output of a function fn with or without jit. You can use chex.variants to run the test with both the jitted and non-jitted version of the function by simply decorating a test method with @chex.variants, and then using self.variant(fn) in place of fn in the body of the test.

    def fn(x, y):
      return x + y
    ...

    class ExampleTest(chex.TestCase):

      @chex.variants(with_jit=True, without_jit=True)
      def test(self):
        var_fn = self.variant(fn)
        self.assertEqual(fn(1, 2), 3)
        self.assertEqual(var_fn(1, 2), fn(1, 2))

If you define the function in the test method, you may also use self.variant as a decorator in the function definition. For example:

    class ExampleTest(chex.TestCase):

      @chex.variants(with_jit=True, without_jit=True)
      def test(self):
        @self.variant
        def var_fn(x, y):
          return x + y

        self.assertEqual(var_fn(1, 2), 3)

Example of parameterized test:

    from absl.testing import parameterized

    # Could also be:
    #  `class ExampleParameterizedTest(chex.TestCase, parameterized.TestCase):`
    #  `class ExampleParameterizedTest(chex.TestCase):`
    class ExampleParameterizedTest(parameterized.TestCase):

      @chex.variants(with_jit=True, without_jit=True)
      @parameterized.named_parameters(
          ('case_positive', 1, 2, 3),
          ('case_negative', -1, -2, -3),
      )
      def test(self, arg_1, arg_2, expected):
        @self.variant
        def var_fn(x, y):
          return x + y

        self.assertEqual(var_fn(arg_1, arg_2), expected)

Chex currently supports the following variants:

  • with_jit -- applies jax.jit() transformation to the function.
  • without_jit -- uses the function as is, i.e. identity transformation.
  • with_device -- places all arguments (except those specified in the ignore_argnums argument) into device memory before applying the function.
  • without_device -- places all arguments in RAM before applying the function.
  • with_pmap -- applies jax.pmap() transformation to the function (see notes below).

See documentation in variants.py for more details on the supported variants. More examples can be found in variants_test.py.

Variants notes

  • Test classes that use @chex.variants must inherit from chex.TestCase (or any other base class that unrolls test generators within TestCase, e.g. absl.testing.parameterized.TestCase).

  • [jax.vmap] All variants can be applied to a vmapped function; please see an example in variants_test.py (test_vmapped_fn_named_params and test_pmap_vmapped_fn).

  • [@chex.all_variants] You can get all supported variants by using the decorator @chex.all_variants.

  • [with_pmap variant] jax.pmap(fn) performs a parallel map of fn onto multiple devices. Most tests run in a single-device environment (i.e. with access to a single CPU or GPU), in which case jax.pmap is functionally equivalent to jax.jit, so the with_pmap variant is skipped by default (although it works fine with a single device). Below we describe a way to properly test fn if it is supposed to be used in multi-device environments (TPUs or multiple CPUs/GPUs). To disable skipping with_pmap variants in case of a single device, add --chex_skip_pmap_variant_if_single_device=false to your test command.

Fakes (fake.py)

Debugging in JAX is made more difficult by code transformations such as jit and pmap, which introduce optimizations that make code hard to inspect and trace. It can also be difficult to disable those transformations during debugging as they can be called at several places in the underlying code. Chex provides tools to globally replace jax.jit with a no-op transformation and jax.pmap with a (non-parallel) jax.vmap, in order to more easily debug code in a single-device context.

For example, you can use Chex to fake pmap and have it replaced with a vmap. This can be achieved by wrapping your code with a context manager:

    with chex.fake_pmap():
      @jax.pmap
      def fn(inputs):
        ...

      # Function will be vmapped over inputs
      fn(inputs)

The same functionality can also be invoked with start and stop:

    fake_pmap = chex.fake_pmap()
    fake_pmap.start()
    ... your jax code ...
    fake_pmap.stop()

In addition, you can fake a real multi-device test environment with a multi-threaded CPU. See section Faking multi-device test environments for more details.

See documentation in fake.py and examples in fake_test.py for more details.

Faking multi-device test environments

In situations where you do not have easy access to multiple devices, you can still test parallel computation using single-device multi-threading.

In particular, one can force XLA to use a single CPU's threads as separate devices, i.e. to fake a real multi-device environment with a multi-threaded one. These two options are theoretically equivalent from the XLA perspective because they expose the same interface and use identical abstractions.

Chex has a flag chex_n_cpu_devices that specifies the number of CPU threads to use as XLA devices.

To set up a multi-threaded XLA environment for absl tests, define a setUpModule function in your test module:

    def setUpModule():
      chex.set_n_cpu_devices()

Now you can launch your test with python test.py --chex_n_cpu_devices=N to run it in a multi-device regime. Note that all tests within a module will have access to N devices.

More examples can be found in variants_test.py, fake_test.py and fake_set_n_cpu_devices_test.py.

Using named dimension sizes.

Chex comes with a small utility that allows you to package a collection of dimension sizes into a single object. The basic idea is:

    dims = chex.Dimensions(B=batch_size, T=sequence_len, E=embedding_dim)
    ...
    chex.assert_shape(arr, dims['BTE'])

String lookups are translated into integer tuples. For instance, let's say batch_size == 3, sequence_len == 5 and embedding_dim == 7, then

    dims['BTE'] == (3, 5, 7)
    dims['B'] == (3,)
    dims['TTBEE'] == (5, 5, 3, 7, 7)
    ...

You can also assign dimension sizes dynamically as follows:

    dims['XY'] = some_matrix.shape
    dims.Z = 13

For more examples, see the chex.Dimensions documentation.

Citing Chex

This repository is part of the DeepMind JAX Ecosystem. To cite Chex, please use the DeepMind JAX Ecosystem citation.

