
jax.numpy.reshape


jax.numpy.reshape(a, shape, order='C', *, copy=None, out_sharding=None)

Return a reshaped copy of an array.

JAX implementation of numpy.reshape(), implemented in terms of jax.lax.reshape().
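The sketch below compares the NumPy-style wrapper with the lower-level primitive on a small array. It assumes that, for a C-order reshape with an explicit target shape, the two calls produce identical results. It does not show how jnp.reshape() handles -1 or order='F' internally.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(6).reshape(2, 3)

# High-level, NumPy-compatible API
a = jnp.reshape(x, (3, 2))
# Lower-level primitive, given explicit sizes here (no -1)
b = lax.reshape(x, (3, 2))

print(a.shape, b.shape)             # (3, 2) (3, 2)
print(bool(jnp.array_equal(a, b)))  # True
```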

Parameters:
  • a (ArrayLike) – input array to reshape

  • shape (DimSize | Shape) – integer or sequence of integers giving the new shape, which must match the size of the input array. If any single dimension is given size -1, it is replaced with a value such that the output has the correct size.

  • order (str) – 'F' or 'C'; specifies whether the reshape should apply column-major (Fortran-style, "F") or row-major (C-style, "C") order. Default is "C". JAX does not support order="A".

  • copy (bool | None) – unused by JAX; JAX always returns a copy, though under JIT the compiler may optimize such copies away.
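To make the -1 rule for shape concrete, the hypothetical helper infer_shape below resolves a single -1 entry from the input size. It is a minimal sketch written for illustration and is not part of JAX. The final lines check that its result agrees with jnp.reshape() for one case.

```python
import math

import jax.numpy as jnp


def infer_shape(size, shape):
    """Hypothetical helper: resolve a single -1 in `shape` for an array of `size` elements."""
    shape = (shape,) if isinstance(shape, int) else tuple(shape)
    if shape.count(-1) > 1:
        raise ValueError("at most one dimension may be -1")
    if -1 in shape:
        # Product of the explicitly given dimensions
        known = math.prod(d for d in shape if d != -1)
        if known == 0 or size % known != 0:
            raise ValueError(f"cannot reshape {size} elements into {shape}")
        shape = tuple(size // known if d == -1 else d for d in shape)
    if math.prod(shape) != size:
        raise ValueError(f"cannot reshape {size} elements into {shape}")
    return shape


x = jnp.ones((2, 3))
print(infer_shape(x.size, (-1, 2)))   # (3, 2)
print(jnp.reshape(x, (-1, 2)).shape)  # (3, 2)
```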

Returns:

reshaped copy of input array with the specified shape.

Return type:

Array

Notes

Unlike numpy.reshape(), jax.numpy.reshape() returns a copy rather than a view of the input array. However, under JIT the compiler optimizes away such copies when possible, so in practice this generally has no performance cost.
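The sketch below contrasts the two behaviours described above. A NumPy reshape of a contiguous array shares memory with its input. A JAX reshape yields an independent, immutable array. The jitted function flatten_and_scale is a hypothetical example. Whether XLA actually elides the copy inside it is a compiler detail, and this snippet cannot observe it.

```python
import numpy as np
import jax
import jax.numpy as jnp

x_np = np.arange(6).reshape(2, 3)
y_np = np.reshape(x_np, 6)
# NumPy: reshaping a contiguous array gives a view that shares memory with the input
print(np.shares_memory(x_np, y_np))  # True

x = jnp.arange(6).reshape(2, 3)
y = jnp.reshape(x, 6)
# JAX arrays are immutable: an indexed "update" builds a new array and leaves x untouched
y2 = y.at[0].set(100)
print(int(x[0, 0]), int(y2[0]))  # 0 100


@jax.jit
def flatten_and_scale(a):
    # Under jit, the reshape becomes part of the compiled computation,
    # where the compiler is free to avoid a physical copy
    return jnp.reshape(a, -1) * 2


print(flatten_and_scale(x))
```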

See also

Examples

>>> import jax.numpy as jnp
>>> x = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
>>> jnp.reshape(x, 6)
Array([1, 2, 3, 4, 5, 6], dtype=int32)
>>> jnp.reshape(x, (3, 2))
Array([[1, 2],
       [3, 4],
       [5, 6]], dtype=int32)

You can use -1 to automatically compute a shape that is consistent with the input size:

>>> jnp.reshape(x, -1)  # -1 is inferred to be 6
Array([1, 2, 3, 4, 5, 6], dtype=int32)
>>> jnp.reshape(x, (-1, 2))  # -1 is inferred to be 3
Array([[1, 2],
       [3, 4],
       [5, 6]], dtype=int32)

The default ordering of axes in the reshape is C-style row-major ordering. To use Fortran-style column-major ordering, specify order='F':

>>> jnp.reshape(x, 6, order='F')
Array([1, 4, 2, 5, 3, 6], dtype=int32)
>>> jnp.reshape(x, (3, 2), order='F')
Array([[1, 5],
       [4, 3],
       [2, 6]], dtype=int32)
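One way to reason about order='F' is that reversing all axes (a.T) turns a column-major traversal into a row-major one. A Fortran-order reshape to a shape s can therefore be emulated by a C-order reshape of a.T to the reversed shape, followed by another full transpose. The helper reshape_fortran below is a hypothetical sketch of that identity, not how JAX implements it, and it assumes the target shape is passed as a tuple.

```python
import jax.numpy as jnp


def reshape_fortran(a, shape):
    """Hypothetical helper: emulate order='F' using only C-order reshapes and transposes.

    `shape` is assumed to be a tuple of ints (no -1 handling).
    """
    # a.T reverses every axis, so its row-major order equals a's column-major order.
    # Fill the reversed target shape in row-major order, then reverse the axes back.
    return jnp.reshape(a.T, tuple(reversed(shape))).T


x = jnp.array([[1, 2, 3],
               [4, 5, 6]])
print(reshape_fortran(x, (3, 2)))
print(bool(jnp.array_equal(reshape_fortran(x, (3, 2)),
                           jnp.reshape(x, (3, 2), order='F'))))  # True
```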

For convenience, this functionality is also available via the jax.Array.reshape() method:

>>> x.reshape(3, 2)
Array([[1, 2],
       [3, 4],
       [5, 6]], dtype=int32)
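The sketch below checks that the method form accepts the new shape either as separate integers or as a single tuple, and that both forms agree with jnp.reshape(). It also assumes, as in NumPy, that the method accepts -1 and the order keyword.

```python
import jax.numpy as jnp

x = jnp.array([[1, 2, 3],
               [4, 5, 6]])

# Shape given as separate integers or as one tuple
a = x.reshape(3, 2)
b = x.reshape((3, 2))
c = jnp.reshape(x, (3, 2))
print(bool(jnp.array_equal(a, b)) and bool(jnp.array_equal(a, c)))  # True

# -1 inference and column-major ordering through the method form
print(x.reshape(-1, order='F'))
```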