jax.numpy.delete
jax.numpy.delete(arr, obj, axis=None, *, assume_unique_indices=False)
Delete entry or entries from an array.
JAX implementation of numpy.delete().
- Parameters:
arr (ArrayLike) – array from which entries will be deleted.
obj (ArrayLike | slice) – index, indices, or slice to be deleted.
axis (int | None) – axis along which entries will be deleted. If None (the default), arr is flattened before deletion.
assume_unique_indices (bool) – In case of array-like integer (not boolean) indices, assume the indices are unique, and perform the deletion in a way that is compatible with JIT and other JAX transformations.
- Returns:
Copy of arr with specified indices deleted.
- Return type:
Array
Note
delete() usually requires the index specification to be static. If the index is an integer array that is guaranteed to contain unique entries, you may specify assume_unique_indices=True to perform the operation in a manner that does not require static indices.
See also
jax.numpy.insert(): insert entries into an array.
Examples
Delete entries from a 1D array:
>>> a = jnp.array([4, 5, 6, 7, 8, 9])
>>> jnp.delete(a, 2)
Array([4, 5, 7, 8, 9], dtype=int32)
>>> jnp.delete(a, slice(1, 4))  # delete a[1:4]
Array([4, 8, 9], dtype=int32)
>>> jnp.delete(a, slice(None, None, 2))  # delete a[::2]
Array([5, 7, 9], dtype=int32)
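The parameter description above distinguishes integer indices from boolean ones. As a minimal sketch, assuming a boolean mask is interpreted as in numpy.delete() (entries where the mask is True are removed, and the mask length must match the axis length), deletion by mask might look like this. The names `mask` and `masked_result` are illustrative only.

```python
import jax.numpy as jnp

a = jnp.array([4, 5, 6, 7, 8, 9])

# Boolean mask with one entry per element of `a`; True marks entries to delete.
# The mask is a concrete (non-traced) array here, so its contents are known
# when the output shape is computed.
mask = jnp.array([True, False, True, False, True, False])

masked_result = jnp.delete(a, mask)
print(masked_result)  # the entries at positions 1, 3, 5 remain: 5, 7, 9
```

Because the number of True entries determines the output length, a mask like this is expected to be static rather than a traced value.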
Delete entries from a 2D array along a specified axis:
>>> a2 = jnp.array([[4, 5, 6],
...                 [7, 8, 9]])
>>> jnp.delete(a2, 1, axis=1)
Array([[4, 6],
       [7, 9]], dtype=int32)
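To make the role of axis concrete, the following sketch contrasts the default axis=None with an explicit axis=0 on the same 2D input. It assumes the numpy.delete() convention that axis=None operates on the flattened array. The variable names are illustrative.

```python
import jax.numpy as jnp

a2 = jnp.array([[4, 5, 6],
                [7, 8, 9]])

# axis=None (the default): a2 is flattened to [4, 5, 6, 7, 8, 9] first,
# then the entry at flat index 1 is removed, giving a 1D result.
flat_result = jnp.delete(a2, 1)

# axis=0: remove row 0, leaving a single row and all three columns.
row_result = jnp.delete(a2, 0, axis=0)

print(flat_result.shape)  # (5,)
print(row_result.shape)   # (1, 3)
```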
Delete multiple entries via a sequence of indices:
>>> indices = jnp.array([0, 1, 3])
>>> jnp.delete(a, indices)
Array([6, 8, 9], dtype=int32)
This will fail under jit() and other transformations, because the output shape cannot be known with the possibility of duplicate indices:
>>> jax.jit(jnp.delete)(a, indices)
Traceback (most recent call last):
  ...
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[3].
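The shape dependence behind this error can be seen outside of JIT. In this small sketch, two static index arrays of the same length produce outputs of different lengths, which is why the index shape alone does not determine the result shape. The names `unique_idx` and `dup_idx` are illustrative, and NumPy arrays are used so that the indices are plainly static.

```python
import numpy as np
import jax.numpy as jnp

a = jnp.array([4, 5, 6, 7, 8, 9])

# Both index arrays have shape (3,), but they name a different number
# of distinct positions.
unique_idx = np.array([0, 1, 3])  # three distinct positions
dup_idx = np.array([0, 0, 1])     # only two distinct positions

out_unique = jnp.delete(a, unique_idx)
out_dup = jnp.delete(a, dup_idx)

print(out_unique.shape)  # (3,)
print(out_dup.shape)     # (4,)
```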
If you can ensure that the indices are unique, pass assume_unique_indices to allow this to be executed under JIT:
>>> jit_delete = jax.jit(jnp.delete, static_argnames=['assume_unique_indices'])
>>> jit_delete(a, indices, assume_unique_indices=True)
Array([6, 8, 9], dtype=int32)
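An alternative to marking the flag as a static argument is to fix it inside a jitted wrapper, so that only the arrays are traced. The sketch below assumes a hypothetical helper named `drop_entries`. It is one possible arrangement, not a pattern prescribed by the JAX documentation.

```python
import jax
import jax.numpy as jnp

@jax.jit
def drop_entries(arr, idx):
    # `idx` is a traced integer array here. The flag is a Python literal
    # inside the function body, so no static_argnames entry is needed.
    # The caller is responsible for guaranteeing that `idx` has no duplicates.
    return jnp.delete(arr, idx, assume_unique_indices=True)

a = jnp.array([4, 5, 6, 7, 8, 9])
result = drop_entries(a, jnp.array([0, 1, 3]))
print(result)  # the entries at positions 2, 4, 5 remain: 6, 8, 9
```

With this arrangement the uniqueness guarantee is baked into the helper, so it is only appropriate where every call site can uphold it.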
