
JEP 18137: Scope of JAX NumPy & SciPy Wrappers

Jake VanderPlas

October 2023

Until now, the intended scope of jax.numpy and jax.scipy has been relatively ill-defined. This document proposes a well-defined scope for these packages to better guide and evaluate future contributions, and to motivate the removal of some out-of-scope code.

Background

From the beginning, JAX has aimed to provide a NumPy-like API for executing code in XLA, and a big part of the project’s development has been building out the jax.numpy and jax.scipy namespaces as JAX-based implementations of NumPy and SciPy APIs. There has always been an implicit understanding that some parts of numpy and scipy are out-of-scope for JAX, but this scope has not been well defined. This can lead to confusion and frustration for contributors, because there’s no clear answer to whether potential jax.numpy and jax.scipy contributions will be accepted into JAX.

Why Limit the Scope?

To avoid leaving this unsaid, we should be explicit: it is a fact that any code included in a project like JAX incurs a small but nonzero ongoing maintenance burden for the developers. The success of a project over time directly relates to the ability of maintainers to continue this maintenance for the sum of all the project’s parts: documenting functionality, responding to questions, fixing bugs, etc. For long-term success and sustainability of any software tool, it’s vital that maintainers carefully weigh whether any particular contribution will be a net positive for the project given its goals and resources.

Evaluation Rubric

This document proposes a rubric of six axes along which the scope of any particular numpy or scipy API can be judged for inclusion into JAX. An API which is strong along all axes is an excellent candidate for inclusion in the JAX package; a significant weakness along any of the six axes is a good argument against inclusion in JAX.

Axis 1: XLA alignment

The first axis we consider is the degree to which the proposed API aligns with native XLA operations. For example, jax.numpy.exp() is a function that more-or-less directly mirrors jax.lax.exp. A large number of functions in numpy, scipy.special, numpy.linalg, scipy.linalg, and others meet this criterion: such functions pass the XLA-alignment check when considering their inclusion into JAX.
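To make the alignment concrete, here is a minimal sketch that assumes a recent JAX release. It checks jax.numpy.exp() against jax.lax.exp and then inspects the traced program with jax.make_jaxpr. For a floating-point input, the jaxpr should contain the exp primitive and little else.

```python
import jax
import jax.numpy as jnp
from jax import lax

x = jnp.linspace(-2.0, 2.0, 5)

# The NumPy-style wrapper and the XLA-level op agree on float inputs.
same = jnp.allclose(jnp.exp(x), lax.exp(x))
print(bool(same))

# The traced program for jnp.exp is essentially a single `exp` primitive.
jaxpr_text = str(jax.make_jaxpr(jnp.exp)(x))
print(jaxpr_text)
```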

On the other end, there are functions like numpy.unique(), which do not directly correspond to any XLA operation. In some cases these are fundamentally incompatible with JAX’s current computational model, which requires statically-shaped arrays (e.g. unique returns a value-dependent dynamic array shape). Such functions do not pass the XLA alignment check when considering their inclusion into JAX.
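The following short sketch shows the static-shape constraint in practice. It assumes a JAX version whose jnp.unique accepts the size and fill_value keywords, which are a JAX extension not present in NumPy. Outside of jit, the output length depends on the data. Under jit, the caller must fix the length up front. Calling jnp.unique without size inside a jitted function fails with a concretization error.

```python
import jax
import jax.numpy as jnp

x = jnp.array([3, 1, 3, 2, 1])

# Eager mode: the output length (3 here) depends on the values in x.
eager = jnp.unique(x)
print(eager.tolist())

# Under jit, shapes must be known at trace time, so the caller fixes the
# output length with `size`; unused slots are padded with `fill_value`.
@jax.jit
def unique_padded(x):
    return jnp.unique(x, size=4, fill_value=-1)

padded = unique_padded(x)
print(padded.tolist())
```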

We also consider as part of this axis the need for pure function semantics. For example, numpy.random is built on an implicitly-updated state-based RNG, which is fundamentally incompatible with JAX’s computational model built on XLA.
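The sketch below is an illustration and is not taken from the JEP. It shows one concrete way that implicit RNG state breaks down. Under jax.jit, a call to np.random.rand() runs only once, while the function is being traced. Its result is then frozen into the compiled computation as a constant, so repeated calls return the same "random" offset.

```python
import numpy as np
import jax
import jax.numpy as jnp

@jax.jit
def add_numpy_noise(x):
    # np.random.rand() executes at trace time only; its value becomes a
    # constant baked into the compiled XLA computation.
    return x + np.random.rand()

x = jnp.float32(0.0)
a = add_numpy_noise(x)
b = add_numpy_noise(x)  # cache hit: no retrace, so the same offset is reused
print(bool(a == b))
```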

Axis 2: Array API Alignment

The second axis we consider focuses on the Python Array API Standard: this is in some sense a community-driven outline of which array operations are central to array-oriented programming across a wide range of user communities. If an API in numpy or scipy is listed within the Array API standard, it is a strong signal that JAX should include it. Using the example from above, the Array API standard includes several variants of numpy.unique() (unique_all, unique_counts, unique_inverse, unique_values). This suggests that, despite the function not being precisely aligned with XLA, it is important enough to the Python user community that JAX should perhaps implement it.

Axis 3: Existence of Downstream Implementations

For functionality that does not align with Axis 1 or 2, an important consideration for inclusion into JAX is whether there exist well-supported downstream packages that supply the functionality in question. A good example of this is scipy.optimize: while JAX does include a minimal set of wrappers of scipy.optimize functionality, a much more complete treatment exists in the JAXopt package, which is actively maintained by JAX collaborators. In cases like this, we should lean toward pointing users and contributors to these specialized packages rather than re-implementing such APIs in JAX itself.

Axis 4: Complexity & Robustness of Implementation

For functionality that does not align with XLA, one consideration is the degree of complexity of the proposed implementation. This aligns to some degree with Axis 1, but is nevertheless important to call out. A number of functions have been contributed to JAX whose relatively complex implementations are difficult to validate and introduce outsized maintenance burdens. An example is jax.scipy.special.bessel_jn(): as of the writing of this JEP, its current implementation is a non-straightforward iterative approximation that has convergence issues in some domains, and proposed fixes introduce further complexity. Had we more carefully weighed the complexity and robustness of the implementation when accepting the contribution, we might have chosen not to accept this contribution to the package.

Axis 5: Functional vs. Object-Oriented APIs

JAX works best with functional APIs rather than object-oriented APIs. Object-oriented APIs can hide impure semantics, which often makes them difficult to implement well. NumPy and SciPy generally stick to functional APIs, but sometimes provide object-oriented convenience wrappers.

One example of this is numpy.polynomial.Polynomial, which wraps lower-level operations like numpy.polyadd(), numpy.polydiv(), etc. In general, when there are both functional and object-oriented APIs available, JAX should avoid providing wrappers for the object-oriented APIs and instead provide wrappers for the functional APIs.
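By way of comparison, the functional NumPy polynomial routines already have jax.numpy counterparts. This sketch assumes the same highest-degree-first coefficient convention as numpy.polyadd(). It composes the routines directly, without any wrapper class holding state.

```python
import jax.numpy as jnp

# Coefficients are ordered from highest degree to lowest, as in numpy.polyval.
p = jnp.array([1.0, 2.0])        # x + 2
q = jnp.array([1.0, 0.0, -1.0])  # x**2 - 1

total = jnp.polyadd(p, q)          # x**2 + x + 1
product = jnp.polymul(p, q)        # x**3 + 2*x**2 - x - 2
value = jnp.polyval(product, 2.0)  # 8 + 8 - 2 - 2 = 12
```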

In cases where only the object-oriented APIs exist, JAX should avoid providing wrappers unless the case is strong along other axes.

Axis 6: General “Importance” to JAX Users & Stakeholders

The decision to include a NumPy/SciPy API in JAX should also take into account the importance of the algorithm to the general user community. It is admittedly difficult to quantify who is a “stakeholder” and how this importance should be measured; but we include this to make clear that any decision about what to include in JAX’s NumPy and SciPy wrappers will involve some amount of discretion that cannot be easily quantified.

For existing APIs, searches for usage on GitHub may be useful in establishing importance or lack thereof. As an example, we might return to jax.scipy.special.bessel_jn(), discussed above: a search shows that this function has only a handful of uses on GitHub, probably partly due to the previously mentioned accuracy issues.

Evaluation: what’s in scope?

In this section, we’ll attempt to evaluate the NumPy and SciPy APIs, including some examples from the current JAX API, in light of the above rubric. This will not be a comprehensive listing of all existing functions and classes, but rather a more general discussion by submodule and topic, with relevant examples.

NumPy APIs

numpy namespace

We consider the functions in the main numpy namespace to be essentially all in-scope for JAX, due to their general alignment with XLA (Axis 1) and the Python Array API (Axis 2), as well as their general importance to the JAX user community (Axis 6). Some functions are perhaps borderline (functions like numpy.intersect1d(), numpy.setdiff1d(), and numpy.union1d() arguably fail parts of the rubric), but for simplicity we declare that all array functions in the main numpy namespace are in-scope for JAX.

numpy.linalg & numpy.fft

The numpy.linalg and numpy.fft submodules contain many functions that broadly align with functionality provided by XLA. Others have complicated device-specific lowerings, but represent a case where importance to stakeholders (Axis 6) outweighs complexity. For this reason, we deem both of these submodules in-scope for JAX.
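Here is a small sketch of both submodules in use, with arbitrary illustrative inputs. jnp.linalg.solve follows numpy.linalg.solve, and jnp.fft.fft and jnp.fft.ifft follow their numpy.fft counterparts.

```python
import jax.numpy as jnp

# numpy.linalg-style dense solve: find x with a @ x == b.
a = jnp.array([[3.0, 1.0], [1.0, 2.0]])
b = jnp.array([9.0, 8.0])
x = jnp.linalg.solve(a, b)  # exact solution is [2, 3]

# numpy.fft-style round trip: ifft(fft(s)) recovers the real signal s.
s = jnp.arange(8.0)
roundtrip = jnp.fft.ifft(jnp.fft.fft(s)).real
```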

numpy.random

numpy.random is out-of-scope for JAX, because state-based RNGs are fundamentally incompatible with JAX’s computation model. We instead focus on jax.random, which offers similar functionality using a counter-based PRNG.
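With the counter-based design, RNG state is an explicit value that the caller threads through the program. The following minimal sketch uses jax.random.key and jax.random.split (older code often uses jax.random.PRNGKey instead). Reusing a key reproduces the same samples, while splitting off a fresh key yields new ones.

```python
import jax

key = jax.random.key(0)  # explicit PRNG state, created from a seed

# Splitting produces new keys; nothing is mutated behind the scenes.
key, subkey = jax.random.split(key)
a = jax.random.normal(subkey, (3,))
b = jax.random.normal(subkey, (3,))  # same key, so identical samples

key, subkey2 = jax.random.split(key)
c = jax.random.normal(subkey2, (3,))  # fresh key, so new samples
```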

numpy.ma & numpy.polynomial

The numpy.ma and numpy.polynomial submodules are mostly concerned with providing object-oriented interfaces to computations that can be expressed via other functional means (Axis 5); for this reason, we deem them out-of-scope for JAX.

numpy.testing

NumPy’s testing functionality only really makes sense for host-side computation, and so we don’t include any wrappers for it in JAX. That said, JAX arrays are compatible with numpy.testing, and JAX makes frequent use of it throughout the JAX test suite.
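As a quick illustration of that compatibility, numpy.testing assertions accept JAX arrays directly and compare them as host NumPy arrays. The inputs below are arbitrary.

```python
import numpy as np
import jax.numpy as jnp

x = jnp.linspace(0.0, 1.0, 5)
identity = jnp.sin(x) ** 2 + jnp.cos(x) ** 2

# Both assertions accept JAX arrays and compare them on the host.
np.testing.assert_allclose(identity, np.ones(5), rtol=1e-5)
np.testing.assert_array_equal(jnp.arange(3), np.array([0, 1, 2]))
```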

SciPy APIs

SciPy has no functions in the top-level namespace, but includes a number of submodules. We consider each below, leaving out modules which have been deprecated.

scipy.cluster

The scipy.cluster module includes tools for hierarchical clustering, k-means, and related algorithms. These are weak along several axes, and would be better served by a downstream package. One function already exists within JAX (jax.scipy.cluster.vq.vq()) but has no obvious usage on GitHub: this suggests that clustering is not broadly important to JAX users.

Recommendation: deprecate and remove jax.scipy.cluster.vq().

scipy.constants

The scipy.constants module includes mathematical and physical constants. These constants can be used directly with JAX, and so there is no reason to re-implement this in JAX.
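For example, a constant from scipy.constants is a plain Python float and combines with a JAX array without any wrapper. The photon-energy formula E = hc/λ used here is only an illustrative application.

```python
import scipy.constants
import jax.numpy as jnp

# scipy.constants exposes plain Python floats, usable directly with JAX.
wavelengths_m = jnp.array([400e-9, 550e-9, 700e-9])
photon_energy_j = scipy.constants.h * scipy.constants.c / wavelengths_m
```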

scipy.datasets

The scipy.datasets module includes tools to fetch and load datasets. These fetched datasets can be used directly with JAX, and so there is no reason to re-implement this in JAX.

scipy.fft

The scipy.fft module contains functions that broadly align with functionality provided by XLA, and fare well along other axes as well. For this reason, we deem them in-scope for JAX.
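The following sketch checks one jax.scipy.fft routine against SciPy. It assumes that jax.scipy.fft.dct in the installed JAX version supports type=2 with norm="ortho". The input signal is arbitrary.

```python
import numpy as np
import scipy.fft
import jax.numpy as jnp
import jax.scipy.fft

x = np.linspace(0.0, 1.0, 16, dtype=np.float32)

# Orthonormal type-II DCT computed by JAX and by SciPy on the same input.
ours = jax.scipy.fft.dct(jnp.asarray(x), type=2, norm="ortho")
ref = scipy.fft.dct(x, type=2, norm="ortho")
print(np.allclose(np.asarray(ours), ref, atol=1e-4))
```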

scipy.integrate

The scipy.integrate module contains functions for numerical integration. The more sophisticated of these (quad, dblquad, ode) are out-of-scope for JAX by Axes 1 & 4, since they tend to be loopy algorithms based on dynamic numbers of evaluations. jax.experimental.ode.odeint() is related, but rather limited and not under any active development.

JAX does currently include jax.scipy.integrate.trapezoid(), but this is only because numpy.trapz() was recently deprecated in favor of this. For any particular input, its implementation could be replaced with one line of jax.numpy expressions, so it’s not a particularly useful API to provide.
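To illustrate the one-line claim, the hypothetical helper trapezoid_sketch (not a JAX API) implements the composite trapezoid rule using only jax.numpy. For y = x² sampled at 101 points on [0, 1], the result should be close to the exact value 1/3.

```python
import jax.numpy as jnp

def trapezoid_sketch(y, x):
    # Hypothetical helper: composite trapezoid rule, sum of 0.5*(y_i + y_{i+1})*dx_i.
    return 0.5 * jnp.sum((y[1:] + y[:-1]) * jnp.diff(x))

x = jnp.linspace(0.0, 1.0, 101)
approx = trapezoid_sketch(x ** 2, x)  # exact integral is 1/3
```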

Based on Axes 1, 2, 4, and 6, scipy.integrate should be considered out-of-scope for JAX.

Recommendation: remove jax.scipy.integrate.trapezoid(), which was added in JAX 0.4.14.

scipy.interpolate

The scipy.interpolate module provides both low-level and object-oriented routines for interpolating in one or more dimensions. These APIs rate poorly along a number of the axes above: they are class-based rather than low-level, and none but the simplest methods can be expressed efficiently in terms of XLA operations.

JAX does currently have wrappers for scipy.interpolate.RegularGridInterpolator. Were we considering this contribution today, we would probably reject it by the above criteria. But this code has been fairly stable, so there is not much downside to continuing to maintain it.

Going forward, we should consider other members of scipy.interpolate to be out-of-scope for JAX.

scipy.io

The scipy.io submodule has to do with file input/output. There is no reason to re-implement this in JAX.

scipy.linalg

The scipy.linalg submodule contains functions that broadly align with functionality provided by XLA, and fast linear algebra is broadly important to the JAX user community. For this reason, we deem it in-scope for JAX.
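As an example of the SciPy-style linear algebra that JAX already wraps, this sketch factors a small symmetric positive-definite matrix with jax.scipy.linalg.cho_factor and then solves a system with cho_solve. The call pattern mirrors SciPy's, and the matrix and right-hand side are illustrative.

```python
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve

a = jnp.array([[4.0, 2.0], [2.0, 3.0]])  # symmetric positive definite
b = jnp.array([6.0, 5.0])

c_and_lower = cho_factor(a)      # (factor, lower) tuple, as in SciPy
x = cho_solve(c_and_lower, b)    # exact solution is [1, 1]
```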

scipy.ndimage

The scipy.ndimage submodule contains a set of tools for working on image data. Many of these overlap with tools in scipy.signal (e.g. convolutions and filtering). JAX currently provides one scipy.ndimage API, in jax.scipy.ndimage.map_coordinates(). Additionally, JAX provides some image-related tools in the jax.image module. The DeepMind ecosystem includes dm-pix, a more full-featured set of tools for image manipulation in JAX. Given all these factors, I’d suggest that scipy.ndimage should be considered out-of-scope for JAX core; we can point interested users and contributors to dm-pix. We can consider moving map_coordinates to dm-pix or to another appropriate package.
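For reference, here is a minimal sketch of the one existing wrapper. The illustrative image is linear in its indices (image[i, j] = 4*i + j), so linear interpolation (order=1) should reproduce that formula exactly at fractional coordinates. The JAX version supports only order 0 or 1, unlike SciPy's higher-order splines.

```python
import jax.numpy as jnp
from jax.scipy.ndimage import map_coordinates

image = jnp.arange(16.0).reshape(4, 4)  # image[i, j] = 4*i + j

# Coordinates are given as one array per axis: (rows, cols).
rows = jnp.array([0.5, 2.0])
cols = jnp.array([0.5, 1.5])

# order=1 selects linear interpolation.
samples = map_coordinates(image, [rows, cols], order=1)  # 4*i + j at each point
```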

scipy.odr

The scipy.odr module provides an object-oriented wrapper around ODRPACK for performing orthogonal distance regressions. It is not clear that this could be cleanly expressed using existing JAX primitives, and so we deem it out of scope for JAX itself.

scipy.optimize

The scipy.optimize module provides high-level and low-level interfaces for optimization. Such functionality is important to a lot of JAX users, and very early on JAX created jax.scipy.optimize wrappers. However, developers of these routines soon realized that the scipy.optimize API was too constraining. Different teams began working on the JAXopt package and the Optimistix package, each of which contains a much more comprehensive and better-tested set of optimization routines in JAX.

Because of these well-supported external packages, we now consider scipy.optimize to be out-of-scope for JAX.
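For users who still depend on the existing wrapper, here is a minimal sketch with jax.scipy.optimize.minimize, which supports method="BFGS". The objective is an illustrative convex quadratic with its minimum at (1, -2). JAXopt and Optimistix offer richer solvers, but their APIs are not shown here.

```python
import jax.numpy as jnp
from jax.scipy.optimize import minimize

def objective(p):
    # Illustrative convex quadratic with its minimum at (1, -2).
    return (p[0] - 1.0) ** 2 + 10.0 * (p[1] + 2.0) ** 2

result = minimize(objective, jnp.zeros(2), method="BFGS")
print(result.x, result.success)
```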

Recommendation: deprecate jax.scipy.optimize and/or make it a lightweight wrapper around an optional JAXopt or Optimistix dependency.

🟡 scipy.signal

The scipy.signal module is mixed. Some functions are squarely in-scope for JAX, e.g. correlate and convolve, which are more user-friendly wrappers of lax.conv_general_dilated. Many others are squarely out-of-scope: these are domain-specific tools with no viable lowering path to XLA. Potential contributions to jax.scipy.signal will have to be weighed on a case-by-case basis.
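Here is a sketch of the in-scope part. jax.scipy.signal.convolve and correlate follow the scipy.signal calling convention, including the mode argument. The results are checked against the NumPy and SciPy reference implementations on small illustrative inputs.

```python
import numpy as np
import scipy.signal
import jax.numpy as jnp
from jax.scipy.signal import convolve, correlate

a = jnp.array([1.0, 2.0, 3.0])
k = jnp.array([0.0, 1.0, 0.5])

full = convolve(a, k, mode="full")   # length len(a) + len(k) - 1
same = correlate(a, k, mode="same")  # length len(a)

# Reference results from NumPy / SciPy on the same host arrays.
full_ref = np.convolve(np.asarray(a), np.asarray(k))
same_ref = scipy.signal.correlate(np.asarray(a), np.asarray(k), mode="same")
```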

🟡 scipy.sparse

The scipy.sparse submodule mainly contains data structures for storing and operating on sparse matrices and arrays in a variety of formats. Additionally, scipy.sparse.linalg contains a number of matrix-free solvers, suitable for use with sparse matrices, dense matrices, and linear operators.

The scipy.sparse array and matrix data structures are out-of-scope for JAX, because they do not align with JAX’s computational model (e.g. many operations depend on dynamically-sized buffers). JAX has developed the jax.experimental.sparse module as an alternative set of data structures that are more in line with JAX’s computational constraints. For these reasons, we consider the data structures in scipy.sparse to be out-of-scope for JAX.
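To illustrate the static-shape alternative, here is a minimal sketch using the BCOO format from jax.experimental.sparse. The number of stored elements (nse) is fixed when the array is built, which keeps buffer sizes static under JAX's tracing model. The matrix is illustrative.

```python
import jax.numpy as jnp
from jax.experimental import sparse

dense = jnp.array([[0.0, 2.0, 0.0],
                   [1.0, 0.0, 0.0],
                   [0.0, 0.0, 3.0]])

# fromdense counts the nonzeros once; the resulting buffer sizes are static.
m = sparse.BCOO.fromdense(dense)
v = jnp.ones(3)
y = m @ v  # sparse-dense matmul, same values as dense @ v
```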

On the other hand, scipy.sparse.linalg has proven to be an interesting area, and jax.scipy.sparse.linalg includes the bicgstab, cg, and gmres solvers. These are useful to the JAX user community (Axis 6) but aside from this do not fare well along other axes. They would be very suitable for moving into a downstream library; one potential option may be Lineax, which features a number of linear solvers built on JAX.
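The matrix-free style is what makes these solvers attractive. In this sketch, jax.scipy.sparse.linalg.cg receives a Python function rather than a matrix. The operator apply_laplacian is a hypothetical example: a 1D discrete Laplacian, which is symmetric positive definite as cg requires.

```python
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

def apply_laplacian(x):
    # Hypothetical matrix-free operator: (2*x_i - x_{i-1} - x_{i+1}) with
    # zero (Dirichlet) boundaries; symmetric positive definite.
    padded = jnp.pad(x, 1)
    return 2.0 * x - padded[:-2] - padded[2:]

b = jnp.ones(8)
x, info = cg(apply_laplacian, b, tol=1e-6)
residual = jnp.linalg.norm(apply_laplacian(x) - b)
```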

Recommendation: explore moving sparse solvers into Lineax, and otherwise treat scipy.sparse as out-of-scope for JAX.

scipy.spatial

The scipy.spatial module contains mainly object-oriented interfaces to spatial/distance computations and nearest neighbor searches. It is mostly out-of-scope for JAX.

The scipy.spatial.transform submodule provides tools for manipulating three-dimensional spatial rotations. It is a relatively complicated object-oriented interface, and could perhaps be better served by a downstream project. JAX currently contains partial implementations of Rotation and Slerp within jax.scipy.spatial.transform. These are object-oriented wrappers of otherwise basic functions, and they introduce a very large API surface and have very few users. It is our judgment that they are out-of-scope for JAX itself, with users better served by a hypothetical downstream project.

The scipy.spatial.distance submodule contains a useful collection of distance metrics, and it might be tempting to provide JAX wrappers for these. That said, with jit and vmap it would be straightforward for a user to define efficient versions of most of these from scratch if needed, so adding them to JAX is not particularly beneficial.
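To illustrate the from-scratch approach, this sketch writes a Euclidean distance for a single pair of vectors and lifts it to all pairs with nested jax.vmap. The result is the same kind of (m, n) matrix that scipy.spatial.distance.cdist returns. The names euclidean and pairwise are hypothetical.

```python
import jax
import jax.numpy as jnp

def euclidean(u, v):
    # Distance between two vectors, written once for the single-pair case.
    return jnp.sqrt(jnp.sum((u - v) ** 2))

# Inner vmap maps over rows of xb; outer vmap maps over rows of xa.
pairwise = jax.jit(
    jax.vmap(jax.vmap(euclidean, in_axes=(None, 0)), in_axes=(0, None))
)

xa = jnp.array([[0.0, 0.0], [3.0, 4.0]])
xb = jnp.array([[0.0, 0.0], [6.0, 8.0], [3.0, 0.0]])
d = pairwise(xa, xb)  # shape (2, 3)
```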

Recommendation: consider deprecating and removing the Rotation and Slerp APIs, and consider scipy.spatial as a whole out-of-scope for future contributions.

scipy.special

The scipy.special module includes implementations of a number of more specialized functions. In many cases, these functions are squarely in scope: for example, functions like gammaln, betainc, digamma, and many others correspond directly to available XLA primitives, and are clearly in-scope by Axis 1 and others.
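Here is a quick sketch of three of these functions on inputs with known closed-form values. gammaln(5) = log(4!) = log 24. digamma(1) equals minus the Euler-Mascheroni constant, about -0.5772. The regularized incomplete beta function with a = b = 1 reduces to x.

```python
import math
import jax.numpy as jnp
from jax.scipy.special import betainc, digamma, gammaln

g = gammaln(jnp.float32(5.0))            # log(4!) = log(24)
d = digamma(jnp.float32(1.0))            # minus the Euler-Mascheroni constant
b = betainc(1.0, 1.0, jnp.float32(0.3))  # equals x when a = b = 1
print(float(g), math.log(24.0), float(d), float(b))
```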

Other functions require more complicated implementations; one example mentioned above is bessel_jn. Despite not aligning on Axes 1 and 2, these functions tend to be very strong along Axis 6: scipy.special provides fundamental functions necessary for computation in a variety of domains, so even functions with complicated implementations should lean toward in-scope, so long as the implementations are well-designed and robust.

There are a few existing function wrappers that we should take a closer look at; for example:

  • jax.scipy.special.lpmn(): this generates Legendre polynomials via a complicated fori_loop, in a way that does not match the scipy API (e.g. for scipy, z must be a scalar, while for JAX, z must be a 1D array). The function has few discoverable uses, making it a weak candidate along Axes 1, 2, 4, and 6.

  • jax.scipy.special.lpmn_values(): this has similar weaknesses to lpmn above.

  • jax.scipy.special.sph_harm(): this is built on lpmn, and similarly has an API that diverges from the corresponding scipy function.

  • jax.scipy.special.bessel_jn(): as discussed under Axis 4 above, this has weaknesses in terms of robustness of implementation and little usage. We might consider replacing it with a new, more robust implementation (e.g. #17038).

Recommendation: refactor and improve robustness & test coverage for bessel_jn. Consider deprecating lpmn, lpmn_values, and sph_harm if they cannot be modified to more closely match the scipy APIs.

scipy.stats

The scipy.stats module contains a wide range of statistical functions, including discrete and continuous distributions, summary statistics, and hypothesis testing. JAX currently wraps a number of these in jax.scipy.stats, primarily including 20 or so statistical distributions, along with a few other functions (mode, rankdata, gaussian_kde). In general these are well-aligned with JAX: distributions usually are expressible in terms of efficient XLA operations, and the APIs are clean and functional.
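Here is a minimal sketch of why the functional distribution API fits JAX. norm.logpdf is a pure function, so jax.grad can differentiate a log-likelihood built from it. For this illustrative data with unit scale, the gradient of the negative log-likelihood at mu = 0 is -sum(data - mu) = -4.5.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

def neg_log_likelihood(mu, data):
    # Pure function of (mu, data): safe to differentiate, jit, or vmap.
    return -jnp.sum(norm.logpdf(data, loc=mu, scale=1.0))

data = jnp.array([0.5, 1.5, 2.5])
g = jax.grad(neg_log_likelihood)(0.0, data)  # -sum(data - 0) = -4.5
```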

We don’t currently have any wrappers for hypothesis testing functions, probably because these are less useful to the primary user base of JAX.

Regarding distributions, in some cases tensorflow_probability provides similar functionality, and in the future we might consider whether to deprecate the scipy.stats distributions in favor of that implementation.

Recommendation: going forward, we should treat statistical distributions and summary statistics as in-scope, and consider hypothesis tests and related functionality generally out-of-scope.

