jax.lax.map
- jax.lax.map(f, xs, *, batch_size=None)
Map a function over leading array axes.
Like Python’s builtin map, except inputs and outputs are in the form of stacked arrays. Consider using the vmap() transform instead, unless you need to apply a function element by element for reduced memory usage or heterogeneous computation with other control flow primitives.
When xs is an array type, the semantics of map() are given by this Python implementation:

    def map(f, xs):
      return np.stack([f(x) for x in xs])
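As a quick check of that reference semantics, the following sketch compares lax.map against the stacked list comprehension on a small array. The function f and the input shape are arbitrary choices made for illustration; any function that maps one slice of xs to an array would do.

```python
import jax.numpy as jnp
import numpy as np
from jax import lax


def f(x):
    # Per-element function: x is one row of xs, with shape (3,).
    return jnp.sin(x) * 2.0 + x.sum()


xs = jnp.arange(12.0).reshape(4, 3)

# lax.map applies f to each slice xs[i] along the leading axis.
ys_map = lax.map(f, xs)

# Reference semantics: apply f row by row, then stack the results.
ys_ref = np.stack([f(x) for x in xs])

print(ys_map.shape)                  # (4, 3)
print(np.allclose(ys_map, ys_ref))   # True
```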
Like scan(), map() is implemented in terms of JAX primitives, so many of the same advantages over a Python loop apply: xs may be an arbitrary nested pytree type, and the mapped computation is compiled only once.
If batch_size is provided, the computation is executed in batches of that size and parallelized using vmap(). This can be used either as a more performant version of map or as a memory-efficient version of vmap. If the axis is not divisible by the batch size, the remainder is processed in a separate vmap and concatenated to the result.

    >>> x = jnp.ones((10, 3, 4))
    >>> def f(x):
    ...   print('inner shape:', x.shape)
    ...   return x + 1
    >>> y = lax.map(f, x, batch_size=3)
    inner shape: (3, 4)
    inner shape: (3, 4)
    >>> y.shape
    (10, 3, 4)
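In the example above, “inner shape” is printed twice: once while tracing the batched computation and once while tracing the remainder computation. Since 10 = 3 × 3 + 1, the first nine elements are processed in three batches of three, and the last element forms the remainder.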
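The batching strategy described above can be sketched as a scan() over full batches, with vmap() applied inside each step, plus one extra vmap() call for any leftover elements. batched_map_sketch is a hypothetical helper written for illustration only: it accepts a single array (not a pytree) as xs and is not JAX's actual implementation; it only mirrors the behavior the text describes.

```python
import jax
import jax.numpy as jnp
from jax import lax


def batched_map_sketch(f, xs, batch_size):
    """Hypothetical helper: map f over the leading axis of a single array xs,
    batch_size elements at a time, vectorizing each batch with jax.vmap."""
    n = xs.shape[0]
    num_batches, remainder = divmod(n, batch_size)

    # Full batches: reshape to (num_batches, batch_size, ...) and scan over
    # them, applying the vmapped f to one whole batch per scan step.
    main = xs[: num_batches * batch_size].reshape(
        num_batches, batch_size, *xs.shape[1:])
    _, ys = lax.scan(lambda carry, batch: (carry, jax.vmap(f)(batch)), (), main)
    ys = ys.reshape(num_batches * batch_size, *ys.shape[2:])

    # Leftover elements (n not divisible by batch_size) get their own vmap
    # call, and the results are concatenated after the batched output.
    if remainder:
        ys_rem = jax.vmap(f)(xs[num_batches * batch_size:])
        ys = jnp.concatenate([ys, ys_rem], axis=0)
    return ys


x = jnp.ones((10, 3, 4))
y_sketch = batched_map_sketch(lambda v: v + 1, x, batch_size=3)
y_ref = jax.vmap(lambda v: v + 1)(x)
print(y_sketch.shape)                             # (10, 3, 4)
print(bool(jnp.array_equal(y_sketch, y_ref)))     # True
```

With n = 10 and batch_size = 3, divmod gives three scan steps of three elements each and a remainder of one, matching the two traces seen in the doctest.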
- Parameters:
f – a Python function to apply element-wise over the first axis or axes of xs.
xs – values over which to map along the leading axis.
batch_size (int | None) – (optional) integer specifying the size of the batch for each step to execute in parallel.
- Returns:
Mapped values, with the per-element outputs of f stacked along a new leading axis.
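Because xs may be a nested pytree, the following sketch maps over a tuple of two arrays that share a leading axis of length 4. Each call to the function receives one slice of every leaf, and the function may itself return a pytree, whose leaves come back stacked along a new leading axis. The names step, a, and b are illustrative choices, not part of the API.

```python
import jax.numpy as jnp
from jax import lax


def step(pair):
    # One slice of each leaf: a_i has shape (2,), b_i is a scalar.
    a_i, b_i = pair
    return {"scaled": a_i * b_i, "total": a_i.sum() + b_i}


a = jnp.arange(8.0).reshape(4, 2)     # leading axis of length 4
b = jnp.array([1.0, 2.0, 3.0, 4.0])   # same leading-axis length as a
out = lax.map(step, (a, b))

print(out["scaled"].shape)  # (4, 2)
print(out["total"].shape)   # (4,)
```

All leaves of xs must share the same leading-axis length, since each step takes slice i from every leaf.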
