jax.numpy.insert
jax.numpy.insert(arr, obj, values, axis=None)
Insert entries into an array at specified indices.
JAX implementation of numpy.insert().
- Parameters:
arr (ArrayLike) – array object into which values will be inserted.
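obj (ArrayLike | slice) – slice, integer, or array of integer indices specifying the insertion locations. Each index refers to a position in the original arr.
values (ArrayLike) – array of values to be inserted; a scalar is repeated at every insertion location.
axis (int | None) – axis along which to insert values in multi-dimensional arrays. If None (the default), arr is flattened before insertion.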
- Returns:
A copy of arr with values inserted at the specified locations.
- Return type:
Array
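A short sketch of how the axis parameter changes the result, assuming JAX is installed and using a small 2 by 2 input chosen only for illustration. With the default axis=None the input is flattened first, so the result is 1-D; with axis=0 a whole row is inserted and the scalar value is repeated across it. The input itself is left unchanged, since a new array is returned.

```python
import jax.numpy as jnp

x = jnp.array([[1, 2], [3, 4]])  # illustrative input

# axis=None (the default): x is flattened to [1, 2, 3, 4] before inserting.
flat = jnp.insert(x, 1, 0)
assert flat.tolist() == [1, 0, 2, 3, 4]

# axis=0: a full row is inserted before row 1; the scalar 0 fills the row.
rows = jnp.insert(x, 1, 0, axis=0)
assert rows.tolist() == [[1, 2], [0, 0], [3, 4]]

# The original array is not modified; jnp.insert returns a new array.
assert x.tolist() == [[1, 2], [3, 4]]
```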
See also
jax.numpy.delete(): delete entries from an array.
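A minimal illustration of how the two functions relate, assuming a single scalar index: deleting the entry at the index where a value was inserted recovers the original array.

```python
import jax.numpy as jnp

x = jnp.arange(5)
y = jnp.insert(x, 2, 99)  # [0, 1, 99, 2, 3, 4]

# Removing the inserted entry at the same index restores x.
restored = jnp.delete(y, 2)
assert restored.tolist() == [0, 1, 2, 3, 4]
```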
Examples
Inserting a single value:
>>> x = jnp.arange(5)
>>> jnp.insert(x, 2, 99)
Array([ 0,  1, 99,  2,  3,  4], dtype=int32)
Inserting multiple identical values using a slice:
>>> jnp.insert(x, slice(None, None, 2), -1)
Array([-1,  0,  1, -1,  2,  3, -1,  4], dtype=int32)
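The slice is resolved against the length of the insertion axis of the original array, so slice(None, None, 2) on a length-5 axis selects positions 0, 2 and 4. This sketch, assuming the same x = jnp.arange(5), expands the slice with Python's built-in slice.indices() and checks that passing the resulting explicit index array gives the same output.

```python
import jax.numpy as jnp

x = jnp.arange(5)
s = slice(None, None, 2)

# slice.indices(n) returns (start, stop, step) for a sequence of length n.
explicit = jnp.arange(*s.indices(x.shape[0]))

from_slice = jnp.insert(x, s, -1)
from_indices = jnp.insert(x, explicit, -1)  # scalar -1 repeated at each index

assert jnp.array_equal(from_slice, from_indices)
```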
Inserting multiple values using an array of indices:
>>> indices = jnp.array([4, 2, 5])
>>> values = jnp.array([10, 11, 12])
>>> jnp.insert(x, indices, values)
Array([ 0,  1, 11,  2,  3, 10,  4, 12], dtype=int32)
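Indices may also repeat. In this sketch (inputs chosen for illustration), both values target original position 1 and appear in the output in the order they were given, which is also what numpy.insert() produces for the same inputs.

```python
import numpy as np
import jax.numpy as jnp

x = jnp.arange(5)

# Two values aimed at the same original position are kept in the given order.
out = jnp.insert(x, jnp.array([1, 1]), jnp.array([7, 8]))
assert out.tolist() == [0, 7, 8, 1, 2, 3, 4]

# Cross-check against NumPy on the same inputs.
assert out.tolist() == np.insert(np.arange(5), [1, 1], [7, 8]).tolist()
```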
Inserting columns into a 2D array:
>>> x = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
>>> indices = jnp.array([1, 3])
>>> values = jnp.array([[10, 11],
...                     [12, 13]])
>>> jnp.insert(x, indices, values, axis=1)
Array([[ 1, 10,  2,  3, 11],
       [ 4, 12,  5,  6, 13]], dtype=int32)
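Since this function is the JAX implementation of numpy.insert(), one way to sanity-check the examples above is to run the same inputs through both functions. This is a sketch assuming NumPy and JAX are both installed; it only checks these particular inputs and is not a proof of general equivalence. Values are compared rather than dtypes, because JAX defaults to int32 here while NumPy's arange typically yields int64.

```python
import numpy as np
import jax.numpy as jnp

x = np.arange(5)

# (arr, obj, values, axis) for each example above.
cases = [
    (x, 2, 99, None),
    (x, slice(None, None, 2), -1, None),
    (x, np.array([4, 2, 5]), np.array([10, 11, 12]), None),
    (np.array([[1, 2, 3], [4, 5, 6]]), np.array([1, 3]),
     np.array([[10, 11], [12, 13]]), 1),
]

for arr, obj, values, axis in cases:
    expected = np.insert(arr, obj, values, axis=axis)
    actual = jnp.insert(jnp.asarray(arr), obj, values, axis=axis)
    # Compare element values only; dtypes may differ between the libraries.
    assert np.array_equal(np.asarray(actual), expected)
```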
