JAX documentation

jax.lax.map

jax.lax.map(f, xs, *, batch_size=None)

Map a function over leading array axes.

Like Python’s builtin map, except inputs and outputs are in the form of stacked arrays. Consider using the vmap() transform instead, unless you need to apply a function element by element for reduced memory usage or heterogeneous computation with other control flow primitives.
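
As an illustration of the difference (an example written for this page, not taken from the JAX sources; f and the input values are arbitrary), the sketch below applies a per-row lax.cond with both lax.map and jax.vmap. Both give the same values. lax.map processes one row at a time, whereas under vmap a cond with a batched predicate is typically converted to a select that evaluates both branches for every row.

```python
import jax
import jax.numpy as jnp
from jax import lax

def f(x):
  # Per-row branch: double rows with a positive sum, negate the others.
  return lax.cond(x.sum() > 0, lambda v: v * 2.0, lambda v: -v, x)

xs = jnp.array([[1.0, 2.0], [-3.0, 1.0], [0.5, -0.25]])

ys_map = lax.map(f, xs)    # sequential: one row per step
ys_vmap = jax.vmap(f)(xs)  # vectorized: all rows at once

print(ys_map)
print(jnp.allclose(ys_map, ys_vmap))
```

The results agree; the choice between the two is about memory use and how control flow inside f is executed, not about the values computed.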

When xs is an array type, the semantics of map() are given by this Python implementation:

```python
import numpy as np

def map(f, xs):
  return np.stack([f(x) for x in xs])
```
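
These reference semantics can be checked directly against lax.map. The following is a minimal sketch, assuming NumPy and JAX are installed; reference_map and f are hypothetical helpers defined only for this comparison (reference_map is renamed so that it does not shadow Python's builtin map).

```python
import numpy as np
import jax.numpy as jnp
from jax import lax

def reference_map(f, xs):
  # Pure-Python reference: apply f to each leading slice, then stack.
  return np.stack([f(x) for x in xs])

def f(x):
  return x ** 2 + 1.0

xs = np.arange(12.0).reshape(4, 3)
expected = reference_map(f, xs)
actual = lax.map(f, jnp.asarray(xs))
print(np.allclose(expected, actual))
```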

Like scan(), map() is implemented in terms of JAX primitives, so many of the same advantages over a Python loop apply: xs may be an arbitrary nested pytree type, and the mapped computation is compiled only once.
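
As a sketch of the pytree case (the dictionary keys "a", "b", "sum" and "scaled" are arbitrary names chosen for illustration), xs below is a dict whose leaves share a leading axis of length 4. f receives one slice of every leaf per step, and the outputs are stacked leaf by leaf into a dict with the same structure as the value f returns.

```python
import jax.numpy as jnp
from jax import lax

# Every leaf of xs must have the same leading-axis length (here 4).
xs = {"a": jnp.arange(4.0), "b": jnp.ones((4, 2))}

def f(x):
  # Inside f, the leading axis is gone: x["a"] is a scalar, x["b"] has shape (2,).
  return {"sum": x["a"] + x["b"].sum(), "scaled": x["b"] * x["a"]}

ys = lax.map(f, xs)
# Each output leaf gains a leading axis of length 4.
print(ys["sum"].shape, ys["scaled"].shape)
```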

If batch_size is provided, the computation is executed in batches of that size and parallelized using vmap(). This can be used as either a more performant version of map or as a memory-efficient version of vmap. If the axis is not divisible by the batch size, the remainder is processed in a separate vmap and concatenated to the result.

```python
>>> import jax.numpy as jnp
>>> from jax import lax
>>> x = jnp.ones((10, 3, 4))
>>> def f(x):
...   print('inner shape:', x.shape)
...   return x + 1
>>> y = lax.map(f, x, batch_size=3)
inner shape: (3, 4)
inner shape: (3, 4)
>>> y.shape
(10, 3, 4)
```

In the example above, “inner shape” is printed twice: once while tracing the batched computation and once while tracing the remainder computation.

Parameters:
  • f – a Python function to apply element-wise over the first axis or axes of xs.

  • xs – values over which to map along the leading axis.

  • batch_size (int | None) – (optional) integer specifying the size of the batch for each step to execute in parallel.

Returns:

Mapped values, i.e. the outputs of f stacked along a new leading axis (leaf by leaf when f returns a pytree).
