jax.vmap
jax.vmap(fun, in_axes=0, out_axes=0, axis_name=None, axis_size=None, spmd_axis_name=None)
Vectorizing map. Creates a function which maps fun over argument axes.

Parameters:
fun (F) – Function to be mapped over additional axes.
in_axes (int | None | Sequence[Any]) –
An integer, None, or sequence of values specifying which input array axes to map over.

If each positional argument to fun is an array, then in_axes can be an integer, a None, or a tuple of integers and Nones with length equal to the number of positional arguments to fun. An integer or None indicates which array axis to map over for all arguments (with None indicating not to map any axis), and a tuple indicates which axis to map for each corresponding positional argument. Axis integers must be in the range [-ndim, ndim) for each array, where ndim is the number of dimensions (axes) of the corresponding input array.

If the positional arguments to fun are container (pytree) types, in_axes must be a sequence with length equal to the number of positional arguments to fun, and for each argument the corresponding element of in_axes can be a container with a matching pytree structure specifying the mapping of its container elements. In other words, in_axes must be a container tree prefix of the positional argument tuple passed to fun. See this link for more detail: https://docs.jax.dev/en/latest/pytrees.html#applying-optional-parameters-to-pytrees

Either axis_size must be provided explicitly, or at least one positional argument must have in_axes not None. The sizes of the mapped input axes for all mapped positional arguments must all be equal.

Arguments passed as keywords are always mapped over their leading axis (i.e. axis index 0).

See below for examples.
out_axes (Any) – An integer, None, or (nested) standard Python container (tuple/list/dict) thereof indicating where the mapped axis should appear in the output. All outputs with a mapped axis must have a non-None out_axes specification. Axis integers must be in the range [-ndim, ndim) for each output array, where ndim is the number of dimensions (axes) of the array returned by the vmap()-ed function, which is one more than the number of dimensions (axes) of the corresponding array returned by fun.
axis_name (AxisName | None) – Optional, a hashable Python object used to identify the mapped axis so that parallel collectives can be applied.
axis_size (int | None) – Optional, an integer indicating the size of the axis to be mapped. If not provided, the mapped axis size is inferred from arguments.
spmd_axis_name (AxisName | tuple[AxisName, ...] | None)

Returns:
Batched/vectorized version of fun with arguments that correspond to those of fun, but with extra array axes at positions indicated by in_axes, and a return value that corresponds to that of fun, but with extra array axes at positions indicated by out_axes.

Return type:
F
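
The keyword-argument rule described under in_axes can be illustrated with a minimal sketch. The function scale and the arrays xs and factors are hypothetical names chosen for illustration; the point is only that an argument passed by keyword is mapped over axis 0, while in_axes describes the positional arguments.

```python
import jax.numpy as jnp
from jax import vmap

def scale(x, factor):
    # Hypothetical per-example function: x has shape (2,), factor is a scalar.
    return x * factor

xs = jnp.arange(6.0).reshape(3, 2)        # 3 batch elements, each of shape (2,)
factors = jnp.array([1.0, 10.0, 100.0])   # one factor per batch element

# Both arguments positional: in_axes names an axis for each of them.
by_position = vmap(scale, in_axes=(0, 0))(xs, factors)

# factor passed by keyword: it is mapped over its leading axis, and
# in_axes only describes the single positional argument xs.
by_keyword = vmap(scale, in_axes=0)(xs, factor=factors)

print(by_position.shape)  # (3, 2)
```

Both calls should produce the same result, since in each case row i of xs is multiplied by factors[i].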

For example, we can implement a matrix-matrix product using a vector dot product:

>>> import jax.numpy as jnp
>>> from jax import vmap
>>>
>>> vv = lambda x, y: jnp.vdot(x, y)  # ([a], [a]) -> []
>>> mv = vmap(vv, (0, None), 0)       # ([b,a], [a]) -> [b]      (b is the mapped axis)
>>> mm = vmap(mv, (None, 1), 1)       # ([b,a], [a,c]) -> [b,c]  (c is the mapped axis)

Here we use [a,b] to indicate an array with shape (a,b). Here are some variants:

>>> mv1 = vmap(vv, (0, 0), 0)   # ([b,a], [b,a]) -> [b]        (b is the mapped axis)
>>> mv2 = vmap(vv, (0, 1), 0)   # ([b,a], [a,b]) -> [b]        (b is the mapped axis)
>>> mm2 = vmap(mv2, (1, 1), 0)  # ([b,c,a], [a,c,b]) -> [c,b]  (c is the mapped axis)

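One way to sanity-check the shape annotations above is to compare the vmap-built functions against reference computations. This is only a sketch: the input arrays lhs and rhs and the reference expressions (the @ operator, jnp.sum, jnp.diagonal) are choices made here for illustration and are not part of the original example.

```python
import jax.numpy as jnp
from jax import vmap

vv = lambda x, y: jnp.vdot(x, y)   # ([a], [a]) -> []
mv = vmap(vv, (0, None), 0)        # ([b,a], [a]) -> [b]
mm = vmap(mv, (None, 1), 1)        # ([b,a], [a,c]) -> [b,c]
mv1 = vmap(vv, (0, 0), 0)          # ([b,a], [b,a]) -> [b]
mv2 = vmap(vv, (0, 1), 0)          # ([b,a], [a,b]) -> [b]

b, a, c = 2, 3, 4
lhs = jnp.arange(float(b * a)).reshape(b, a)   # shape [b,a]
rhs = jnp.arange(float(a * c)).reshape(a, c)   # shape [a,c]

# mm should agree with an ordinary matrix product.
matmul_ok = jnp.allclose(mm(lhs, rhs), lhs @ rhs)

# mv1 takes the dot product of matching rows.
rowdot_ok = jnp.allclose(mv1(lhs, lhs), jnp.sum(lhs * lhs, axis=1))

# mv2 pairs row i of the first argument with column i of the second,
# which is the diagonal of the matrix product.
diag_ok = jnp.allclose(mv2(lhs, lhs.T), jnp.diagonal(lhs @ lhs.T))

print(mm(lhs, rhs).shape)  # (2, 4)
```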
Here’s an example of using container types in in_axes to specify which axes of the container elements to map over:

>>> A, B, C, D = 2, 3, 4, 5
>>> x = jnp.ones((A, B))
>>> y = jnp.ones((B, C))
>>> z = jnp.ones((C, D))
>>> def foo(tree_arg):
...   x, (y, z) = tree_arg
...   return jnp.dot(x, jnp.dot(y, z))
>>> tree = (x, (y, z))
>>> print(foo(tree))
[[12. 12. 12. 12. 12.]
 [12. 12. 12. 12. 12.]]
>>> from jax import vmap
>>> K = 6  # batch size
>>> x = jnp.ones((K, A, B))  # batch axis in different locations
>>> y = jnp.ones((B, K, C))
>>> z = jnp.ones((C, D, K))
>>> tree = (x, (y, z))
>>> vfoo = vmap(foo, in_axes=((0, (1, 2)),))
>>> print(vfoo(tree).shape)
(6, 2, 5)

Here’s another example using container types in in_axes, this time a dictionary, to specify the elements of the container to map over:

>>> dct = {'a': 0., 'b': jnp.arange(5.)}
>>> x = 1.
>>> def foo(dct, x):
...   return dct['a'] + dct['b'] + x
>>> out = vmap(foo, in_axes=({'a': None, 'b': 0}, None))(dct, x)
>>> print(out)
[1. 2. 3. 4. 5.]

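The phrase "container tree prefix" in the in_axes description means that a single axis value can stand in for an entire sub-container. A minimal sketch of this, using the hypothetical names affine, params, and x:

```python
import jax.numpy as jnp
from jax import vmap

def affine(params, x):
    # Per-example shapes: params['w'] is (4, 3), params['b'] is (4,), x is (3,).
    return jnp.dot(params['w'], x) + params['b']

K = 5  # batch size
params = {'w': jnp.ones((K, 4, 3)), 'b': jnp.zeros((K, 4))}
x = jnp.ones((3,))  # shared across the batch, so its in_axes entry is None

# Full specification: one axis entry per leaf of the dictionary.
full = vmap(affine, in_axes=({'w': 0, 'b': 0}, None))(params, x)

# Prefix specification: the single 0 applies to every leaf of params.
prefix = vmap(affine, in_axes=(0, None))(params, x)

print(prefix.shape)  # (5, 4)
```

The prefix form is convenient when every leaf of a container is batched along the same axis; the full form is needed once leaves differ, as in the dictionary example above.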
The results of a vectorized function can be mapped or unmapped. For example, the function below returns a pair with the first element mapped and the second unmapped. Only for unmapped results can we specify out_axes to be None (to keep them unmapped).

>>> print(vmap(lambda x, y: (x + y, y * 2.), in_axes=(0, None), out_axes=(0, None))(jnp.arange(2.), 4.))
(Array([4., 5.], dtype=float32), 8.0)

If the out_axes is specified for an unmapped result, the result is broadcast across the mapped axis:

>>> print(vmap(lambda x, y: (x + y, y * 2.), in_axes=(0, None), out_axes=0)(jnp.arange(2.), 4.))
(Array([4., 5.], dtype=float32), Array([8., 8.], dtype=float32, weak_type=True))

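When no input is mapped at all, the size of the broadcast axis cannot be inferred from the arguments, which is where axis_size comes in. The following sketch assumes in_axes=None combined with an explicit axis_size; the function double and the array x are hypothetical names.

```python
import jax.numpy as jnp
from jax import vmap

def double(x):
    # Hypothetical function applied to an unmapped input.
    return x * 2.0

x = jnp.arange(3.0)  # shape (3,), not mapped

# No input axis is mapped, so the mapped axis size is given explicitly.
# out_axes=0 then broadcasts the unmapped result across that axis.
out = vmap(double, in_axes=None, out_axes=0, axis_size=4)(x)

print(out.shape)  # (4, 3)
```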
If the out_axes is specified for a mapped result, the result is transposed accordingly.

Finally, here’s an example using axis_name together with collectives:

>>> from jax import lax
>>> xs = jnp.arange(3. * 4.).reshape(3, 4)
>>> print(vmap(lambda x: lax.psum(x, 'i'), axis_name='i')(xs))
[[12. 15. 18. 21.]
 [12. 15. 18. 21.]
 [12. 15. 18. 21.]]

See the jax.pmap() docstring for more examples involving collectives.
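
The earlier statement that out_axes transposes a mapped result can be checked with a short sketch. The function double and the array xs are hypothetical names; the last call also exercises a negative axis index, which is allowed within the range [-ndim, ndim).

```python
import jax.numpy as jnp
from jax import vmap

def double(x):
    # Hypothetical per-example function: (3,) -> (3,)
    return x * 2.0

xs = jnp.arange(6.0).reshape(2, 3)  # batch of 2 vectors of shape (3,)

first = vmap(double, in_axes=0, out_axes=0)(xs)    # mapped axis stays at position 0
second = vmap(double, in_axes=0, out_axes=1)(xs)   # mapped axis moved to position 1
last = vmap(double, in_axes=0, out_axes=-1)(xs)    # -1 refers to the last output axis

print(first.shape, second.shape)  # (2, 3) (3, 2)
```

Because the output here has two axes, out_axes=1 and out_axes=-1 name the same position, and both results equal the transpose of the out_axes=0 result.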
