jax.numpy.permute_dims
jax.numpy.permute_dims(a, /, axes)
Permute the axes/dimensions of an array.
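JAX implementation of array_api.permute_dims().

Parameters:
- a (ArrayLike): input array.
- axes (tuple[int, ...]): tuple of integers in the range [0, a.ndim) specifying the axes permutation.

Returns:
a copy of a with axes permuted.

Return type:
Array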
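The role of axes is easiest to see on an array with more than two dimensions. The following is a minimal sketch, assuming only that jax is installed. The shape (2, 3, 4) and the permutation (2, 0, 1) are arbitrary illustrative choices. It shows that output axis i is taken from input axis axes[i], so out.shape[i] == a.shape[axes[i]].

```python
import jax.numpy as jnp

# Illustrative 3-D input with shape (2, 3, 4); the values 0..23 make positions easy to track.
a = jnp.arange(24).reshape(2, 3, 4)

# axes[i] names the input axis that becomes output axis i (illustrative choice).
axes = (2, 0, 1)
out = jnp.permute_dims(a, axes)

# Shape rule: out.shape[i] == a.shape[axes[i]].
print(out.shape)  # (4, 2, 3)

# Element rule for this particular permutation: out[i, j, k] == a[j, k, i],
# because output axis 0 comes from input axis 2, axis 1 from axis 0, and axis 2 from axis 1.
print(bool(out[3, 1, 2] == a[1, 2, 3]))  # True
```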
Examples
>>> import jax.numpy as jnp
>>> a = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
>>> jnp.permute_dims(a, (1, 0))
Array([[1, 4],
       [2, 5],
       [3, 6]], dtype=int32)
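A permutation can be undone by applying its inverse, where inv[axes[i]] == i. The following is a minimal sketch under the same assumptions as before: jax is installed, and the shape and permutation are illustrative. The sketch also checks that jnp.transpose with the same axes should give the same result as jnp.permute_dims. The inverse is built in plain Python so that it stays a tuple of ints, matching the documented type of axes.

```python
import jax.numpy as jnp

# Illustrative input and permutation.
a = jnp.arange(24).reshape(2, 3, 4)
axes = (2, 0, 1)

b = jnp.permute_dims(a, axes)

# jnp.transpose with an explicit axes tuple reorders the axes the same way.
print(bool(jnp.array_equal(b, jnp.transpose(a, axes))))  # True

# Inverse permutation: inv[j] is the position of input axis j within axes.
inv = tuple(axes.index(j) for j in range(len(axes)))
print(inv)  # (1, 2, 0)

# Applying the inverse restores the original layout.
restored = jnp.permute_dims(b, inv)
print(bool(jnp.array_equal(restored, a)))  # True
```

The same inverse could be computed with jnp.argsort, but the plain Python version avoids converting a traced array back into the static tuple that axes expects.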
