jax.numpy.vectorize

jax.numpy.vectorize(pyfunc, *, excluded=frozenset({}), signature=None)

Define a vectorized function with broadcasting.

vectorize() is a convenience wrapper for defining vectorized functions with broadcasting, in the style of NumPy’s generalized universal functions. It allows for defining functions that are automatically repeated across any leading dimensions, without the implementation of the function needing to be concerned about how to handle higher-dimensional inputs.
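As a minimal sketch of the default behavior (no signature), the function clipped_diff below is a hypothetical example written only for 0-d inputs; its assert documents that assumption. Wrapping it with jnp.vectorize lets it accept arrays whose shapes broadcast against each other, NumPy-style:

```python
import jax.numpy as jnp

def clipped_diff(x, y):
    # Hypothetical scalar-only function: |x - y|, capped at 1.0.
    # The body never has to think about batch dimensions.
    assert x.ndim == 0 and y.ndim == 0
    return jnp.minimum(jnp.abs(x - y), 1.0)

# No signature: every input is treated as a scalar core, and all
# leading dimensions are broadcast and mapped over.
vdiff = jnp.vectorize(clipped_diff)

x = jnp.array([[0.0], [0.5], [2.0]])  # shape (3, 1)
y = jnp.array([0.0, 1.0, 1.5, 3.0])   # shape (4,)
out = vdiff(x, y)
print(out.shape)  # (3, 4)
```

Inside the wrapped call, clipped_diff only ever sees one pair of scalars, so the assert holds even though the arguments are 2-D and 1-D.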

jax.numpy.vectorize() has essentially the same interface as numpy.vectorize, but it is syntactic sugar for an auto-batching transformation (vmap()) rather than a Python loop. This should be considerably more efficient, but pyfunc must be written in terms of functions that act on JAX arrays.
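To make the relationship with vmap() concrete, the sketch below vectorizes a hypothetical 1-D function dot with the signature '(n),(n)->()' and compares it with a hand-written jax.vmap call that maps over the batch axis of a and leaves b unbatched. This only illustrates the idea; the exact sequence of vmap calls that jnp.vectorize builds internally is an implementation detail.

```python
import jax
import jax.numpy as jnp

def dot(a, b):
    # Hypothetical core computation, written for 1-D vectors only.
    return jnp.sum(a * b)

batched_dot = jnp.vectorize(dot, signature='(n),(n)->()')

a = jnp.arange(12.0).reshape(4, 3)  # a batch of four length-3 vectors
b = jnp.array([1.0, 0.0, -1.0])     # a single vector, broadcast over the batch

via_vectorize = batched_dot(a, b)

# Hand-written equivalent: map over axis 0 of `a`, pass `b` through unmapped.
via_vmap = jax.vmap(dot, in_axes=(0, None))(a, b)
```

Both results have shape (4,). With vectorize, the in_axes bookkeeping is derived from the signature and the input shapes instead of being spelled out by hand.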

Parameters:
  • pyfunc – function to vectorize.

  • excluded – optional set of integers representing positional arguments for which the function will not be vectorized. These will be passed directly to pyfunc unmodified.

  • signature – optional generalized universal function signature, e.g., (m,n),(n)->(m) for vectorized matrix-vector multiplication. If provided, pyfunc will be called with (and expected to return) arrays with shapes given by the sizes of the corresponding core dimensions. By default, pyfunc is assumed to take scalar arrays as input, and if signature is None, pyfunc can produce outputs of any shape.
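A small sketch of the excluded parameter, using a hypothetical function scale whose second positional argument is a plain Python string. Listing position 1 in excluded means the string reaches scale unchanged, while the first argument is converted to an array and mapped over as usual:

```python
import jax.numpy as jnp

def scale(x, mode):
    # Hypothetical example. `mode` is a Python string; because position 1
    # is excluded, it is passed through as-is rather than being converted
    # to an array and batched like `x`.
    if mode == "double":
        return 2 * x
    return x

scale_v = jnp.vectorize(scale, excluded={1})

doubled = scale_v(jnp.arange(3.0), "double")      # [0., 2., 4.]
unchanged = scale_v(jnp.arange(3.0), "identity")  # [0., 1., 2.]
```

Without excluded, the string would go through the same array conversion as x, which is expected to fail because JAX arrays do not hold strings.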

Returns:

Vectorized version of the given function.

Examples

Here are a few examples of how one could write vectorized linear algebra routines using vectorize():

>>> import jax.numpy as jnp
>>> from functools import partial
>>> @partial(jnp.vectorize, signature='(k),(k)->(k)')
... def cross_product(a, b):
...   assert a.shape == b.shape and a.ndim == b.ndim == 1
...   return jnp.array([a[1] * b[2] - a[2] * b[1],
...                     a[2] * b[0] - a[0] * b[2],
...                     a[0] * b[1] - a[1] * b[0]])
>>> @partial(jnp.vectorize, signature='(n,m),(m)->(n)')
... def matrix_vector_product(matrix, vector):
...   assert matrix.ndim == 2 and matrix.shape[1:] == vector.shape
...   return matrix @ vector

These functions are only written to handle 1D or 2D arrays (the assert statements will never be violated), but with vectorize they support arbitrary dimensional inputs with NumPy-style broadcasting, e.g.,

>>> cross_product(jnp.ones(3), jnp.ones(3)).shape
(3,)
>>> cross_product(jnp.ones((2, 3)), jnp.ones(3)).shape
(2, 3)
>>> cross_product(jnp.ones((1, 2, 3)), jnp.ones((2, 1, 3))).shape
(2, 2, 3)
>>> matrix_vector_product(jnp.ones(3), jnp.ones(3))
Traceback (most recent call last):
ValueError: input with shape (3,) does not have enough dimensions for all core dimensions ('n', 'm') on vectorized function with excluded=frozenset() and signature='(n,m),(m)->(n)'
>>> matrix_vector_product(jnp.ones((2, 3)), jnp.ones(3)).shape
(2,)
>>> matrix_vector_product(jnp.ones((2, 3)), jnp.ones((4, 3))).shape
(4, 2)

Note that this has different semantics than jnp.matmul:

>>> jnp.matmul(jnp.ones((2, 3)), jnp.ones((4, 3)))
Traceback (most recent call last):
TypeError: dot_general requires contracting dimensions to have the same shape, got [3] and [4].
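One way to see the difference is to write out what the vectorized matrix_vector_product computes for a single (2, 3) matrix and a batch of four length-3 vectors: it applies the matrix to each vector in the batch. The sketch below reuses the definition from the examples above with illustrative inputs, and expresses the same result with jnp.einsum and with jnp.matmul after transposing the matrix so the contracting axes line up:

```python
import jax.numpy as jnp
from functools import partial

@partial(jnp.vectorize, signature='(n,m),(m)->(n)')
def matrix_vector_product(matrix, vector):
    assert matrix.ndim == 2 and matrix.shape[1:] == vector.shape
    return matrix @ vector

A = jnp.arange(6.0).reshape(2, 3)   # one matrix with n=2, m=3
V = jnp.arange(12.0).reshape(4, 3)  # a batch of four vectors with m=3

batched = matrix_vector_product(A, V)     # shape (4, 2)

# The same computation written explicitly: apply A to every row of V.
explicit = jnp.einsum('nm,bm->bn', A, V)

# With matmul, the contracting axes must be aligned by hand: (4, 3) @ (3, 2).
via_matmul = jnp.matmul(V, A.T)
```

vectorize treats the trailing axes named in the signature as core dimensions and broadcasts the rest, whereas jnp.matmul always contracts the last axis of its first argument with the second-to-last axis of its second.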