jax.numpy.transpose
- jax.numpy.transpose(a, axes=None)
Return a transposed version of an N-dimensional array.
JAX implementation of numpy.transpose(), implemented in terms of jax.lax.transpose().
- Parameters:
a (ArrayLike) – input array
axes (Sequence[int] | None) – optionally specify the permutation using a sequence of a.ndim integers i satisfying 0 <= i < a.ndim. Defaults to range(a.ndim)[::-1], i.e. reverses the order of all axes.
- Returns:
transposed copy of the array.
- Return type:
Array
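The following is a minimal sketch of how the axes permutation is read, assuming the NumPy convention that output axis i is taken from input axis axes[i]. It also compares the result against a direct call to jax.lax.transpose() with the same permutation, since the function is described above as being implemented in terms of it. The shape (2, 3, 4) and the permutation (1, 2, 0) are arbitrary choices for illustration.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(24).reshape(2, 3, 4)
axes = (1, 2, 0)

y = jnp.transpose(x, axes)

# Output axis i comes from input axis axes[i], so each output dimension
# has the size of the input dimension it was taken from.
assert y.shape == tuple(x.shape[ax] for ax in axes)  # (3, 4, 2)

# Passing the same permutation straight to the lax primitive gives the same array.
z = lax.transpose(x, axes)
assert jnp.array_equal(y, z)
```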
See also
- jax.Array.transpose(): equivalent function via an Array method.
- jax.Array.T: equivalent function via an Array property.
- jax.numpy.matrix_transpose(): transpose the last two axes of an array. This is suitable for working with batched 2D matrices.
- jax.numpy.swapaxes(): swap any two axes in an array.
- jax.numpy.moveaxis(): move an axis to another position in the array.
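The related functions listed above can each be expressed as a particular axes permutation. This sketch checks that correspondence on an arbitrary 3D array: swapping the last two axes via transpose, matrix_transpose, and swapaxes, and moving axis 0 to the end via moveaxis.

```python
import jax.numpy as jnp

x = jnp.arange(24).reshape(2, 3, 4)

# Swapping the last two axes: three spellings of the same permutation.
a = jnp.transpose(x, (0, 2, 1))
b = jnp.matrix_transpose(x)
c = jnp.swapaxes(x, -1, -2)
assert jnp.array_equal(a, b) and jnp.array_equal(a, c)

# Moving axis 0 to the last position corresponds to the permutation (1, 2, 0).
d = jnp.moveaxis(x, 0, -1)
assert jnp.array_equal(d, jnp.transpose(x, (1, 2, 0)))
assert d.shape == (3, 4, 2)
```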
Note
Unlike numpy.transpose(), jax.numpy.transpose() returns a copy rather than a view of the input array. However, under JIT the compiler will optimize away such copies when possible, so in practice this does not affect performance.
Examples
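To make the copy-versus-view distinction from the Note concrete, this sketch contrasts NumPy, where the transpose shares memory with its input, with JAX, where arrays are immutable and updates go through .at[] to produce new arrays. The jitted helper gram is a hypothetical name used only to show a transpose traced inside jax.jit; the sketch does not inspect whether the compiler actually elides the copy.

```python
import numpy as np
import jax
import jax.numpy as jnp

x_np = np.arange(6).reshape(2, 3)

# NumPy returns a view: the transposed array shares x_np's buffer.
assert np.shares_memory(x_np, np.transpose(x_np))

# JAX arrays are immutable, so the copy cannot be observed through mutation;
# .at[].set() returns a new array and leaves both x and y unchanged.
x = jnp.asarray(x_np)
y = jnp.transpose(x)
y_updated = y.at[0, 0].set(100)
assert int(x[0, 0]) == 0 and int(y[0, 0]) == 0
assert int(y_updated[0, 0]) == 100

# Inside jax.jit the transpose becomes part of the compiled program,
# where the compiler may be able to avoid materializing a separate copy.
@jax.jit
def gram(m):
    return jnp.transpose(m) @ m

assert gram(x).shape == (3, 3)
```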
For a 1D array, the transpose is the identity:
>>> import jax.numpy as jnp
>>> x = jnp.array([1, 2, 3, 4])
>>> jnp.transpose(x)
Array([1, 2, 3, 4], dtype=int32)
For a 2D array, the transpose is a matrix transpose:
>>> x = jnp.array([[1, 2],
...                [3, 4]])
>>> jnp.transpose(x)
Array([[1, 3],
       [2, 4]], dtype=int32)
For an N-dimensional array, the transpose reverses the order of the axes:
>>> x = jnp.zeros(shape=(3, 4, 5))
>>> jnp.transpose(x).shape
(5, 4, 3)
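Reversing the axes also reverses the index order of each element: for a 3D input, y[k, j, i] == x[i, j, k]. This sketch checks that on an array with distinct entries (unlike the all-zeros example, where element positions cannot be told apart) and compares against the equivalent einsum subscript permutation.

```python
import jax.numpy as jnp

x = jnp.arange(24).reshape(2, 3, 4)
y = jnp.transpose(x)  # default: axes reversed, shape (4, 3, 2)

# Reversing the axes reverses the index order: y[k, j, i] == x[i, j, k].
assert y.shape == (4, 3, 2)
assert int(y[3, 2, 1]) == int(x[1, 2, 3])  # both equal 1*12 + 2*4 + 3 = 23

# The same reversal written as an einsum subscript permutation.
assert jnp.array_equal(y, jnp.einsum("ijk->kji", x))
```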
The axes argument can be specified to change this default behavior:
>>> jnp.transpose(x, (0, 2, 1)).shape
(3, 5, 4)
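When an explicit axes permutation is used, it can be undone by transposing again with the inverse permutation. This sketch computes the inverse with np.argsort, which for a permutation returns its inverse; the permutation (2, 0, 1) and the variable name inverse are illustrative choices, not part of the API.

```python
import numpy as np
import jax.numpy as jnp

x = jnp.arange(60).reshape(3, 4, 5)
axes = (2, 0, 1)
y = jnp.transpose(x, axes)  # shape (5, 3, 4)

# argsort of a permutation gives the inverse permutation, here (1, 2, 0).
inverse = tuple(int(i) for i in np.argsort(axes))
x_back = jnp.transpose(y, inverse)

# Applying the inverse restores both the shape and the values.
assert x_back.shape == x.shape
assert jnp.array_equal(x_back, x)
```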
Since swapping the last two axes is a common operation, it can be done via its own API, jax.numpy.matrix_transpose():
>>> jnp.matrix_transpose(x).shape
(3, 5, 4)
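A typical use of matrix_transpose on batched 2D matrices is forming one Gram matrix per batch element. This sketch treats the leading axis as the batch axis (an assumption about layout chosen for illustration) and relies on the @ operator broadcasting over that axis.

```python
import jax.numpy as jnp

# A batch of three 4x5 matrices.
x = jnp.arange(60).reshape(3, 4, 5)

# matrix_transpose swaps only the last two axes; the batch axis stays put.
xt = jnp.matrix_transpose(x)
assert xt.shape == (3, 5, 4)

# Batched Gram matrices: matmul broadcasts over the leading batch axis.
gram = x @ xt
assert gram.shape == (3, 4, 4)

# Each Gram matrix is symmetric, so it equals its own matrix transpose.
assert jnp.array_equal(gram, jnp.matrix_transpose(gram))
```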
For convenience, transposes may also be performed using the jax.Array.transpose() method or the jax.Array.T property:
>>> x = jnp.array([[1, 2],
...                [3, 4]])
>>> x.transpose()
Array([[1, 3],
       [2, 4]], dtype=int32)
>>> x.T
Array([[1, 3],
       [2, 4]], dtype=int32)
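The method and property forms also apply to N-dimensional arrays. This sketch assumes, following the NumPy convention, that jax.Array.transpose() accepts the permutation either as a single tuple or as separate integer arguments, and that .T reverses all axes just like jnp.transpose(x) with no axes argument.

```python
import jax.numpy as jnp

x = jnp.zeros(shape=(3, 4, 5))

# The permutation can be passed as one tuple or as separate integers.
assert x.transpose((0, 2, 1)).shape == (3, 5, 4)
assert x.transpose(0, 2, 1).shape == (3, 5, 4)

# For an N-dimensional array, .T reverses all axes, matching jnp.transpose(x).
assert x.T.shape == (5, 4, 3)
assert x.T.shape == jnp.transpose(x).shape
```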
