JAX  documentation - Home

Automatic vectorization

In the previous section we discussed JIT compilation via the jax.jit() function. This notebook discusses another of JAX’s transforms: vectorization via jax.vmap().

Manual vectorization

Consider the following simple code that computes the convolution of two one-dimensional vectors:

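```python
import jax
import jax.numpy as jnp

x = jnp.arange(5)
w = jnp.array([2., 3., 4.])

def convolve(x, w):
  output = []
  # Slide a length-3 window over x and take its dot product with w.
  for i in range(1, len(x)-1):
    output.append(jnp.dot(x[i-1:i+2], w))
  return jnp.array(output)

convolve(x, w)
```

```
Array([11., 20., 29.], dtype=float32)
```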
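Strictly speaking, the loop in convolve() computes a sliding dot product without flipping w, which is what NumPy calls a cross-correlation; NumPy-style convolution flips the kernel first. As a sanity check (a sketch, assuming the standard jnp.correlate and jnp.convolve functions with mode="valid"), both library routines should reproduce the loop's output:

```python
import jax.numpy as jnp

x = jnp.arange(5)
w = jnp.array([2., 3., 4.])

def convolve(x, w):
  # Same loop as above: dot product of each length-3 window of x with w.
  output = []
  for i in range(1, len(x)-1):
    output.append(jnp.dot(x[i-1:i+2], w))
  return jnp.array(output)

# "valid" keeps only positions where the window fits entirely inside x.
ref_corr = jnp.correlate(x, w, mode="valid")
# jnp.convolve flips its second argument, so reverse w to undo the flip.
ref_conv = jnp.convolve(x, w[::-1], mode="valid")

print(convolve(x, w))  # expected values: 11, 20, 29
print(ref_corr, ref_conv)
```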

Suppose we would like to apply this function to a batch of weights w and a batch of vectors x.

```python
xs = jnp.stack([x, x])
ws = jnp.stack([w, w])
```

The most naive option would be to simply loop over the batch in Python:

```python
def manually_batched_convolve(xs, ws):
  output = []
  # Call the unbatched convolve() once per batch element.
  for i in range(xs.shape[0]):
    output.append(convolve(xs[i], ws[i]))
  return jnp.stack(output)

manually_batched_convolve(xs, ws)
```

```
Array([[11., 20., 29.],
       [11., 20., 29.]], dtype=float32)
```
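Because xs and ws repeat the same row, the output above cannot reveal whether row i of xs is really paired with row i of ws. A minimal sketch with a hypothetical batch of distinct rows (xs_distinct and ws_distinct are illustrative names, not part of the original) makes the pairing visible:

```python
import jax.numpy as jnp

def convolve(x, w):
  output = []
  for i in range(1, len(x)-1):
    output.append(jnp.dot(x[i-1:i+2], w))
  return jnp.array(output)

def manually_batched_convolve(xs, ws):
  # Row i of xs is paired with row i of ws.
  output = []
  for i in range(xs.shape[0]):
    output.append(convolve(xs[i], ws[i]))
  return jnp.stack(output)

x = jnp.arange(5)
w = jnp.array([2., 3., 4.])

# Hypothetical test batch whose rows differ, so a mix-up between
# batch elements would show up in the result.
xs_distinct = jnp.stack([x, x + 1])
ws_distinct = jnp.stack([w, -w])

result = manually_batched_convolve(xs_distinct, ws_distinct)
print(result)  # expected rows: [11, 20, 29] and [-20, -29, -38]
```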

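This produces the correct result; however, it is not very efficient, because every batch element triggers its own separate sequence of small array operations dispatched from Python.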
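One rough way to check this claim is to time the Python loop against jax.vmap(convolve), which this section introduces later. The sketch below is illustrative only: time_call is a hypothetical helper, the problem size is arbitrary, and the numbers it prints will depend on your hardware and backend.

```python
import time
import jax
import jax.numpy as jnp

def convolve(x, w):
  output = []
  for i in range(1, len(x)-1):
    output.append(jnp.dot(x[i-1:i+2], w))
  return jnp.array(output)

def manually_batched_convolve(xs, ws):
  output = []
  for i in range(xs.shape[0]):
    output.append(convolve(xs[i], ws[i]))
  return jnp.stack(output)

def time_call(fn, *args, repeats=3):
  """Hypothetical helper: best wall-clock time (seconds) over `repeats` calls."""
  fn(*args).block_until_ready()  # warm-up, so one-time compilation is not timed
  best = float("inf")
  for _ in range(repeats):
    start = time.perf_counter()
    fn(*args).block_until_ready()  # JAX dispatches asynchronously; wait for results
    best = min(best, time.perf_counter() - start)
  return best

# Arbitrary illustrative sizes: 16 vectors of length 20, 16 weight triples.
kx, kw = jax.random.split(jax.random.PRNGKey(0))
xs = jax.random.normal(kx, (16, 20))
ws = jax.random.normal(kw, (16, 3))

loop_result = manually_batched_convolve(xs, ws)
vmap_result = jax.vmap(convolve)(xs, ws)

print("Python loop:", time_call(manually_batched_convolve, xs, ws))
print("jax.vmap:   ", time_call(jax.vmap(convolve), xs, ws))
```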

In order to batch the computation efficiently, you would normally have to rewrite the function manually to ensure it is done in vectorized form. This is not particularly difficult to implement, but does involve changing how the function treats indices, axes, and other parts of the input.

For example, we could manually rewrite convolve() to support vectorized computation across the batch dimension as follows:

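```python
def manually_vectorized_convolve(xs, ws):
  output = []
  for i in range(1, xs.shape[-1]-1):
    # Take the same window from every row at once (shape (batch, 3)),
    # multiply by each row's weights, and sum over the window axis.
    output.append(jnp.sum(xs[:, i-1:i+2] * ws, axis=1))
  # Each entry has shape (batch,); stacking on axis 1 gives (batch, n_out).
  return jnp.stack(output, axis=1)

manually_vectorized_convolve(xs, ws)
```

```
Array([[11., 20., 29.],
       [11., 20., 29.]], dtype=float32)
```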
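The rewrite above still loops over output positions in Python. Removing that loop as well requires exactly the kind of index and axis bookkeeping described earlier. The following is a sketch under the assumption that a gather with an index array is acceptable; fully_vectorized_convolve is a hypothetical name, and the window length is read from ws rather than hard-coded:

```python
import jax.numpy as jnp

def fully_vectorized_convolve(xs, ws):
  # Hypothetical helper: gathers every window at once instead of looping.
  k = ws.shape[-1]                   # window length (3 in this example)
  n_out = xs.shape[-1] - k + 1       # number of positions where the window fits
  # idx[n, j] = n + j: column indices of window n, shape (n_out, k).
  idx = jnp.arange(n_out)[:, None] + jnp.arange(k)[None, :]
  windows = xs[:, idx]               # shape (batch, n_out, k)
  # Broadcast each row's weights over its windows, then sum over the window axis.
  return jnp.sum(windows * ws[:, None, :], axis=-1)

x = jnp.arange(5)
w = jnp.array([2., 3., 4.])
xs = jnp.stack([x, x])
ws = jnp.stack([w, w])

print(fully_vectorized_convolve(xs, ws))  # expected: two rows of 11, 20, 29
```

Even in this small case, getting the broadcast shapes right takes care, which is the motivation for letting JAX do the rewrite.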

Such re-implementation can be messy and error-prone as the complexity of a function increases; fortunately JAX provides another way.

Automatic vectorization

In JAX, the jax.vmap() transformation is designed to generate such a vectorized implementation of a function automatically:

```python
auto_batch_convolve = jax.vmap(convolve)

auto_batch_convolve(xs, ws)
```

```
Array([[11., 20., 29.],
       [11., 20., 29.]], dtype=float32)
```

It does this by tracing the function similarly to jax.jit(), and automatically adding batch axes at the beginning of each input.
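One way to observe this is to record the shapes the function sees while it is being traced. In the sketch below, convolve_logging and seen_shapes are hypothetical names; the function body is the same loop as convolve(). Under jax.vmap the function only sees per-example shapes, and the batch axis reappears at position 0 of the output. jax.make_jaxpr prints the batched program that tracing produced.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(5)
w = jnp.array([2., 3., 4.])
xs = jnp.stack([x, x])
ws = jnp.stack([w, w])

seen_shapes = []  # hypothetical log, filled in at trace time

def convolve_logging(x, w):
  # Record the shapes visible inside the function during tracing.
  seen_shapes.append((x.shape, w.shape))
  output = []
  for i in range(1, len(x)-1):
    output.append(jnp.dot(x[i-1:i+2], w))
  return jnp.array(output)

batched = jax.vmap(convolve_logging)
out = batched(xs, ws)

print(seen_shapes)  # [((5,), (3,))]: the batch axis is hidden from the function
print(out.shape)    # (2, 3): vmap adds the batch axis back at the front

# Tracing again (this appends a second entry to seen_shapes) to print the program.
print(jax.make_jaxpr(batched)(xs, ws))
```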

If the batch dimension is not the first, you may use the in_axes and out_axes arguments to specify the location of the batch dimension in inputs and outputs. These may be an integer if the batch axis is the same for all inputs and outputs, or lists otherwise.

```python
auto_batch_convolve_v2 = jax.vmap(convolve, in_axes=1, out_axes=1)

# Move the batch dimension of both inputs from axis 0 to axis 1.
xst = jnp.transpose(xs)
wst = jnp.transpose(ws)

auto_batch_convolve_v2(xst, wst)
```

```
Array([[11., 11.],
       [20., 20.],
       [29., 29.]], dtype=float32)
```
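When the inputs disagree about where the batch axis lives, in_axes takes one entry per positional argument. A minimal sketch, assuming x is batched along axis 1 and w along axis 0 (mixed is a hypothetical name), with out_axes=0 placing the batch axis first in the result:

```python
import jax
import jax.numpy as jnp

def convolve(x, w):
  output = []
  for i in range(1, len(x)-1):
    output.append(jnp.dot(x[i-1:i+2], w))
  return jnp.array(output)

x = jnp.arange(5)
w = jnp.array([2., 3., 4.])
xst = jnp.transpose(jnp.stack([x, x]))  # shape (5, 2): batch along axis 1
ws = jnp.stack([w, w])                  # shape (2, 3): batch along axis 0

# One in_axes entry per argument; both batch axes must have the same size (2).
mixed = jax.vmap(convolve, in_axes=(1, 0), out_axes=0)
out = mixed(xst, ws)
print(out.shape)  # (2, 3)
```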

jax.vmap() also supports the case where only one of the arguments is batched: for example, if you would like to convolve a single set of weights w with a batch of vectors x. In this case, the in_axes entry for the unbatched argument can be set to None:

```python
batch_convolve_v3 = jax.vmap(convolve, in_axes=[0, None])

batch_convolve_v3(xs, w)
```

```
Array([[11., 20., 29.],
       [11., 20., 29.]], dtype=float32)
```
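The same mechanism works in the other direction: a single vector x can be convolved with a batch of weights by marking x as unbatched. In this sketch the weight batch [w, 2 * w] and the name batch_over_weights are hypothetical, chosen so that the second row is easy to check by hand:

```python
import jax
import jax.numpy as jnp

def convolve(x, w):
  output = []
  for i in range(1, len(x)-1):
    output.append(jnp.dot(x[i-1:i+2], w))
  return jnp.array(output)

x = jnp.arange(5)
w = jnp.array([2., 3., 4.])
ws = jnp.stack([w, 2 * w])  # hypothetical batch of two weight vectors

# in_axes=(None, 0): reuse the same x for every row of ws.
batch_over_weights = jax.vmap(convolve, in_axes=(None, 0))
out = batch_over_weights(x, ws)
print(out)  # expected rows: [11, 20, 29] and [22, 40, 58]
```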

Combining transformations

As with all JAX transformations, jax.jit() and jax.vmap() are designed to be composable, which means you can wrap a vmapped function with jit, or a jitted function with vmap, and everything will work correctly:

```python
jitted_batch_convolve = jax.jit(auto_batch_convolve)

jitted_batch_convolve(xs, ws)
```

```
Array([[11., 20., 29.],
       [11., 20., 29.]], dtype=float32)
```
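The example above shows jit wrapping a vmapped function. The reverse order, vmap wrapping a jitted function, is the other composition mentioned in the text. A short sketch comparing the two orderings (jit_of_vmap and vmap_of_jit are hypothetical names) should give identical results:

```python
import jax
import jax.numpy as jnp

def convolve(x, w):
  output = []
  for i in range(1, len(x)-1):
    output.append(jnp.dot(x[i-1:i+2], w))
  return jnp.array(output)

x = jnp.arange(5)
w = jnp.array([2., 3., 4.])
xs = jnp.stack([x, x])
ws = jnp.stack([w, w])

jit_of_vmap = jax.jit(jax.vmap(convolve))   # batch first, then compile
vmap_of_jit = jax.vmap(jax.jit(convolve))   # compile first, then batch

a = jit_of_vmap(xs, ws)
b = vmap_of_jit(xs, ws)
print(jnp.allclose(a, b))  # True
```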
