jax.experimental.sparse module
Note
The methods in jax.experimental.sparse are experimental reference implementations, and not recommended for use in performance-critical applications. The submodule is no longer being actively developed, but the team will continue supporting existing features as best we can.
The jax.experimental.sparse module includes experimental support for sparse matrix operations in JAX. The primary interfaces made available are the BCOO sparse array type, and the sparsify() transform.
Batched-coordinate (BCOO) sparse matrices
The main high-level sparse object currently available in JAX is the BCOO, or batched coordinate sparse array, which offers a compressed storage format compatible with JAX transformations, in particular JIT (e.g. jax.jit()), batching (e.g. jax.vmap()), and autodiff (e.g. jax.grad()).
Here is an example of creating a sparse array from a dense array:
>>> from jax.experimental import sparse
>>> import jax.numpy as jnp
>>> import numpy as np

>>> M = jnp.array([[0., 1., 0., 2.],
...                [3., 0., 0., 0.],
...                [0., 0., 4., 0.]])

>>> M_sp = sparse.BCOO.fromdense(M)
>>> M_sp
BCOO(float32[3, 4], nse=4)
Convert back to a dense array with the todense() method:

>>> M_sp.todense()
Array([[0., 1., 0., 2.],
       [3., 0., 0., 0.],
       [0., 0., 4., 0.]], dtype=float32)
The BCOO format is a somewhat modified version of the standard COO format, and the dense representation can be seen in the data and indices attributes:
>>> M_sp.data  # Explicitly stored data
Array([1., 2., 3., 4.], dtype=float32)

>>> M_sp.indices  # Indices of the stored data
Array([[0, 1],
       [0, 3],
       [1, 0],
       [2, 2]], dtype=int32)
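Conversely (a minimal sketch, not part of the original page), the same array can be constructed directly from these two buffers by passing a (data, indices) pair to the BCOO constructor along with the dense shape:

```python
import jax.numpy as jnp
from jax.experimental import sparse

# Build the 3x4 BCOO array directly from its sparse representation:
# data holds the nonzero values, indices their (row, col) positions.
data = jnp.array([1., 2., 3., 4.])
indices = jnp.array([[0, 1], [0, 3], [1, 0], [2, 2]])
M_sp = sparse.BCOO((data, indices), shape=(3, 4))
dense = M_sp.todense()
```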
BCOO objects have familiar array-like attributes, as well as sparse-specific attributes:
>>> M_sp.ndim
2

>>> M_sp.shape
(3, 4)

>>> M_sp.dtype
dtype('float32')

>>> M_sp.nse  # "number of specified elements"
4
BCOO objects also implement a number of array-like methods, to allow you to use them directly within JAX programs. For example, here we compute the transposed matrix-vector product:
>>> y = jnp.array([3., 6., 5.])

>>> M_sp.T @ y
Array([18.,  3., 20.,  6.], dtype=float32)

>>> M.T @ y  # Compare to dense version
Array([18.,  3., 20.,  6.], dtype=float32)
BCOO objects are designed to be compatible with JAX transforms, including jax.jit(), jax.vmap(), jax.grad(), and others. For example:
>>> from jax import grad, jit

>>> def f(y):
...     return (M_sp.T @ y).sum()
...
>>> jit(grad(f))(y)
Array([3., 3., 4.], dtype=float32)
Note, however, that under normal circumstances jax.numpy and jax.lax functions do not know how to handle sparse matrices, so attempting to compute things like jnp.dot(M_sp.T, y) will result in an error (however, see the next section).
Sparsify transform
An overarching goal of the JAX sparse implementation is to provide a means to switch from dense to sparse computation seamlessly, without having to modify the dense implementation. This sparse experiment accomplishes this through the sparsify() transform.
Consider this function, which computes a more complicated result from a matrix and a vector input:
>>> def f(M, v):
...     return 2 * jnp.dot(jnp.log1p(M.T), v) + 1
...
>>> f(M, y)
Array([17.635532,  5.158883, 17.09438 ,  7.591674], dtype=float32)
Were we to pass a sparse matrix to this directly, it would result in an error, because jnp functions do not recognize sparse inputs. However, with sparsify(), we get a version of this function that does accept sparse matrices:
>>> f_sp = sparse.sparsify(f)

>>> f_sp(M_sp, y)
Array([17.635532,  5.158883, 17.09438 ,  7.591674], dtype=float32)
Support for sparsify() includes a large number of the most common primitives, including:

- generalized (batched) matrix products & Einstein summations (dot_general_p)
- zero-preserving elementwise binary operations (e.g. add_p, mul_p, etc.)
- zero-preserving elementwise unary operations (e.g. abs_p, neg_p, etc.)
- summation reductions (reduce_sum_p)
- general indexing operations (slice_p, lax.dynamic_slice_p, lax.gather_p)
- concatenation and stacking (concatenate_p)
- transposition & reshaping (transpose_p, reshape_p, squeeze_p, broadcast_in_dim_p)
- some higher-order functions (cond_p, while_p, scan_p)
- some simple 1D convolutions (conv_general_dilated_p)
Nearly any jax.numpy function that lowers to these supported primitives can be used within a sparsify transform to operate on sparse arrays. This set of primitives is enough to enable relatively sophisticated sparse workflows, as the next section will show.
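As a quick illustration (a sketch, not from the original page), here is a small function that composes several of the supported primitives listed above and gives the same result dense or sparse. The names g, A, and x are illustrative only:

```python
import jax.numpy as jnp
from jax.experimental import sparse

# g lowers to abs_p (zero-preserving unary op), dot_general_p (matrix
# product), and reduce_sum_p (summation reduction) -- all supported.
def g(A, x):
    return jnp.sum(jnp.abs(A) @ x)

A = jnp.array([[0., -1., 0.],
               [2., 0., 0.]])
x = jnp.array([1., 2., 3.])
A_sp = sparse.BCOO.fromdense(A)

dense_result = g(A, x)                       # dense inputs
sparse_result = sparse.sparsify(g)(A_sp, x)  # same function, sparse input
```

Both calls should agree, since every primitive in g has a sparse rule.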
Example: sparse logistic regression
As an example of a more complicated sparse workflow, let's consider a simple logistic regression implemented in JAX. Notice that the following implementation has no reference to sparsity:
>>> import functools
>>> from sklearn.datasets import make_classification
>>> from jax.scipy import optimize

>>> def sigmoid(x):
...     return 0.5 * (jnp.tanh(x / 2) + 1)
...
>>> def y_model(params, X):
...     return sigmoid(jnp.dot(X, params[1:]) + params[0])
...
>>> def loss(params, X, y):
...     y_hat = y_model(params, X)
...     return -jnp.mean(y * jnp.log(y_hat) + (1 - y) * jnp.log(1 - y_hat))
...
>>> def fit_logreg(X, y):
...     params = jnp.zeros(X.shape[1] + 1)
...     result = optimize.minimize(functools.partial(loss, X=X, y=y),
...                                x0=params, method='BFGS')
...     return result.x

>>> X, y = make_classification(n_classes=2, random_state=1701)
>>> params_dense = fit_logreg(X, y)
>>> print(params_dense)
[-0.7298445   0.29893667  1.0248291  -0.44436368  0.8785025  -0.7724008
 -0.62893456  0.2934014   0.82974285  0.16838408 -0.39774987 -0.5071844
  0.2028872   0.5227761  -0.3739224  -0.7104083   2.4212713   0.6310087
 -0.67060554  0.03139788 -0.05359547]
This returns the best-fit parameters of a dense logistic regression problem. To fit the same model on sparse data, we can apply the sparsify() transform:
>>> Xsp = sparse.BCOO.fromdense(X)  # Sparse version of the input
>>> fit_logreg_sp = sparse.sparsify(fit_logreg)  # Sparse-transformed fit function
>>> params_sparse = fit_logreg_sp(Xsp, y)
>>> print(params_sparse)
[-0.72971725  0.29878938  1.0246326  -0.44430563  0.8784217  -0.77225566
 -0.6288222   0.29335397  0.8293481   0.16820715 -0.39764675 -0.5069753
  0.202579    0.522672   -0.3740134  -0.7102678   2.4209507   0.6310593
 -0.670236    0.03132951 -0.05356663]
Sparse API Reference

| sparsify | Experimental sparsification transform. |
| grad | Sparse-aware version of jax.grad(). |
| value_and_grad | Sparse-aware version of jax.value_and_grad(). |
| empty | Create an empty sparse array. |
| eye | Create 2D sparse identity matrix. |
| todense | Convert input to a dense matrix. |
| random_bcoo | Generate a random BCOO matrix. |
| JAXSparse | Base class for high-level JAX sparse objects. |
BCOO Data Structure

BCOO is the Batched COO format, and is the main sparse data structure implemented in jax.experimental.sparse. Its operations are compatible with JAX's core transformations, including batching (e.g. jax.vmap()) and autodiff (e.g. jax.grad()).
| BCOO | Experimental batched COO matrix implemented in JAX. |
| bcoo_broadcast_in_dim | Expand the size and rank of a BCOO array by duplicating the data. |
| bcoo_concatenate | Sparse implementation of jax.lax.concatenate(). |
| bcoo_dot_general | A general contraction operation. |
| bcoo_dot_general_sampled | A contraction operation with output computed at given sparse indices. |
| bcoo_dynamic_slice | Sparse implementation of jax.lax.dynamic_slice(). |
| bcoo_extract | Extract values from a dense array according to the sparse array's indices. |
| bcoo_fromdense | Create BCOO-format sparse matrix from a dense matrix. |
| bcoo_gather | BCOO version of lax.gather. |
| bcoo_multiply_dense | An element-wise multiplication between a sparse and a dense array. |
| bcoo_multiply_sparse | An element-wise multiplication of two sparse arrays. |
| bcoo_update_layout | Update the storage layout (i.e. n_batch & n_dense) of a BCOO matrix. |
| bcoo_reduce_sum | Sum array element over given axes. |
| bcoo_reshape | Sparse implementation of jax.numpy.reshape(). |
| bcoo_slice | Sparse implementation of jax.lax.slice(). |
| bcoo_sort_indices | Sort indices of a BCOO array. |
| bcoo_squeeze | Sparse implementation of jax.numpy.squeeze(). |
| bcoo_sum_duplicates | Sums duplicate indices within a BCOO array, returning an array with sorted indices. |
| bcoo_todense | Convert batched sparse matrix to a dense matrix. |
| bcoo_transpose | Transpose a BCOO-format array. |
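Several entries above refer to the n_batch & n_dense storage layout of a BCOO matrix: these count how many leading dimensions are stored densely as batch dimensions and how many trailing dimensions are stored densely per entry. A small sketch (assuming only the documented n_batch keyword of fromdense()):

```python
import jax.numpy as jnp
from jax.experimental import sparse

# A stack of two 2x2 matrices; with n_batch=1 the leading axis is stored
# as a dense batch dimension, and only the trailing 2x2 blocks are sparse.
stacked = jnp.array([[[1., 0.],
                      [0., 2.]],
                     [[0., 3.],
                      [4., 0.]]])
batched = sparse.BCOO.fromdense(stacked, n_batch=1)
layout = (batched.n_batch, batched.n_dense)
```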
BCSR Data Structure

BCSR is the Batched Compressed Sparse Row format, and is under development. Its operations are compatible with JAX's core transformations, including batching (e.g. jax.vmap()) and autodiff (e.g. jax.grad()).
| BCSR | Experimental batched CSR matrix implemented in JAX. |
| bcsr_dot_general | A general contraction operation. |
| bcsr_extract | Extract values from a dense matrix at given BCSR (indices, indptr). |
| bcsr_fromdense | Create BCSR-format sparse matrix from a dense matrix. |
| bcsr_todense | Convert batched sparse matrix to a dense matrix. |
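For illustration (a sketch, not from the original page), the CSR layout differs from COO by replacing per-element row indices with a compressed row-pointer array, so indices holds only column positions:

```python
import jax.numpy as jnp
from jax.experimental import sparse

M = jnp.array([[0., 1., 0., 2.],
               [3., 0., 0., 0.],
               [0., 0., 4., 0.]])
M_bcsr = sparse.BCSR.fromdense(M)

# indices holds the column of each stored value; indptr[i]:indptr[i+1]
# delimits the stored values belonging to row i.
cols = M_bcsr.indices
row_ptr = M_bcsr.indptr
```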
Other Sparse Data Structures

Other sparse data structures include COO, CSR, and CSC. These are reference implementations of simple sparse structures with a few core operations implemented. Their operations are generally compatible with autodiff transformations such as jax.grad(), but not with batching transforms like jax.vmap().
| COO | Experimental COO matrix implemented in JAX. |
| CSC | Experimental CSC matrix implemented in JAX; API subject to change. |
| CSR | Experimental CSR matrix implemented in JAX. |
| coo_fromdense | Create a COO-format sparse matrix from a dense matrix. |
| coo_matmat | Product of COO sparse matrix and a dense matrix. |
| coo_matvec | Product of COO sparse matrix and a dense vector. |
| coo_todense | Convert a COO-format sparse matrix to a dense matrix. |
| csr_fromdense | Create a CSR-format sparse matrix from a dense matrix. |
| csr_matmat | Product of CSR sparse matrix and a dense matrix. |
| csr_matvec | Product of CSR sparse matrix and a dense vector. |
| csr_todense | Convert a CSR-format sparse matrix to a dense matrix. |
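As a brief sketch of how these lower-level structures are used (values chosen for illustration), a COO matrix can be built with fromdense and multiplied against a dense vector with coo_matvec:

```python
import jax.numpy as jnp
from jax.experimental import sparse

M = jnp.array([[0., 1., 0., 2.],
               [3., 0., 0., 0.],
               [0., 0., 4., 0.]])
v = jnp.ones(4)

M_coo = sparse.COO.fromdense(M)
result = sparse.coo_matvec(M_coo, v)  # equivalent to M @ v
```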
jax.experimental.sparse.linalg
Sparse linear algebra routines.
| spsolve | A sparse direct solver using QR factorization. |
| lobpcg_standard | Compute the top-k standard eigenvalues using the LOBPCG routine. |
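As a hedged sketch of lobpcg_standard (the argument names and three-part return value are assumptions to verify against the docstring), the routine takes a matrix and an initial guess X with one column per requested eigenvalue:

```python
import jax
import jax.numpy as jnp
from jax.experimental.sparse import linalg as sparse_linalg

# Find the 4 largest eigenvalues of a simple diagonal test matrix.
n, k = 100, 4
A = jnp.diag(jnp.arange(1.0, n + 1.0))  # eigenvalues 1..100
X = jax.random.normal(jax.random.PRNGKey(0), (n, k))

# Assumed return structure: (eigenvalues, eigenvectors, iteration count).
theta, U, iters = sparse_linalg.lobpcg_standard(A, X, m=200)
```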
