jax.numpy.ndarray.at
abstract property ndarray.at
Helper property for index update functionality.
The `at` property provides a functionally pure equivalent of in-place array modifications. In particular:
| Alternate syntax | Equivalent in-place expression |
|---|---|
| `x = x.at[idx].set(y)` | `x[idx] = y` |
| `x = x.at[idx].add(y)` | `x[idx] += y` |
| `x = x.at[idx].subtract(y)` | `x[idx] -= y` |
| `x = x.at[idx].multiply(y)` | `x[idx] *= y` |
| `x = x.at[idx].divide(y)` | `x[idx] /= y` |
| `x = x.at[idx].power(y)` | `x[idx] **= y` |
| `x = x.at[idx].min(y)` | `x[idx] = minimum(x[idx], y)` |
| `x = x.at[idx].max(y)` | `x[idx] = maximum(x[idx], y)` |
| `x = x.at[idx].apply(ufunc)` | `ufunc.at(x, idx)` |
| `x = x.at[idx].get()` | `x = x[idx]` |

None of the `x.at` expressions modify the original `x`; instead, they return a modified copy of `x`. However, inside a `jax.jit`-compiled function, expressions like `x = x.at[idx].set(y)` are guaranteed to be applied in place.

Unlike NumPy in-place operations such as `x[idx] += y`, if multiple indices refer to the same location, all updates are applied (NumPy applies only the last update rather than all of them). The order in which conflicting updates are applied is implementation-defined and may be nondeterministic (e.g., due to concurrency on some hardware platforms).

By default, JAX assumes that all indices are in bounds. Alternative out-of-bound index semantics can be specified via the `mode` parameter (see below).

Parameters (passed as keyword arguments to the methods above, e.g. `x.at[idx].add(y, mode='clip')`):
- `mode` – string specifying the out-of-bound indexing mode. Options are:
  - `"promise_in_bounds"` (default): the user promises that indices are in bounds, and no additional checking is performed. In practice, this means that out-of-bounds indices in `get()` will be clipped, and out-of-bounds indices in `set()`, `add()`, etc. will be dropped.
  - `"clip"`: clamp out-of-bounds indices into the valid range.
  - `"drop"`: ignore out-of-bound indices.
  - `"fill"`: alias for `"drop"`. For `get()`, the optional `fill_value` argument specifies the value that will be returned.

  See `jax.lax.GatherScatterMode` for more details.
- `wrap_negative_indices` – If `True` (default), negative indices indicate a position from the end of the array, as in Python and NumPy indexing. If `False`, negative indices are considered out of bounds and behave according to the `mode` parameter.
- `fill_value` – Only applies to the `get()` method: the value to return for out-of-bounds slices when `mode` is `'fill'`. Ignored otherwise. Defaults to `NaN` for inexact (floating-point and complex) types, the most negative representable value for signed integer types, the largest representable value for unsigned integer types, and `True` for booleans.
- `indices_are_sorted` – If `True`, the implementation assumes that the (normalized) indices passed to `at[]` are sorted in ascending order, which can lead to more efficient execution on some backends. If `True` but the indices are not actually sorted, the output is undefined.
- `unique_indices` – If `True`, the implementation assumes that the (normalized) indices passed to `at[]` are unique, which can lead to more efficient execution on some backends. If `True` but the indices are not actually unique, the output is undefined.
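The copy semantics and the duplicate-index behavior described at the top of this page can be checked directly. This is a minimal sketch assuming a standard JAX install alongside NumPy; `np.add.at` appears only as NumPy's unbuffered counterpart for comparison, `bump` is a hypothetical helper name, and the `indices_are_sorted`/`unique_indices` hints are passed only where the indices really are sorted and unique.

```python
import jax
import jax.numpy as jnp
import numpy as np

x = jnp.zeros(3)

# .at[...] returns a new array; x itself is never modified.
y = x.at[0].set(5.0)  # y: [5., 0., 0.]; x is still [0., 0., 0.]

# Duplicate indices: JAX applies every update, so position 0 is incremented twice.
dup_idx = jnp.array([0, 0, 2])
dup = x.at[dup_idx].add(1.0)  # [2., 0., 1.]

# NumPy's buffered `a[idx] += v` keeps only one update per repeated index.
x_np = np.zeros(3)
x_np[np.array([0, 0, 2])] += 1.0  # [1., 0., 1.]

# np.add.at is NumPy's unbuffered form and agrees with the JAX result.
x_np_at = np.zeros(3)
np.add.at(x_np_at, np.array([0, 0, 2]), 1.0)  # [2., 0., 1.]

# Performance hints: safe here only because [0, 2] is sorted and has no repeats.
hinted = x.at[jnp.array([0, 2])].add(
    1.0, indices_are_sorted=True, unique_indices=True
)  # [1., 0., 1.]

# Hypothetical helper: the same functional syntax is used unchanged under jax.jit.
@jax.jit
def bump(a, i):
    return a.at[i].add(1.0)

bumped = bump(x, 1)  # [0., 1., 0.]
```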
Examples
>>> import jax.numpy as jnp
>>> x = jnp.arange(5.0)
>>> x
Array([0., 1., 2., 3., 4.], dtype=float32)
>>> x.at[2].get()
Array(2., dtype=float32)
>>> x.at[2].add(10)
Array([ 0.,  1., 12.,  3.,  4.], dtype=float32)
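The other methods in the table at the top of this page follow the same pattern. The sketch below applies several of them to the same `x`; the index choices are arbitrary, and `jnp.negative` is just one possible elementwise function for `apply()`.

```python
import jax.numpy as jnp

x = jnp.arange(5.0)  # [0., 1., 2., 3., 4.]

tripled = x.at[1].multiply(3.0)   # [0., 3., 2., 3., 4.]
halved = x.at[4].divide(2.0)      # [0., 1., 2., 3., 2.]
squared = x.at[3].power(2)        # [0., 1., 2., 9., 4.]
lowered = x.at[2:].min(2.5)       # elementwise minimum on a slice: [0., 1., 2., 2.5, 2.5]
raised = x.at[:2].max(1.5)        # elementwise maximum on a slice: [1.5, 1.5, 2., 3., 4.]

# apply() runs a unary elementwise function at the selected positions only.
negated = x.at[jnp.array([1, 3])].apply(jnp.negative)  # [0., -1., 2., -3., 4.]
```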
By default, out-of-bound indices are ignored in updates, but this behavior can be controlled with the `mode` parameter:

>>> x.at[10].add(10)  # dropped
Array([0., 1., 2., 3., 4.], dtype=float32)
>>> x.at[20].add(10, mode='clip')  # clipped
Array([ 0.,  1.,  2.,  3., 14.], dtype=float32)
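The `mode` setting is applied index by index, so an index array that mixes valid and out-of-bound positions only has its out-of-bound entries dropped, clipped, or filled. A short sketch with an arbitrarily chosen index array:

```python
import jax.numpy as jnp

x = jnp.arange(5.0)
idx = jnp.array([1, 7])  # 1 is valid; 7 is out of bounds for length 5

dropped = x.at[idx].set(-1.0, mode='drop')  # [0., -1., 2., 3., 4.]
clipped = x.at[idx].set(-1.0, mode='clip')  # 7 is clamped to 4: [0., -1., 2., 3., -1.]
gathered = x.at[idx].get(mode='fill', fill_value=0.0)  # [1., 0.]
```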
For `get()`, out-of-bound indices are clipped by default:

>>> x.at[20].get()  # out-of-bounds indices clipped
Array(4., dtype=float32)
>>> x.at[20].get(mode='fill')  # out-of-bounds indices filled with NaN
Array(nan, dtype=float32)
>>> x.at[20].get(mode='fill', fill_value=-1)  # custom fill value
Array(-1., dtype=float32)
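When `fill_value` is omitted, the default depends on the dtype, as listed under the `fill_value` parameter. A sketch that checks those defaults for signed, unsigned, and boolean arrays (the array contents are arbitrary):

```python
import jax.numpy as jnp

ints = jnp.arange(5)                    # signed integers (int32 unless 64-bit mode is enabled)
uints = jnp.arange(5, dtype=jnp.uint8)  # unsigned 8-bit integers
flags = jnp.zeros(5, dtype=bool)        # booleans, all False

int_fill = ints.at[20].get(mode='fill')    # most negative value of ints.dtype
uint_fill = uints.at[20].get(mode='fill')  # 255, the largest uint8 value
bool_fill = flags.at[20].get(mode='fill')  # True
```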
Negative indices count from the end of the array, but this behavior can be disabled by setting `wrap_negative_indices=False`:

>>> x.at[-1].set(99)
Array([ 0.,  1.,  2.,  3., 99.], dtype=float32)
>>> x.at[-1].set(99, wrap_negative_indices=False, mode='drop')  # dropped!
Array([0., 1., 2., 3., 4.], dtype=float32)
