jax.numpy.transpose

jax.numpy.transpose(a, axes=None)

Return a transposed version of an N-dimensional array.

JAX implementation of numpy.transpose(), implemented in terms of jax.lax.transpose().
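
The relationship to jax.lax.transpose() can be checked directly. The following is a minimal sketch, assuming the lax-level call is given an explicit permutation tuple; the input shape is an arbitrary choice made for illustration.

```python
import jax.numpy as jnp
from jax import lax

# Illustrative 3D input; the shape (2, 3, 4) is an arbitrary choice.
x = jnp.arange(24).reshape(2, 3, 4)

# With no axes argument, jnp.transpose reverses the axis order.
via_jnp = jnp.transpose(x)

# The same result, written as an explicit permutation at the lax level.
via_lax = lax.transpose(x, permutation=(2, 1, 0))

print(via_jnp.shape)                             # (4, 3, 2)
print(bool(jnp.array_equal(via_jnp, via_lax)))   # True
```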

Parameters:
  • a (ArrayLike) – input array

  • axes (Sequence[int] | None) – optionally specify the permutation using a length-a.ndim sequence of integers i satisfying 0 <= i < a.ndim. Defaults to range(a.ndim)[::-1], i.e. reverses the order of all axes.

Returns:

transposed copy of the array.

Return type:

Array
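
To make the meaning of the axes permutation concrete, the sketch below checks that axis n of the result corresponds to axis axes[n] of the input. The array contents and the permutation (2, 0, 1) are hypothetical values chosen for illustration.

```python
import jax.numpy as jnp

a = jnp.arange(24).reshape(2, 3, 4)
axes = (2, 0, 1)  # hypothetical permutation, chosen for illustration
out = jnp.transpose(a, axes)

# Axis n of the result is axis axes[n] of the input.
expected_shape = tuple(a.shape[i] for i in axes)
print(out.shape == expected_shape)   # True; both are (4, 2, 3)

# For this permutation, out[i, j, k] reads the element a[j, k, i].
print(int(out[3, 1, 2]), int(a[1, 2, 3]))   # 23 23
```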

See also

Note

Unlike numpy.transpose(), jax.numpy.transpose() returns a copy rather than a view of the input array. However, under JIT, the compiler will optimize away such copies when possible, so this typically has no performance impact in practice.
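
The view behavior on the NumPy side can be observed with np.shares_memory, and the JAX transpose can be placed inside a jitted function where the compiler is free to optimize it. This is a rough sketch only; the function name gram is hypothetical, and the example does not measure whether a copy is actually elided.

```python
import numpy as np
import jax
import jax.numpy as jnp

x_np = np.arange(6, dtype=np.float32).reshape(2, 3)

# NumPy returns a view: the transposed array shares memory with its input.
print(np.shares_memory(x_np, np.transpose(x_np)))   # True

@jax.jit
def gram(a):
    # Hypothetical helper: computes a^T a. Under JIT the transpose is part
    # of the compiled computation rather than a separate eager step.
    return jnp.transpose(a) @ a

x = jnp.asarray(x_np)
print(gram(x).shape)   # (3, 3)
```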

Examples

For a 1D array, the transpose is the identity:

>>> import jax.numpy as jnp
>>> x = jnp.array([1, 2, 3, 4])
>>> jnp.transpose(x)
Array([1, 2, 3, 4], dtype=int32)

For a 2D array, the transpose is a matrix transpose:

>>> x = jnp.array([[1, 2],
...                [3, 4]])
>>> jnp.transpose(x)
Array([[1, 3],
       [2, 4]], dtype=int32)

For an N-dimensional array, the transpose reverses the order of the axes:

>>> x = jnp.zeros(shape=(3, 4, 5))
>>> jnp.transpose(x).shape
(5, 4, 3)

The axes argument can be specified to change this default behavior:

>>> jnp.transpose(x, (0, 2, 1)).shape
(3, 5, 4)
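
A transpose with an explicit axes argument can be undone by transposing again with the inverse permutation. The sketch below computes that inverse with np.argsort, which is one common way to do it and an assumption of this example rather than part of the jax.numpy.transpose API; the input data and permutation are illustrative.

```python
import numpy as np
import jax.numpy as jnp

x = jnp.arange(60).reshape(3, 4, 5)
axes = (1, 2, 0)                      # illustrative permutation
y = jnp.transpose(x, axes)
print(y.shape)                        # (4, 5, 3)

# argsort of a permutation yields its inverse permutation.
inverse = tuple(int(i) for i in np.argsort(axes))
print(inverse)                        # (2, 0, 1)

restored = jnp.transpose(y, inverse)
print(restored.shape)                 # (3, 4, 5)
print(bool(jnp.array_equal(restored, x)))   # True
```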

Since swapping the last two axes is a common operation, it can be done via its own API, jax.numpy.matrix_transpose():

>>> jnp.matrix_transpose(x).shape
(3, 5, 4)
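
For arrays with more than one leading batch dimension, swapping the last two axes can be written in several equivalent ways. The following sketch compares them on a hypothetical 4D input; jnp.swapaxes is included only as a point of comparison.

```python
import jax.numpy as jnp

# Hypothetical batch of 2 x 3 matrices, each of shape (4, 5).
batch = jnp.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)

a = jnp.matrix_transpose(batch)           # dedicated API
b = jnp.transpose(batch, (0, 1, 3, 2))    # explicit permutation
c = jnp.swapaxes(batch, -1, -2)           # swap the last two axes

print(a.shape)                            # (2, 3, 5, 4)
print(bool(jnp.array_equal(a, b)))        # True
print(bool(jnp.array_equal(a, c)))        # True
```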

For convenience, transposes may also be performed using the jax.Array.transpose() method or the jax.Array.T property:

>>> x = jnp.array([[1, 2],
...                [3, 4]])
>>> x.transpose()
Array([[1, 3],
       [2, 4]], dtype=int32)
>>> x.T
Array([[1, 3],
       [2, 4]], dtype=int32)
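
The method and property forms behave the same way on higher-dimensional arrays. As a brief sketch on an illustrative 3D input: the T property reverses all axes, while the method accepts a permutation either as separate integers or as a single tuple.

```python
import jax.numpy as jnp

x3 = jnp.arange(24).reshape(2, 3, 4)

# .T reverses the order of all axes, like jnp.transpose(x3).
print(x3.T.shape)                       # (4, 3, 2)

# The method form accepts the permutation as separate integers or a tuple.
print(x3.transpose(0, 2, 1).shape)      # (2, 4, 3)
print(x3.transpose((0, 2, 1)).shape)    # (2, 4, 3)
```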