jax.numpy.load
- jax.numpy.load(file, *args, **kwargs)
Load JAX arrays from npy files.
JAX wrapper of numpy.load().
This function is a simple wrapper of numpy.load(), but in the case of .npy files created with numpy.save() or jax.numpy.save(), the output is returned as a jax.Array and bfloat16 data types are restored. For .npz files, results are returned as ordinary NumPy arrays.
This function requires concrete inputs and reads the file on the host, so it is not compatible with transformations like jax.jit() or jax.vmap().
- Parameters:
file (IO[bytes] | str | os.PathLike[Any]) – a file name string, a binary file-like object, or a path-like object containing the array data.
args (Any) – additional positional arguments, passed through to numpy.load().
kwargs (Any) – additional keyword arguments, passed through to numpy.load().
- Returns:
the array stored in the file.
- Return type:
Array
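jax.numpy.load runs numpy.load() on the host and needs a concrete file. One common pattern is therefore to load eagerly and apply jax.jit() only to the array computation. The sketch below makes a few assumptions:

- it writes a temporary file with numpy.save();
- `normalize` and the file name `data.npy` are hypothetical;
- it shows that a pathlib.Path is accepted for `file`;
- it forwards one keyword argument, `allow_pickle=False`, to numpy.load(). This value is already NumPy's default and is passed only to illustrate forwarding.

```python
import pathlib
import tempfile

import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def normalize(x):
    # Hypothetical pure array computation: safe to trace and compile.
    return (x - x.mean()) / x.std()


with tempfile.TemporaryDirectory() as tmp:
    path = pathlib.Path(tmp) / "data.npy"  # hypothetical file name
    np.save(path, np.arange(1.0, 5.0, dtype=np.float32))

    # File I/O happens eagerly, outside of jax.jit; the keyword argument
    # is simply forwarded to numpy.load().
    data = jnp.load(path, allow_pickle=False)

# `data` is an in-memory jax.Array, so it stays usable after the file is gone.
result = normalize(data)
print(type(data).__name__, data.dtype)
print(result)
```

Keeping the load outside the jitted function means the compiled function sees only array arguments. It never sees a file object or a path.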
See also
jax.numpy.save(): save an array to a file.
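jax.numpy.save() writes a single array. Several arrays can instead be stored together with numpy.savez(). For such an .npz file, the description above says jax.numpy.load returns ordinary NumPy arrays rather than jax.Array values. The minimal sketch below uses an in-memory buffer; the member names `weights` and `step` are arbitrary. It shows the difference and an explicit jnp.asarray() conversion.

```python
import io

import jax
import jax.numpy as jnp
import numpy as np

buf = io.BytesIO()  # in-memory stand-in for an .npz file
np.savez(buf, weights=np.ones((2, 3), dtype=np.float32), step=np.array(10))
buf.seek(0)

# For .npz input the members come back as plain NumPy arrays, so convert
# explicitly if JAX arrays are wanted.
with jnp.load(buf) as archive:
    weights_np = archive["weights"]
    step = int(archive["step"])

weights = jnp.asarray(weights_np)
print(type(weights_np), isinstance(weights, jax.Array))
```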
Examples
>>> import io
>>> import jax.numpy as jnp
>>> f = io.BytesIO()  # use an in-memory file-like object.
>>> x = jnp.array([2, 4, 6, 8], dtype='bfloat16')
>>> jnp.save(f, x)
>>> f.seek(0)
0
>>> jnp.load(f)
Array([2, 4, 6, 8], dtype=bfloat16)
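The bfloat16 restoration described above can be observed by reading the same .npy bytes with plain numpy.load(). On the NumPy side, the data is expected to come back as an untyped 2-byte void dtype. This sketch assumes the ml_dtypes package, which JAX uses for its bfloat16 type, is importable. The manual `view` step only illustrates the idea; it is not a claim about how jax.numpy.load is implemented internally.

```python
import io

import jax.numpy as jnp
import ml_dtypes  # assumption: installed alongside JAX
import numpy as np

x = jnp.array([2, 4, 6, 8], dtype="bfloat16")
buf = io.BytesIO()  # in-memory file-like object
jnp.save(buf, x)

# Plain NumPy has no built-in bfloat16, so the bytes load as raw 2-byte
# void ("V") records rather than numbers.
buf.seek(0)
raw = np.load(buf)

# Reinterpreting the same bytes as bfloat16 recovers the values.
restored = raw.view(ml_dtypes.bfloat16)

# jnp.load performs the restoration itself and returns a jax.Array.
buf.seek(0)
loaded = jnp.load(buf)

print(raw.dtype.kind, raw.dtype.itemsize)
print(loaded)
```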
