JEP 28661: Supporting the __jax_array__ protocol

@jakevdp, May 2025

An occasional user request is for the ability to define custom array-like objects that work with JAX APIs. JAX currently has a partial implementation of a mechanism that does this via a __jax_array__ method defined on the custom object. This was never intended to be a load-bearing public API (see the discussion at #4725), but has become essential to packages like Keras and flax, which explicitly document the ability to use their custom array objects with JAX functions. This JEP proposes a design for full, documented support of the __jax_array__ protocol.

Levels of array extensibility

Requests for extensibility of JAX arrays come in a few flavors:

Level 1 extensibility: polymorphic inputs

What I’ll call “Level 1” extensibility is the desire that JAX APIs accept polymorphic inputs. That is, a user desires behavior like this:

    class CustomArray:
      data: np.ndarray
      ...

    x = CustomArray(np.arange(5))
    result = jnp.sin(x)  # Converts `x` to JAX array and returns a JAX array

Under this extensibility model, JAX functions would accept CustomArray objects as inputs, implicitly converting them to jax.Array objects for the sake of computation. This is similar to the functionality offered by NumPy via the __array__ method, and in JAX (in many but not all cases) via the __jax_array__ method.
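For comparison, the NumPy side of this behavior can be sketched as follows. NumpyWrapped is a hypothetical class introduced only for illustration; the dtype and copy parameters follow the signature NumPy documents for __array__, and the copy argument is ignored here for brevity.

```python
import numpy as np

class NumpyWrapped:
  """Hypothetical container that exposes its data through __array__."""

  def __init__(self, data):
    self.data = np.asarray(data, dtype=float)

  def __array__(self, dtype=None, copy=None):
    # NumPy calls this hook when it needs an ndarray for the object.
    if dtype is None:
      return self.data
    return self.data.astype(dtype)

x = NumpyWrapped([0.0, 1.0, 2.0])
result = np.sin(x)  # input is converted implicitly; output is a plain ndarray
print(type(result).__name__)
```

The output type is ndarray rather than NumpyWrapped, which is exactly the Level 1 contract: polymorphic inputs, standard outputs.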

This is the mode of extensibility that has been requested by the maintainers of flax.nnx and others. The current implementation is also used by JAX internally for the case of symbolic dimensions.

Level 2 extensibility: polymorphic outputs

What I’ll call “Level 2” extensibility is the desire that JAX APIs should not only accept polymorphic inputs, but also wrap outputs to match the class of the input. That is, a user desires behavior like this:

    class CustomArray:
      data: np.ndarray
      ...

    x = CustomArray(np.arange(5))
    result = jnp.sin(x)  # returns a new CustomArray

Under this extensibility model, JAX functions would not only accept custom objects as inputs, but have some protocol to determine how to correctly re-wrap outputs with the same class. In NumPy, this sort of functionality is offered in varying degrees by the special __array_ufunc__, __array_wrap__, and __array_function__ protocols, which allow user-defined objects to customize how NumPy API functions operate on arbitrary inputs and map input types to outputs. JAX does not currently have any equivalent to these interfaces in NumPy.
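To make the NumPy comparison concrete, here is a minimal sketch of output re-wrapping with __array_ufunc__. Rewrapped is a hypothetical class, and the sketch handles only plain ufunc calls (no out= argument, no reductions).

```python
import numpy as np

class Rewrapped:
  """Hypothetical container whose ufunc results keep the container type."""

  def __init__(self, data):
    self.data = np.asarray(data)

  def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
    if method != "__call__":
      return NotImplemented
    # Unwrap any Rewrapped inputs, apply the ufunc, then re-wrap the output.
    raw = [i.data if isinstance(i, Rewrapped) else i for i in inputs]
    return Rewrapped(ufunc(*raw, **kwargs))

x = Rewrapped(np.arange(5))
result = np.sin(x)  # returns a new Rewrapped, not an ndarray
print(type(result).__name__)
```

Nothing comparable exists for jax.numpy functions today, which is why this proposal stops at Level 1.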

This is the mode of extensibility that has been requested by the maintainers of keras, among others.

Level 3 extensibility: subclassing Array

What I’ll call “Level 3” extensibility is the desire that the JAX array object itself could be subclassable. NumPy provides some APIs that allow this (see Subclassing ndarray), but this sort of approach would take some extra thought in JAX due to the need for representing array objects abstractly via tracing.

This mode of extensibility has occasionally been requested by users who want to add special metadata to JAX arrays, such as units of measurement.
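As an illustration of what Level 3 looks like in NumPy, the sketch below attaches a unit string to an ndarray subclass. UnitArray and its unit attribute are hypothetical names; the __new__ / __array_finalize__ pattern is the one described in NumPy's subclassing guide. No unit arithmetic is attempted: the metadata is simply carried along.

```python
import numpy as np

class UnitArray(np.ndarray):
  """Hypothetical ndarray subclass that carries a `unit` string."""

  def __new__(cls, values, unit=None):
    obj = np.asarray(values).view(cls)
    obj.unit = unit
    return obj

  def __array_finalize__(self, obj):
    # Called for views and derived results; copy metadata from the parent.
    if obj is None:
      return
    self.unit = getattr(obj, "unit", None)

length = UnitArray([1.0, 2.0, 3.0], unit="m")
doubled = length * 2
print(type(doubled).__name__, doubled.unit)
```

This works in NumPy because every array is a concrete object. In JAX, the same value may be a tracer inside a transformation, so a subclass would also need an abstract counterpart.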

Synopsis

For the sake of this proposal, we will stick with the simplest, level 1 extensibility model. The proposed interface is the one currently non-uniformly supported by a number of JAX APIs, the __jax_array__ method. Its usage looks something like this:

    import jax
    import jax.numpy as jnp
    import numpy as np

    class CustomArray:
      data: np.ndarray

      def __init__(self, data: np.ndarray):
        self.data = data

      def __jax_array__(self) -> jax.Array:
        return jnp.asarray(self.data)

    arr = CustomArray(np.arange(5))
    result = jnp.multiply(arr, 2)
    print(repr(result))
    # Array([0, 2, 4, 6, 8], dtype=int32)

We may revisit other extensibility levels in the future.

Design challenges

JAX presents some interesting design challenges related to this kind of extensibility, which have not been fully explored previously. We’ll discuss them in turn here:

Priority of __jax_array__ vs. PyTree flattening

JAX already has a supported mechanism for registering custom objects, namely pytree registration (see Custom pytree nodes). If we also support __jax_array__, which one should take precedence?

To put this more concretely, what should be the result of this code?

    @jax.jit
    def f(x):
      print("is JAX array:", isinstance(x, jax.Array))

    f(CustomArray(...))

If we choose to prioritize __jax_array__ at the JIT boundary, then the output of this function would be:

    is JAX array: True

That is, at the JIT boundary, the CustomArray object would be converted into an array via __jax_array__, and its shape and dtype would be used to construct a standard JAX tracer for the function.

If we choose to prioritize pytree flattening at the JIT boundary, then the output of this function would be:

    is JAX array: False

That is, at the JIT boundary, the CustomArray object is flattened, and then unflattened before being passed to the JIT-compiled function for tracing. If CustomArray has been registered as a pytree, it will generally contain traced arrays as its attributes, and when x is passed to any JAX API that supports __jax_array__, these traced attributes will be converted to a single traced array according to the logic specified in the method.
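A minimal sketch of the pytree-first behavior, assuming a hypothetical PytreeArray class that is registered as a pytree and deliberately does not define __jax_array__ (so the outcome does not depend on which mechanism a given JAX version prioritizes). The seen dictionary is only a way to record what the function observes at trace time.

```python
import jax
import jax.numpy as jnp

class PytreeArray:
  """Hypothetical container registered as a pytree (no __jax_array__)."""

  def __init__(self, data):
    self.data = data

# Flatten to (children, aux_data); unflatten rebuilds the container.
jax.tree_util.register_pytree_node(
    PytreeArray,
    lambda obj: ((obj.data,), None),
    lambda aux, children: PytreeArray(children[0]),
)

seen = {}

@jax.jit
def f(x):
  # Runs at trace time: x is a PytreeArray whose attribute is a traced array.
  seen["type"] = type(x).__name__
  seen["x_is_jax_array"] = isinstance(x, jax.Array)
  seen["data_is_jax_array"] = isinstance(x.data, jax.Array)
  return x.data * 2

out = f(PytreeArray(jnp.arange(5)))
print(seen, out)
```

The container itself is not a jax.Array inside the traced function, but its attribute is, which matches the description of flattening followed by unflattening.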

There are deeper consequences here for how other transformations like vmap and grad work when encountering custom objects: for example, if we prioritize pytree flattening, vmap would operate over the dimensions of the flattened contents of the custom object, while if we prioritize __jax_array__, vmap would operate over the converted array dimensions.
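The pytree side of the vmap distinction can be sketched as follows, assuming a hypothetical Pair container with two array attributes. Only the pytree path is shown; under __jax_array__ priority, vmap would instead see whatever single array the method returns.

```python
import jax
import jax.numpy as jnp

class Pair:
  """Hypothetical pytree container with two array attributes."""

  def __init__(self, a, b):
    self.a = a
    self.b = b

jax.tree_util.register_pytree_node(
    Pair,
    lambda obj: ((obj.a, obj.b), None),
    lambda aux, children: Pair(*children),
)

def combine(p):
  # Inside vmap, each attribute has lost its leading (batch) axis.
  return jnp.dot(p.a, p.b)

batch = Pair(jnp.ones((4, 3)), jnp.ones((4, 3)))
out = jax.vmap(combine)(batch)  # maps over axis 0 of every flattened leaf
print(out.shape)
```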

This also has consequences for JIT invariance. Consider a function like this:

    def f(x):
      if isinstance(x, CustomArray):
        return x.custom_method()
      else:
        # do something else
        ...

    result1 = f(x)
    result2 = jax.jit(f)(x)

If jit consumes x via pytree flattening, the results should agree for a well-specified flattening rule. If jit consumes x via __jax_array__, the results will differ because x is no longer a CustomArray within the JIT-compiled version of the function.
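A runnable version of the invariant case, assuming a hypothetical Scaled class that is registered as a pytree and has no __jax_array__ method. The body of custom_method is invented purely so there is something to compare.

```python
import jax
import jax.numpy as jnp

class Scaled:
  """Hypothetical pytree container with a custom method."""

  def __init__(self, data):
    self.data = data

  def custom_method(self):
    return self.data * 10

jax.tree_util.register_pytree_node(
    Scaled,
    lambda obj: ((obj.data,), None),
    lambda aux, children: Scaled(children[0]),
)

def f(x):
  if isinstance(x, Scaled):
    return x.custom_method()
  return jnp.asarray(x) + 1

x = Scaled(jnp.arange(3))
result1 = f(x)           # eager: takes the Scaled branch
result2 = jax.jit(f)(x)  # jit: x is rebuilt as Scaled, same branch
print(result1, result2)
```

Because the object is reconstructed as the same class inside the traced function, both calls take the same branch.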

Synopsis

As of JAX v0.6.0, transformations prioritize __jax_array__ when it is available. This status quo can lead to confusion around lack of JIT invariance, and the current implementation in practice leads to subtle bugs in the case of automatic differentiation, where the forward and backward passes do not treat inputs consistently.

Because the pytree extensibility mechanism already exists for the case of customizing transformations, it seems most straightforward if transformations act only via this mechanism: that is, we propose to remove __jax_array__ parsing during abstractification. This approach will preserve object identity through transformations, and give the user the most possible flexibility. If the user wants to opt in to array conversion semantics, that is always possible by explicitly casting their input via jnp.asarray, which will trigger the __jax_array__ protocol.
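The explicit opt-in might look like the sketch below. The double function is a hypothetical stand-in for any transformed function; the conversion happens before the call, so the result does not depend on how a given JAX version treats __jax_array__ at the transformation boundary.

```python
import jax
import jax.numpy as jnp
import numpy as np

class CustomArray:
  def __init__(self, data):
    self.data = data

  def __jax_array__(self):
    return jnp.asarray(self.data)

@jax.jit
def double(x):
  return x * 2

arr = CustomArray(np.arange(5))
converted = jnp.asarray(arr)  # explicit conversion triggers __jax_array__
result = double(converted)    # the transformation only ever sees a jax.Array
print(result)
```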

Which APIs should support __jax_array__?

JAX has a number of different levels of API, from the level of explicit primitive binding (e.g. jax.lax.add_p.bind(x, y)) to the jax.lax APIs (e.g. jax.lax.add(x, y)) to the jax.numpy APIs (e.g. jax.numpy.add(x, y)). Which of these API categories should handle implicit conversion via __jax_array__?
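For orientation, the three levels compute the same thing when given ordinary arrays of matching dtype. The sketch below uses plain JAX arrays only, so it says nothing about how each level treats custom objects.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(3, dtype=jnp.float32)
y = jnp.ones(3, dtype=jnp.float32)

via_primitive = jax.lax.add_p.bind(x, y)  # explicit primitive binding
via_lax = jax.lax.add(x, y)               # jax.lax wrapper
via_numpy = jnp.add(x, y)                 # jax.numpy wrapper

print(via_primitive, via_lax, via_numpy)
```

The layers differ in how much input handling they perform before reaching the primitive, which is why the choice of layer determines where implicit conversion has to be implemented and tested.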

In order to limit the scope of the change and the required testing, I propose that __jax_array__ only be explicitly supported in jax.numpy APIs: after all, it is inspired by the __array__ protocol which is supported by the NumPy package. We could always expand this in the future to jax.lax APIs if needed.

This is in line with the current state of the package, where __jax_array__ handling is mainly within the input validation utilities used by jax.numpy APIs.
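To illustrate the layering, here is a rough sketch of a numpy-level wrapper that converts inputs before delegating to jax.lax. Both to_jax_array and my_add are hypothetical helpers written for this example; they are not JAX's internal utilities and omit the validation and dtype promotion a real jax.numpy function performs.

```python
import jax
import jax.numpy as jnp
import numpy as np

def to_jax_array(value):
  """Hypothetical input-conversion helper (not JAX's internal utility)."""
  if hasattr(value, "__jax_array__"):
    value = value.__jax_array__()
  return jnp.asarray(value)

def my_add(x, y):
  """Hypothetical numpy-level wrapper: convert first, then call jax.lax."""
  return jax.lax.add(to_jax_array(x), to_jax_array(y))

class CustomArray:
  def __init__(self, data):
    self.data = data

  def __jax_array__(self):
    return jnp.asarray(self.data)

a = CustomArray(np.arange(5))
b = CustomArray(np.arange(5))
total = my_add(a, b)  # conversion happens in the wrapper, not in jax.lax
print(total)
```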

Implementation

With these design choices in mind, we plan to implement this as follows:

  • Adding runtime support to jax.numpy: This is likely the easiest part, as most jax.numpy functions use a common internal utility (ensure_arraylike) to validate inputs and convert them to arrays. This utility already supports __jax_array__, and so most jax.numpy APIs are already compliant.

  • Adding test coverage: To ensure compliance across the APIs, we should add a new test scaffold that calls every jax.numpy API with custom inputs and validates correct behavior.

  • Deprecating __jax_array__ during abstractification: Currently JAX’s abstractification pass, used in jit and other transformations, does parse the __jax_array__ protocol, and this is not the behavior we want long-term. We need to deprecate this behavior, and ensure that downstream packages that rely on it can move toward pytree registration or explicit array conversion where necessary.

  • Adding type annotations: The type interface for jax.numpy functions is in jax/numpy/__init__.pyi, and we’ll need to change each input type from ArrayLike to ArrayLike | SupportsJAXArray, where the latter is a protocol with a __jax_array__ method. We cannot add this directly to the ArrayLike definition, because ArrayLike is used in contexts where __jax_array__ should not be supported.

  • Documentation: Once the above support is added, we should add a documentation section on array extensibility that outlines exactly what to expect regarding the __jax_array__ protocol, with examples of how it can be used in conjunction with pytree registration in order to effectively work with user-defined types.
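The SupportsJAXArray protocol named in the type-annotation item could be sketched with typing.Protocol as below. This is an assumption about its shape, not the definition that ships with JAX. The return type is left as Any to keep the sketch free of a JAX import, where a real definition would presumably use jax.Array, and runtime_checkable is added only so the structural match can be demonstrated with isinstance.

```python
from typing import Any, Protocol, runtime_checkable

@runtime_checkable
class SupportsJAXArray(Protocol):
  """Sketch of a structural type for objects defining __jax_array__."""

  def __jax_array__(self) -> Any: ...

class CustomArray:
  def __init__(self, data):
    self.data = data

  def __jax_array__(self):
    return self.data  # a real implementation would return a jax.Array

matches = isinstance(CustomArray([1, 2]), SupportsJAXArray)
plain_list_matches = isinstance([1, 2], SupportsJAXArray)
print(matches, plain_list_matches)
```

Keeping the protocol separate from ArrayLike lets individual function signatures opt in with ArrayLike | SupportsJAXArray without widening every existing use of ArrayLike.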

