jax.vmap


jax.vmap(fun, in_axes=0, out_axes=0, axis_name=None, axis_size=None, spmd_axis_name=None)

Vectorizing map. Creates a function which maps fun over argument axes.

Parameters:
  • fun (F) – Function to be mapped over additional axes.

  • in_axes (int | None | Sequence[Any]) –

    An integer, None, or sequence of values specifying which input array axes to map over.

    If each positional argument to fun is an array, then in_axes can be an integer, a None, or a tuple of integers and Nones with length equal to the number of positional arguments to fun. An integer or None indicates which array axis to map over for all arguments (with None indicating not to map any axis), and a tuple indicates which axis to map for each corresponding positional argument. Axis integers must be in the range [-ndim, ndim) for each array, where ndim is the number of dimensions (axes) of the corresponding input array.

    If the positional arguments to fun are container (pytree) types, in_axes must be a sequence with length equal to the number of positional arguments to fun, and for each argument the corresponding element of in_axes can be a container with a matching pytree structure specifying the mapping of its container elements. In other words, in_axes must be a container tree prefix of the positional argument tuple passed to fun. See this link for more detail: https://docs.jax.dev/en/latest/pytrees.html#applying-optional-parameters-to-pytrees

    Either axis_size must be provided explicitly, or at least one positional argument must have in_axes not None. The sizes of the mapped input axes for all mapped positional arguments must all be equal.

    Arguments passed as keywords are always mapped over their leading axis (i.e. axis index 0).

    See below for examples.

  • out_axes (Any) – An integer, None, or (nested) standard Python container (tuple/list/dict) thereof indicating where the mapped axis should appear in the output. All outputs with a mapped axis must have a non-None out_axes specification. Axis integers must be in the range [-ndim, ndim) for each output array, where ndim is the number of dimensions (axes) of the array returned by the vmap()-ed function, which is one more than the number of dimensions (axes) of the corresponding array returned by fun.

  • axis_name (AxisName | None) – Optional, a hashable Python object used to identify the mapped axis so that parallel collectives can be applied.

  • axis_size (int | None) – Optional, an integer indicating the size of the axis to be mapped. If not provided, the mapped axis size is inferred from arguments.

  • spmd_axis_name (AxisName | tuple[AxisName, ...] | None)

Returns:

Batched/vectorized version of fun with arguments that correspond to those of fun, but with extra array axes at positions indicated by in_axes, and a return value that corresponds to that of fun, but with extra array axes at positions indicated by out_axes.

Return type:

F
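The in_axes and axis_size rules in the parameter list can be checked with a small sketch. It shows three cases: a negative axis index, a keyword argument that is mapped over its leading axis even though in_axes only names the positional argument, and a call in which no argument is mapped so axis_size has to be given explicitly. The helper `scale` and the array shapes are hypothetical and were chosen only for illustration.

```python
import jax
import jax.numpy as jnp

def scale(x, y):
    # Hypothetical helper: x and y are per-example slices once vmap
    # has removed the mapped axis.
    return x * y

xs = jnp.arange(6.).reshape(2, 3)   # shape (2, 3)
ys = jnp.arange(3.)                 # shape (3,)

# Map axis -1 (the last axis, size 3) of xs; the scalar 2.0 is unmapped.
out_neg = jax.vmap(scale, in_axes=(-1, None))(xs, 2.0)
print(out_neg.shape)  # (3, 2)

# in_axes covers only the single positional argument; the keyword
# argument y is mapped over its leading axis (size 3) regardless.
out_kw = jax.vmap(scale, in_axes=(1,))(xs, y=ys)
print(out_kw.shape)   # (3, 2)

# No positional argument is mapped, so axis_size must be given; the
# unmapped result is broadcast along the new axis (out_axes=0).
out_bcast = jax.vmap(lambda x: x + 1.0, in_axes=None, axis_size=4)(ys)
print(out_bcast.shape)  # (4, 3)
```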

For example, we can implement a matrix-matrix product using a vector dot product:

>>> import jax.numpy as jnp
>>> from jax import vmap
>>>
>>> vv = lambda x, y: jnp.vdot(x, y)  #  ([a], [a]) -> []
>>> mv = vmap(vv, (0, None), 0)      #  ([b,a], [a]) -> [b]      (b is the mapped axis)
>>> mm = vmap(mv, (None, 1), 1)      #  ([b,a], [a,c]) -> [b,c]  (c is the mapped axis)
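The sketch below checks that the composed mm really computes a matrix product by comparing it against `x @ y`. The inputs are small integer-valued arrays, which were picked so that float32 arithmetic stays exact and an exact equality check is meaningful.

```python
import jax.numpy as jnp
from jax import vmap

vv = lambda x, y: jnp.vdot(x, y)   # ([a], [a]) -> []
mv = vmap(vv, (0, None), 0)        # ([b,a], [a]) -> [b]
mm = vmap(mv, (None, 1), 1)        # ([b,a], [a,c]) -> [b,c]

# Small integer-valued inputs keep float32 arithmetic exact.
x = jnp.arange(6.).reshape(2, 3)   # [b, a] with b=2, a=3
y = jnp.arange(12.).reshape(3, 4)  # [a, c] with c=4

print(mv(x, y[:, 0]).shape)              # (2,)
print(mm(x, y).shape)                    # (2, 4)
print(jnp.array_equal(mm(x, y), x @ y))  # True
```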

Here we use [a,b] to indicate an array with shape (a,b). Here are some variants:

>>> mv1 = vmap(vv, (0, 0), 0)   #  ([b,a], [b,a]) -> [b]        (b is the mapped axis)
>>> mv2 = vmap(vv, (0, 1), 0)   #  ([b,a], [a,b]) -> [b]        (b is the mapped axis)
>>> mm2 = vmap(mv2, (1, 1), 0)  #  ([b,c,a], [a,c,b]) -> [c,b]  (c is the mapped axis)
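One way to confirm the shape annotations on these variants is to compare each one with an equivalent `jnp.einsum` contraction. In the check below, mm2 should compute out[k, j] = sum over i of x[j, k, i] * y[i, k, j]. The sizes b, c and a are arbitrary choices made for this sketch.

```python
import jax.numpy as jnp
from jax import vmap

vv = lambda x, y: jnp.vdot(x, y)
mv1 = vmap(vv, (0, 0), 0)    # ([b,a], [b,a]) -> [b]
mv2 = vmap(vv, (0, 1), 0)    # ([b,a], [a,b]) -> [b]
mm2 = vmap(mv2, (1, 1), 0)   # ([b,c,a], [a,c,b]) -> [c,b]

b, c, a = 2, 3, 4
x2 = jnp.arange(b * a, dtype=jnp.float32).reshape(b, a)
x3 = jnp.arange(b * c * a, dtype=jnp.float32).reshape(b, c, a)
y3 = jnp.arange(a * c * b, dtype=jnp.float32).reshape(a, c, b)

# mv1: row-wise dot products of two [b, a] arrays
r1 = mv1(x2, x2)
# mv2: dot of row j of the first array with column j of the second
r2 = mv2(x2, x2.T)
# mm2: for each k along c, apply mv2 to x3[:, k, :] and y3[:, k, :]
r3 = mm2(x3, y3)

print(jnp.allclose(r1, jnp.einsum('ba,ba->b', x2, x2)))
print(jnp.allclose(r2, jnp.einsum('ba,ab->b', x2, x2.T)))
print(r3.shape)  # (3, 2)
print(jnp.allclose(r3, jnp.einsum('bca,acb->cb', x3, y3)))
```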

Here’s an example of using container types in in_axes to specify which axes of the container elements to map over:

>>> A, B, C, D = 2, 3, 4, 5
>>> x = jnp.ones((A, B))
>>> y = jnp.ones((B, C))
>>> z = jnp.ones((C, D))
>>> def foo(tree_arg):
...   x, (y, z) = tree_arg
...   return jnp.dot(x, jnp.dot(y, z))
>>> tree = (x, (y, z))
>>> print(foo(tree))
[[12. 12. 12. 12. 12.]
 [12. 12. 12. 12. 12.]]
>>> from jax import vmap
>>> K = 6  # batch size
>>> x = jnp.ones((K, A, B))  # batch axis in different locations
>>> y = jnp.ones((B, K, C))
>>> z = jnp.ones((C, D, K))
>>> tree = (x, (y, z))
>>> vfoo = vmap(foo, in_axes=((0, (1, 2)),))
>>> print(vfoo(tree).shape)
(6, 2, 5)
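Because the example above fills every leaf with ones, a wrong choice of axis would go unnoticed. The sketch below uses distinct integer values instead and compares vfoo against a plain Python loop that slices example k out of each leaf along that leaf's own batch axis. The loop serves only as a reference here and is not something vmap does internally.

```python
import jax.numpy as jnp
from jax import vmap

A, B, C, D, K = 2, 3, 4, 5, 6

def foo(tree_arg):
    x, (y, z) = tree_arg
    return jnp.dot(x, jnp.dot(y, z))

# Batch axis at 0 for x, 1 for y and 2 for z, as in the example above.
# Distinct integer values make a mis-chosen axis visible.
x = jnp.arange(K * A * B, dtype=jnp.float32).reshape(K, A, B)
y = jnp.arange(B * K * C, dtype=jnp.float32).reshape(B, K, C)
z = jnp.arange(C * D * K, dtype=jnp.float32).reshape(C, D, K)

vfoo = vmap(foo, in_axes=((0, (1, 2)),))
batched = vfoo((x, (y, z)))

# Reference: take example k from each leaf along its own batch axis.
looped = jnp.stack([foo((x[k], (y[:, k, :], z[:, :, k]))) for k in range(K)])
print(batched.shape)                            # (6, 2, 5)
print(jnp.allclose(batched, looped, rtol=1e-5))  # True
```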

Here’s another example using container types in in_axes, this time a dictionary, to specify the elements of the container to map over:

>>> dct = {'a': 0., 'b': jnp.arange(5.)}
>>> x = 1.
>>> def foo(dct, x):
...   return dct['a'] + dct['b'] + x
>>> out = vmap(foo, in_axes=({'a': None, 'b': 0}, None))(dct, x)
>>> print(out)
[1. 2. 3. 4. 5.]
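Since in_axes only has to be a tree prefix, a single None or integer can stand in for an entire dictionary. The sketch below first leaves every leaf of dct unmapped with a bare None and maps x over a hypothetical batch `xs`. It then shows that a bare 0 for dct would also try to map the scalar leaf dct['a']. The sketch assumes that vmap reports this as a ValueError.

```python
import jax.numpy as jnp
from jax import vmap

def foo(dct, x):
    return dct['a'] + dct['b'] + x

dct = {'a': 0., 'b': jnp.arange(5.)}
xs = jnp.arange(3.)  # hypothetical batch of x values

# A bare None is a prefix of the whole dict: all of its leaves stay
# unmapped, while x is mapped over axis 0.
out = vmap(foo, in_axes=(None, 0))(dct, xs)
print(out.shape)  # (3, 5)

# A bare 0 applies to every leaf, including the rank-0 dct['a'],
# which has no axis 0 to map over (assumed to raise ValueError).
try:
    vmap(foo, in_axes=(0, None))(dct, 1.)
    rejected = False
except ValueError:
    rejected = True
print(rejected)
```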

The results of a vectorized function can be mapped or unmapped. For example, the function below returns a pair with the first element mapped and the second unmapped. Only unmapped results may be given an out_axes of None (to keep them unmapped).

>>> print(vmap(lambda x, y: (x + y, y * 2.), in_axes=(0, None), out_axes=(0, None))(jnp.arange(2.), 4.))
(Array([4., 5.], dtype=float32), 8.0)

If out_axes is specified for an unmapped result, the result is broadcast across the mapped axis:

>>> print(vmap(lambda x, y: (x + y, y * 2.), in_axes=(0, None), out_axes=0)(jnp.arange(2.), 4.))
(Array([4., 5.], dtype=float32), Array([8., 8.], dtype=float32, weak_type=True))

If out_axes is specified for a mapped result, the result is transposed accordingly, so that the mapped axis appears at the requested position in the output.
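The sketch below makes the transposition concrete. A per-row function maps over axis 0 of a (3, 4) array. With out_axes=1 the mapped axis of size 3 lands in second position, which gives the transpose of the out_axes=0 result. The negative index -1 refers to the same position here because the output is 2-D. The array and the function are illustrative choices.

```python
import jax.numpy as jnp
from jax import vmap

x = jnp.arange(12.).reshape(3, 4)   # mapped axis 0 has size 3
f = lambda row: row * 2.            # per example: [4] -> [4]

out0 = vmap(f, in_axes=0, out_axes=0)(x)      # mapped axis first:  (3, 4)
out1 = vmap(f, in_axes=0, out_axes=1)(x)      # mapped axis second: (4, 3)
out_neg = vmap(f, in_axes=0, out_axes=-1)(x)  # last axis of 2-D output

print(out0.shape, out1.shape)        # (3, 4) (4, 3)
print(jnp.array_equal(out1, out0.T))  # True
```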

Finally, here’s an example using axis_name together with collectives:

>>> from jax import lax
>>> xs = jnp.arange(3. * 4.).reshape(3, 4)
>>> print(vmap(lambda x: lax.psum(x, 'i'), axis_name='i')(xs))
[[12. 15. 18. 21.]
 [12. 15. 18. 21.]
 [12. 15. 18. 21.]]
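Other collectives accept the same axis name. The sketch below uses `lax.pmean` to subtract the mean taken across the mapped axis, which should match centering the rows with ordinary NumPy-style reductions. It then uses `lax.axis_index` to return each mapped instance's position along 'i'. Both functions exist in jax.lax, and this particular use of them is only an illustration.

```python
import jax.numpy as jnp
from jax import lax, vmap

xs = jnp.arange(3. * 4.).reshape(3, 4)

# Subtract the mean across the mapped axis 'i' (i.e. across rows).
centered = vmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')(xs)
print(jnp.allclose(centered, xs - xs.mean(axis=0)))  # True

# lax.axis_index gives each mapped instance its position along 'i'.
idx = vmap(lambda x: lax.axis_index('i'), axis_name='i')(xs)
print(idx)
```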

See the jax.pmap() docstring for more examples involving collectives.
