
jax.numpy.load

jax.numpy.load(file, *args, **kwargs)

Load JAX arrays from npy files.

JAX wrapper of numpy.load().

This function is a simple wrapper of numpy.load(), but for .npy files created with numpy.save() or jax.numpy.save(), the output is returned as a jax.Array and bfloat16 data types are restored. For .npz files, the archive object produced by numpy.load() is returned unchanged, so its entries are normal NumPy arrays.
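The two return paths described above can be checked with in-memory buffers. This is a minimal sketch, assuming `jax` and `numpy` are installed; the archive entry names `a` and `b` are arbitrary choices for illustration.

```python
import io

import jax
import jax.numpy as jnp
import numpy as np

# .npy: a single saved array comes back as a jax.Array.
npy_buf = io.BytesIO()
np.save(npy_buf, np.arange(4, dtype=np.float32))
npy_buf.seek(0)
single = jnp.load(npy_buf)
print(isinstance(single, jax.Array))  # True

# .npz: the archive is passed through from numpy.load(), so each entry is a
# plain numpy.ndarray rather than a jax.Array.
npz_buf = io.BytesIO()
np.savez(npz_buf, a=np.arange(3), b=np.ones(2))
npz_buf.seek(0)
with jnp.load(npz_buf) as archive:
    entry_a = archive["a"]          # read before the archive is closed
    names = sorted(archive.files)   # entry names stored in the archive
print(type(entry_a) is np.ndarray)  # True
```

If device arrays are needed from an .npz archive, each entry can be converted explicitly, for example with `jnp.asarray(archive["a"])`.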

This function requires concrete array inputs, and is not compatible with transformations like jax.jit() or jax.vmap().
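A common way to work within this restriction is to call jnp.load eagerly, outside any transformation, and pass the resulting concrete array into the transformed function. This is a minimal sketch assuming the data fits in memory; the function name `double` is hypothetical.

```python
import io

import jax
import jax.numpy as jnp


@jax.jit
def double(x):
    # Traced computation: it receives an array, never a file.
    return 2 * x


buf = io.BytesIO()
jnp.save(buf, jnp.arange(4.0))
buf.seek(0)

# Load eagerly on the host, then hand the concrete array to the jitted function.
x = jnp.load(buf)
y = double(x)
print(y)  # [0. 2. 4. 6.]
```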

Parameters:
  • file (IO[bytes] | str | os.PathLike[Any]) – binary file-like object, string, or path-like object from which to read the array data.

  • args (Any) – for additional arguments, see numpy.load()

  • kwargs (Any) – for additional arguments, see numpy.load()
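Because extra arguments are forwarded to numpy.load(), NumPy's own options can be passed directly. The sketch below loads from a plain string path and passes `allow_pickle=False` (NumPy's default, spelled out here), which makes numpy.load() refuse files containing pickled object arrays. The file name `weights.npy` is hypothetical.

```python
import os
import tempfile

import jax
import jax.numpy as jnp
import numpy as np

with tempfile.TemporaryDirectory() as tmpdir:
    # A plain string path; "weights.npy" is a hypothetical file name.
    path = os.path.join(tmpdir, "weights.npy")
    np.save(path, np.arange(6, dtype=np.float32).reshape(2, 3))
    # allow_pickle is not a parameter of jnp.load itself; it travels through
    # **kwargs to numpy.load().
    w = jnp.load(path, allow_pickle=False)

print(w.shape, w.dtype)  # (2, 3) float32
```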

Returns:

the array stored in the file.

Return type:

Array

See also

  • jax.numpy.save(): save an array to a file.

Examples

>>> import io
>>> import jax.numpy as jnp
>>> f = io.BytesIO()  # use an in-memory file-like object.
>>> x = jnp.array([2, 4, 6, 8], dtype='bfloat16')
>>> jnp.save(f, x)
>>> f.seek(0)
0
>>> jnp.load(f)
Array([2, 4, 6, 8], dtype=bfloat16)
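The bfloat16 restoration matters because NumPy itself has no native bfloat16 type (current JAX releases rely on the separate `ml_dtypes` package for it). As a sketch under that assumption, reading the same bytes with plain numpy.load() should report a raw 2-byte void dtype, while jnp.load recovers bfloat16.

```python
import io

import jax.numpy as jnp
import numpy as np

buf = io.BytesIO()
jnp.save(buf, jnp.array([2, 4, 6, 8], dtype="bfloat16"))

# Plain NumPy does not recognise the stored bfloat16 data and reports a
# 2-byte void dtype instead.
buf.seek(0)
raw = np.load(buf)
print(raw.dtype.kind, raw.dtype.itemsize)

# jnp.load reads the same bytes and restores bfloat16.
buf.seek(0)
restored = jnp.load(buf)
print(restored.dtype)  # bfloat16
```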