Support for the array API standard#

Note

Array API standard support is still experimental and hidden behind an environment variable. Only a small part of the public API is covered right now.

This guide describes how to use and add support for the Python array API standard. This standard allows users to use any array API compatible array library with parts of SciPy out of the box.

The RFC defines how SciPy implements support for the standard, with the main principle being “array type in equals array type out”. In addition, the implementation performs stricter validation of allowed array-like inputs, e.g. rejecting NumPy matrix and masked array instances, and arrays with object dtype.

In the following, an array API compatible namespace is denoted as xp.

Using array API standard support#

To enable array API standard support, an environment variable must be set before importing SciPy:

export SCIPY_ARRAY_API=1

This both enables array API standard support and the stricter input validation for array-like arguments. Note that this environment variable is meant to be temporary, as a way to make incremental changes and merge them into ``main`` without affecting backwards compatibility immediately. We do not intend to keep this environment variable around long-term.
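
The variable can also be set from Python, as long as that happens before SciPy is first imported in the process. A minimal sketch (the submodule imported here is just an example):

import os
os.environ["SCIPY_ARRAY_API"] = "1"  # must run before the first `import scipy`

import scipy.cluster.vq  # array API standard support is now enabled for this process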

This clustering example shows usage with PyTorch tensors as inputs and return values:

>>> import torch
>>> from scipy.cluster.vq import vq
>>> code_book = torch.tensor([[1., 1., 1.],
...                           [2., 2., 2.]])
>>> features = torch.tensor([[1.9, 2.3, 1.7],
...                          [1.5, 2.5, 2.2],
...                          [0.8, 0.6, 1.7]])
>>> code, dist = vq(features, code_book)
>>> code
tensor([1, 1, 0], dtype=torch.int32)
>>> dist
tensor([0.4359, 0.7348, 0.8307])

Note that the above example works for PyTorch CPU tensors. For GPU tensors or CuPy arrays, the expected result for vq is a TypeError, because vq uses compiled code in its implementation, which won’t work on GPU.
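
For illustration, with a CUDA-enabled PyTorch build (a hypothetical session, not run here; the exact error message depends on the PyTorch version), moving the tensors from the example above to the GPU is expected to fail along these lines:

>>> code, dist = vq(features.to("cuda"), code_book.to("cuda"))
...
TypeError: ...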

The stricter array input validation will reject np.matrix and np.ma.MaskedArray instances, as well as arrays with object dtype:

>>> import numpy as np
>>> from scipy.cluster.vq import vq
>>> code_book = np.array([[1., 1., 1.],
...                       [2., 2., 2.]])
>>> features = np.array([[1.9, 2.3, 1.7],
...                      [1.5, 2.5, 2.2],
...                      [0.8, 0.6, 1.7]])
>>> vq(features, code_book)
(array([1, 1, 0], dtype=int32), array([0.43588989, 0.73484692, 0.83066239]))
>>> # The above uses numpy arrays; trying to use np.matrix instances or object
>>> # arrays instead will yield an exception with `SCIPY_ARRAY_API=1`:
>>> vq(np.asmatrix(features), code_book)
...
TypeError: 'numpy.matrix' are not supported
>>> vq(np.ma.asarray(features), code_book)
...
TypeError: 'numpy.ma.MaskedArray' are not supported
>>> vq(features.astype(np.object_), code_book)
...
TypeError: object arrays are not supported

Example capabilities table#

| Library | CPU       | GPU |
|---------|-----------|-----|
| NumPy   | ✅        | n/a |
| CuPy    | n/a       | ✅  |
| PyTorch | ✅        | ✅  |
| JAX     | ⚠️ no JIT | ❌  |
| Dask    | ❌        | n/a |

In the example above, the feature has some support for NumPy, CuPy, PyTorch, and JAX arrays, but no support for Dask arrays. Some backends, like JAX and PyTorch, natively support multiple devices (CPU and GPU), but SciPy support for such arrays may be limited; for instance, this SciPy feature is only expected to work with JAX arrays located on the CPU. Additionally, some backends can have major caveats; in the example the function will fail when running inside jax.jit. Additional caveats may be listed in the docstring of the function.

While the elements of the table marked with “n/a” are inherently out of scope, we are continually working on filling in the rest. Dask wrapping around backends other than NumPy (notably, CuPy) is currently out of scope, but that may change in the future.

Please see the tracker issue for updates.

Implementation notes#

A key part of the support for the array API standard and specific compatibility functions for NumPy, CuPy and PyTorch is provided through array-api-compat. This package is included in the SciPy codebase via a git submodule (under scipy/_lib), so no new dependencies are introduced.

array-api-compat provides generic utility functions and adds aliases such as xp.concat (which, for NumPy, mapped to np.concatenate before NumPy added np.concat in NumPy 2.0). This allows using a uniform API across NumPy, PyTorch, CuPy and JAX (with other libraries, such as Dask, being worked on).

When the environment variable isn’t set and hence array API standard support in SciPy is disabled, we still use the wrapped version of the NumPy namespace, which is array_api_compat.numpy. That should not change behavior of SciPy functions, as it’s effectively the existing numpy namespace with a number of aliases added and a handful of functions amended/added for array API standard support. When support is enabled, xp = array_namespace(input) will be the standard-compatible namespace matching the input array type to a function (e.g., if the input to cluster.vq.kmeans is a PyTorch tensor, then xp is array_api_compat.torch).
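
As an illustration of the dispatch (array_namespace is a private helper in scipy._lib._array_api, shown here only to clarify the mechanism), a minimal sketch with NumPy input:

from scipy._lib._array_api import array_namespace
import numpy as np

x = np.arange(3.0)
xp = array_namespace(x)   # the wrapped array_api_compat.numpy namespace
y = xp.concat([x, x])     # standard name; equivalent to np.concatenate here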

Adding array API standard support to a SciPy function#

New code added to SciPy should follow the array API standard as closely as possible (such code typically reflects best-practice idioms for NumPy usage as well). By following the standard from the start, adding support for it is typically straightforward, and we ideally don’t need to maintain any customization.
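
For instance, here is a small sketch of standard-compliant idioms next to their NumPy-only equivalents (array_api_strict is used purely as an example of a standard-conforming namespace):

import array_api_strict as xp   # any standard-compliant namespace works here

x = xp.asarray([[1.0, 2.0], [3.0, 4.0]])
y = xp.astype(x, xp.float64)    # instead of the NumPy-only x.astype(np.float64)
y = xp.permute_dims(x, (1, 0))  # instead of x.transpose(1, 0)
y = xp.concat([x, x], axis=0)   # instead of np.concatenate([x, x], axis=0)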

Various helper functions are available in scipy._lib._array_api - please see the __all__ in that module for a list of current helpers, and their docstrings for more information.

To add support to a SciPy function which is defined in a .py file, the changes you have to make are:

  1. Input array validation,

  2. Using xp rather than np functions,

  3. When calling into compiled code, converting the array to a NumPy array before the call and back to the input array type after.

Input array validation uses the following pattern:

xp = array_namespace(arr)  # where arr is the input array
# alternatively, if there are multiple array inputs, include them all:
xp = array_namespace(arr1, arr2)

# replace np.asarray with xp.asarray
arr = xp.asarray(arr)

# uses of non-standard parameters of np.asarray can be replaced with _asarray
arr = _asarray(arr, order='C', dtype=xp.float64, xp=xp)

Note that if one input is a non-NumPy array type, all array-like inputs have to be of that type; trying to mix non-NumPy arrays with lists, Python scalars or other arbitrary Python objects will raise an exception. For NumPy arrays, such inputs will continue to be accepted for backwards compatibility reasons.
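
A sketch of what that means in practice (hypothetical inputs, assuming PyTorch is installed; the exact exception type and message may vary):

import torch
from scipy._lib._array_api import array_namespace

a = torch.arange(3)
xp = array_namespace(a, torch.ones(3))  # fine: all inputs are PyTorch tensors
xp = array_namespace(a, [0, 1, 2])      # raises: a list mixed with a non-NumPy array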

If a function calls into compiled code just once, use the following pattern:

x = np.asarray(x)  # convert to numpy right before compiled call(s)
y = _call_compiled_code(x)
y = xp.asarray(y)  # convert back to original array type

If there are multiple calls to compiled code, make sure to do the conversion only once to avoid unnecessary overhead.
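
A sketch of that pattern, following the conventions above and using hypothetical compiled routines _compiled_step_one and _compiled_step_two:

def toto_two_calls(a):
    xp = array_namespace(a)
    a = np.asarray(a)           # convert to NumPy once, before both compiled calls
    b = _compiled_step_one(a)   # hypothetical compiled routine
    c = _compiled_step_two(b)   # hypothetical compiled routine
    return xp.asarray(c)        # convert back to the input array type once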

Here is an example for a hypothetical public SciPy function toto:

def toto(a, b):
    a = np.asarray(a)
    b = np.asarray(b, copy=True)

    c = np.sum(a) - np.prod(b)

    # this is some C or Cython call
    d = cdist(c)

    return d

You would convert this like so:

def toto(a, b):
    xp = array_namespace(a, b)
    a = xp.asarray(a)
    b = xp_copy(b, xp=xp)  # our custom helper is needed for copy

    c = xp.sum(a) - xp.prod(b)

    # this is some C or Cython call
    c = np.asarray(c)
    d = cdist(c)
    d = xp.asarray(d)

    return d

Going through compiled code requires going back to a NumPy array, because SciPy’s extension modules only work with NumPy arrays (or memoryviews in the case of Cython). For arrays on CPU, the conversions should be zero-copy, while on GPU and other devices the attempted conversion will raise an exception. The reason for that is that silent data transfer between devices is considered bad practice, as it is likely to be a large and hard-to-detect performance bottleneck.
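
For example, with a CuPy array the conversion fails instead of silently copying data to host memory (a sketch, assuming CuPy is installed; the exact message depends on the CuPy version):

>>> import numpy as np
>>> import cupy as cp
>>> np.asarray(cp.arange(3))
...
TypeError: Implicit conversion to a NumPy array is not allowed. ...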

Adding tests#

To run a test on multiple array backends, add the xp fixture to it; the fixture takes the value of the array namespace currently under test.
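
A minimal sketch of such a test, reusing the hypothetical toto function from the section above:

def test_toto_roundtrip(xp):
    # `xp` takes the value of each backend namespace in turn (numpy, torch, ...)
    a = xp.asarray([1.0, 2.0, 3.0])
    b = xp.asarray([0.0, 2.0, 5.0])
    result = toto(a, b)
    assert isinstance(result, type(a))  # "array type in equals array type out"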

The following pytest markers are available:

  • skip_xp_backends(backend=None, reason=None, np_only=False, cpu_only=False, eager_only=False, exceptions=None): skip certain backends or categories of backends. See the docstring of scipy.conftest.skip_or_xfail_xp_backends for information on how to use this marker to skip tests.

  • xfail_xp_backends(backend=None, reason=None, np_only=False, cpu_only=False, eager_only=False, exceptions=None): xfail certain backends or categories of backends. See the docstring of scipy.conftest.skip_or_xfail_xp_backends for information on how to use this marker to xfail tests.

  • skip_xp_invalid_arg is used to skip tests that use arguments which are invalid when SCIPY_ARRAY_API is enabled. For instance, some tests of scipy.stats functions pass masked arrays to the function being tested, but masked arrays are incompatible with the array API. Use of the skip_xp_invalid_arg decorator allows these tests to protect against regressions when SCIPY_ARRAY_API is not used without resulting in failures when SCIPY_ARRAY_API is used. In time, we will want these functions to emit deprecation warnings when they receive array API invalid input, and this decorator will check that the deprecation warning is emitted without it causing the test to fail. When SCIPY_ARRAY_API=1 behavior becomes the default and only behavior, these tests (and the decorator itself) will be removed.

  • array_api_backends: this marker is automatically added by the xp fixture to all tests that use it. This is useful e.g. to select all and only such tests:

    spin test -b all -m array_api_backends

scipy._lib._array_api contains array-agnostic assertions such as xp_assert_close which can be used to replace assertions from numpy.testing.

When these assertions are executed within a test that uses the xp fixture, they enforce that the namespaces of both the actual and desired arrays match the namespace which was set by the fixture. Tests without the xp fixture infer the namespace from the desired array. This machinery can be overridden by explicitly passing the xp= parameter to the assertion functions.
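
For instance, in a test that does not use the xp fixture, the namespace checked against is inferred from the desired array unless it is passed explicitly (a sketch; the values and the hypothetical toto function are arbitrary):

import numpy as np
from scipy._lib._array_api import array_namespace, xp_assert_close

def test_toto_numpy_only():
    a = np.asarray([1.0, 2.0, 3.0])
    b = np.asarray([0.0, 2.0, 5.0])
    # namespace inferred from the desired array `a` ...
    xp_assert_close(toto(a, b), a)
    # ... or stated explicitly:
    xp_assert_close(toto(a, b), a, xp=array_namespace(a))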

The following examples demonstrate how to use the markers:

from scipy.conftest import skip_xp_invalid_arg
from scipy._lib._array_api import xp_assert_close
...
@pytest.mark.skip_xp_backends(np_only=True, reason='skip reason')
def test_toto1(self, xp):
    a = xp.asarray([1, 2, 3])
    b = xp.asarray([0, 2, 5])
    xp_assert_close(toto(a, b), a)
...
@pytest.mark.skip_xp_backends('array_api_strict', reason='skip reason 1')
@pytest.mark.skip_xp_backends('cupy', reason='skip reason 2')
def test_toto2(self, xp):
    ...
...
# Do not run when SCIPY_ARRAY_API is used
@skip_xp_invalid_arg
def test_toto_masked_array(self):
    ...

Passing names of backends into exceptions means that they will not be skipped by cpu_only=True or eager_only=True. This is useful when delegation is implemented for some, but not all, non-CPU backends, and the CPU code path requires conversion to NumPy for compiled code:

# array-api-strict and CuPy will always be skipped, for the given reasons.
# All libraries using a non-CPU device will also be skipped, apart from
# JAX, for which delegation is implemented (hence non-CPU execution is supported).
@pytest.mark.skip_xp_backends(cpu_only=True, exceptions=['jax.numpy'])
@pytest.mark.skip_xp_backends('array_api_strict', reason='skip reason 1')
@pytest.mark.skip_xp_backends('cupy', reason='skip reason 2')
def test_toto(self, xp):
    ...

After applying these markers, spin test can be used with the new option -b or --array-api-backend:

spin test -b numpy -b torch -s cluster

This automatically sets SCIPY_ARRAY_API appropriately. To test a library that has multiple devices with a non-default device, a second environment variable (SCIPY_DEVICE, only used in the test suite) can be set. Valid values depend on the array library under test, e.g. for PyTorch valid values are "cpu", "cuda", and "mps". To run the test suite with the PyTorch MPS backend, use: SCIPY_DEVICE=mps spin test -b torch.

Note that there is a GitHub Actions workflow which tests with array-api-strict,PyTorch, and JAX on CPU.

Testing Practice#

It’s important that for any supported functionf, there exist tests usingthexp fixture that restrict use of alternative backends to only the functionf being tested. Other functions evaluated within a test, for the purpose ofproducing reference values, inputs, round-trip calculations, etc. should insteaduse the NumPy backend. This helps ensure that any failures that occur on a backendactually relate to the function of interest, and avoids the need to skip backendsdue to lack of support for functions other thanf. Property based integrationtests which check that some invariant holds using the same alternative backendacross different functions can also have value, giving a window into the generalhealth of backend support for a module, but in order to ensure the test suiteactually reflects the state of backend support for each function, it’s vital tohave tests which isolate use of the alternative backend only to the function beingtested.

To help facilitate such backend isolation, there is a function _xp_copy_to_numpy in scipy._lib._array_api which can copy an arbitrary xp array to a NumPy array, bypassing any device transfer guards, while preserving dtypes. It is essential that this function is only used in tests for functions other than the one being tested. Attempts to copy a device array to NumPy outside of tests should fail, because otherwise it can become opaque whether a function is working on GPU or not.

When attempting to isolate use of alternative backends to a particular function, one must be mindful that PyTorch allows for setting a default dtype, and SciPy is tested with both default dtype float32 and float64 (this is controlled with the environment variable SCIPY_DEFAULT_DTYPE). Tests using the xp fixture rely on xp.asarray producing arrays with the default dtype when list input is given and no explicit dtype is specified. This means that if a test involves taking input arrays and passing them to a function other than the one being tested in order to produce inputs for the function being tested, the following may appear natural to write but would not produce the correct dtype behavior:

# z and p will have dtype complex128 regardless of the value of
# SCIPY_DEFAULT_DTYPE
z = np.asarray([1j, -1j, 2j, -2j])
p = np.asarray([1+1j, 3-100j, 3+100j, 1-1j])
k = 23
# np.poly will preserve dtype
b = k * np.poly(z).real
a = np.poly(p).real
# Input arrays z, p and reference outputs b, a keep their NumPy dtypes
# (complex128 and float64) after conversion.
z, p, b, a = map(xp.asarray, (z, p, b, a))
# With these complex128/float64 inputs, the outputs bp and ap will have
# dtype float64. Note that the parameter k is a Python scalar which does
# not impact output dtype for NumPy >= 2.0.
bp, ap = zpk2tf(z, p, k)
# xp_assert_close checks for matching dtype. Due to the way the
# code was written above, zpk2tf is not tested with float32 inputs
# when SCIPY_DEFAULT_DTYPE is float32.
xp_assert_close(b, bp)
xp_assert_close(a, ap)

One could instead construct all inputs as xp arrays and then copy to NumPy arrays in order to ensure the default dtype is respected:

# calls to xp.asarray will respect the default dtype.
z = xp.asarray([1j, -1j, 2j, -2j])
p = xp.asarray([1+1j, 3-100j, 3+100j, 1-1j])
k = 23
# _xp_copy_to_numpy preserves dtype, as does np.poly.
b = k * np.poly(_xp_copy_to_numpy(z)).real
a = np.poly(_xp_copy_to_numpy(p)).real
# b and a will have dtype float32
b, a = map(xp.asarray, (b, a))
# zpk2tf is tested with float32 inputs when SCIPY_DEFAULT_DTYPE=float32
# as intended.
bp, ap = zpk2tf(z, p, k)
xp_assert_close(b, bp)
xp_assert_close(a, ap)

Testing the JAX JIT compiler#

The JAX JIT compiler introduces special restrictions to all code wrapped by @jax.jit, which are not present when running JAX in eager mode. Notably, boolean masks in __getitem__ and in .at[] aren’t supported, and you can’t materialize the arrays by applying bool(), float(), np.asarray() etc. to them.
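
As a small illustration of the materialization restriction (a sketch, assuming JAX is installed):

import jax
import jax.numpy as jnp

@jax.jit
def bad_materialize(x):
    # fine in eager mode, but inside jax.jit the traced array cannot be
    # converted to a Python scalar, so calling this function raises an error
    return float(jnp.sum(x))

bad_materialize(jnp.arange(3.0))  # raises: traced arrays cannot be materialized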

To properly test SciPy with JAX, you need to wrap the tested SciPy functions with @jax.jit before they are called by the unit tests. To achieve this, you should tag them as follows in your test module:

from scipy._lib.array_api_extra.testing import lazy_xp_function
from scipy.mymodule import toto

lazy_xp_function(toto)

def test_toto(xp):
    a = xp.asarray([1, 2, 3])
    b = xp.asarray([0, 2, 5])
    # When xp==jax.numpy, toto is wrapped with @jax.jit
    xp_assert_close(toto(a, b), a)

See the full documentation here.

Additional information#

Here are some additional resources which motivated some design decisions andhelped during the development phase:

  • Initial PR with some discussions

  • The work was quick-started from this PR, with some inspiration taken from scikit-learn.

  • PR adding Array API support to scikit-learn

  • Some other relevant scikit-learn PRs: #22554 and #25956

API Coverage#

The tables below show the current state of alternative backend support across SciPy’s modules. Currently only public functions and function-like callable objects are included in the tables, but it is planned to eventually also include relevant public classes. Functions which are deemed out-of-scope are excluded from consideration. If a module or submodule contains no in-scope functions, it is excluded from the tables. For example, scipy.spatial.transform is currently excluded because its API contains no functions, but it may be included in the future when the scope expands to include classes. scipy.odr and scipy.datasets are excluded because their contents are considered out-of-scope.

There is not yet a formal policy for which functions should be considered out-of-scope for alternative backend support. Some general rules of thumb that are being followed are to exclude:

  • functions which do not operate on arrays, such as scipy.constants.value

  • functions which are too implementation-specific, such as those in scipy.linalg.blas which give direct wrappers to low-level BLAS routines.

  • functions which would inherently be very difficult or even impossible to compute efficiently on accelerated computing devices.

As an example, the contents of scipy.odr are considered out-of-scope for a combination of reasons 2 and 3 above. scipy.odr essentially provides a direct wrapper of the monolithic ODRPACK Fortran library, and its API is tied to the structure of this monolithic library. Creation of an efficient GPU-accelerated implementation of nonlinear weighted orthogonal distance regression is also a challenging problem in its own right. Nevertheless, views on what to consider in-scope are evolving, and something which is now considered out-of-scope may be deemed in-scope in the future if sufficient user interest and feasibility are demonstrated.

Note

The coverage percentages shown below may be below the true values, because alternative backend support was added for some functions before the infrastructure for registering this support was developed. This situation is denoted by placing asterisks next to the percentages. Documentation of alternative backend support is currently a work in progress.

Support on CPU#

| module | torch | jax | dask |
|---|---|---|---|
| cluster.vq (4) | 100% | 100% | 100% |
| cluster.hierarchy (29) | 97% | 97% | 97% |
| constants (3) | 100% | 100% | 100% |
| differentiate (3) | 100% | 100% | 0% |
| fft (32) | 100% | 94% | 100% |
| integrate (19) | 37% | 21% | 26% |
| interpolate (14) | 43%* | 43%* | 43%* |
| io (9) | 0%* | 0%* | 0%* |
| linalg (95) | 3%* | 3%* | 2%* |
| linalg.interpolative (9) | 0%* | 0%* | 0%* |
| ndimage (73) | 100% | 100% | 100% |
| optimize (57) | 7%* | 4%* | 7%* |
| optimize.elementwise (4) | 75% | 0% | 0% |
| signal (144) | 62% | 56% | 60% |
| signal.windows (26) | 96% | 88% | 92% |
| sparse (35) | 0%* | 0%* | 0%* |
| sparse.linalg (32) | 0%* | 0%* | 0%* |
| sparse.csgraph (25) | 0%* | 0%* | 0%* |
| spatial (9) | 0%* | 0%* | 0%* |
| spatial.distance (27) | 0%* | 0%* | 0%* |
| special (340) | 28%* | 28%* | 28%* |
| stats (132) | 53% | 48% | 35% |
| stats.contingency (7) | 0%* | 0%* | 0%* |
| stats.qmc (4) | 0%* | 0%* | 0%* |

Support on GPU#

| module | cupy | torch | jax |
|---|---|---|---|
| cluster.vq (4) | 25% | 25% | 25% |
| cluster.hierarchy (29) | 28% | 28% | 28% |
| constants (3) | 100% | 100% | 100% |
| differentiate (3) | 100% | 100% | 100% |
| fft (32) | 75% | 75% | 75% |
| integrate (19) | 42% | 37% | 21% |
| interpolate (14) | 0%* | 0%* | 0%* |
| io (9) | 0%* | 0%* | 0%* |
| linalg (95) | 3%* | 3%* | 3%* |
| linalg.interpolative (9) | 0%* | 0%* | 0%* |
| ndimage (73) | 93% | 0% | 1% |
| optimize (57) | 7%* | 7%* | 4%* |
| optimize.elementwise (4) | 100% | 75% | 0% |
| signal (144) | 67% | 33% | 19% |
| signal.windows (26) | 96% | 96% | 88% |
| sparse (35) | 0%* | 0%* | 0%* |
| sparse.linalg (32) | 0%* | 0%* | 0%* |
| sparse.csgraph (25) | 0%* | 0%* | 0%* |
| spatial (9) | 0%* | 0%* | 0%* |
| spatial.distance (27) | 0%* | 0%* | 0%* |
| special (340) | 28%* | 12%* | 12%* |
| stats (132) | 42% | 42% | 43% |
| stats.contingency (7) | 0%* | 0%* | 0%* |
| stats.qmc (4) | 0%* | 0%* | 0%* |

Support with JIT#

| module | jax |
|---|---|
| cluster.vq (4) | 25% |
| cluster.hierarchy (29) | 79% |
| constants (3) | 100% |
| differentiate (3) | 0% |
| fft (32) | 94% |
| integrate (19) | 11% |
| interpolate (14) | 0%* |
| io (9) | 0%* |
| linalg (95) | 1%* |
| linalg.interpolative (9) | 0%* |
| ndimage (73) | 1% |
| optimize (57) | 4%* |
| optimize.elementwise (4) | 0% |
| signal (144) | 20% |
| signal.windows (26) | 88% |
| sparse (35) | 0%* |
| sparse.linalg (32) | 0%* |
| sparse.csgraph (25) | 0%* |
| spatial (9) | 0%* |
| spatial.distance (27) | 0%* |
| special (340) | 12%* |
| stats (132) | 27% |
| stats.contingency (7) | 0%* |
| stats.qmc (4) | 0%* |