Buffer donation

When JAX executes a computation it uses buffers on the device for all inputs and outputs. If you know that one of the inputs is not needed after the computation, and if it matches the shape and element type of one of the outputs, you can specify that you want the corresponding input buffer to be donated to hold an output. This will reduce the memory required for the execution by the size of the donated buffer.
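As a minimal sketch of the idea (assuming a recent JAX release, where jax.Array provides is_deleted()), the function below returns an array with the same shape and dtype as its input, so the input buffer is eligible for donation. Whether the buffer is actually consumed depends on the backend, so the printed flag may differ between CPU, GPU, and TPU. The name scale is illustrative only.

```python
import jax
import jax.numpy as jnp

def scale(x):
  # Illustrative function: the output has the same shape (1024,) and
  # dtype (float32) as `x`, so XLA may reuse x's buffer for the result.
  return x * 2.0

x = jnp.ones((1024,), dtype=jnp.float32)
scale_donating = jax.jit(scale, donate_argnums=0)
out = scale_donating(x)

# On backends that honor donation, `x` is now invalid and must not be read.
# is_deleted() only inspects the array's state, so it is safe to call.
print("input donated:", x.is_deleted())
print(out.shape, out.dtype)
```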

If you have something like the following pattern, you can use buffer donation:

params, state = jax.pmap(update_fn, donate_argnums=(0, 1))(params, state)

You can think of this as a way to do a memory-efficient functional update on your immutable JAX arrays. Within the boundaries of a computation XLA can make this optimization for you, but at the jit/pmap boundary you need to guarantee to XLA that you will not use the donated input buffer after calling the donating function.
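The functional-update pattern can be sketched with jax.jit on a single device. The update rule here, a hypothetical sgd_step doing plain gradient descent on a quadratic loss, is only a stand-in for update_fn. What matters is that the result is rebound to the same Python name, so the donated input is never referenced again.

```python
import jax
import jax.numpy as jnp

# Hypothetical update rule: one gradient-descent step on sum(p ** 2).
def sgd_step(params, lr):
  grads = jax.grad(lambda p: jnp.sum(p ** 2))(params)
  return params - lr * grads

# Donate `params` (argument 0); `lr` is not donated.
sgd_step_donating = jax.jit(sgd_step, donate_argnums=(0,))

params = jnp.full((4,), 1.0, dtype=jnp.float32)
for _ in range(3):
  # Rebinding `params` to the result means the donated input buffer is
  # never used again, which is the guarantee XLA needs at the jit boundary.
  params = sgd_step_donating(params, 0.1)

print(params)  # approximately 0.512 per entry (each step multiplies by 0.8)
```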

You achieve this by using the donate_argnums parameter to the functions jax.jit(), jax.pjit(), and jax.pmap(). This parameter is an index or a sequence of (0-based) indices into the positional argument list:

import jax
import numpy as np

def add(x, y):
  return x + y

x = jax.device_put(np.ones((2, 3)))
y = jax.device_put(np.ones((2, 3)))
# Execute `add` with donation of the buffer for `y`. The result has
# the same shape and type as `y`, so it will share its buffer.
z = jax.jit(add, donate_argnums=(1,))(x, y)
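The same donation can also be declared where the function is defined, by applying jax.jit as a decorator through functools.partial. This is a common idiom rather than something required by the text above. A minimal sketch:

```python
import functools

import jax
import jax.numpy as jnp

# Argument index 1 (`y`) is donated on every call to `add`.
@functools.partial(jax.jit, donate_argnums=(1,))
def add(x, y):
  return x + y

x = jnp.ones((2, 3))
y = jnp.ones((2, 3))
z = add(x, y)  # `y` must not be used after this call
```

Declaring donation at the definition site keeps the donation policy next to the function body, so call sites cannot accidentally forget it.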

Note that this currently does not work when calling your function with keyword arguments! The following code will not donate any buffers:

params, state = jax.pmap(update_fn, donate_argnums=(0, 1))(params=params, state=state)
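The remedy is to pass donated arguments positionally. The runnable sketch below uses jax.jit instead of jax.pmap so that it works on a single device. The body of update_fn is hypothetical.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for `update_fn` above.
def update_fn(params, state):
  return params + 1.0, state * 2.0

update = jax.jit(update_fn, donate_argnums=(0, 1))

params = jnp.zeros((3,))
state = jnp.ones((3,))

# Pass the donated arguments positionally so that `donate_argnums`
# (which indexes the positional argument list) applies to them.
params, state = update(params, state)
```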

If an argument whose buffer is donated is a pytree, then all the buffers for its components are donated:

from typing import List

import jax
import numpy as np
from jax import Array

def add_ones(xs: List[Array]):
  return [x + 1 for x in xs]

xs = [jax.device_put(np.ones((2, 3))), jax.device_put(np.ones((3, 4)))]
# Execute `add_ones` with donation of all the buffers for `xs`.
# The outputs have the same shape and type as the elements of `xs`,
# so they will share those buffers.
z = jax.jit(add_ones, donate_argnums=0)(xs)
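Dictionaries are pytrees too. In this sketch the function name decay is illustrative. Every leaf of params is donated, and each leaf has an output of the same shape and dtype whose storage its buffer can be reused for.

```python
import jax
import jax.numpy as jnp

# Illustrative function: scale every leaf of a parameter dict.
def decay(params):
  return jax.tree_util.tree_map(lambda p: p * 0.9, params)

params = {"w": jnp.ones((2, 3)), "b": jnp.zeros((3,))}
# Argument 0 is a dict, so the buffers of all its leaves are donated.
new_params = jax.jit(decay, donate_argnums=0)(params)

# Only read the outputs; the leaves of `params` may now be invalid.
print({k: v.shape for k, v in new_params.items()})
```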

It is not allowed to donate a buffer that is used subsequently in your program. JAX will give an error because the buffer for y has become invalid after it was donated:

# Donate the buffer for `y`
z = jax.jit(add, donate_argnums=(1,))(x, y)
w = y + 1  # Reuses `y` whose buffer was donated above
# >> RuntimeError: Invalid argument: CopyToHostAsync() called on invalid buffer
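If you still need y after the call, one option is to donate a copy of y instead of y itself. This is a minimal sketch, assuming jnp.copy allocates a fresh buffer, which is its stated purpose of returning a copy of the array:

```python
import jax
import jax.numpy as jnp

def add(x, y):
  return x + y

add_donating = jax.jit(add, donate_argnums=(1,))

x = jnp.ones((2, 3))
y = jnp.ones((2, 3))

# Donate a copy so that `y`'s own buffer stays valid.
z = add_donating(x, jnp.copy(y))
w = y + 1  # fine: `y` itself was never donated

# is_deleted() reports whether an array's buffer has been given away.
print(y.is_deleted())
```

The copy costs one extra allocation, so this is only worthwhile when donation is still needed to reuse that buffer for the output.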

You will get a warning if the donated buffer is not used, e.g., because there are more donated buffers than can be used for the outputs:

# Execute `add` with donation of the buffers for both `x` and `y`.
# One of those buffers will be used for the result, but the other will
# not be used.
z = jax.jit(add, donate_argnums=(0, 1))(x, y)
# >> UserWarning: Some donated buffers were not usable: f32[2,3]{1,0}
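To inspect or silence this warning programmatically, the standard warnings module can be used. This sketch records warnings whose text contains the message quoted above. The exact wording, and whether the warning appears at all, may vary across JAX versions and backends.

```python
import warnings

import jax
import jax.numpy as jnp

def add(x, y):
  return x + y

x = jnp.ones((2, 3))
y = jnp.ones((2, 3))

# There is only one output, so at most one of the two donated buffers can
# be reused. Record warnings raised by this call instead of printing them.
with warnings.catch_warnings(record=True) as caught:
  warnings.simplefilter("always")
  z = jax.jit(add, donate_argnums=(0, 1))(x, y)

unused = [w for w in caught if "donated buffers were not usable" in str(w.message)]
print(f"{len(unused)} donation warning(s) recorded")
```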

The donation may also be unused if there is no output whose shape matches the donation:

y = jax.device_put(np.ones((1, 3)))  # `y` has different shape than the output
# Execute `add` with donation of the buffer for `y`.
z = jax.jit(add, donate_argnums=(1,))(x, y)
# >> UserWarning: Some donated buffers were not usable: f32[1,3]{1,0}
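Before donating, you can estimate which arguments have a matching output by tracing the function with jax.eval_shape, which computes output shapes and dtypes without running the computation. The helper donation_candidates is hypothetical and only approximates XLA's rule. It ignores that each output buffer can be reused by at most one donated input (the situation in the warning example above), and it ignores memory layout.

```python
import jax
import jax.numpy as jnp

def add(x, y):
  return x + y

def donation_candidates(fn, *args):
  """Hypothetical helper: for each positional argument, report whether every
  leaf has some output leaf with the same shape and dtype."""
  out_leaves = jax.tree_util.tree_leaves(jax.eval_shape(fn, *args))
  out_sigs = {(o.shape, o.dtype) for o in out_leaves}
  return [
      all((leaf.shape, leaf.dtype) in out_sigs
          for leaf in jax.tree_util.tree_leaves(a))
      for a in args
  ]

x = jnp.ones((2, 3))
y = jnp.ones((1, 3))  # broadcasts against `x`; the output shape is (2, 3)
print(donation_candidates(add, x, y))  # [True, False]
```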
