jax.numpy.vectorize
- jax.numpy.vectorize(pyfunc, *, excluded=frozenset({}), signature=None)
Define a vectorized function with broadcasting.
vectorize() is a convenience wrapper for defining vectorized functions with broadcasting, in the style of NumPy's generalized universal functions. It lets you define functions that are automatically repeated across any leading dimensions, without the implementation of the function needing to handle higher-dimensional inputs itself.

jax.numpy.vectorize() has an interface similar to numpy.vectorize (it accepts the excluded and signature arguments), but it is syntactic sugar for an auto-batching transformation (vmap()) rather than a Python loop. This should be considerably more efficient, but the implementation must be written in terms of functions that act on JAX arrays.

- Parameters:
pyfunc – function to vectorize.
excluded – optional set of integers representing positional arguments for which the function will not be vectorized. These will be passed directly to pyfunc unmodified.
signature – optional generalized universal function signature, e.g., (m,n),(n)->(m) for vectorized matrix-vector multiplication. If provided, pyfunc will be called with (and expected to return) arrays with shapes given by the size of corresponding core dimensions. By default, pyfunc is assumed to take scalar arrays as input, and if signature is None, pyfunc can produce outputs of any shape.
- Returns:
Vectorized version of the given function.
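The sketch below illustrates two points from the description and parameters above. The functions relu_scalar and normalize are hypothetical examples, not part of JAX. The first part checks that, with no signature, a scalar function wrapped by vectorize() gives the same result as applying vmap() once per leading axis. The second uses excluded to pass a Python string through unchanged, assuming (as in current JAX versions) that signature then describes only the non-excluded arguments.

```python
import jax
import jax.numpy as jnp

# Part 1: no signature, so pyfunc is written for scalars. Batched inputs reach
# it as traced JAX values, so it uses jnp.where; a Python `if x > 0:` would
# raise an error for those traced values.
def relu_scalar(x):  # hypothetical example function
    return jnp.where(x > 0, x, 0.0)

relu = jnp.vectorize(relu_scalar)

x = jnp.array([[-1.0, 2.0], [3.0, -4.0]])
# Mapping over each of the two leading axes by hand with nested vmap.
manual = jax.vmap(jax.vmap(relu_scalar))(x)
print(relu(x))
print(jnp.array_equal(relu(x), manual))

# Part 2: argument 1 (`mode`) is a Python string, which cannot be converted to
# a JAX array, so it is excluded and passed to pyfunc unmodified. The
# signature "(n)->(n)" covers only the remaining argument `v`.
def normalize(v, mode):  # hypothetical example function
    if mode == "l1":
        return v / jnp.sum(jnp.abs(v))
    return v / jnp.linalg.norm(v)

normalize_rows = jnp.vectorize(normalize, excluded={1}, signature="(n)->(n)")

rows = jnp.array([[3.0, 4.0], [1.0, 1.0]])
print(normalize_rows(rows, "l2"))  # each row scaled to unit Euclidean norm
print(normalize_rows(rows, "l1"))  # each row scaled to unit L1 norm
```

Because `mode` is excluded, the Python `if` inside normalize compares a concrete string rather than a traced array, so ordinary control flow is safe for that argument.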
Examples
Here are a few examples of how one could write vectorized linear algebra routines using vectorize():

>>> import jax.numpy as jnp
>>> from functools import partial
>>> @partial(jnp.vectorize, signature='(k),(k)->(k)')
... def cross_product(a, b):
...     assert a.shape == b.shape and a.ndim == b.ndim == 1
...     return jnp.array([a[1] * b[2] - a[2] * b[1],
...                       a[2] * b[0] - a[0] * b[2],
...                       a[0] * b[1] - a[1] * b[0]])
>>> @partial(jnp.vectorize, signature='(n,m),(m)->(n)')
... def matrix_vector_product(matrix, vector):
...     assert matrix.ndim == 2 and matrix.shape[1:] == vector.shape
...     return matrix @ vector

These functions are only written to handle 1D or 2D arrays (the assert statements will never be violated), but with vectorize they support inputs of arbitrary dimension with NumPy-style broadcasting, e.g.,

>>> cross_product(jnp.ones(3), jnp.ones(3)).shape
(3,)
>>> cross_product(jnp.ones((2, 3)), jnp.ones(3)).shape
(2, 3)
>>> cross_product(jnp.ones((1, 2, 3)), jnp.ones((2, 1, 3))).shape
(2, 2, 3)
>>> matrix_vector_product(jnp.ones(3), jnp.ones(3))
Traceback (most recent call last):
ValueError: input with shape (3,) does not have enough dimensions for all core dimensions ('n', 'm') on vectorized function with excluded=frozenset() and signature='(n,m),(m)->(n)'
>>> matrix_vector_product(jnp.ones((2, 3)), jnp.ones(3)).shape
(2,)
>>> matrix_vector_product(jnp.ones((2, 3)), jnp.ones((4, 3))).shape
(4, 2)
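As an extra sanity check (not part of the original examples), the sketch below re-declares cross_product without its assert and compares it against jnp.cross on the broadcast case above. It also spells out the shape rule: the loop (non-core) dimensions (1, 2) and (2, 1) broadcast to (2, 2), and the core dimension k = 3 is appended.

```python
from functools import partial

import jax.numpy as jnp

@partial(jnp.vectorize, signature='(k),(k)->(k)')
def cross_product(a, b):
    # Core computation for a single pair of length-3 vectors.
    return jnp.array([a[1] * b[2] - a[2] * b[1],
                      a[2] * b[0] - a[0] * b[2],
                      a[0] * b[1] - a[1] * b[0]])

a = jnp.arange(6.0).reshape(1, 2, 3)
b = jnp.arange(6.0, 12.0).reshape(2, 1, 3)

# Broadcast the loop dimensions like NumPy, then append the core dimension.
expected_shape = tuple(jnp.broadcast_shapes(a.shape[:-1], b.shape[:-1])) + (3,)

result = cross_product(a, b)
print(result.shape, expected_shape)
print(jnp.allclose(result, jnp.cross(a, b)))
```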
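
Note that this has different semantics than jnp.matmul, which treats a 2D second argument as a single matrix rather than as a batch of vectors; here that means contracting a dimension of size 3 against one of size 4: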

>>> jnp.matmul(jnp.ones((2, 3)), jnp.ones((4, 3)))
Traceback (most recent call last):
TypeError: dot_general requires contracting dimensions to have the same shape, got [3] and [4].
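If you want the batched matrix-vector behavior without vectorize, one option is to make the batch explicit for matmul. The sketch below, which uses hypothetical arrays `matrix` and `vectors`, shows three equivalent formulations and checks them against a vectorize-based version. The einsum subscripts and the added trailing axis are one choice among several.

```python
import jax.numpy as jnp

matrix = jnp.arange(6.0).reshape(2, 3)    # one (n, m) = (2, 3) matrix
vectors = jnp.arange(12.0).reshape(4, 3)  # a batch of four length-3 vectors

# vectorize-based version, as in matrix_vector_product above.
mvp = jnp.vectorize(lambda m, v: m @ v, signature='(n,m),(m)->(n)')

# Contract over m for each vector in the batch.
batched = jnp.einsum('nm,bm->bn', matrix, vectors)   # shape (4, 2)
via_transpose = vectors @ matrix.T                   # shape (4, 2)
# With matmul: add a trailing axis so each vector is an (m, 1) column,
# let matmul broadcast over the batch axis, then drop the trailing axis.
via_columns = jnp.squeeze(matrix @ vectors[..., None], axis=-1)

print(batched.shape)
print(jnp.allclose(batched, mvp(matrix, vectors)))
```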
