jax.numpy.einsum

jax.numpy.einsum(subscripts, /, *operands, out=None, optimize='auto', precision=None, preferred_element_type=None, _dot_general=<function dot_general>, out_sharding=None)

Einstein summation

JAX implementation of numpy.einsum().

einsum is a powerful and generic API for computing various reductions, inner products, outer products, axis reorderings, and combinations thereof across one or more input arrays. It has a somewhat complicated overloaded API; the arguments below reflect the most common calling convention. The Examples section below demonstrates some of the alternative calling conventions.
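To make the summation rule concrete, here is a deliberately slow reference sketch in plain Python and NumPy. It is not how JAX computes einsum: it assumes explicit-mode subscripts with single-letter labels and no ellipsis, and the helper name naive_einsum is hypothetical. Every label is assigned each value in its range, the operand entries selected by those values are multiplied, and the product is added into the output cell named by the output labels, so any label absent from the output is summed over.

```python
import itertools

import numpy as np
import jax.numpy as jnp


def naive_einsum(subscripts, *operands):
    """Slow reference Einstein summation (hypothetical helper, not part of JAX).

    Assumes explicit mode ("in1,in2,...->out"), single-letter labels, no "...".
    """
    inputs, output = subscripts.replace(" ", "").split("->")
    specs = inputs.split(",")
    arrays = [np.asarray(op) for op in operands]
    if len(specs) != len(arrays):
        raise ValueError("number of subscript groups must match number of operands")

    # Each label must have one consistent size across all operands.
    sizes = {}
    for spec, arr in zip(specs, arrays):
        if len(spec) != arr.ndim:
            raise ValueError(f"{spec!r} does not match an array of rank {arr.ndim}")
        for label, dim in zip(spec, arr.shape):
            if sizes.setdefault(label, dim) != dim:
                raise ValueError(f"inconsistent size for label {label!r}")

    labels = sorted(sizes)
    out = np.zeros([sizes[c] for c in output], dtype=np.result_type(*arrays))
    # Visit every assignment of values to all labels. Labels missing from the
    # output map many assignments onto the same output cell, so they are summed.
    for values in itertools.product(*(range(sizes[c]) for c in labels)):
        env = dict(zip(labels, values))
        term = 1
        for spec, arr in zip(specs, arrays):
            # A label repeated within one operand (e.g. "ii") selects a diagonal.
            term = term * arr[tuple(env[c] for c in spec)]
        out[tuple(env[c] for c in output)] += term
    return out


M = jnp.arange(16).reshape(4, 4)
x = jnp.arange(4)
print(naive_einsum("ij,j->i", M, x))  # [14 38 62 86]
print(naive_einsum("ii->", M))        # 30
```

The nested loop visits every combination of label values, so its cost grows with the product of all label sizes; it is only meant as an executable statement of the rule.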

Parameters:
  • subscripts – string of axis labels: one group of labels per operand, separated by commas, optionally followed by -> and the labels of the output axes.

  • *operands – sequence of one or more arrays corresponding to the subscripts.

  • optimize (str | bool | list[tuple[int, ...]]) – specify how to optimize the order of computation. In JAX this defaults to "auto", which produces optimized expressions via the opt_einsum package. Other options are True (same as "optimal"), False (unoptimized), or any string supported by opt_einsum, which includes "optimal", "greedy", "eager", and others. It may also be a pre-computed path (see einsum_path()).

  • precision (None | str | Precision | tuple[str, str] | tuple[Precision, Precision] | DotAlgorithm | DotAlgorithmPreset) – either None (default), which means the default precision for the backend, or a Precision enum value (Precision.DEFAULT, Precision.HIGH or Precision.HIGHEST). As the type annotation shows, a pair of such values, a DotAlgorithm, or a DotAlgorithmPreset is also accepted.

  • preferred_element_type (str | type[Any] | dtype | SupportsDType | None) – either None (default), which means the default accumulation type for the input types, or a datatype, indicating to accumulate results to and return a result with that datatype.

  • out (None) – unsupported by JAX

  • _dot_general (Callable[[...], Array]) – optionally override the dot_general callable used by einsum. This parameter is experimental, and may be removed without warning at any time.
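A short sketch of the preferred_element_type and precision parameters, assuming JAX's default 32-bit mode and arbitrary illustrative inputs. Two int8 operands promote to an int8 result by default, while preferred_element_type=jnp.int32 requests int32 accumulation and an int32 result.

```python
import jax.numpy as jnp
from jax import lax

# Two int8 vectors: by default the result keeps the promoted input dtype (int8);
# preferred_element_type asks for int32 accumulation and an int32 result.
a = jnp.arange(4, dtype=jnp.int8)
b = jnp.arange(4, dtype=jnp.int8)
default = jnp.einsum("i,i->", a, b)
widened = jnp.einsum("i,i->", a, b, preferred_element_type=jnp.int32)
print(default.dtype, widened.dtype)  # int8 int32

# precision is forwarded to the underlying dot_general; Precision.HIGHEST
# requests the most accurate (and typically slowest) mode the backend offers.
A = jnp.ones((3, 4), dtype=jnp.float32)
B = jnp.ones((4, 2), dtype=jnp.float32)
C = jnp.einsum("ij,jk->ik", A, B, precision=lax.Precision.HIGHEST)
print(C.shape)  # (3, 2)
```

Widening the accumulator matters most for low-precision integer or floating-point inputs, where long sums can overflow or lose precision in the input dtype.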

Returns:

Array containing the result of the Einstein summation.

Return type:

Array

Examples

The mechanics of einsum are perhaps best demonstrated by example. Here we show how to use einsum to compute a number of quantities from one or more arrays. For more discussion and examples of einsum, see the documentation of numpy.einsum().

>>> import jax.numpy as jnp
>>> M = jnp.arange(16).reshape(4, 4)
>>> x = jnp.arange(4)
>>> y = jnp.array([5, 4, 3, 2])

Vector inner product

>>> jnp.einsum('i,i', x, y)
Array(16, dtype=int32)
>>> jnp.vecdot(x, y)
Array(16, dtype=int32)

Here are some alternative einsum calling conventions to compute the same result:

>>> jnp.einsum('i,i->', x, y)  # explicit form
Array(16, dtype=int32)
>>> jnp.einsum(x, (0,), y, (0,))  # implicit form via indices
Array(16, dtype=int32)
>>> jnp.einsum(x, (0,), y, (0,), ())  # explicit form via indices
Array(16, dtype=int32)
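The string form and the index form carry the same information: each distinct letter becomes an integer label, each operand is followed by the tuple of labels for its axes, and an optional final tuple gives the output labels. The helper below, to_sublist_args, is hypothetical (it is not part of JAX) and sketches that translation for explicit-mode strings without ellipses, so the two conventions can be compared directly.

```python
import jax.numpy as jnp


def to_sublist_args(subscripts, *operands):
    """Hypothetical helper: turn explicit-mode string subscripts such as
    'ij,j->i' into the interleaved (operand, labels, ..., output_labels) form."""
    inputs, output = subscripts.replace(" ", "").split("->")
    specs = inputs.split(",")
    # Give each distinct letter a small integer label (the numbering is arbitrary).
    index = {c: i for i, c in enumerate(sorted(set("".join(specs))))}
    args = []
    for op, spec in zip(operands, specs):
        args += [op, tuple(index[c] for c in spec)]
    # A trailing tuple selects explicit mode; () means "reduce to a scalar".
    args.append(tuple(index[c] for c in output))
    return args


M = jnp.arange(16).reshape(4, 4)
x = jnp.arange(4)
y = jnp.array([5, 4, 3, 2])

# 'ij,j->i' becomes einsum(M, (0, 1), x, (1,), (0,)).
print(jnp.einsum(*to_sublist_args("ij,j->i", M, x)))  # [14 38 62 86]
```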

Matrix-vector product

>>> jnp.einsum('ij,j->i', M, x)  # explicit form
Array([14, 38, 62, 86], dtype=int32)
>>> jnp.matmul(M, x)
Array([14, 38, 62, 86], dtype=int32)

Here are some alternative einsum calling conventions to compute the same result:

>>> jnp.einsum('ij,j', M, x)  # implicit form
Array([14, 38, 62, 86], dtype=int32)
>>> jnp.einsum(M, (0, 1), x, (1,), (0,))  # explicit form via indices
Array([14, 38, 62, 86], dtype=int32)
>>> jnp.einsum(M, (0, 1), x, (1,))  # implicit form via indices
Array([14, 38, 62, 86], dtype=int32)

Outer product

>>> jnp.einsum("i,j->ij", x, y)
Array([[ 0,  0,  0,  0],
       [ 5,  4,  3,  2],
       [10,  8,  6,  4],
       [15, 12,  9,  6]], dtype=int32)
>>> jnp.outer(x, y)
Array([[ 0,  0,  0,  0],
       [ 5,  4,  3,  2],
       [10,  8,  6,  4],
       [15, 12,  9,  6]], dtype=int32)

Some other ways of computing outer products:

>>> jnp.einsum("i,j", x, y)  # implicit form
Array([[ 0,  0,  0,  0],
       [ 5,  4,  3,  2],
       [10,  8,  6,  4],
       [15, 12,  9,  6]], dtype=int32)
>>> jnp.einsum(x, (0,), y, (1,), (0, 1))  # explicit form via indices
Array([[ 0,  0,  0,  0],
       [ 5,  4,  3,  2],
       [10,  8,  6,  4],
       [15, 12,  9,  6]], dtype=int32)
>>> jnp.einsum(x, (0,), y, (1,))  # implicit form via indices
Array([[ 0,  0,  0,  0],
       [ 5,  4,  3,  2],
       [10,  8,  6,  4],
       [15, 12,  9,  6]], dtype=int32)

1D array sum

>>> jnp.einsum("i->", x)  # requires explicit form
Array(6, dtype=int32)
>>> jnp.einsum(x, (0,), ())  # explicit form via indices
Array(6, dtype=int32)
>>> jnp.sum(x)
Array(6, dtype=int32)

Sum along an axis

>>> jnp.einsum("...j->...", M)  # requires explicit form
Array([ 6, 22, 38, 54], dtype=int32)
>>> jnp.einsum(M, (..., 0), (...,))  # explicit form via indices
Array([ 6, 22, 38, 54], dtype=int32)
>>> M.sum(-1)
Array([ 6, 22, 38, 54], dtype=int32)
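The ellipsis can also carry leading batch axes through a contraction. As a sketch (the shapes below are arbitrary illustrative choices), a batched matrix product written with ... can be compared against jnp.matmul, which likewise treats leading axes as batch dimensions:

```python
import jax.numpy as jnp

# A batch of five 2 x 3 matrices and a batch of five 3 x 4 matrices.
A = jnp.arange(5 * 2 * 3).reshape(5, 2, 3)
B = jnp.arange(5 * 3 * 4).reshape(5, 3, 4)

# '...' stands for the leading batch axes, which pass through unchanged;
# j is contracted within each batch element.
C = jnp.einsum("...ij,...jk->...ik", A, B)
print(C.shape)  # (5, 2, 4)

# The ellipsis may also match zero axes, so the same subscripts
# handle a single pair of matrices.
D = jnp.einsum("...ij,...jk->...ik", A[0], B[0])
print(D.shape)  # (2, 4)
```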

Matrix transpose

>>> y = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
>>> jnp.einsum("ij->ji", y)  # explicit form
Array([[1, 4],
       [2, 5],
       [3, 6]], dtype=int32)
>>> jnp.einsum("ji", y)  # implicit form
Array([[1, 4],
       [2, 5],
       [3, 6]], dtype=int32)
>>> jnp.einsum(y, (1, 0))  # implicit form via indices
Array([[1, 4],
       [2, 5],
       [3, 6]], dtype=int32)
>>> jnp.einsum(y, (0, 1), (1, 0))  # explicit form via indices
Array([[1, 4],
       [2, 5],
       [3, 6]], dtype=int32)
>>> jnp.transpose(y)
Array([[1, 4],
       [2, 5],
       [3, 6]], dtype=int32)

Matrix diagonal

>>> jnp.einsum("ii->i", M)
Array([ 0,  5, 10, 15], dtype=int32)
>>> jnp.diagonal(M)
Array([ 0,  5, 10, 15], dtype=int32)

Matrix trace

>>> jnp.einsum("ii", M)
Array(30, dtype=int32)
>>> jnp.trace(M)
Array(30, dtype=int32)
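Diagonal and trace extraction combine with the ellipsis in the same way. A sketch on a stack of three 4 x 4 matrices (an arbitrary illustrative shape), checked against jnp.diagonal and jnp.trace applied to the last two axes:

```python
import jax.numpy as jnp

# A stack of three 4 x 4 matrices; S[0] equals M from the examples above.
S = jnp.arange(3 * 4 * 4).reshape(3, 4, 4)

# Repeating a label within one operand selects its diagonal;
# '...' keeps the stacking axis.
diags = jnp.einsum("...ii->...i", S)   # shape (3, 4)
traces = jnp.einsum("...ii->...", S)   # shape (3,)

assert jnp.array_equal(diags, jnp.diagonal(S, axis1=-2, axis2=-1))
assert jnp.array_equal(traces, jnp.trace(S, axis1=-2, axis2=-1))
print(traces)  # [ 30  94 158]
```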

Tensor products

>>> x = jnp.arange(30).reshape(2, 3, 5)
>>> y = jnp.arange(60).reshape(3, 4, 5)
>>> jnp.einsum('ijk,jlk->il', x, y)  # explicit form
Array([[ 3340,  3865,  4390,  4915],
       [ 8290,  9940, 11590, 13240]], dtype=int32)
>>> jnp.tensordot(x, y, axes=[(1, 2), (0, 2)])
Array([[ 3340,  3865,  4390,  4915],
       [ 8290,  9940, 11590, 13240]], dtype=int32)
>>> jnp.einsum('ijk,jlk', x, y)  # implicit form
Array([[ 3340,  3865,  4390,  4915],
       [ 8290,  9940, 11590, 13240]], dtype=int32)
>>> jnp.einsum(x, (0, 1, 2), y, (1, 3, 2), (0, 3))  # explicit form via indices
Array([[ 3340,  3865,  4390,  4915],
       [ 8290,  9940, 11590, 13240]], dtype=int32)
>>> jnp.einsum(x, (0, 1, 2), y, (1, 3, 2))  # implicit form via indices
Array([[ 3340,  3865,  4390,  4915],
       [ 8290,  9940, 11590, 13240]], dtype=int32)
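To see what 'ijk,jlk->il' computes, it can be spelled out with broadcasting: insert length-1 axes so both operands line up on the axis order (i, j, l, k), multiply elementwise, and sum over the labels j and k that are absent from the output. This is only a conceptual check; it materializes the full (2, 3, 4, 5) product, whereas jnp.einsum routes a contraction like this one through dot_general.

```python
import jax.numpy as jnp

x = jnp.arange(30).reshape(2, 3, 5)   # axes labelled i, j, k
y = jnp.arange(60).reshape(3, 4, 5)   # axes labelled j, l, k

# Align both operands on the common axis order (i, j, l, k).
prod = x[:, :, None, :] * y[None, :, :, :]   # shape (2, 3, 4, 5)
# Sum out the contracted labels j (axis 1) and k (axis 3), leaving (i, l).
manual = prod.sum(axis=(1, 3))               # shape (2, 4)

assert jnp.array_equal(manual, jnp.einsum("ijk,jlk->il", x, y))
```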

Chained dot products

>>> w = jnp.arange(5, 9).reshape(2, 2)
>>> x = jnp.arange(6).reshape(2, 3)
>>> y = jnp.arange(-2, 4).reshape(3, 2)
>>> z = jnp.array([[2, 4, 6], [3, 5, 7]])
>>> jnp.einsum('ij,jk,kl,lm->im', w, x, y, z)
Array([[ 481,  831, 1181],
       [ 651, 1125, 1599]], dtype=int32)
>>> jnp.einsum(w, (0, 1), x, (1, 2), y, (2, 3), z, (3, 4))  # implicit, via indices
Array([[ 481,  831, 1181],
       [ 651, 1125, 1599]], dtype=int32)
>>> w @ x @ y @ z  # direct chain of matmuls
Array([[ 481,  831, 1181],
       [ 651, 1125, 1599]], dtype=int32)
>>> jnp.linalg.multi_dot([w, x, y, z])
Array([[ 481,  831, 1181],
       [ 651, 1125, 1599]], dtype=int32)
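For longer chains the pairwise contraction order affects cost, which is what the optimize parameter controls. A sketch using jnp.einsum_path to compute an order once (here with the "greedy" strategy, an arbitrary choice) and then passing that path back through optimize; the exact path and cost report depend on the installed opt_einsum version, so they are printed rather than asserted.

```python
import jax.numpy as jnp

w = jnp.arange(5, 9).reshape(2, 2)
x = jnp.arange(6).reshape(2, 3)
y = jnp.arange(-2, 4).reshape(3, 2)
z = jnp.array([[2, 4, 6], [3, 5, 7]])

# Ask opt_einsum (via jnp.einsum_path) for a pairwise contraction order.
path, info = jnp.einsum_path("ij,jk,kl,lm->im", w, x, y, z, optimize="greedy")
print(path)  # list of operand-index tuples, one per pairwise contraction
print(info)  # human-readable report, including estimated FLOP counts

# Reuse the precomputed path so einsum does not search for an order again.
result = jnp.einsum("ij,jk,kl,lm->im", w, x, y, z, optimize=path)
```

Precomputing the path is mainly useful when the same subscripts and shapes are contracted repeatedly.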