jax.numpy.atleast_3d


jax.numpy.atleast_3d(*arys)

Convert inputs to arrays with at least 3 dimensions.

JAX implementation of numpy.atleast_3d().
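To check the NumPy correspondence for shapes, you can run both functions side by side. This sketch compares only output shapes for a few inputs chosen here for illustration. It does not test dtypes or every edge case.

```python
import numpy as np
import jax.numpy as jnp

# Inputs covering each branch of the shape rule: 0-D, 1-D, 2-D, and 3-D.
inputs = [np.float32(1.0), np.arange(4), np.ones((2, 3)), np.zeros((2, 3, 4))]

# Pair up the output shapes produced by NumPy and by JAX for each input.
shapes = [(np.atleast_3d(a).shape, jnp.atleast_3d(a).shape) for a in inputs]
for np_shape, jnp_shape in shapes:
    print(np_shape, jnp_shape)
```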

Parameters:
  • *arys (ArrayLike): zero or more array-like arguments.
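"ArrayLike" covers more than JAX arrays. As a small illustration, the sketch below passes a Python scalar, a NumPy array, and a JAX array. Each result comes back as a `jax.Array`. The particular input values are arbitrary.

```python
import numpy as np
import jax
import jax.numpy as jnp

# A Python scalar, a NumPy array, and a JAX array are all ArrayLike inputs.
from_python = jnp.atleast_3d(5)
from_numpy = jnp.atleast_3d(np.arange(3))
from_jax = jnp.atleast_3d(jnp.ones((2, 2)))

for result in (from_python, from_numpy, from_jax):
    # Whatever the input type, the output is a jax.Array with at least 3 dims.
    print(isinstance(result, jax.Array), result.shape)
```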

Returns:

An array or list of arrays corresponding to the input values. Arrays of shape () are converted to shape (1, 1, 1), 1D arrays of shape (N,) are converted to shape (1, N, 1), 2D arrays of shape (M, N) are converted to shape (M, N, 1), and arrays of all other shapes are returned unchanged.
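The shape rule above can be restated with `jnp.expand_dims`. The helper `atleast_3d_sketch` below is hypothetical. It is a minimal single-argument sketch of the documented behaviour, not JAX's actual source, and it is checked against `jnp.atleast_3d` for one shape per branch.

```python
import jax.numpy as jnp

def atleast_3d_sketch(a):
    """Hypothetical single-argument restatement of the documented shape rule."""
    a = jnp.asarray(a)
    if a.ndim == 0:
        return jnp.expand_dims(a, axis=(0, 1, 2))  # () -> (1, 1, 1)
    if a.ndim == 1:
        return jnp.expand_dims(a, axis=(0, 2))     # (N,) -> (1, N, 1)
    if a.ndim == 2:
        return jnp.expand_dims(a, axis=2)          # (M, N) -> (M, N, 1)
    return a                                       # ndim >= 3: unchanged

results = []
for shape in [(), (4,), (2, 3), (2, 3, 4)]:
    a = jnp.arange(int(jnp.prod(jnp.array(shape))), dtype=jnp.float32).reshape(shape)
    results.append((atleast_3d_sketch(a), jnp.atleast_3d(a)))
    print(shape, results[-1][0].shape, results[-1][1].shape)
```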

Return type:

Array | list[Array]
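The return type depends on how many arguments are passed. This short sketch shows both cases: one input yields a single Array, and several inputs yield a sequence of Arrays (documented here as a list). The inputs are illustrative.

```python
import jax
import jax.numpy as jnp

# One argument: a single Array is returned.
single = jnp.atleast_3d(jnp.arange(3))

# Several arguments: one converted Array per input, in the same order.
several = jnp.atleast_3d(jnp.arange(3), jnp.ones((2, 2)))

print(isinstance(single, jax.Array), single.shape)
print(type(several).__name__, [a.shape for a in several])
```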

Examples

Scalar arguments are converted to 3D, size-1 arrays:

>>> import jax.numpy as jnp
>>> x = jnp.float32(1.0)
>>> jnp.atleast_3d(x)
Array([[[1.]]], dtype=float32)

1D arrays have a unit dimension prepended and appended:

>>> y = jnp.arange(4)
>>> jnp.atleast_3d(y).shape
(1, 4, 1)
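Because unit axes are added on both sides, the original 1D data ends up along axis 1. Indexing the two outer axes at 0 should therefore recover the input. This small check uses illustrative variable names.

```python
import jax.numpy as jnp

y = jnp.arange(4)
y3 = jnp.atleast_3d(y)

# The original elements lie along the middle axis; axes 0 and 2 have length 1.
middle = y3[0, :, 0]
print(y3.shape, jnp.array_equal(middle, y))
```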

2D arrays have a unit dimension appended:

>>> z = jnp.ones((2, 3))
>>> jnp.atleast_3d(z).shape
(2, 3, 1)
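These examples stop at 2D inputs. The Returns description states that arrays with three or more dimensions pass through unchanged. The sketch below checks that with two arbitrarily chosen shapes.

```python
import jax.numpy as jnp

# Inputs that already have three or more dimensions.
w3 = jnp.ones((2, 3, 4))
w4 = jnp.ones((2, 3, 4, 5))

outputs = [jnp.atleast_3d(w) for w in (w3, w4)]
for w, out in zip((w3, w4), outputs):
    # No unit axis is added: the output shape equals the input shape.
    print(w.shape, out.shape)
```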

Multiple arguments can be passed to the function at once, in which case a list of results is returned:

>>> x3, y3 = jnp.atleast_3d(x, y)
>>> print(x3)
[[[1.]]]
>>> print(y3)
[[[0]
  [1]
  [2]
  [3]]]
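The multi-argument form is useful for depth-wise stacking. NumPy documents `np.dstack` as concatenation along the third axis after converting the inputs with `atleast_3d`. The sketch below assumes `jnp.dstack` follows the same semantics and compares the two on a pair of illustrative 2D arrays.

```python
import jax.numpy as jnp

a = jnp.ones((2, 3))
b = jnp.zeros((2, 3))

# Promote both inputs to 3D in one call, then join them along axis 2.
stacked = jnp.concatenate(jnp.atleast_3d(a, b), axis=2)

print(stacked.shape)  # (2, 3, 2)
print(jnp.array_equal(stacked, jnp.dstack([a, b])))
```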