jax.numpy.put
jax.numpy.put(a, ind, v, mode=None, *, inplace=True)
Put elements into an array at given indices.
JAX implementation of numpy.put().
The semantics of numpy.put() are to modify arrays in-place, which is not possible for JAX’s immutable arrays. The JAX version returns a modified copy of the input, and adds the inplace parameter, which must be set to False by the user as a reminder of this API difference.
- Parameters:
a (Array | ndarray | bool | number | int | float | complex | TypedNdArray) – array into which values will be placed.
ind (Array | ndarray | bool | number | int | float | complex | TypedNdArray) – array of indices over the flattened array at which to put values.
v (Array | ndarray | bool | number | int | float | complex | TypedNdArray) – array of values to put into the array.
mode (str | None) – string specifying how to handle out-of-bound indices. Supported values:
  - "clip" (default): clip out-of-bound indices to the final index.
  - "wrap": wrap out-of-bound indices to the beginning of the array.
inplace (bool) – must be set to False to indicate that the input is not modified in-place, but rather a modified copy is returned.
- Returns:
A copy of a with specified entries updated.
- Return type:
See also
- jax.numpy.place(): place elements into an array via boolean mask.
- jax.numpy.ndarray.at(): array updates using NumPy-style indexing.
- jax.numpy.take(): extract values from an array at given indices.
Examples
>>> x = jnp.zeros(5, dtype=int)
>>> indices = jnp.array([0, 2, 4])
>>> values = jnp.array([10, 20, 30])
>>> jnp.put(x, indices, values, inplace=False)
Array([10,  0, 20,  0, 30], dtype=int32)
This is equivalent to the following jax.numpy.ndarray.at indexing syntax:
>>> x.at[indices].set(values)
Array([10,  0, 20,  0, 30], dtype=int32)
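The copy-on-update behavior described above can be checked directly. The following is a minimal sketch, not part of the original examples: it confirms that the input array keeps its old values after the call, and that leaving inplace at its default of True is rejected. This section does not name the exception type, so the sketch catches a broad Exception as an assumption.

```python
import jax.numpy as jnp

x = jnp.zeros(5, dtype=int)
indices = jnp.array([0, 2, 4])
values = jnp.array([10, 20, 30])

# jnp.put returns a new array; it does not write into x.
updated = jnp.put(x, indices, values, inplace=False)

# The input is untouched; only the returned copy carries the new values.
x_unchanged = bool((x == 0).all())
updated_list = updated.tolist()

# Leaving inplace at its default (True) is expected to be rejected.
# The exact exception type is an assumption here, so catch broadly.
try:
    jnp.put(x, indices, values)
    default_rejected = False
except Exception:
    default_rejected = True
```

Code ported from NumPy therefore needs to rebind the result, for example x = jnp.put(x, indices, values, inplace=False), rather than rely on a side effect.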
There are two modes for handling out-of-bound indices. By default they are clipped:
>>> indices = jnp.array([0, 2, 6])
>>> jnp.put(x, indices, values, inplace=False, mode='clip')
Array([10,  0, 20,  0, 30], dtype=int32)
Alternatively, they can be wrapped to the beginning of the array:
>>> jnp.put(x, indices, values, inplace=False, mode='wrap')
Array([10, 30, 20,  0,  0], dtype=int32)
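To make the two modes concrete, the sketch below reproduces them by hand for the non-negative indices used above. It assumes that clipping corresponds to jnp.clip against the last valid index and that wrapping corresponds to a modulo by the array size. This matches the outputs shown here, but it is an illustration and not a statement of how the library implements either mode.

```python
import jax.numpy as jnp

x = jnp.zeros(5, dtype=int)
indices = jnp.array([0, 2, 6])  # 6 is out of bounds for an array of size 5
values = jnp.array([10, 20, 30])

# "clip": clamp out-of-bound indices to the last valid index (assumed equivalent).
clipped = jnp.clip(indices, 0, x.size - 1)
# "wrap": wrap out-of-bound indices around with a modulo (assumed equivalent).
wrapped = indices % x.size

clip_manual = x.at[clipped].set(values)
wrap_manual = x.at[wrapped].set(values)

# Reference results from jnp.put for comparison.
clip_put = jnp.put(x, indices, values, inplace=False, mode="clip")
wrap_put = jnp.put(x, indices, values, inplace=False, mode="wrap")
```

For this input, index 6 lands on position 4 under "clip" and on position 1 under "wrap", which is why the value 30 moves between the two results.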
For N-dimensional inputs, the indices refer to the flattened array:
>>> x = jnp.zeros((3, 5), dtype=int)
>>> indices = jnp.array([0, 7, 14])
>>> jnp.put(x, indices, values, inplace=False)
Array([[10,  0,  0,  0,  0],
       [ 0,  0, 20,  0,  0],
       [ 0,  0,  0,  0, 30]], dtype=int32)
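The mapping from flattened indices to N-dimensional positions can be made explicit with jnp.unravel_index. The sketch below converts the flat indices to row and column coordinates and applies the same update through the .at syntax. The variable names are illustrative only, and row-major flattening is assumed, as implied by the result above.

```python
import jax.numpy as jnp

x = jnp.zeros((3, 5), dtype=int)
flat_indices = jnp.array([0, 7, 14])
values = jnp.array([10, 20, 30])

# Convert flat (row-major) positions into per-axis coordinates.
rows, cols = jnp.unravel_index(flat_indices, x.shape)

# Same update expressed with explicit 2-D coordinates.
via_at = x.at[rows, cols].set(values)
via_put = jnp.put(x, flat_indices, values, inplace=False)
```

Here flat index 7 corresponds to row 1, column 2, since 7 = 1 * 5 + 2.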
