
jax.numpy.setdiff1d

jax.numpy.setdiff1d(ar1, ar2, assume_unique=False, *, size=None, fill_value=None)

Compute the set difference of two 1D arrays.

JAX implementation of numpy.setdiff1d().

Because the size of the output of setdiff1d is data-dependent, the function is not typically compatible with jit() and other JAX transformations. The JAX version adds the optional size argument, which must be specified statically for jnp.setdiff1d to be used in such contexts.
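To see why the output size is data-dependent, the sketch below (assuming a standard `jax` install; the inputs are arbitrary illustrations) passes two input pairs with identical shapes and gets results of different lengths. It then shows that tracing the call under `jax.jit` without `size` raises `jax.errors.ConcretizationTypeError`.

```python
import jax
import jax.numpy as jnp

# Two input pairs with identical shapes: (4,) and (4,).
a = jnp.array([1, 2, 3, 4])
b1 = jnp.array([3, 4, 5, 6])   # shares two values with a
b2 = jnp.array([7, 8, 9, 10])  # shares no values with a

# Eager calls work, but the output length depends on the values, not the shapes.
diff1 = jnp.setdiff1d(a, b1)
diff2 = jnp.setdiff1d(a, b2)
print(diff1.shape, diff2.shape)  # (2,) (4,)

# Under jit the inputs are abstract tracers, so no output shape can be chosen
# at trace time and tracing fails.
try:
    jax.jit(jnp.setdiff1d)(a, b1)
    jit_failed = False
except jax.errors.ConcretizationTypeError:
    jit_failed = True
print(jit_failed)  # True
```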

Parameters:
  • ar1 (ArrayLike) – first array of elements to be differenced.

  • ar2 (ArrayLike) – second array of elements to be differenced.

  • assume_unique (bool) – if True, assume the input arrays contain unique values. This allows a more efficient implementation, but if assume_unique is True and the input arrays contain duplicates, the behavior is undefined. Default: False.

  • size (int | None) – if specified, return only the first size sorted elements. If there are fewer elements than size indicates, the return value will be padded with fill_value.

  • fill_value (ArrayLike | None) – when size is specified and there are fewer than the indicated number of elements, fill the remaining entries with fill_value. Defaults to the minimum value.
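The following sketch illustrates how these parameters behave in eager mode; the input values are chosen only for illustration. By default duplicates in `ar1` are removed and the result is sorted; `assume_unique=True` gives the same answer when the inputs really are unique; and `size` with `fill_value` pads the result to a fixed length. An explicit `fill_value` is passed so the example does not depend on the default.

```python
import jax.numpy as jnp

ar1 = jnp.array([4, 1, 1, 3, 2, 4])  # unsorted, with duplicates
ar2 = jnp.array([3, 5])

# Default: duplicates in ar1 are removed and the result is sorted.
r_default = jnp.setdiff1d(ar1, ar2)
print(r_default.tolist())  # [1, 2, 4]

# assume_unique=True is only safe when the inputs have no duplicates;
# this input is already unique (and sorted).
u1 = jnp.array([1, 2, 3, 4])
r_unique = jnp.setdiff1d(u1, ar2, assume_unique=True)
print(r_unique.tolist())  # [1, 2, 4]

# size and fill_value: pad the 3-element difference out to length 5.
r_padded = jnp.setdiff1d(ar1, ar2, size=5, fill_value=-1)
print(r_padded.tolist())  # [1, 2, 4, -1, -1]
```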

Returns:

an array containing the set difference of elements in the input arrays ar1 and ar2, i.e. the elements in ar1 that are not contained in ar2.

Return type:

Array
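The return value can be checked against `numpy.setdiff1d`, which this function mirrors, and against a direct spelling of "the unique elements of ar1 that are not in ar2" using `jnp.isin` and `jnp.unique`. The inputs below are arbitrary illustrations; the boolean-mask indexing only works outside of `jit`.

```python
import numpy as np
import jax.numpy as jnp

a = np.array([5, 3, 9, 3, 1, 7])
b = np.array([9, 2, 3])

expected = np.setdiff1d(a, b)  # NumPy reference result
result = jnp.setdiff1d(a, b)   # the JAX version also accepts NumPy arrays

# Direct definition: keep elements of a not found in b, then sort and deduplicate.
ja, jb = jnp.asarray(a), jnp.asarray(b)
manual = jnp.unique(ja[~jnp.isin(ja, jb)])

print(expected.tolist(), result.tolist(), manual.tolist())  # [1, 5, 7] [1, 5, 7] [1, 5, 7]
```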

See also

Examples

Computing the set difference of two arrays:

>>> ar1 = jnp.array([1, 2, 3, 4])
>>> ar2 = jnp.array([3, 4, 5, 6])
>>> jnp.setdiff1d(ar1, ar2)
Array([1, 2], dtype=int32)

Because the output shape is dynamic, this will fail under jit() and other transformations:

>>> jax.jit(jnp.setdiff1d)(ar1, ar2)
Traceback (most recent call last):
   ...
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[4].
The error occurred while tracing the function setdiff1d at /Users/vanderplas/github/jax-ml/jax/jax/_src/numpy/setops.py:64 for jit. This concrete value was not available in Python because it depends on the value of the argument ar1.

To ensure statically known output shapes, you can pass a static size argument:

>>> jit_setdiff1d = jax.jit(jnp.setdiff1d, static_argnames=['size'])
>>> jit_setdiff1d(ar1, ar2, size=2)
Array([1, 2], dtype=int32)
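An alternative to `static_argnames` is to bind `size` with `functools.partial` before applying `jax.jit`; the bound keyword is then an ordinary Python constant at trace time. A minimal sketch, with `ar1` and `ar2` redefined so the block runs on its own (`setdiff1d_size2` is a hypothetical name):

```python
import functools

import jax
import jax.numpy as jnp

ar1 = jnp.array([1, 2, 3, 4])
ar2 = jnp.array([3, 4, 5, 6])

# size is fixed when the partial is created, so jit never treats it as a traced value.
setdiff1d_size2 = jax.jit(functools.partial(jnp.setdiff1d, size=2))
out = setdiff1d_size2(ar1, ar2)
print(out.tolist())  # [1, 2]
```

This fixes one `size` per compiled function, whereas `static_argnames` lets the caller pick `size` per call (each distinct value triggers a recompile).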

If size is too small, the difference is truncated:

>>> jit_setdiff1d(ar1, ar2, size=1)
Array([1], dtype=int32)

If size is too large, the output is padded with fill_value:

>>> jit_setdiff1d(ar1, ar2, size=4, fill_value=0)
Array([1, 2, 0, 0], dtype=int32)
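Every element of the difference comes from `ar1`, so its length can never exceed `ar1.size`; that value is therefore a safe `size` when truncation must be avoided. Inside a jitted function, array shapes are static, so `size` can be derived from them without `static_argnames`. In this sketch `setdiff1d_padded` is a hypothetical helper, and `-1` is an assumed sentinel that is only unambiguous if it can never occur in the data.

```python
import jax
import jax.numpy as jnp

@jax.jit
def setdiff1d_padded(ar1, ar2):
    # Hypothetical helper. ar1.size is a Python int at trace time, and the set
    # difference is never longer than ar1, so the result is never truncated.
    # -1 is an assumed sentinel: pick a value that cannot occur in ar1.
    return jnp.setdiff1d(ar1, ar2, size=ar1.size, fill_value=-1)

ar1 = jnp.array([1, 2, 3, 4])
out_partial = setdiff1d_padded(ar1, jnp.array([3, 4, 5, 6]))
out_full = setdiff1d_padded(ar1, jnp.array([7, 8, 9, 10]))
print(out_partial.tolist())  # [1, 2, -1, -1]
print(out_full.tolist())     # [1, 2, 3, 4]

# Outside of jit, the sentinel makes the valid prefix easy to recover.
valid = out_partial[out_partial != -1]
print(valid.tolist())  # [1, 2]
```

Both calls have the same input shapes, so they reuse a single compiled function.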