jax.numpy.place

jax.numpy.place(arr, mask, vals, *, inplace=True)

Update array elements based on a mask.

JAX implementation of numpy.place().

The semantics of numpy.place() are to modify arrays in-place, which is not possible for JAX’s immutable arrays. The JAX version returns a modified copy of the input, and adds the inplace parameter, which must be set to False by the user as a reminder of this API difference.
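The copy semantics described above can be checked with a small sketch. The array names and the value 7 are chosen here purely for illustration:

```python
import jax.numpy as jnp

x = jnp.zeros(4, dtype=int)
mask = jnp.array([True, False, True, False])

# inplace=False is passed explicitly; the call returns a new array.
y = jnp.place(x, mask, 7, inplace=False)

print(x)  # [0 0 0 0]  (the input is left unchanged)
print(y)  # [7 0 7 0]  (masked entries are set in the returned copy)
```

Code ported from NumPy, where the call is made for its side effect, therefore needs to rebind the result, for example `x = jnp.place(x, mask, 7, inplace=False)`.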

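Parameters:

arr – array into which values will be placed.
mask – boolean mask indicating which entries of arr to update.
vals – values to place into arr at the masked positions. An array of values is repeated to fill the masked entries.
inplace – must be set to False, indicating that a modified copy is returned rather than arr being modified in place.

Returns: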

A copy of arr with masked values set to entries from vals.

Return type:

Array

See also

Examples

>>> x = jnp.zeros((3, 5), dtype=int)
>>> mask = (jnp.arange(x.size) % 3 == 0).reshape(x.shape)
>>> mask
Array([[ True, False, False,  True, False],
       [False,  True, False, False,  True],
       [False, False,  True, False, False]], dtype=bool)

Placing a scalar value:

>>> jnp.place(x, mask, 1, inplace=False)
Array([[1, 0, 0, 1, 0],
       [0, 1, 0, 0, 1],
       [0, 0, 1, 0, 0]], dtype=int32)

In this case, jnp.place is similar to the masked array update syntax:

>>> x.at[mask].set(1)
Array([[1, 0, 0, 1, 0],
       [0, 1, 0, 0, 1],
       [0, 0, 1, 0, 0]], dtype=int32)

jnp.place differs from the masked update syntax when placing values from an array. The array is repeated as needed to fill the masked entries:

>>> vals = jnp.array([1, 3, 5])
>>> jnp.place(x, mask, vals, inplace=False)
Array([[1, 0, 0, 3, 0],
       [0, 5, 0, 0, 1],
       [0, 0, 3, 0, 0]], dtype=int32)
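
One way to picture the repetition is to cycle vals by hand and then apply an ordinary masked update. The sketch below is only an illustration of the behavior and is not necessarily how jnp.place is implemented. The names n_masked, filled, and manual are introduced here, and the approach assumes the mask is a concrete array (not traced under jit) so that the number of masked entries is known.

```python
import jax.numpy as jnp

x = jnp.zeros((3, 5), dtype=int)
mask = (jnp.arange(x.size) % 3 == 0).reshape(x.shape)
vals = jnp.array([1, 3, 5])

# Number of entries selected by the mask (5 in this example).
n_masked = int(mask.sum())

# Cycle vals until there is one value per masked entry: [1, 3, 5, 1, 3].
filled = jnp.resize(vals, (n_masked,))

# Write the cycled values into the masked positions in row-major order.
manual = x.at[mask].set(filled)

# Compare against jnp.place on the same inputs.
placed = jnp.place(x, mask, vals, inplace=False)
print(bool((manual == placed).all()))  # True
```

The masked positions are visited in row-major order, which is why the third masked entry (row 1, column 1) receives 5 and the fourth (row 1, column 4) wraps around to 1.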