jax.numpy.trim_zeros
jax.numpy.trim_zeros(filt, trim='fb', axis=None)
Trim leading and/or trailing zeros of the input array.
JAX implementation of numpy.trim_zeros().
Parameters:
filt (ArrayLike) – N-dimensional input array.
trim (str) – string, optional, default = 'fb'. Specifies from which end the input is trimmed.
  - 'f': trims only the leading zeros.
  - 'b': trims only the trailing zeros.
  - 'fb': trims both leading and trailing zeros.
axis (int | Sequence[int] | None) – optional axis or axes along which to trim. If not specified, trim along all axes of the array.
Returns:
An array containing the trimmed input with the same dtype as filt.
Return type:
Array
Examples
One-dimensional input:
>>> import jax.numpy as jnp
>>> x = jnp.array([0, 0, 2, 0, 1, 4, 3, 0, 0, 0])
>>> jnp.trim_zeros(x)
Array([2, 0, 1, 4, 3], dtype=int32)
>>> jnp.trim_zeros(x, trim='f')
Array([2, 0, 1, 4, 3, 0, 0, 0], dtype=int32)
>>> jnp.trim_zeros(x, trim='b')
Array([0, 0, 2, 0, 1, 4, 3], dtype=int32)
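The return value keeps the dtype of filt, while its length depends on the values in the input rather than on the input shape alone. A minimal sketch of checking both properties on one-dimensional inputs; the variable names are illustrative only:

```python
import jax.numpy as jnp

# Two inputs with different dtypes and different amounts of zero padding.
a = jnp.array([0, 0, 7, 0, 3, 0], dtype=jnp.int16)
b = jnp.array([0.0, 0.5, 0.0], dtype=jnp.float32)

trimmed_a = jnp.trim_zeros(a)  # interior zeros are kept, only the ends are trimmed
trimmed_b = jnp.trim_zeros(b)

# The dtype follows the input; the shape follows the data.
print(trimmed_a.dtype, trimmed_a.shape)
print(trimmed_b.dtype, trimmed_b.shape)
```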
Two-dimensional input:
>>> x = jnp.zeros((4, 5)).at[1:3, 1:4].set(1)
>>> x
Array([[0., 0., 0., 0., 0.],
       [0., 1., 1., 1., 0.],
       [0., 1., 1., 1., 0.],
       [0., 0., 0., 0., 0.]], dtype=float32)
>>> jnp.trim_zeros(x)
Array([[1., 1., 1.],
       [1., 1., 1.]], dtype=float32)
>>> jnp.trim_zeros(x, trim='f')
Array([[1., 1., 1., 0.],
       [1., 1., 1., 0.],
       [0., 0., 0., 0.]], dtype=float32)
>>> jnp.trim_zeros(x, axis=0)
Array([[0., 1., 1., 1., 0.],
       [0., 1., 1., 1., 0.]], dtype=float32)
>>> jnp.trim_zeros(x, axis=1)
Array([[0., 0., 0.],
       [1., 1., 1.],
       [1., 1., 1.],
       [0., 0., 0.]], dtype=float32)
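To make the meaning of the axis and trim arguments concrete, the following sketch reproduces trimming along a single axis with a boolean mask and a slice. The helper trim_along_axis is hypothetical and is not how JAX implements trim_zeros; it assumes a concrete, non-empty input and only illustrates the semantics shown in the examples above.

```python
import jax.numpy as jnp

def trim_along_axis(x, axis, trim="fb"):
    """Hypothetical helper: drop all-zero leading/trailing slabs along one axis."""
    x = jnp.asarray(x)
    n = x.shape[axis]
    # Bring `axis` to the front and flatten the rest, so each row is one slab.
    slabs = jnp.moveaxis(x != 0, axis, 0).reshape(n, -1)
    nonzero = slabs.any(axis=1)  # True where the slab holds any nonzero value
    if not bool(nonzero.any()):
        start, stop = 0, 0  # all zeros: nothing survives along this axis
    else:
        # argmax on a boolean array returns the index of the first True.
        start = int(jnp.argmax(nonzero)) if "f" in trim else 0
        stop = n - int(jnp.argmax(nonzero[::-1])) if "b" in trim else n
    index = [slice(None)] * x.ndim
    index[axis] = slice(start, stop)
    return x[tuple(index)]

x = jnp.zeros((4, 5)).at[1:3, 1:4].set(1)
v = jnp.array([0, 0, 2, 0, 1, 4, 3, 0, 0, 0])

rows_trimmed = trim_along_axis(x, axis=0)             # shape (2, 5)
cols_trimmed = trim_along_axis(x, axis=1)             # shape (4, 3)
both_trimmed = trim_along_axis(rows_trimmed, axis=1)  # shape (2, 3)
front_only = trim_along_axis(v, axis=0, trim="f")
```

Applying the helper along each axis in turn yields the bounding box of the nonzero entries, which matches the behavior described for axis=None.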
