
jax.numpy.from_dlpack


jax.numpy.from_dlpack(x, /, *, device=None, copy=None)

Construct a JAX array via DLPack.

JAX implementation of numpy.from_dlpack().

Parameters:
  • x (Any) – An object that implements the DLPack protocol via the __dlpack__ and __dlpack_device__ methods, or a legacy DLPack tensor on either CPU or GPU.

  • device (xc.Device | Sharding | None) – An optional Device or Sharding, representing the single device onto which the returned array should be placed. If given, the result is committed to that device. If unspecified, the resulting array is unpacked onto the same device it originated from. Setting device to a device different from the source of x requires a copy, so copy must be set to either True or None.

  • copy (bool | None) – An optional boolean controlling whether or not a copy is performed. If copy=True, a copy is always performed, even when unpacking onto the same device. If copy=False, a copy is never performed, and an error is raised if one would be required. When copy=None (the default), a copy may be performed if needed for a device transfer.
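A minimal sketch of the device and copy parameters, assuming a NumPy array on the host CPU as the source and a JAX installation that exposes at least one CPU device via jax.devices("cpu"). The variable names are illustrative only.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Source buffer: a NumPy array living on the host CPU.
x_numpy = np.arange(4, dtype=np.float32)

# Explicit placement: the result is committed to the first CPU device,
# which is also where the source lives, so no transfer is needed.
cpu = jax.devices("cpu")[0]
x_on_cpu = jnp.from_dlpack(x_numpy, device=cpu)

# copy=True always copies, even though source and target are both the host CPU.
x_copied = jnp.from_dlpack(x_numpy, copy=True)

print(x_on_cpu.devices())
print(x_copied)
```

Requesting a device other than the source's device would need copy=True or copy=None; passing copy=False in that situation is expected to raise an error, as described above.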

Returns:

A JAX array of the input buffer.

Return type:

Array

Note

While JAX arrays are always immutable, DLPack buffers cannot be marked as immutable, and it is possible for processes external to JAX to mutate them in-place. If a JAX Array is constructed from a DLPack buffer without copying and the source buffer is later modified in-place, it may lead to undefined behavior when using the associated JAX array.
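One way to sidestep this hazard is to request a copy explicitly. The sketch below assumes a NumPy source on the CPU: with copy=True the JAX array owns an independent buffer, so a later in-place write to the NumPy array does not reach it. Whether copy=None shares memory with the source is treated as an implementation detail and is deliberately not relied on here.

```python
import numpy as np
import jax.numpy as jnp

src = np.zeros(3, dtype=np.float32)

# copy=True gives the JAX array its own buffer, detached from `src`.
snapshot = jnp.from_dlpack(src, copy=True)

# In-place mutation performed outside JAX, after the import.
src[0] = 1.0

# The JAX array still holds the values captured at import time.
print(snapshot)
```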

Examples

Passing data between NumPy and JAX via DLPack:

>>> import numpy as np
>>> rng = np.random.default_rng(42)
>>> x_numpy = rng.random(4, dtype='float32')
>>> print(x_numpy)
[0.08925092 0.773956   0.6545715  0.43887842]
>>> hasattr(x_numpy, "__dlpack__")  # NumPy supports the DLPack interface
True

>>> import jax.numpy as jnp
>>> x_jax = jnp.from_dlpack(x_numpy)
>>> print(x_jax)
[0.08925092 0.773956   0.6545715  0.43887842]
>>> hasattr(x_jax, "__dlpack__")  # JAX supports the DLPack interface
True

>>> x_numpy_round_trip = np.from_dlpack(x_jax)
>>> print(x_numpy_round_trip)
[0.08925092 0.773956   0.6545715  0.43887842]
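This exchange works because jnp.from_dlpack relies only on the __dlpack__ and __dlpack_device__ methods described for x, not on the source being a NumPy array. As a sketch of that duck typing, the hypothetical DLPackProducer class below (not part of JAX or NumPy) wraps a NumPy array and forwards both protocol methods to it. Forwarding *args and **kwargs is an assumption made so that whichever keywords the consumer passes (such as stream or max_version) reach NumPy unchanged.

```python
import numpy as np
import jax.numpy as jnp


class DLPackProducer:
    """Hypothetical minimal DLPack producer that delegates to a NumPy array."""

    def __init__(self, array):
        self._array = array

    def __dlpack__(self, *args, **kwargs):
        # Forward the consumer's arguments (stream, max_version, ...) to NumPy,
        # which returns a capsule describing the wrapped buffer.
        return self._array.__dlpack__(*args, **kwargs)

    def __dlpack_device__(self):
        # Report the wrapped buffer's (device type, device id) pair.
        return self._array.__dlpack_device__()


producer = DLPackProducer(np.arange(3, dtype=np.float32))
y = jnp.from_dlpack(producer)
print(y)
```

A real producer, such as a GPU tensor library, would return a capsule describing its own buffer rather than delegating to NumPy; the consumer side stays the same.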