jax.numpy.mgrid
jax.numpy.mgrid = <jax._src.numpy.index_tricks._Mgrid object>
Return dense multi-dimensional “meshgrid”.
LAX-backend implementation of numpy.mgrid. This is a convenience wrapper for functionality provided by jax.numpy.meshgrid() with sparse=False and indexing="ij"; the resulting arrays are stacked along a new leading axis.

See also
jnp.ogrid: open/sparse version of jnp.mgrid
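The relationship to jax.numpy.meshgrid() and jnp.ogrid can be checked directly. The sketch below assumes matrix-style indexing="ij" (the ordering mgrid uses, unlike meshgrid's default "xy") and rebuilds the same dense grid by hand and from the open components returned by jnp.ogrid.

```python
import jax.numpy as jnp

# Dense grid from mgrid: shape (2, 2, 3), one stacked array per slice.
dense = jnp.mgrid[:2, :3]

# The same grid built by hand: 1-D ranges passed to meshgrid with
# matrix ("ij") indexing, then stacked along a new leading axis.
rows, cols = jnp.meshgrid(jnp.arange(2), jnp.arange(3), indexing="ij")
manual = jnp.stack([rows, cols])

# ogrid returns open (sparse) components shaped (2, 1) and (1, 3);
# broadcasting them against each other recovers the dense grid.
open_rows, open_cols = jnp.ogrid[:2, :3]
broadcast = jnp.stack(jnp.broadcast_arrays(open_rows, open_cols))

print(dense.shape, open_rows.shape, open_cols.shape)
print(bool(jnp.array_equal(dense, manual)), bool(jnp.array_equal(dense, broadcast)))
```

The open form stores only 2 + 3 values instead of 2 * 2 * 3, which is why jnp.ogrid is preferable when the grid is only consumed through broadcasting.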
Examples
Pass [start:stop:step] to generate values similar to jax.numpy.arange():

>>> jnp.mgrid[0:4:1]
Array([0, 1, 2, 3], dtype=int32)
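With a real-valued step, a single slice follows ordinary Python slice conventions: the stop value is excluded, an omitted start defaults to 0, an omitted step defaults to 1, and negative steps count down. A short sketch comparing each case against jnp.arange:

```python
import jax.numpy as jnp

# Explicit start, stop and step: identical to arange(0, 4, 1).
a = jnp.mgrid[0:4:1]
b = jnp.arange(0, 4, 1)

# Omitted start and step fall back to 0 and 1.
c = jnp.mgrid[:4]

# A negative step counts down; the stop value is still excluded.
down = jnp.mgrid[4:0:-1]

print(a, c, down)
```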
Passing an imaginary step generates values similar to jax.numpy.linspace():

>>> jnp.mgrid[0:1:4j]
Array([0.        , 0.33333334, 0.6666667 , 1.        ], dtype=float32)
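The imaginary step is read as a point count rather than a spacing: 4j means "4 evenly spaced points", and the stop value is included, as with jnp.linspace(start, stop, num). A minimal comparison with a real step over the same interval, where the stop value is excluded:

```python
import jax.numpy as jnp

# Imaginary step: 4 points from 0 to 1, endpoint included.
grid = jnp.mgrid[0:1:4j]
lin = jnp.linspace(0, 1, 4)

# Real step of 0.25: also 4 points, but 1.0 is excluded.
stepped = jnp.mgrid[0:1:0.25]

print(grid)     # ends at 1.0
print(stepped)  # ends at 0.75
```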
Multiple slices can be used to create broadcasted grids of indices:
>>> jnp.mgrid[:2, :3]
Array([[[0, 0, 0],
        [1, 1, 1]],

       [[0, 1, 2],
        [0, 1, 2]]], dtype=int32)
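A common use of the dense grid is to unpack it into one index array per axis and evaluate an expression elementwise over every position. As an illustration (the expression i * 3 + j is only an example), this computes the row-major flat index of each cell in a 2 x 3 grid:

```python
import jax.numpy as jnp

# Unpack the stacked grid: i holds row indices, j holds column indices,
# each with shape (2, 3).
i, j = jnp.mgrid[:2, :3]

# Evaluate an expression at every grid position; here, the row-major
# flat index of each cell.
flat_index = i * 3 + j

print(flat_index)
# [[0 1 2]
#  [3 4 5]]
```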
