Transfer guard

JAX may transfer data between the host and devices, and between devices, during type conversion and input sharding. To log or disallow any unintended transfers, the user may configure a JAX transfer guard.

JAX transfer guards distinguish between two types of transfers:

  • Explicit transfers: jax.device_put*() and jax.device_get() calls.

  • Implicit transfers: Other transfers (e.g., printing a DeviceArray).
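
To make the distinction concrete, here is a minimal sketch run under the default guard level. The names host_value, device_value, fetched, and doubled are illustrative only and do not come from the JAX documentation.

```python
import jax
import numpy as np

host_value = np.arange(4, dtype=np.float32)

# Explicit host-to-device transfer: the caller asks for the copy by name.
device_value = jax.device_put(host_value)

# Explicit device-to-host transfer: the result is a NumPy array on the host.
fetched = jax.device_get(device_value)

# Implicit transfers: nothing below names a transfer, but JAX has to move
# the NumPy input to a device to run the jitted function, and printing the
# result reads the value back on the host.
doubled = jax.jit(lambda v: v * 2)(host_value)
print(doubled)
```

With the default "allow" level, all four operations run silently. The stricter guard levels only change how the last two lines are treated.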

A transfer guard can take an action based on its guard level:

  • "allow": Silently allow all transfers (default).

  • "log": Log and allow implicit transfers. Silently allow explicit transfers.

  • "disallow": Disallow implicit transfers. Silently allow explicit transfers.

  • "log_explicit": Log and allow all transfers.

  • "disallow_explicit": Disallow all transfers.

JAX will raise a RuntimeError when disallowing a transfer.
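
The following sketch runs the same implicit transfer under three guard levels and records whether it was blocked. The helper implicit_transfer_outcome is hypothetical. The sketch also assumes that passing a NumPy array to a jitted function counts as an implicit host-to-device transfer, so the result for "disallow" may depend on the JAX version and backend and is reported rather than assumed.

```python
import jax
import numpy as np

identity = jax.jit(lambda v: v)
host_value = np.ones(3, dtype=np.float32)


def implicit_transfer_outcome(level):
    """Run one implicit host-to-device transfer under the given guard level."""
    with jax.transfer_guard(level):
        try:
            # The NumPy argument has to be moved to a device implicitly.
            identity(host_value)
        except RuntimeError as err:
            # Per the text above, a disallowed transfer raises RuntimeError.
            return "blocked: " + type(err).__name__
    return "allowed"


for level in ("allow", "log", "disallow"):
    print(level, "->", implicit_transfer_outcome(level))

# Explicit transfers remain permitted under "disallow".
with jax.transfer_guard("disallow"):
    on_device = jax.device_put(host_value)
```

Catching RuntimeError narrowly, rather than using a bare except, keeps unrelated failures visible while still handling the guard's error.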

The transfer guards use the standard JAX configuration system:

  • A --jax_transfer_guard=GUARD_LEVEL command-line flag and jax.config.update("jax_transfer_guard", GUARD_LEVEL) will set the global option.

  • A with jax.transfer_guard(GUARD_LEVEL): ... context manager will set the thread-local option within the scope of the context manager.

Note that, as with other JAX configuration options, a newly spawned thread uses the global option rather than any thread-local option active in the scope where the thread was spawned.
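
The scoping rules can be observed with a small experiment. This is a sketch, not an authoritative test: try_implicit_transfer and the result keys are hypothetical names, and it assumes that calling a jitted function on a NumPy array is an implicit transfer. It records outcomes instead of asserting them, because the exact behavior may vary by JAX version and backend.

```python
import threading

import jax
import numpy as np

identity = jax.jit(lambda v: v)
host_value = np.ones(2, dtype=np.float32)
identity(host_value)  # Compile once under the default guard level.


def try_implicit_transfer(results, key):
    """Record whether an implicit host-to-device transfer is permitted."""
    try:
        identity(host_value)
        results[key] = "allowed"
    except RuntimeError:
        results[key] = "blocked"


results = {}
jax.config.update("jax_transfer_guard", "allow")  # Global option.

with jax.transfer_guard("disallow"):  # Thread-local, this scope only.
    try_implicit_transfer(results, "main thread, inside scope")
    # A thread spawned here is expected to see the global option ("allow"),
    # not the "disallow" level of the enclosing scope.
    worker = threading.Thread(
        target=try_implicit_transfer,
        args=(results, "new thread, spawned inside scope"),
    )
    worker.start()
    worker.join()

try_implicit_transfer(results, "main thread, after scope")

for key, outcome in results.items():
    print(key, "->", outcome)
```

If a worker thread needs the stricter level, it has to enter its own jax.transfer_guard(...) scope, since the spawning thread's setting is not inherited.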

The transfer guards can also be applied more selectively, based on the direction of transfer. The flag and context manager names are suffixed with the corresponding transfer direction (e.g., --jax_transfer_guard_host_to_device and jax.transfer_guard_host_to_device):

  • "host_to_device": Converting a Python value or NumPy array into a JAX on-device buffer.

  • "device_to_device": Copying a JAX on-device buffer to a different device.

  • "device_to_host": Fetching a JAX on-device buffer.

Fetching a buffer on a CPU device is always allowed regardless of the transfer guard level.
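
A direction-specific guard can be sketched as follows, using the jax.transfer_guard_host_to_device context manager. The variable names are illustrative, and the outcome of the implicit upload is recorded rather than assumed.

```python
import jax
import numpy as np

identity = jax.jit(lambda v: v)
host_value = np.arange(3, dtype=np.float32)
identity(host_value)  # Compile once under the default guard level.

with jax.transfer_guard_host_to_device("disallow"):
    # Explicit upload: still permitted under "disallow".
    device_value = jax.device_put(host_value)

    # The input already lives on a device, so no upload is needed here.
    result = identity(device_value)

    # Implicit upload of a NumPy array: the case this guard is meant to catch.
    try:
        identity(host_value)
        implicit_upload = "allowed"
    except RuntimeError:
        implicit_upload = "blocked"

    # Only the host-to-device direction is restricted in this scope, and
    # this device-to-host fetch is explicit in any case.
    fetched = jax.device_get(result)

print("implicit upload:", implicit_upload)
print("fetched:", fetched)
```

Restricting a single direction is useful when, for example, accidental uploads inside a training loop are the concern but reading results back for logging should stay unrestricted.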

The following shows an example of using the transfer guard.

>>> import jax
>>> import jax.numpy as jnp
>>>
>>> jax.config.update("jax_transfer_guard", "allow")  # This is default.
>>>
>>> x = jnp.array(1)
>>> y = jnp.array(2)
>>> z = jnp.array(3)
>>>
>>> print("x", x)  # All transfers are allowed.
x 1
>>> with jax.transfer_guard("disallow"):
...   print("x", x)  # x has already been fetched into the host.
...   print("y", jax.device_get(y))  # Explicit transfers are allowed.
...   try:
...     print("z", z)  # Implicit transfers are disallowed.
...     assert False, "This line is expected to be unreachable."
...   except:
...     print("z could not be fetched")
x 1
y 2
z could not be fetched
