jax.numpy.union1d
- jax.numpy.union1d(ar1, ar2, *, size=None, fill_value=None)
Compute the set union of two 1D arrays.
JAX implementation of numpy.union1d().
Because the size of the output of union1d is data-dependent, the function is not typically compatible with jit() and other JAX transformations. The JAX version adds the optional size argument, which must be specified statically for jnp.union1d to be used in such contexts.
- Parameters:
ar1 (ArrayLike) – first array of elements to be unioned.
ar2 (ArrayLike) – second array of elements to be unioned.
size (int | None) – if specified, return only the first size sorted elements. If there are fewer elements than size indicates, the return value will be padded with fill_value.
fill_value (ArrayLike | None) – when size is specified and there are fewer than the indicated number of elements, fill the remaining entries with fill_value. Defaults to the minimum value.
- Returns:
an array containing the union of elements in the input arrays.
- Return type:
Array
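The default padding behaviour can be checked directly. The following is a minimal sketch (the arrays are chosen for illustration) that calls union1d eagerly with a size larger than the number of distinct values and leaves fill_value unset, so the padded slots take the default minimum value.

```python
import jax.numpy as jnp

ar1 = jnp.array([1, 2, 3, 4])
ar2 = jnp.array([3, 4, 5, 6])

# size=8 exceeds the six distinct values, so two trailing slots are padded.
# fill_value is omitted, so the padding uses the default (the minimum value).
padded = jnp.union1d(ar1, ar2, size=8)
print(padded)  # [1 2 3 4 5 6 1 1]
```

Because the default padding repeats a value that is already in the union, padded slots cannot be told apart from real elements; passing an explicit fill_value that cannot occur in the data avoids this.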
See also
jax.numpy.intersect1d(): the set intersection of two 1D arrays.
jax.numpy.setxor1d(): the set XOR of two 1D arrays.
jax.numpy.setdiff1d(): the set difference of two 1D arrays.
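These set operations are related. As an illustrative check (not a description of how union1d is implemented), the union equals the sorted unique values of both inputs combined. It also equals the sorted concatenation of the intersection and the symmetric difference, because those two sets are disjoint and together cover every element.

```python
import jax.numpy as jnp

a = jnp.array([1, 2, 3, 4])
b = jnp.array([3, 4, 5, 6])

union = jnp.union1d(a, b)

# Sorted unique values of the combined inputs.
via_unique = jnp.unique(jnp.concatenate([a, b]))

# Intersection ([3, 4]) and symmetric difference ([1, 2, 5, 6]) are disjoint,
# so sorting their concatenation also yields the union.
via_parts = jnp.sort(jnp.concatenate([jnp.intersect1d(a, b), jnp.setxor1d(a, b)]))

print(jnp.array_equal(union, via_unique), jnp.array_equal(union, via_parts))  # True True
```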
Examples
Computing the union of two arrays:
>>> ar1 = jnp.array([1, 2, 3, 4])
>>> ar2 = jnp.array([3, 4, 5, 6])
>>> jnp.union1d(ar1, ar2)
Array([1, 2, 3, 4, 5, 6], dtype=int32)
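The inputs do not have to be sorted or free of duplicates. A small sketch with made-up values shows that each distinct value appears exactly once in the result, in ascending order.

```python
import jax.numpy as jnp

# Unsorted inputs containing duplicates, both within and across arrays.
x = jnp.array([5, 1, 5, 3])
y = jnp.array([3, 2, 2])

# Each distinct value appears once in the result, sorted ascending.
result = jnp.union1d(x, y)
print(result)  # [1 2 3 5]
```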
Because the output shape is dynamic, this will fail under jit() and other transformations:
>>> jax.jit(jnp.union1d)(ar1, ar2)
Traceback (most recent call last):
   ...
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[4].
The error occurred while tracing the function union1d at /Users/vanderplas/github/jax-ml/jax/jax/_src/numpy/setops.py:101 for jit. This concrete value was not available in Python because it depends on the value of the argument ar1.
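Passing size is not enough on its own: it must also be static. The sketch below uses a hypothetical helper, fails_under_jit, to catch jax.errors.ConcretizationTypeError. It shows that omitting size fails, and that passing size as an ordinary (traced) jit argument fails as well, since the traced value cannot be converted to a Python int.

```python
import jax
import jax.numpy as jnp

ar1 = jnp.array([1, 2, 3, 4])
ar2 = jnp.array([3, 4, 5, 6])

def fails_under_jit(fn, *args, **kwargs):
    """Hypothetical helper: True if jax.jit(fn) raises ConcretizationTypeError."""
    try:
        jax.jit(fn)(*args, **kwargs)
    except jax.errors.ConcretizationTypeError:
        return True
    return False

# No size: the output shape depends on the input values.
print(fails_under_jit(jnp.union1d, ar1, ar2))          # True
# size given, but traced rather than static, so it is still unknown at trace time.
print(fails_under_jit(jnp.union1d, ar1, ar2, size=6))  # True
```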
In order to ensure statically known output shapes, you can pass a static size argument:
>>> jit_union1d = jax.jit(jnp.union1d, static_argnames=['size'])
>>> jit_union1d(ar1, ar2, size=6)
Array([1, 2, 3, 4, 5, 6], dtype=int32)
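When union1d is called inside a larger jitted function, one option is to derive size from the input shapes, which are static under jit. In this sketch, union_no_truncation is a hypothetical name and -1 is assumed never to occur in the data. Because the union can hold at most a.size + b.size elements, this choice rules out truncation, at the cost of some padding.

```python
import jax
import jax.numpy as jnp

@jax.jit
def union_no_truncation(a, b):
    # Under jit, a.size and b.size are Python ints, so their sum is a valid
    # static size without static_argnames. The union cannot exceed this many
    # elements, so nothing is dropped; extra slots are padded with -1.
    return jnp.union1d(a, b, size=a.size + b.size, fill_value=-1)

out = union_no_truncation(jnp.array([1, 2, 3, 4]), jnp.array([3, 4, 5, 6]))
print(out.shape)  # (8,)
```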
If size is too small, the union is truncated:
>>> jit_union1d(ar1, ar2, size=4)
Array([1, 2, 3, 4], dtype=int32)
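Truncation happens without any warning. To check whether a given size would drop elements, you can count the distinct values using only fixed-shape operations. In the sketch below, count_union is a hypothetical helper (not part of jax.numpy) that assumes non-empty inputs without NaNs.

```python
import jax
import jax.numpy as jnp

@jax.jit
def count_union(a, b):
    # Hypothetical helper: sort the combined inputs and count positions where
    # the value changes. All shapes are fixed, so this runs under jit.
    # Assumes non-empty inputs without NaNs (NaN != NaN would overcount).
    s = jnp.sort(jnp.concatenate([a, b]))
    return 1 + jnp.sum(s[1:] != s[:-1])

ar1 = jnp.array([1, 2, 3, 4])
ar2 = jnp.array([3, 4, 5, 6])
n = count_union(ar1, ar2)
# Any size smaller than n truncates the union.
print(n, n > 4)  # 6 True
```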
If size is too large, then the output is padded with fill_value:
>>> jit_union1d(ar1, ar2, size=8, fill_value=0)
Array([1, 2, 3, 4, 5, 6, 0, 0], dtype=int32)
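A static size also lets union1d be used under jax.vmap, because every row of a batch then produces an output of the same shape. The sketch below batches two input pairs whose unions have different lengths. It assumes 0 never occurs in the data, so 0 can mark padded slots, and it recovers the per-row union length from that mask.

```python
import jax
import jax.numpy as jnp

# Two (ar1, ar2) pairs; their unions have 4 and 6 elements respectively.
a = jnp.array([[1, 2, 3], [5, 6, 7]])
b = jnp.array([[2, 3, 4], [8, 9, 10]])

# Fixed size=6 gives every row the same output shape, as vmap requires.
# fill_value=0 is assumed absent from the data, so it marks padding.
batched_union = jax.vmap(lambda x, y: jnp.union1d(x, y, size=6, fill_value=0))
out = batched_union(a, b)

valid = out != 0              # True where an entry is a real union element
print(valid.sum(axis=1))      # [4 6]
```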
