jax.ops.segment_max


jax.ops.segment_max(data, segment_ids, num_segments=None, indices_are_sorted=False, unique_indices=False, bucket_size=None, mode=None)

Computes the maximum within segments of an array.

Similar to TensorFlow's segment_max.
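For intuition, the reduction can be viewed as a scatter-max of each row of data into an output buffer that is pre-filled with the identity element of max. The sketch below uses a hypothetical helper, segment_max_reference, and assumes that every id lies in [0, num_segments), that data is a real numeric dtype, and that empty segments hold -inf for floating inputs and the dtype's minimum for integer inputs. It illustrates the semantics only and is not a statement of how the library is implemented.

```python
import jax
import jax.numpy as jnp


def segment_max_reference(data, segment_ids, num_segments):
    """Hypothetical reference: scatter-max each row of data into its segment.

    Assumes all segment_ids are in [0, num_segments) and data is real numeric.
    """
    data = jnp.asarray(data)
    # Identity element for max (assumption): -inf for floats, dtype min for ints.
    if jnp.issubdtype(data.dtype, jnp.floating):
        init = -jnp.inf
    else:
        init = jnp.iinfo(data.dtype).min
    out = jnp.full((num_segments,) + data.shape[1:], init, dtype=data.dtype)
    # Rows that share an id are combined with max; untouched rows keep init.
    return out.at[segment_ids].max(data)


data = jnp.array([3.0, 1.0, 4.0, 1.0, 5.0, 9.0])
segment_ids = jnp.array([2, 0, 2, 0, 3, 3])  # segments 1 and 4 stay empty
ref = segment_max_reference(data, segment_ids, num_segments=5)
lib = jax.ops.segment_max(data, segment_ids, num_segments=5)
print(jnp.array_equal(ref, lib))
```

Comparing against jax.ops.segment_max on a small input, including empty segments, is a quick way to check the assumed fill values on the JAX version at hand.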

Parameters:
  • data (ArrayLike) – an array with the values to be reduced.

  • segment_ids (ArrayLike) – an array with integer dtype that indicates the segments of data (along its leading axis) to be reduced. Values can be repeated and need not be sorted.

  • num_segments (int | None) – optional, an int with nonnegative value indicating the number of segments. The default is the minimum number of segments that would support all indices in segment_ids, calculated as max(segment_ids) + 1. Since num_segments determines the size of the output, a static value must be provided to use segment_max in a JIT-compiled function.

  • indices_are_sorted (bool) – whether segment_ids is known to be sorted.

  • unique_indices (bool) – whether segment_ids is known to be free of duplicates.

  • bucket_size (int | None) – size of bucket to group indices into. segment_max is performed on each bucket separately. Default None means no bucketing.

  • mode (slicing.GatherScatterMode | str | None) – a jax.lax.GatherScatterMode value describing how out-of-bounds indices should be handled. By default, values outside of the range [0, num_segments) are dropped and do not contribute to the result.
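The sketch below shows how segment_ids, num_segments, and mode interact on a small made-up input: ids are unsorted and repeated, one id falls outside [0, num_segments), and the output length follows num_segments when it is given and max(segment_ids) + 1 otherwise. The comment about empty segments holding -inf reflects the identity of max for floating inputs.

```python
import jax.numpy as jnp
from jax.ops import segment_max

data = jnp.array([2.0, 8.0, 5.0, 1.0, 7.0])
# Unsorted, repeated ids; id 4 is out of range when num_segments=3.
segment_ids = jnp.array([1, 0, 1, 4, 0])

# Segment 0 -> max(8, 7), segment 1 -> max(2, 5); segment 2 receives no
# values (it holds -inf), and the value 1.0 with id 4 is dropped by default.
out = segment_max(data, segment_ids, num_segments=3)

# Without num_segments the output length is max(segment_ids) + 1 = 5,
# so the value with id 4 now lands in its own segment.
out_default = segment_max(data, segment_ids)

# With mode="clip", id 4 is clamped into range (to segment 2) instead of dropped.
out_clip = segment_max(data, segment_ids, num_segments=3, mode="clip")

print(out, out_default, out_clip)
```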

Returns:

An array with shape (num_segments,) + data.shape[1:] representing the segment maximums.
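When data has more than one dimension, segments are formed along the leading axis only and the maximum is taken elementwise over the trailing dimensions. A small example with made-up 2-D input:

```python
import jax.numpy as jnp
from jax.ops import segment_max

# Four rows of three features each; rows are grouped along axis 0.
data = jnp.array([[1, 9, 3],
                  [4, 2, 6],
                  [7, 5, 0],
                  [2, 8, 1]])
segment_ids = jnp.array([0, 0, 1, 1])

result = segment_max(data, segment_ids, num_segments=2)
# Shape is (num_segments,) + data.shape[1:] == (2, 3).
# Row 0 of result is the elementwise max of data rows 0 and 1,
# row 1 is the elementwise max of data rows 2 and 3.
print(result.shape, result)
```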

Return type:

Array

Examples

Simple 1D segment max:

>>> import jax.numpy as jnp
>>> from jax.ops import segment_max
>>> data = jnp.arange(6)
>>> segment_ids = jnp.array([0, 0, 1, 1, 2, 2])
>>> segment_max(data, segment_ids)
Array([1, 3, 5], dtype=int32)

Using JIT requires a static num_segments:

>>> from jax import jit
>>> jit(segment_max, static_argnums=2)(data, segment_ids, 3)
Array([1, 3, 5], dtype=int32)
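Another way to keep num_segments static is to bind it before tracing, for example with functools.partial, so the jitted function only receives array arguments. The sketch below also passes indices_are_sorted=True, which is only appropriate because these ids really are sorted; the flag is treated as a promise, and results for unsorted ids passed with it may be incorrect. The name segment_max_3 is hypothetical.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax.ops import segment_max

# num_segments and the sortedness hint are fixed Python values, so jit only
# traces data and segment_ids.
segment_max_3 = jax.jit(partial(segment_max, num_segments=3, indices_are_sorted=True))

data = jnp.arange(6)
segment_ids = jnp.array([0, 0, 1, 1, 2, 2])  # sorted, so the hint is truthful
print(segment_max_3(data, segment_ids))
```

Binding with partial avoids relying on argument positions, which keeps the call site readable if more keyword options are added later.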