jax.ops.segment_sum
jax.ops.segment_sum(data, segment_ids, num_segments=None, indices_are_sorted=False, unique_indices=False, bucket_size=None, mode=None)
Computes the sum within segments of an array.
Similar to TensorFlow’s segment_sum.
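Conceptually, a segment sum behaves like a scatter-add: each row of data is added into the output slot named by its segment ID, and repeated IDs accumulate. The sketch below assumes that framing and compares segment_sum against JAX's indexed-update syntax (.at[...].add). It illustrates the semantics only and makes no claim about how the library implements the function; scatter_add_reference is a hypothetical helper name.

```python
import jax.numpy as jnp
from jax.ops import segment_sum

def scatter_add_reference(data, segment_ids, num_segments):
    # Hypothetical helper: start from zeros of shape (num_segments,) + data.shape[1:].
    out = jnp.zeros((num_segments,) + data.shape[1:], dtype=data.dtype)
    # Add each row of `data` into the slot given by its segment ID;
    # rows that share an ID accumulate into the same slot.
    return out.at[segment_ids].add(data)

data = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
segment_ids = jnp.array([2, 0, 2, 1, 0])  # unsorted, with repeats

expected = scatter_add_reference(data, segment_ids, num_segments=3)
result = segment_sum(data, segment_ids, num_segments=3)
print(result.tolist())  # [7.0, 4.0, 4.0]
```

Segment 0 collects 2.0 + 5.0, segment 1 collects 4.0, and segment 2 collects 1.0 + 3.0, which is why the IDs need not be sorted.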
- Parameters:
data (ArrayLike) – an array with the values to be summed.
segment_ids (ArrayLike) – an array with integer dtype that indicates the segments of data (along its leading axis) to be summed. Values can be repeated and need not be sorted.
num_segments (int | None) – optional, an int with nonnegative value indicating the number of segments. The default is set to be the minimum number of segments that would support all indices in segment_ids, calculated as max(segment_ids) + 1. Since num_segments determines the size of the output, a static value must be provided to use segment_sum in a JIT-compiled function.
indices_are_sorted (bool) – whether segment_ids is known to be sorted.
unique_indices (bool) – whether segment_ids is known to be free of duplicates.
bucket_size (int | None) – size of bucket to group indices into. segment_sum is performed on each bucket separately to improve numerical stability of addition. Default None means no bucketing.
mode (slicing.GatherScatterMode | str | None) – a jax.lax.GatherScatterMode value describing how out-of-bounds indices should be handled. By default, values outside of the range [0, num_segments) are dropped and do not contribute to the sum.
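The sketch below exercises several of the parameters above on a small integer array: the default num_segments, an explicit larger or smaller num_segments, the default handling of out-of-range IDs, mode="clip" (one of the string forms accepted for a GatherScatterMode), and bucket_size. The data values are arbitrary; the expected outputs follow from the rules stated in the parameter descriptions.

```python
import jax.numpy as jnp
from jax.ops import segment_sum

data = jnp.array([10, 20, 30, 40])
segment_ids = jnp.array([1, 1, 3, 0])  # unsorted, with a repeated ID

# num_segments omitted: it defaults to max(segment_ids) + 1 = 4.
# Segment 2 receives no rows, so its sum is 0.
default_out = segment_sum(data, segment_ids)

# A larger num_segments pads the output with empty (zero) segments.
padded_out = segment_sum(data, segment_ids, num_segments=6)

# A smaller num_segments puts ID 3 out of range; by default that row is dropped.
dropped_out = segment_sum(data, segment_ids, num_segments=2)

# With mode="clip", the out-of-range ID 3 is clamped to the last valid segment (1).
clipped_out = segment_sum(data, segment_ids, num_segments=2, mode="clip")

# bucket_size only regroups the additions; for integer data the sum is unchanged.
bucketed_out = segment_sum(data, segment_ids, bucket_size=2)

print(default_out.tolist())   # [40, 30, 0, 30]
print(padded_out.tolist())    # [40, 30, 0, 30, 0, 0]
print(dropped_out.tolist())   # [40, 30]
print(clipped_out.tolist())   # [40, 60]
print(bucketed_out.tolist())  # [40, 30, 0, 30]
```

Bucketing is aimed at floating-point accuracy, so with float inputs the bucketed and unbucketed results may differ in the last few bits rather than match exactly.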
- Returns:
An array with shape (num_segments,) + data.shape[1:] representing the segment sums.
- Return type:
Array
Examples
Simple 1D segment sum:
>>> import jax.numpy as jnp
>>> from jax.ops import segment_sum
>>> data = jnp.arange(5)
>>> segment_ids = jnp.array([0, 0, 1, 1, 2])
>>> segment_sum(data, segment_ids)
Array([1, 5, 4], dtype=int32)
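Only the leading axis of data is segmented, so the returned shape is (num_segments,) + data.shape[1:]. The sketch below uses an arbitrary 5×3 integer array to show a segment sum over rows, where each feature column is summed independently within a segment.

```python
import jax.numpy as jnp
from jax.ops import segment_sum

# Five rows with three features each; only the leading axis is segmented.
data = jnp.arange(15).reshape(5, 3)
segment_ids = jnp.array([0, 0, 1, 1, 2])

sums = segment_sum(data, segment_ids, num_segments=3)

# Output shape is (num_segments,) + data.shape[1:] == (3, 3).
print(sums.shape)     # (3, 3)
print(sums.tolist())  # [[3, 5, 7], [15, 17, 19], [12, 13, 14]]
```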
Using JIT requires a static num_segments:
>>> from jax import jit
>>> jit(segment_sum, static_argnums=2)(data, segment_ids, 3)
Array([1, 5, 4], dtype=int32)
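The doctest above marks num_segments static by position. The sketch below shows two other common ways to keep it static under jax.jit: naming it in static_argnames, or closing over a plain Python int inside a jitted wrapper. The wrapper totals and the constant NUM_SEGMENTS are hypothetical names used for illustration. Leaving num_segments as None inside jit is expected to fail, because the default max(segment_ids) + 1 would have to be computed from a traced value.

```python
import jax
import jax.numpy as jnp
from jax.ops import segment_sum

data = jnp.arange(5)
segment_ids = jnp.array([0, 0, 1, 1, 2])

# Option 1: mark num_segments as static by name so the output shape is fixed
# at trace time. Each distinct num_segments value triggers a new compilation.
jitted_sum = jax.jit(segment_sum, static_argnames="num_segments")
out1 = jitted_sum(data, segment_ids, num_segments=3)

# Option 2 (hypothetical wrapper): close over a segment count known ahead of time.
NUM_SEGMENTS = 3

@jax.jit
def totals(data, segment_ids):
    # NUM_SEGMENTS is a plain Python int, not a traced value, so it stays static.
    return segment_sum(data, segment_ids, num_segments=NUM_SEGMENTS)

out2 = totals(data, segment_ids)
print(out1.tolist(), out2.tolist())  # [1, 5, 4] [1, 5, 4]
```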
