jax.numpy.ix_
- jax.numpy.ix_(*args)
Return a multi-dimensional grid (open mesh) from N one-dimensional sequences.
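To make the term "open mesh" concrete, the sketch below compares `jnp.ix_` with `jnp.meshgrid` using matrix (`"ij"`) indexing and `sparse=True`, which should produce arrays of the same shapes and values. Broadcasting the open mesh with `jnp.broadcast_arrays` expands it into the dense grid that a non-sparse `meshgrid` returns. The input values are illustrative only.

```python
import jax.numpy as jnp

rows = jnp.array([0, 2])
cols = jnp.array([1, 3])

# Open mesh: each output keeps its own axis; every other axis has length 1.
r_open, c_open = jnp.ix_(rows, cols)

# The same open mesh via meshgrid with "ij" indexing and sparse=True.
r_sparse, c_sparse = jnp.meshgrid(rows, cols, indexing="ij", sparse=True)

# Broadcasting the open mesh expands it to a dense grid, which is what
# meshgrid(..., sparse=False) returns directly.
r_dense, c_dense = jnp.broadcast_arrays(r_open, c_open)
r_full, c_full = jnp.meshgrid(rows, cols, indexing="ij")

print(r_open.shape, c_open.shape)    # (2, 1) (1, 2)
print(r_dense.shape, c_dense.shape)  # (2, 2) (2, 2)
```

The open form stores only N one-dimensional sets of values, while the dense form materializes every combination; indexing works with either because of broadcasting.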
JAX implementation of numpy.ix_().
- Parameters:
*args (ArrayLike) – N one-dimensional arrays
- Returns:
Tuple of JAX arrays forming an open mesh, each with N dimensions.
- Return type:
tuple[Array, ...]
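The shape rule behind "each with N dimensions" can be sketched as follows: the i-th output holds the values of the i-th argument along axis i and has length 1 on every other axis. The helper `open_mesh_sketch` is hypothetical and written only for illustration; it is not part of JAX, and the library's actual implementation may differ (for example in how it treats empty or boolean inputs).

```python
import jax.numpy as jnp

def open_mesh_sketch(*args):
    """Hypothetical re-implementation of the open-mesh shape rule (illustration only)."""
    n = len(args)
    out = []
    for i, a in enumerate(args):
        a = jnp.asarray(a)
        if a.ndim != 1:
            raise ValueError(f"expected a 1-D array, got shape {a.shape}")
        # Argument i occupies axis i; every other axis has length 1.
        shape = [1] * n
        shape[i] = a.shape[0]
        out.append(a.reshape(shape))
    return tuple(out)

a, b, c = jnp.arange(2), jnp.arange(3), jnp.arange(4)
mesh = open_mesh_sketch(a, b, c)
print([m.shape for m in mesh])  # [(2, 1, 1), (1, 3, 1), (1, 1, 4)]
```

For non-empty integer inputs like these, the result should match `jnp.ix_(a, b, c)` element for element.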
Examples
>>> rows = jnp.array([0, 2])
>>> cols = jnp.array([1, 3])
>>> open_mesh = jnp.ix_(rows, cols)
>>> open_mesh
(Array([[0],
       [2]], dtype=int32), Array([[1, 3]], dtype=int32))
>>> [grid.shape for grid in open_mesh]
[(2, 1), (1, 2)]
>>> x = jnp.array([[10, 20, 30, 40],
...                [50, 60, 70, 80],
...                [90, 100, 110, 120],
...                [130, 140, 150, 160]])
>>> x[open_mesh]
Array([[ 20,  40],
       [100, 120]], dtype=int32)
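Continuing the example above, the following sketch shows that indexing with the open mesh is equivalent to indexing with manually broadcast index arrays, and that the same mesh can be used to write to the selected sub-grid. Because JAX arrays are immutable, the write goes through the `.at[...].set(...)` interface, which returns a new array; the fill value `0` is an arbitrary choice for illustration.

```python
import jax.numpy as jnp

x = jnp.array([[10, 20, 30, 40],
               [50, 60, 70, 80],
               [90, 100, 110, 120],
               [130, 140, 150, 160]])
rows = jnp.array([0, 2])
cols = jnp.array([1, 3])
open_mesh = jnp.ix_(rows, cols)

# Gathering with the open mesh selects the rows-by-cols sub-grid; it matches
# indexing with index arrays broadcast by hand to shapes (2, 1) and (1, 2).
sub = x[open_mesh]
manual = x[rows[:, None], cols[None, :]]

# Functional update: returns a new array with the selected entries replaced,
# leaving x itself unchanged.
y = x.at[open_mesh].set(0)
```

Here `y` has zeros at positions (0, 1), (0, 3), (2, 1) and (2, 3), while `x` still contains its original values.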
