12.1. Array API support (experimental)#

The Array API specification defines a standard API for all array manipulation libraries with a NumPy-like API. Scikit-learn vendors pinned copies of array-api-compat and array-api-extra.

Scikit-learn’s support for the array API standard requires the environment variable SCIPY_ARRAY_API to be set to 1 before importing scipy and scikit-learn:

export SCIPY_ARRAY_API=1

Please note that this environment variable is intended for temporary use. For more details, refer to SciPy’s Array API documentation.
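If exporting a shell variable is inconvenient, the same effect can be achieved from Python, as long as the assignment happens before scipy and scikit-learn are imported. A minimal sketch:

import os
os.environ["SCIPY_ARRAY_API"] = "1"  # must run before the imports below

import scipy
import sklearn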

Some scikit-learn estimators that primarily rely on NumPy (as opposed to using Cython) to implement the algorithmic logic of their fit, predict or transform methods can be configured to accept any Array API compatible input data structures and automatically dispatch operations to the underlying namespace instead of relying on NumPy.

At this stage, this support is considered experimental and must be enabled explicitly with the array_api_dispatch configuration. See below for details.

Note

Currently, only array-api-strict, cupy, and PyTorch are known to work with scikit-learn’s estimators.

The following video provides an overview of the standard’s design principles and how it facilitates interoperability between array libraries:

12.1.1. Enabling array API support#

Array API support is enabled by setting the array_api_dispatch configuration to True. We recommend setting this configuration globally to ensure consistent behaviour and prevent accidental mixing of array namespaces. Note that in the examples below, we use a context manager (config_context) to avoid having to reset it to False at the end of every code snippet, so as to not affect the rest of the documentation.
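For reference, both ways of enabling the dispatch are sketched below: set_config toggles it globally, while config_context scopes it to a block:

import sklearn
from sklearn import config_context

sklearn.set_config(array_api_dispatch=True)   # global: affects all subsequent calls
sklearn.set_config(array_api_dispatch=False)  # reset to the default

with config_context(array_api_dispatch=True):
    ...  # dispatch is enabled only inside this block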

Scikit-learn accepts array-like inputs for all metrics and some estimators. When array_api_dispatch=False, these inputs are converted into NumPy arrays using numpy.asarray (or numpy.array). While this will successfully convert some array API inputs (e.g., a JAX array), we generally recommend setting array_api_dispatch=True when using array API inputs, because the NumPy conversion can often fail (e.g., for a torch tensor allocated on a GPU).

12.1.2. Example usage#

The example code snippet below demonstrates how to use CuPy to run LinearDiscriminantAnalysis on a GPU:

>>> from sklearn.datasets import make_classification
>>> from sklearn import config_context
>>> from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
>>> import cupy

>>> X_np, y_np = make_classification(random_state=0)
>>> X_cu = cupy.asarray(X_np)
>>> y_cu = cupy.asarray(y_np)
>>> X_cu.device
<CUDA Device 0>

>>> with config_context(array_api_dispatch=True):
...     lda = LinearDiscriminantAnalysis()
...     X_trans = lda.fit_transform(X_cu, y_cu)
>>> X_trans.device
<CUDA Device 0>

After the model is trained, fitted attributes that are arrays will also be from the same Array API namespace as the training data. For example, if CuPy’s Array API namespace was used for training, then fitted attributes will be on the GPU. We provide an experimental _estimator_with_converted_arrays utility that transfers an estimator’s attributes from an Array API namespace to NumPy ndarrays:

>>> from sklearn.utils._array_api import _estimator_with_converted_arrays
>>> cupy_to_ndarray = lambda array: array.get()
>>> lda_np = _estimator_with_converted_arrays(lda, cupy_to_ndarray)
>>> X_trans = lda_np.transform(X_np)
>>> type(X_trans)
<class 'numpy.ndarray'>

12.1.2.1. PyTorch Support#

PyTorch Tensors can also be passed directly:

>>> import torch
>>> X_torch = torch.asarray(X_np, device="cuda", dtype=torch.float32)
>>> y_torch = torch.asarray(y_np, device="cuda", dtype=torch.float32)

>>> with config_context(array_api_dispatch=True):
...     lda = LinearDiscriminantAnalysis()
...     X_trans = lda.fit_transform(X_torch, y_torch)
>>> type(X_trans)
<class 'torch.Tensor'>
>>> X_trans.device.type
'cuda'

12.1.3. Support for Array API-compatible inputs#

The following estimators and other tools in scikit-learn support Array API compatible inputs.

12.1.3.1. Estimators#

12.1.3.2. Meta-estimators#

Meta-estimators that accept Array API inputs, provided that the base estimator does too:

12.1.3.3. Metrics#

12.1.3.4. Tools#

Coverage is expected to grow over time. Please follow the dedicated meta-issue on GitHub to track progress.

12.1.4. Input and output array type handling#

Estimators and scoring functions are able to accept input arrays from different array libraries and/or devices. When a mixed set of input arrays is passed, scikit-learn converts arrays as needed to make them all consistent.

For estimators, the rule is “everything follows X”: mixed array inputs are converted so that they all match the array library and device of X. For scoring functions the rule is “everything follows y_pred”: mixed array inputs are converted so that they all match the array library and device of y_pred.

When a function or method has been called with array API compatible inputs, the convention is to return arrays from the same array library and on the same device as the input data.

12.1.4.1. Estimators#

When an estimator is fitted with an array API compatible X, all other array inputs, including constructor arguments (e.g., y, sample_weight), will be converted to match the array library and device of X, if they do not already. This behaviour enables switching from processing on the CPU to processing on the GPU at any point within a pipeline.
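As a minimal sketch of this behaviour (reusing the X_np and y_np arrays from the example above and assuming a CUDA device), y can stay a NumPy array on the CPU while X lives on the GPU:

import torch
from sklearn import config_context
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

X_torch = torch.asarray(X_np, device="cuda", dtype=torch.float32)

with config_context(array_api_dispatch=True):
    # y_np is a NumPy array on the CPU; it is converted internally
    # to match the library and device of X_torch.
    lda = LinearDiscriminantAnalysis().fit(X_torch, y_np)

# Fitted attributes follow X: torch tensors allocated on the GPU.
print(type(lda.coef_), lda.coef_.device)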

This allows estimators to accept mixed input types, enabling X to be moved to a different device within a pipeline, without explicitly moving y. Note that scikit-learn pipelines do not allow transformation of y (to avoid leakage).

Take for example a pipeline where X and y both start on CPU, and go through the following three steps:

1. X initially contains categorical string data (thus needs to be on CPU), which is target encoded to numerical values in TargetEncoder.
2. X is then explicitly moved to GPU to improve the performance of Ridge, as shown in the sketch below.
3. y cannot be transformed by the pipeline (recall scikit-learn pipelines do not allow transformation of y), but as Ridge is able to accept mixed input types, this is not a problem and the pipeline can be run.
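A minimal sketch of such a pipeline, assuming a CUDA-enabled PyTorch installation; the to_gpu helper is hypothetical, not a scikit-learn built-in, and Ridge is used with its array API compatible solver:

import numpy as np
import torch
from sklearn import config_context
from sklearn.linear_model import Ridge
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import FunctionTransformer, TargetEncoder

# Categorical X and numeric y, both starting on CPU.
X = np.array([["a"], ["b"], ["a"], ["c"], ["b"], ["a"]], dtype=object)
y = np.array([1.0, 2.0, 1.5, 3.0, 2.5, 1.2])

def to_gpu(X):
    # Hypothetical helper: move the now-numerical features to the GPU.
    return torch.asarray(np.asarray(X, dtype=np.float32), device="cuda")

pipe = make_pipeline(
    TargetEncoder(),              # step 1: CPU, encodes string categories
    FunctionTransformer(to_gpu),  # step 2: X moves to the GPU
    Ridge(solver="svd"),          # step 3: accepts GPU X with CPU y
)

with config_context(array_api_dispatch=True):
    pipe.fit(X, y)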

The fitted attributes of an estimator fitted with an array API compatible X will be arrays from the same library as the input and stored on the same device. The predict and transform methods subsequently expect inputs from the same array library and device as the data passed to the fit method.

12.1.4.2. Scoring functions#

When an array API compatible y_pred is passed to a scoring function, all other array inputs (e.g., y_true, sample_weight) will be converted to match the array library and device of y_pred, if they do not already. This allows scoring functions to accept mixed input types, enabling them to be used within a meta-estimator (or function that accepts estimators), with a pipeline that moves input arrays between devices (e.g., CPU to GPU).

For example, to be able to use the pipeline described above within e.g., cross_validate or GridSearchCV, the scoring function internally called needs to be able to accept mixed input types.

The output type of scoring functions depends on the number of output values. When a scoring function returns a scalar value, it will return a Python scalar (typically a float instance) instead of an array scalar value. For scoring functions that support multiclass or multioutput, an array from the same array library and device as y_pred will be returned when multiple values need to be output.
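For instance, a scalar metric such as r2_score can be called with a NumPy y_true and a GPU-allocated y_pred; a hedged sketch, assuming a CUDA-enabled PyTorch installation:

import numpy as np
import torch
from sklearn import config_context
from sklearn.metrics import r2_score

y_true = np.array([1.0, 2.0, 3.0])
y_pred = torch.asarray([1.1, 1.9, 3.2], device="cuda")

with config_context(array_api_dispatch=True):
    # y_true is converted to match the library and device of y_pred;
    # the result is a plain Python float, not an array scalar.
    score = r2_score(y_true, y_pred)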

12.1.5. Common estimator checks#

Add the array_api_support tag to an estimator’s set of tags to indicate that it supports the array API. This will enable dedicated checks as part of the common tests to verify that the estimators’ results are the same when using vanilla NumPy and array API inputs.
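With the __sklearn_tags__ protocol (scikit-learn 1.6 and later), setting the tag looks roughly as follows; MyEstimator is a placeholder name:

from sklearn.base import BaseEstimator

class MyEstimator(BaseEstimator):
    def __sklearn_tags__(self):
        tags = super().__sklearn_tags__()
        tags.array_api_support = True  # opt into the array API common checks
        return tags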

To run these checks you need to install array-api-strict in your test environment. This allows you to run checks without having a GPU. To run the full set of checks you also need to install PyTorch and CuPy and have a GPU. Checks that cannot be executed or have missing dependencies will be automatically skipped. Therefore it’s important to run the tests with the -v flag to see which checks are skipped:

pip install array-api-strict  # and other libraries as needed
pytest -k "array_api" -v

Running the scikit-learn tests against array-api-strict should help reveal most code problems related to handling multiple device inputs via the use of simulated non-CPU devices. This allows for fast iterative development and debugging of array API related code.

However, to ensure full handling of PyTorch or CuPy inputs allocated on actual GPU devices, it is necessary to run the tests against those libraries and hardware. This can either be achieved by using Google Colab or leveraging our CI infrastructure on pull requests (manually triggered by maintainers for cost reasons).

12.1.5.1. Note on MPS device support#

On macOS, PyTorch can use the Metal Performance Shaders (MPS) to access hardware accelerators (e.g., the internal GPU component of the M1 or M2 chips). However, the MPS device support for PyTorch is incomplete at the time of writing. See the corresponding issue in the PyTorch GitHub tracker for more details.

To enable the MPS support in PyTorch, set the environment variable PYTORCH_ENABLE_MPS_FALLBACK=1 before running the tests:

PYTORCH_ENABLE_MPS_FALLBACK=1 pytest -k "array_api" -v

At the time of writing, all scikit-learn tests should pass; however, the computational speed is not necessarily better than with the CPU device.

12.1.5.2. Note on device support for float64#

Certain operations within scikit-learn will automatically perform operations on floating-point values with float64 precision to prevent overflows and ensure correctness (e.g., metrics.pairwise.euclidean_distances, preprocessing.StandardScaler). However, certain combinations of array namespaces and devices, such as PyTorch on MPS (see Note on MPS device support), do not support the float64 data type. In these cases, scikit-learn will revert to using the float32 data type instead. This can result in different behavior (typically numerically unstable results) compared to not using array API dispatching or using a device with float64 support.