jax.numpy.ndarray.at

abstract property ndarray.at

Helper property for index update functionality.

The ndarray.at property provides a functionally pure equivalent of in-place array modifications.

In particular:

| Alternate syntax | Equivalent in-place expression |
|---|---|
| x = x.at[idx].set(y) | x[idx] = y |
| x = x.at[idx].add(y) | x[idx] += y |
| x = x.at[idx].subtract(y) | x[idx] -= y |
| x = x.at[idx].multiply(y) | x[idx] *= y |
| x = x.at[idx].divide(y) | x[idx] /= y |
| x = x.at[idx].power(y) | x[idx] **= y |
| x = x.at[idx].min(y) | x[idx] = minimum(x[idx], y) |
| x = x.at[idx].max(y) | x[idx] = maximum(x[idx], y) |
| x = x.at[idx].apply(ufunc) | ufunc.at(x, idx) |
| x = x.at[idx].get() | x = x[idx] |

None of the x.at expressions modify the original x; instead they return a modified copy of x. However, inside a jit()-compiled function, expressions like x = x.at[idx].set(y) are guaranteed to be applied in-place.
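A minimal sketch of these copy semantics, assuming a recent JAX release; the function name bump_first is hypothetical. Outside jit, the original array is left untouched. Inside the jit-compiled function the same rebinding pattern is used, and since the argument is not donated (no donate_argnums is passed to jax.jit), the caller's array is still intact after the call.

```python
import jax
import jax.numpy as jnp

x = jnp.zeros(4)
y = x.at[0].set(7.0)  # returns a modified copy; x itself is not changed


@jax.jit
def bump_first(v):
    # Hypothetical example function using the rebinding pattern v = v.at[...].op(...).
    # Under jit, XLA can perform this scatter in place on its working buffer.
    v = v.at[0].add(1.0)
    return v


z = bump_first(y)
# z is [8., 0., 0., 0.]; x is still all zeros and y is still [7., 0., 0., 0.]
```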

Unlike NumPy in-place operations such as x[idx] += y, if multiple indices refer to the same location, all updates will be applied (NumPy would only apply the last update, rather than applying all updates). The order in which conflicting updates are applied is implementation-defined and may be nondeterministic (e.g., due to concurrency on some hardware platforms).
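The difference shows up with a repeated index. The sketch below, assuming JAX and NumPy are installed, uses add() because the sum does not depend on the order in which the repeated updates land; with set() and conflicting values, which value wins is not specified. NumPy's unbuffered np.add.at is included for comparison, since it is the NumPy operation whose result JAX's scatter-add matches here.

```python
import numpy as np
import jax.numpy as jnp

idx = [0, 0, 2]  # index 0 appears twice

# NumPy fancy-index +=: the read-modify-write is buffered, so the
# repeated index 0 ends up incremented only once.
a = np.zeros(3)
a[idx] += 1  # a == [1., 0., 1.]

# NumPy's unbuffered ufunc.at accumulates every occurrence.
b = np.zeros(3)
np.add.at(b, idx, 1)  # b == [2., 0., 1.]

# JAX applies every update, matching np.add.at rather than a[idx] += 1.
# (JAX expects an array, not a Python list, as a sequence index.)
x = jnp.zeros(3).at[jnp.array(idx)].add(1.0)  # x == [2., 0., 1.]
```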

By default, JAX assumes that all indices are in-bounds. Alternative out-of-bound index semantics can be specified via the mode parameter (see below).

Parameters:
  • mode – String specifying the out-of-bound indexing mode. Options are:

    • "promise_in_bounds": (default) The user promises that indices are in bounds. No additional checking will be performed. In practice, this means that out-of-bounds indices in get() will be clipped, and out-of-bounds indices in set(), add(), etc. will be dropped.

    • "clip": clamp out-of-bounds indices into the valid range.

    • "drop": ignore out-of-bound indices.

    • "fill": alias for "drop". For get(), the optional fill_value argument specifies the value that will be returned.

    See jax.lax.GatherScatterMode for more details.

  • wrap_negative_indices – If True (default), negative indices indicate position from the end of the array, similar to Python and NumPy indexing. If False, negative indices are considered out-of-bounds and behave according to the mode parameter.

  • fill_value – Only applies to the get() method: the fill value to return for out-of-bounds slices when mode is 'fill'. Ignored otherwise. Defaults to NaN for inexact types, the largest negative value for signed types, the largest positive value for unsigned types, and True for booleans.

  • indices_are_sorted – If True, the implementation will assume that the (normalized) indices passed to at[] are sorted in ascending order, which can lead to more efficient execution on some backends. If True but the indices are not actually sorted, the output is undefined.

  • unique_indices – If True, the implementation will assume that the (normalized) indices passed to at[] are unique, which can result in more efficient execution on some backends. If True but the indices are not actually unique, the output is undefined.
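The parameters above combine on a single call. This is a sketch assuming a recent JAX release; idx and vals are hypothetical inputs containing one deliberately out-of-bounds index. It contrasts "drop" and "clip" for an update, shows "fill" with a custom fill_value for get(), and passes the two performance hints only on indices that really are sorted and unique.

```python
import jax.numpy as jnp

x = jnp.arange(5.0)                   # [0., 1., 2., 3., 4.]
idx = jnp.array([1, 3, 7])            # hypothetical: 7 is out of bounds for length 5
vals = jnp.array([10.0, 30.0, 70.0])

# "drop": the update aimed at index 7 is ignored.
dropped = x.at[idx].set(vals, mode="drop")   # [0., 10., 2., 30., 4.]

# "clip": index 7 is clamped to the last valid index, 4.
clipped = x.at[idx].set(vals, mode="clip")   # [0., 10., 2., 30., 70.]

# "fill" with get(): the out-of-bounds read returns fill_value.
filled = x.at[idx].get(mode="fill", fill_value=-1.0)  # [1., 3., -1.]

# These indices are genuinely sorted and unique, so the hints are safe here;
# passing them for unsorted or repeated indices would make the output undefined.
hinted = x.at[jnp.array([0, 2, 4])].add(
    1.0, indices_are_sorted=True, unique_indices=True
)  # [1., 1., 3., 3., 5.]
```

The hints change only how the backend may execute the scatter, not the result, as long as the promise they make about the indices holds.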

Examples

>>> import jax.numpy as jnp
>>> x = jnp.arange(5.0)
>>> x
Array([0., 1., 2., 3., 4.], dtype=float32)
>>> x.at[2].get()
Array(2., dtype=float32)
>>> x.at[2].add(10)
Array([ 0.,  1., 12.,  3.,  4.], dtype=float32)

By default, out-of-bound indices are ignored in updates, but this behavior can be controlled with the mode parameter:

>>> x.at[10].add(10)  # dropped
Array([0., 1., 2., 3., 4.], dtype=float32)
>>> x.at[20].add(10, mode='clip')  # clipped
Array([ 0.,  1.,  2.,  3., 14.], dtype=float32)

For get(), out-of-bound indices are clipped by default:

>>> x.at[20].get()  # out-of-bounds indices clipped
Array(4., dtype=float32)
>>> x.at[20].get(mode='fill')  # out-of-bounds indices filled with NaN
Array(nan, dtype=float32)
>>> x.at[20].get(mode='fill', fill_value=-1)  # custom fill value
Array(-1., dtype=float32)

Negative indices count from the end of the array, but this behavior can be disabled by setting wrap_negative_indices=False:

>>> x.at[-1].set(99)
Array([ 0.,  1.,  2.,  3., 99.], dtype=float32)
>>> x.at[-1].set(99, wrap_negative_indices=False, mode='drop')  # dropped!
Array([0., 1., 2., 3., 4.], dtype=float32)