jax.numpy.intersect1d


jax.numpy.intersect1d(ar1, ar2, assume_unique=False, return_indices=False, *, size=None, fill_value=None)

Compute the set intersection of two 1D arrays.

JAX implementation of numpy.intersect1d().
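For concrete (non-traced) inputs the result should agree with NumPy. The sketch below compares the two functions on arbitrary example arrays (`a` and `b` are illustrative). Only the values are compared, because JAX defaults to 32-bit integers unless 64-bit mode is enabled, while NumPy typically uses int64.

```python
import numpy as np
import jax.numpy as jnp

# Arbitrary example inputs; duplicates are allowed when assume_unique=False.
a = [1, 2, 2, 3, 4]
b = [4, 3, 3, 7]

# Outside of jit the inputs are concrete, so no size argument is needed.
jax_result = jnp.intersect1d(jnp.array(a), jnp.array(b))
np_result = np.intersect1d(np.array(a), np.array(b))

# Compare values only; the integer dtypes may differ between the libraries.
print(np.asarray(jax_result), np_result)  # [3 4] [3 4]
assert np.array_equal(np.asarray(jax_result), np_result)
```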

Because the size of the output of intersect1d is data-dependent, the function is not typically compatible with jit() and other JAX transformations. The JAX version adds the optional size argument, which must be specified statically for jnp.intersect1d to be used in such contexts.

Parameters:
  • ar1 (ArrayLike) – first array of values to intersect.

  • ar2 (ArrayLike) – second array of values to intersect.

  • assume_unique (bool) – if True, assume the input arrays contain unique values. This allows a more efficient implementation, but if assume_unique is True and the input arrays contain duplicates, the behavior is undefined. Default: False.

  • return_indices (bool) – if True, return arrays of indices specifying where the intersected values first appear in the input arrays.

  • size (int | None) – if specified, return only the first size sorted elements. If there are fewer elements than size indicates, the return value will be padded with fill_value, and returned indices will be padded with an out-of-bound index.

  • fill_value (ArrayLike | None) – when size is specified and there are fewer than the indicated number of elements, fill the remaining entries with fill_value. Defaults to the smallest value in the intersection.

Returns:

An array intersection, or if return_indices=True, a tuple of arrays (intersection, ar1_indices, ar2_indices). Returned values are:

  • intersection: a sorted 1D array of the unique values that appear in both ar1 and ar2.

  • ar1_indices: (returned if return_indices=True) an array of shape intersection.shape containing the indices in flattened ar1 of the values in intersection. For 1D inputs, intersection is equivalent to ar1[ar1_indices].

  • ar2_indices: (returned if return_indices=True) an array of shape intersection.shape containing the indices in flattened ar2 of the values in intersection. For 1D inputs, intersection is equivalent to ar2[ar2_indices].
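Because the returned indices refer to the flattened inputs, a multi-dimensional input must be raveled before indexing with them. A minimal sketch with an arbitrary 2D ar1:

```python
import jax.numpy as jnp

ar1 = jnp.array([[1, 2],
                 [3, 4]])  # 2D input: treated as its flattened version [1, 2, 3, 4]
ar2 = jnp.array([4, 1])

intersection, ar1_indices, ar2_indices = jnp.intersect1d(
    ar1, ar2, return_indices=True)
print(intersection)  # [1 4]
print(ar1_indices)   # [0 3]  positions in ar1.ravel()
print(ar2_indices)   # [1 0]

# Index the raveled array, not the 2D one, to recover the intersection.
assert (ar1.ravel()[ar1_indices] == intersection).all()
assert (ar2[ar2_indices] == intersection).all()
```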

Return type:

Array | tuple[Array, Array, Array]


Examples

>>> ar1 = jnp.array([1, 2, 3, 4])
>>> ar2 = jnp.array([3, 4, 5, 6])
>>> jnp.intersect1d(ar1, ar2)
Array([3, 4], dtype=int32)

Computing intersection with indices:

>>> intersection, ar1_indices, ar2_indices = jnp.intersect1d(ar1, ar2, return_indices=True)
>>> intersection
Array([3, 4], dtype=int32)

ar1_indices gives the indices of the intersected values within ar1:

>>> ar1_indices
Array([2, 3], dtype=int32)
>>> jnp.all(intersection == ar1[ar1_indices])
Array(True, dtype=bool)

ar2_indices gives the indices of the intersected values within ar2:

>>> ar2_indices
Array([0, 1], dtype=int32)
>>> jnp.all(intersection == ar2[ar2_indices])
Array(True, dtype=bool)