jax.numpy.squeeze

jax.numpy.squeeze(a, axis=None)

Remove one or more length-1 axes from an array.

JAX implementation of numpy.squeeze(), implemented via jax.lax.squeeze().
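Since the function is implemented via jax.lax.squeeze(), the two should agree when the axes are given explicitly. The sketch below checks this on a small array; the (2, 1, 3, 1) shape is an arbitrary choice, and the agreement is verified by the code rather than stated on this page. Note that jax.lax.squeeze() has no default for its dimensions argument, so the axes to remove must always be listed.

```python
import jax
import jax.numpy as jnp

# An array with two length-1 axes (axes 1 and 3).
x = jnp.arange(6).reshape(2, 1, 3, 1)

# jnp.squeeze with an explicit tuple of axes...
a = jnp.squeeze(x, axis=(1, 3))

# ...and the lower-level jax.lax.squeeze, which always takes an explicit
# sequence of dimensions to remove.
b = jax.lax.squeeze(x, dimensions=(1, 3))

print(a.shape, b.shape)       # (2, 3) (2, 3)
print(bool((a == b).all()))   # True
```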

Parameters:
  • a (ArrayLike) – input array

  • axis (int | Sequence[int] | None) – integer or sequence of integers specifying the axes to remove. If any specified axis does not have length 1, an error is raised. If not specified, all length-1 axes in a are squeezed.
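A minimal sketch of the accepted forms of axis, using an arbitrary (1, 4, 1) array of zeros. It assumes NumPy-style negative axis indices, which jax.numpy follows here; the shapes in the comment follow from the rules above.

```python
import jax.numpy as jnp

x = jnp.zeros((1, 4, 1))

all_unit = jnp.squeeze(x)           # axis=None: every length-1 axis is removed
first = jnp.squeeze(x, axis=0)      # an integer removes just that axis
last = jnp.squeeze(x, axis=-1)      # negative indices count from the end
both = jnp.squeeze(x, axis=(0, 2))  # a tuple removes several axes at once

print(all_unit.shape, first.shape, last.shape, both.shape)
# (4,) (4, 1) (1, 4) (4,)
```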

Returns:

copy of a with length-1 axes removed.

Return type:

Array
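Because only length-1 axes are removed, no data is lost, and jnp.expand_dims() can re-insert the axes. This sketch assumes you keep track of which axes were squeezed; with axis=None that information is not returned, so it would have to be derived from the original shape.

```python
import jax.numpy as jnp

x = jnp.ones((3, 1, 1))
squeezed = jnp.squeeze(x, axis=(1, 2))  # shape (3,)

# jnp.expand_dims inserts length-1 axes at the given positions of the output,
# so passing the same axes recovers the original shape.
restored = jnp.expand_dims(squeezed, axis=(1, 2))

print(squeezed.shape)  # (3,)
print(restored.shape)  # (3, 1, 1)
```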

Notes

Unlike numpy.squeeze(), jax.numpy.squeeze() returns a copy rather than a view of the input array. However, under JIT the compiler optimizes away such copies when possible, so in practice this usually carries no performance cost.
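The copy-versus-view difference can be sketched as follows. In NumPy, writing through the squeezed result modifies the input; JAX arrays are immutable, so "modifying" the result requires an out-of-place .at[].set() update, which never touches the input. The last lines run the same call under jax.jit and print the traced program, which should include a squeeze primitive; whether XLA then elides the copy is a compiler detail this example does not measure.

```python
import numpy as np
import jax
import jax.numpy as jnp

# NumPy: squeeze returns a view, so writing through it changes the input.
a_np = np.zeros((3, 1))
v = np.squeeze(a_np)
v[0] = 5.0
print(a_np[0, 0])  # 5.0

# JAX: arrays are immutable. Updating the squeezed result with .at[].set()
# produces a new array and leaves the input untouched.
a_jx = jnp.zeros((3, 1))
s = jnp.squeeze(a_jx).at[0].set(5.0)
print(float(a_jx[0, 0]))  # 0.0
print(float(s[0]))        # 5.0

# The same call works under jax.jit; shapes are resolved at trace time.
f = jax.jit(lambda x: jnp.squeeze(x, axis=1) * 2)
print(f(jnp.ones((3, 1))).shape)  # (3,)

# Inspect the traced program for the squeeze operation.
jaxpr = jax.make_jaxpr(lambda x: jnp.squeeze(x, axis=1))(jnp.ones((3, 1)))
print(jaxpr)
```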

See also

Examples

>>> x = jnp.array([[[0]], [[1]], [[2]]])
>>> x.shape
(3, 1, 1)

Squeeze all length-1 dimensions:

>>> jnp.squeeze(x)
Array([0, 1, 2], dtype=int32)
>>> _.shape
(3,)

Equivalently, specifying the axes explicitly:

>>> jnp.squeeze(x, axis=(1, 2))
Array([0, 1, 2], dtype=int32)

Attempting to squeeze a non-unit axis results in an error:

>>> jnp.squeeze(x, axis=0)
Traceback (most recent call last):
  ...
ValueError: cannot select an axis to squeeze out which has size not equal to one, got shape=(3, 1, 1) and dimensions=(0,)
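If you want to drop requested axes only when they actually have length 1, rather than getting the error above, a small wrapper can filter the axes first. squeeze_unit_axes below is a hypothetical helper written for illustration; it is not part of jax.numpy.

```python
import jax.numpy as jnp

def squeeze_unit_axes(a, axes):
    """Hypothetical helper (not a JAX API): squeeze only the axes in `axes`
    that have length 1, keeping any other axes instead of raising."""
    a = jnp.asarray(a)
    unit = set()
    for ax in axes:
        # Reject out-of-range axes rather than silently wrapping them.
        if not -a.ndim <= ax < a.ndim:
            raise ValueError(f"axis {ax} is out of bounds for a {a.ndim}-D array")
        ax = ax % a.ndim  # normalize negative indices; the set drops duplicates
        if a.shape[ax] == 1:
            unit.add(ax)
    if not unit:
        return a  # nothing to remove
    return jnp.squeeze(a, axis=tuple(sorted(unit)))

x = jnp.array([[[0]], [[1]], [[2]]])
print(squeeze_unit_axes(x, (0, 1, 2)).shape)  # (3,)
print(squeeze_unit_axes(x, (0,)).shape)       # (3, 1, 1)
```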

For convenience, this functionality is also available via the jax.Array.squeeze() method:

>>> x.squeeze()
Array([0, 1, 2], dtype=int32)