
jax.numpy.put


jax.numpy.put(a, ind, v, mode=None, *, inplace=True)

Put elements into an array at given indices.

JAX implementation of numpy.put().

The semantics of numpy.put() are to modify arrays in place, which is not possible for JAX’s immutable arrays. The JAX version instead returns a modified copy of the input, and adds the inplace parameter, which must be set to False by the user as a reminder of this API difference.
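To make this API difference concrete, the sketch below contrasts NumPy's in-place np.put with jnp.put. It assumes that jnp.put rejects the default inplace=True by raising ValueError; the exact exception type is an assumption and may differ between JAX versions. tolist() is used only to make the printed output easy to read.

```python
import numpy as np
import jax.numpy as jnp

indices = [0, 2, 4]
values = [10, 20, 30]

# NumPy: np.put mutates its first argument in place and returns None.
a_np = np.zeros(5, dtype=np.int32)
ret = np.put(a_np, indices, values)

# JAX: the default inplace=True is rejected, because JAX arrays are immutable.
# (Assumption: the rejection is signalled with ValueError.)
a_jax = jnp.zeros(5, dtype=jnp.int32)
try:
    jnp.put(a_jax, jnp.array(indices), jnp.array(values))
    rejected = False
except ValueError:
    rejected = True

# JAX: with inplace=False, a modified copy is returned and a_jax is unchanged.
b_jax = jnp.put(a_jax, jnp.array(indices), jnp.array(values), inplace=False)

print(ret)             # None
print(a_np.tolist())   # [10, 0, 20, 0, 30]
print(rejected)        # True
print(a_jax.tolist())  # [0, 0, 0, 0, 0]
print(b_jax.tolist())  # [10, 0, 20, 0, 30]
```

In JAX code, the usual pattern is therefore to rebind the name to the returned array, for example x = jnp.put(x, ind, v, inplace=False).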

Parameters:
  • a (Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) – array into which values will be placed.

  • ind (Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) – array of indices over the flattened array at which to put values.

  • v (Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) – array of values to put into the array.

  • mode (str | None) –

    string specifying how to handle out-of-bound indices. Supported values:

    • "clip" (default): clip out-of-bound indices to the final index.

    • "wrap": wrap out-of-bound indices to the beginning of the array.

  • inplace (bool) – must be set to False to indicate that the input is not modified in place; instead, a modified copy is returned.

Returns:

A copy of a with the specified entries updated.

Return type:

Array
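The "clip" and "wrap" modes described above can be approximated on top of the .at update syntax. The helper put_sketch below is hypothetical and is not JAX's actual implementation: it assumes that ind and v contain the same number of elements and that ind has no duplicate entries (when several indices target the same location, the order of .at[...].set() updates is implementation-defined).

```python
import jax.numpy as jnp


def put_sketch(a, ind, v, mode="clip"):
    """Hypothetical re-implementation of the documented clip/wrap behavior.

    Not JAX's implementation. Assumes ind and v have the same number of
    elements and that ind contains no duplicate indices.
    """
    a = jnp.asarray(a)
    flat = a.ravel()                  # indices refer to the flattened array
    ind = jnp.ravel(jnp.asarray(ind))
    if mode == "clip":
        # Clamp into [0, size - 1]: indices past the end hit the last element.
        ind = jnp.clip(ind, 0, flat.size - 1)
    elif mode == "wrap":
        # Reduce modulo the size: indices past the end wrap to the start.
        ind = ind % flat.size
    else:
        raise ValueError(f"unsupported mode: {mode!r}")
    updated = flat.at[ind].set(jnp.ravel(jnp.asarray(v)))
    return updated.reshape(a.shape)   # restore the input's shape


x = jnp.zeros(5, dtype=jnp.int32)
indices = jnp.array([0, 2, 6])        # index 6 is out of bounds for size 5
values = jnp.array([10, 20, 30])
print(put_sketch(x, indices, values, "clip").tolist())  # [10, 0, 20, 0, 30]
print(put_sketch(x, indices, values, "wrap").tolist())  # [10, 30, 20, 0, 0]
```

Under "clip", index 6 becomes 4; under "wrap", it becomes 6 % 5 = 1, which is why 30 lands in position 1.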

See also

  • jax.numpy.ndarray.at – functional indexed updates of JAX arrays.

Examples

>>> import jax.numpy as jnp
>>> x = jnp.zeros(5, dtype=int)
>>> indices = jnp.array([0, 2, 4])
>>> values = jnp.array([10, 20, 30])
>>> jnp.put(x, indices, values, inplace=False)
Array([10,  0, 20,  0, 30], dtype=int32)

This is equivalent to the following jax.numpy.ndarray.at indexing syntax:

>>> x.at[indices].set(values)
Array([10,  0, 20,  0, 30], dtype=int32)
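Because jnp.put returns a new array rather than mutating its input, it can be used inside jax.jit-compiled functions like any other pure function. A minimal sketch, with hypothetical function names put_jit and at_set_jit, comparing a jitted jnp.put call against the equivalent .at update:

```python
import jax
import jax.numpy as jnp


@jax.jit
def put_jit(x, indices, values):
    # inplace=False is a Python constant here, so it is fixed at trace time.
    return jnp.put(x, indices, values, inplace=False)


@jax.jit
def at_set_jit(x, indices, values):
    # The equivalent functional update using .at indexing.
    return x.at[indices].set(values)


x = jnp.zeros(5, dtype=jnp.int32)
indices = jnp.array([0, 2, 4])
values = jnp.array([10, 20, 30])
print(put_jit(x, indices, values).tolist())     # [10, 0, 20, 0, 30]
print(at_set_jit(x, indices, values).tolist())  # [10, 0, 20, 0, 30]
```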

There are two modes for handling out-of-bound indices. By default they are clipped:

>>> indices = jnp.array([0, 2, 6])
>>> jnp.put(x, indices, values, inplace=False, mode='clip')
Array([10,  0, 20,  0, 30], dtype=int32)

Alternatively, they can be wrapped to the beginning of the array:

>>> jnp.put(x, indices, values, inplace=False, mode='wrap')
Array([10, 30, 20,  0,  0], dtype=int32)

For N-dimensional inputs, the indices refer to the flattened array:

>>> x = jnp.zeros((3, 5), dtype=int)
>>> indices = jnp.array([0, 7, 14])
>>> jnp.put(x, indices, values, inplace=False)
Array([[10,  0,  0,  0,  0],
       [ 0,  0, 20,  0,  0],
       [ 0,  0,  0,  0, 30]], dtype=int32)
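The flattened-index rule can be checked by hand: in a (3, 5) array stored in row-major (C) order, flat index i addresses row i // 5 and column i % 5, so 0, 7 and 14 map to (0, 0), (1, 2) and (2, 4). The sketch below assumes that default row-major layout and reproduces the result in two ways: with jnp.unravel_index plus .at, and by updating x.ravel() and reshaping back.

```python
import jax.numpy as jnp

x = jnp.zeros((3, 5), dtype=jnp.int32)
indices = jnp.array([0, 7, 14])
values = jnp.array([10, 20, 30])

# Convert flat indices into (row, column) coordinates for shape (3, 5).
rows, cols = jnp.unravel_index(indices, x.shape)
via_coords = x.at[rows, cols].set(values)

# Alternatively, update the flattened array and restore the original shape.
via_ravel = x.ravel().at[indices].set(values).reshape(x.shape)

via_put = jnp.put(x, indices, values, inplace=False)
print(rows.tolist(), cols.tolist())                # [0, 1, 2] [0, 2, 4]
print(bool(jnp.array_equal(via_put, via_coords)))  # True
print(bool(jnp.array_equal(via_put, via_ravel)))   # True
```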