jax.numpy.meshgrid
- jax.numpy.meshgrid(*xi, copy=True, sparse=False, indexing='xy')
Construct N-dimensional grid arrays from N 1-dimensional vectors.
JAX implementation of numpy.meshgrid().
- Parameters:
xi (ArrayLike) – N 1-dimensional arrays to convert to a grid.
copy (bool) – whether to copy the input arrays. JAX supports only copy=True, though under JIT compilation the compiler may opt to avoid copies.
sparse (bool) – if False (default), each returned array has shape [len(x1), len(x2), ..., len(xN)]. If True, the i-th returned array has shape [1, 1, ..., len(xi), ..., 1, 1], with len(xi) in position i. These shapes are for indexing='ij'; with indexing='xy', the first two dimensions are swapped.
indexing (str) – either 'xy' for Cartesian indexing (default) or 'ij' for matrix indexing.
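The interaction between `sparse` and `indexing` is easiest to see from the output shapes. The following is a minimal sketch using two arbitrary inputs of lengths 2 and 3 (the variable names are illustrative only):

```python
import jax.numpy as jnp

x = jnp.arange(2)  # length 2
y = jnp.arange(3)  # length 3

# Dense grids: 'xy' puts len(y) on axis 0, 'ij' keeps the input order.
xy_x, xy_y = jnp.meshgrid(x, y, indexing='xy')
ij_x, ij_y = jnp.meshgrid(x, y, indexing='ij')
print(xy_x.shape, ij_x.shape)  # (3, 2) (2, 3)

# Sparse grids: each output has exactly one non-unit axis,
# placed according to the same (default 'xy') indexing convention.
sx, sy = jnp.meshgrid(x, y, sparse=True)
print(sx.shape, sy.shape)  # (1, 2) (3, 1)
```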
- Returns:
A length-N list of grid arrays.
- Return type:
list[Array]
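For more than two inputs, one grid array is returned per input vector. A small sketch with three inputs (arbitrary lengths 2, 3 and 4) shows that `'xy'` indexing only swaps the first two axes:

```python
import jax.numpy as jnp

a, b, c = jnp.arange(2), jnp.arange(3), jnp.arange(4)

# One output array per input vector.
ij_grids = jnp.meshgrid(a, b, c, indexing='ij')
xy_grids = jnp.meshgrid(a, b, c)  # default indexing='xy'

print(len(ij_grids))                # 3
print([g.shape for g in ij_grids])  # [(2, 3, 4), (2, 3, 4), (2, 3, 4)]
print(xy_grids[0].shape)            # (3, 2, 4): only the first two axes swap
```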
See also
jax.numpy.indices(): generate a grid of indices.
jax.numpy.mgrid: create a meshgrid using indexing syntax.
jax.numpy.ogrid: create an open meshgrid using indexing syntax.
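As a rough illustration of how these related utilities line up with meshgrid for integer ranges, here is a sketch comparing them on a 2 x 3 grid with `indexing='ij'`:

```python
import jax.numpy as jnp

ii, jj = jnp.meshgrid(jnp.arange(2), jnp.arange(3), indexing='ij')

m = jnp.mgrid[0:2, 0:3]    # dense grids stacked along axis 0, shape (2, 2, 3)
o = jnp.ogrid[0:2, 0:3]    # open (sparse) grids, one per slice
idx = jnp.indices((2, 3))  # same values as the mgrid result here

assert (m[0] == ii).all() and (m[1] == jj).all()
assert (idx == m).all()
print(o[0].shape, o[1].shape)  # (2, 1) (1, 3)
```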
Examples
For the following examples, we’ll use these 1D arrays as inputs:
>>> import jax.numpy as jnp
>>> x = jnp.array([1, 2])
>>> y = jnp.array([10, 20, 30])
2D cartesian mesh grid:
>>> x_grid, y_grid = jnp.meshgrid(x, y)
>>> print(x_grid)
[[1 2]
 [1 2]
 [1 2]]
>>> print(y_grid)
[[10 10]
 [20 20]
 [30 30]]
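A common use of a dense grid is to evaluate a function at every (x, y) pair elementwise. A minimal sketch, using addition as an arbitrary example function (the name `table` is illustrative):

```python
import jax.numpy as jnp

x = jnp.array([1, 2])
y = jnp.array([10, 20, 30])
x_grid, y_grid = jnp.meshgrid(x, y)

# With the default 'xy' indexing, entry [i, j] pairs y[i] with x[j].
table = x_grid + y_grid
print(table)
# [[11 12]
#  [21 22]
#  [31 32]]
```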
2D sparse cartesian mesh grid:
>>> x_grid, y_grid = jnp.meshgrid(x, y, sparse=True)
>>> print(x_grid)
[[1 2]]
>>> print(y_grid)
[[10]
 [20]
 [30]]
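Sparse grids are typically combined through broadcasting, which gives the same result as the dense grids while storing only 2 + 3 elements instead of two arrays of 6. A sketch, with multiplication chosen arbitrarily as the function being evaluated:

```python
import jax.numpy as jnp

x = jnp.array([1, 2])
y = jnp.array([10, 20, 30])

x_s, y_s = jnp.meshgrid(x, y, sparse=True)  # shapes (1, 2) and (3, 1)
x_d, y_d = jnp.meshgrid(x, y)               # shapes (3, 2) and (3, 2)

# Broadcasting (1, 2) against (3, 1) yields a (3, 2) result,
# identical to evaluating the same expression on the dense grids.
from_sparse = x_s * y_s
from_dense = x_d * y_d
assert from_sparse.shape == (3, 2)
assert (from_sparse == from_dense).all()
```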
2D matrix-index mesh grid:
>>> x_grid, y_grid = jnp.meshgrid(x, y, indexing='ij')
>>> print(x_grid)
[[1 1 1]
 [2 2 2]]
>>> print(y_grid)
[[10 20 30]
 [10 20 30]]
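For two inputs, the `'ij'` grids are the transposes of the `'xy'` grids. The sketch below checks this and also calls meshgrid inside `jax.jit` (the wrapper name `grids_ij` is hypothetical); whether any copies are elided under JIT is left to the compiler.

```python
import jax
import jax.numpy as jnp

x = jnp.array([1, 2])
y = jnp.array([10, 20, 30])

@jax.jit
def grids_ij(x, y):
    # The list of outputs is returned as a pytree, so it survives jit.
    return jnp.meshgrid(x, y, indexing='ij')

x_ij, y_ij = grids_ij(x, y)
x_xy, y_xy = jnp.meshgrid(x, y)

# In 2D, switching indexing just transposes each grid.
assert (x_ij == x_xy.T).all() and (y_ij == y_xy.T).all()
```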
