jax.numpy.squeeze
jax.numpy.squeeze(a, axis=None)
Remove one or more length-1 axes from an array.
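JAX implementation of numpy.squeeze(), implemented via jax.lax.squeeze().

Parameters:
- a (ArrayLike): input array.
- axis (int | Sequence[int] | None): integer or sequence of integers specifying the axes to remove. If any specified axis does not have length 1, an error is raised. If not specified, all length-1 axes in a are removed.

Returns:
A copy of a with length-1 axes removed.

Return type:
Array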
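Because jax.numpy.squeeze() is implemented via jax.lax.squeeze(), the two can be compared directly. A minimal sketch, assuming the same example array used in the Examples section: jax.lax.squeeze() takes the axes to remove as a required dimensions sequence, so it has no "remove every length-1 axis" default.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([[[0]], [[1]], [[2]]])  # shape (3, 1, 1)

# lax.squeeze needs the axes spelled out explicitly.
y_lax = lax.squeeze(x, dimensions=(1, 2))
# jnp.squeeze with axis=None finds the length-1 axes itself.
y_jnp = jnp.squeeze(x)

print(y_lax.shape)                          # (3,)
print(bool(jnp.array_equal(y_lax, y_jnp)))  # True
```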
Notes
Unlike numpy.squeeze(), jax.numpy.squeeze() returns a copy rather than a view of the input array. However, under JIT the compiler optimizes away such copies when possible, so in practice this has no significant performance impact.
See also
- jax.numpy.expand_dims(): the inverse of squeeze; adds dimensions of length 1.
- jax.Array.squeeze(): equivalent functionality via an array method.
- jax.lax.squeeze(): equivalent XLA API.
- jax.numpy.ravel(): flatten an array into a 1D shape.
- jax.numpy.reshape(): general array reshape.
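Since jax.numpy.expand_dims() is the inverse of squeeze, inserting length-1 axes and then squeezing the same axes recovers the original array. A small sketch with an illustrative input:

```python
import jax.numpy as jnp

v = jnp.arange(3)                    # shape (3,)
w = jnp.expand_dims(v, axis=(0, 2))  # insert length-1 axes: shape (1, 3, 1)
back = jnp.squeeze(w, axis=(0, 2))   # remove the same axes: shape (3,)

print(w.shape, back.shape)             # (1, 3, 1) (3,)
print(bool(jnp.array_equal(back, v)))  # True
```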
Examples
>>> import jax.numpy as jnp
>>> x = jnp.array([[[0]], [[1]], [[2]]])
>>> x.shape
(3, 1, 1)
Squeeze all length-1 dimensions:
>>> jnp.squeeze(x)
Array([0, 1, 2], dtype=int32)
>>> _.shape
(3,)
Equivalently, specifying the axes explicitly:
>>> jnp.squeeze(x, axis=(1, 2))
Array([0, 1, 2], dtype=int32)
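The axis argument also accepts a single integer, including a negative index counted from the end, in which case only that one length-1 axis is removed. When every axis of the input has length 1, squeezing them all leaves a 0-d array. A short sketch (the arrays are illustrative):

```python
import jax.numpy as jnp

x = jnp.array([[[0]], [[1]], [[2]]])  # shape (3, 1, 1)

# Remove a single length-1 axis, by positive or negative index.
print(jnp.squeeze(x, axis=1).shape)   # (3, 1)
print(jnp.squeeze(x, axis=-1).shape)  # (3, 1)

# Every axis has length 1, so squeezing all of them yields a 0-d array.
s = jnp.ones((1, 1))
print(jnp.squeeze(s).shape)           # ()
```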
Attempting to squeeze a non-unit axis results in an error:
>>> jnp.squeeze(x, axis=0)
Traceback (most recent call last):
...
ValueError: cannot select an axis to squeeze out which has size not equal to one, got shape=(3, 1, 1) and dimensions=(0,)
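If you would rather skip requested axes that are not length 1 than raise, one option is a small wrapper that filters the axes by shape before calling jnp.squeeze(). squeeze_if_unit is a hypothetical helper written for illustration; it is not part of JAX.

```python
import jax.numpy as jnp

def squeeze_if_unit(a, axis):
    """Hypothetical helper: squeeze only the requested axes that have length 1.

    Unlike jnp.squeeze, requested axes whose length is not 1 are silently
    skipped instead of raising ValueError.
    """
    a = jnp.asarray(a)
    # Accept a single int or a sequence of ints, like jnp.squeeze.
    axes = (axis,) if isinstance(axis, int) else tuple(axis)
    # Normalize negative indices, drop duplicates, keep only length-1 axes.
    unit_axes = tuple(sorted({ax % a.ndim for ax in axes if a.shape[ax] == 1}))
    if not unit_axes:
        return a  # nothing to remove
    return jnp.squeeze(a, axis=unit_axes)

x = jnp.array([[[0]], [[1]], [[2]]])  # shape (3, 1, 1)
print(squeeze_if_unit(x, (0, 1, 2)).shape)  # (3,): axis 0 is skipped
print(squeeze_if_unit(x, 0).shape)          # (3, 1, 1): nothing removed
```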
For convenience, this functionality is also available via the jax.Array.squeeze() method:
>>> x.squeeze()
Array([0, 1, 2], dtype=int32)
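The Notes state that the copy made by jnp.squeeze() can be optimized away under JIT. A minimal sketch of calling it inside a jitted function follows; squeeze_and_scale is a hypothetical name. Because axis determines the output shape, it is marked static via static_argnames; passing it as a traced argument would typically fail.

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames="axis")
def squeeze_and_scale(a, axis=None):
    # Hypothetical example: squeeze, then scale. Under jit, XLA may fuse the
    # squeeze with the multiply rather than materializing a separate copy.
    return 2 * jnp.squeeze(a, axis=axis)

x = jnp.array([[[0]], [[1]], [[2]]])            # shape (3, 1, 1)
print(squeeze_and_scale(x))                     # [0 2 4]
print(squeeze_and_scale(x, axis=(1, 2)).shape)  # (3,)
```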
