
jax.numpy.delete

jax.numpy.delete(arr, obj, axis=None, *, assume_unique_indices=False)

Delete entry or entries from an array.

JAX implementation of numpy.delete().
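Because the function mirrors numpy.delete(), one quick way to check expected behavior is to compare the two directly. The sketch below assumes both NumPy and JAX are installed; the input values are arbitrary.

```python
import numpy as np
import jax.numpy as jnp

x_np = np.arange(10)
x_jnp = jnp.arange(10)

# Delete the same positions with NumPy and with JAX.
expected = np.delete(x_np, [1, 4, 7])
result = jnp.delete(x_jnp, jnp.array([1, 4, 7]))

print(expected.tolist())  # [0, 2, 3, 5, 6, 8, 9]
print(result.tolist())    # [0, 2, 3, 5, 6, 8, 9]
```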

Parameters:
  • arr (ArrayLike) – array from which entries will be deleted.

  • obj (ArrayLike | slice) – index, indices, or slice to be deleted.

  • axis (int | None) – axis along which entries will be deleted. If None (the default), arr is flattened before deletion.

  • assume_unique_indices (bool) – If obj is an array-like of integer (not boolean) indices, assume the indices are unique and perform the deletion in a way that is compatible with JIT and other JAX transformations.

Returns:

Copy of arr with the specified indices deleted.

Return type:

Array
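Since JAX arrays are immutable, "copy" here means the input is never modified; a new, shorter array is returned with the same dtype. A minimal check with arbitrary values:

```python
import jax.numpy as jnp

a = jnp.array([4, 5, 6, 7, 8, 9])
b = jnp.delete(a, 0)

# `a` is untouched; `b` is a new array one element shorter.
print(a.tolist())          # [4, 5, 6, 7, 8, 9]
print(b.tolist())          # [5, 6, 7, 8, 9]
print(b.dtype == a.dtype)  # True
```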

Note

delete() usually requires the index specification to be static. If the index is an integer array that is guaranteed to contain unique entries, you may specify assume_unique_indices=True to perform the operation in a manner that does not require static indices.
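To illustrate what "static" means here: an index written as a Python constant inside a jitted function is known at trace time, so the output shape can be computed during compilation and no special flag is needed. A minimal sketch; the name drop_third is hypothetical.

```python
import jax
import jax.numpy as jnp

a = jnp.array([4, 5, 6, 7, 8, 9])

# The index 2 is a Python constant captured by the lambda, so it is static
# when the function is traced and the output shape (5,) is known up front.
drop_third = jax.jit(lambda x: jnp.delete(x, 2))

print(drop_third(a).tolist())  # [4, 5, 7, 8, 9]
```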


Examples

Delete entries from a 1D array:

>>> import jax
>>> import jax.numpy as jnp
>>> a = jnp.array([4, 5, 6, 7, 8, 9])
>>> jnp.delete(a, 2)
Array([4, 5, 7, 8, 9], dtype=int32)
>>> jnp.delete(a, slice(1, 4))  # delete a[1:4]
Array([4, 8, 9], dtype=int32)
>>> jnp.delete(a, slice(None, None, 2))  # delete a[::2]
Array([5, 7, 9], dtype=int32)
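Negative indices count from the end of the axis, following numpy.delete() semantics. A short sketch on the same 1D array:

```python
import jax.numpy as jnp

a = jnp.array([4, 5, 6, 7, 8, 9])

# -1 refers to the last entry, whether given as a scalar or inside an array.
last_removed = jnp.delete(a, -1)
ends_removed = jnp.delete(a, jnp.array([0, -1]))

print(last_removed.tolist())  # [4, 5, 6, 7, 8]
print(ends_removed.tolist())  # [5, 6, 7, 8]
```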

Delete entries from a 2D array along a specified axis:

>>> a2 = jnp.array([[4, 5, 6],
...                 [7, 8, 9]])
>>> jnp.delete(a2, 1, axis=1)
Array([[4, 6],
       [7, 9]], dtype=int32)
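The choice of axis changes which entries an index refers to. The sketch below contrasts row deletion, a negative axis, and the default axis=None (which flattens the input first) on the same 2D array:

```python
import jax.numpy as jnp

a2 = jnp.array([[4, 5, 6],
                [7, 8, 9]])

# axis=0 removes a whole row; only the deleted axis shrinks.
rows = jnp.delete(a2, 0, axis=0)

# Negative axes count from the end, so axis=-1 is the column axis here.
cols = jnp.delete(a2, 1, axis=-1)

# axis=None (the default) flattens a2 to [4, 5, 6, 7, 8, 9] first,
# so index 1 refers to the value 5 and the result is 1D.
flat = jnp.delete(a2, 1)

print(rows.tolist())  # [[7, 8, 9]]
print(cols.tolist())  # [[4, 6], [7, 9]]
print(flat.tolist())  # [4, 6, 7, 8, 9]
```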

Delete multiple entries via a sequence of indices:

>>> indices = jnp.array([0, 1, 3])
>>> jnp.delete(a, indices)
Array([6, 8, 9], dtype=int32)
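For intuition, deleting a set of integer indices amounts to keeping the complement of those positions. The sketch below expresses that with an explicit boolean keep-mask; it is an illustration of the equivalence, not a claim about how jnp.delete is implemented internally. Boolean-mask indexing has a data-dependent output shape, so this version only runs eagerly, outside jit.

```python
import jax.numpy as jnp

a = jnp.array([4, 5, 6, 7, 8, 9])
indices = jnp.array([0, 1, 3])

# Mask that is True everywhere except at the positions to delete.
keep = jnp.ones(a.shape[0], dtype=bool).at[indices].set(False)

# Select the surviving entries (eager mode only).
via_mask = a[keep]

print(via_mask.tolist())                # [6, 8, 9]
print(jnp.delete(a, indices).tolist())  # [6, 8, 9]
```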

This will fail under jit() and other transformations, because the output shape cannot be known when duplicate indices are possible:

>>> jax.jit(jnp.delete)(a, indices)
Traceback (most recent call last):
...
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[3].
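The reason the output shape is unknowable at trace time: two index arrays with the same shape can delete different numbers of entries, because a repeated index is only removed once. Run eagerly, the difference is visible:

```python
import jax.numpy as jnp

a = jnp.array([4, 5, 6, 7, 8, 9])

# Both index arrays have shape (3,), so a tracer sees no difference...
unique = jnp.array([0, 1, 3])
with_dup = jnp.array([0, 0, 1])

# ...yet the duplicate 0 is deleted only once, so the output lengths differ.
print(jnp.delete(a, unique).shape)    # (3,)
print(jnp.delete(a, with_dup).shape)  # (4,)
```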

If you can ensure that the indices are unique, pass assume_unique_indices=True to allow this to be executed under JIT:

>>> jit_delete = jax.jit(jnp.delete, static_argnames=['assume_unique_indices'])
>>> jit_delete(a, indices, assume_unique_indices=True)
Array([6, 8, 9], dtype=int32)
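A practical case where uniqueness is guaranteed is when the indices are computed inside the jitted function from a permutation. In the sketch below (drop_two_largest is a hypothetical helper), the indices come from jnp.argsort, whose output is a permutation of the positions, so any slice of it contains distinct entries. The number of indices is fixed at 2, so the output shape is still static even though their values are traced.

```python
import jax
import jax.numpy as jnp

@jax.jit
def drop_two_largest(x):
    # argsort returns a permutation of range(len(x)), so its last two
    # entries (the positions of the two largest values) are distinct.
    idx = jnp.argsort(x)[-2:]
    # Remaining entries keep their original order.
    return jnp.delete(x, idx, assume_unique_indices=True)

x = jnp.array([3, 9, 1, 7, 5])
print(drop_two_largest(x).tolist())  # [3, 1, 5]
```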