jax.numpy.einsum
jax.numpy.einsum(subscripts, /, *operands, out=None, optimize='auto', precision=None, preferred_element_type=None, _dot_general=<function dot_general>, out_sharding=None)
Einstein summation
JAX implementation of numpy.einsum().

einsum is a powerful and generic API for computing various reductions, inner products, outer products, axis reorderings, and combinations thereof across one or more input arrays. It has a somewhat complicated overloaded API; the arguments below reflect the most common calling convention. The Examples section below demonstrates some of the alternative calling conventions.

- Parameters:
subscripts – string containing axes names separated by commas.
*operands – sequence of one or more arrays corresponding to the subscripts.
optimize (str | bool | list[tuple[int, ...]]) – specify how to optimize the order of computation. In JAX this defaults to "auto", which produces optimized expressions via the opt_einsum package. Other options are True (same as "optimal"), False (unoptimized), or any string supported by opt_einsum, which includes "optimal", "greedy", "eager", and others. It may also be a pre-computed path (see einsum_path()).
precision (None | str | Precision | tuple[str, str] | tuple[Precision, Precision] | DotAlgorithm | DotAlgorithmPreset) – either None (default), which means the default precision for the backend, or a Precision enum value (Precision.DEFAULT, Precision.HIGH or Precision.HIGHEST).
preferred_element_type (str | type[Any] | dtype | SupportsDType | None) – either None (default), which means the default accumulation type for the input types, or a datatype, indicating to accumulate results to and return a result with that datatype.
out (None) – unsupported by JAX
_dot_general (Callable[[...], Array]) – optionally override the dot_general callable used by einsum. This parameter is experimental, and may be removed without warning at any time.
- Returns:
array containing the result of the Einstein summation.
- Return type:
See also
Examples
The mechanics of einsum are perhaps best demonstrated by example. Here we show how to use einsum to compute a number of quantities from one or more arrays. For more discussion and examples of einsum, see the documentation of numpy.einsum().

>>> M = jnp.arange(16).reshape(4, 4)
>>> x = jnp.arange(4)
>>> y = jnp.array([5, 4, 3, 2])
Vector product
>>> jnp.einsum('i,i', x, y)
Array(16, dtype=int32)
>>> jnp.vecdot(x, y)
Array(16, dtype=int32)
Here are some alternative einsum calling conventions to compute the same result:

>>> jnp.einsum('i,i->', x, y)  # explicit form
Array(16, dtype=int32)
>>> jnp.einsum(x, (0,), y, (0,))  # implicit form via indices
Array(16, dtype=int32)
>>> jnp.einsum(x, (0,), y, (0,), ())  # explicit form via indices
Array(16, dtype=int32)
Matrix product
>>> jnp.einsum('ij,j->i', M, x)  # explicit form
Array([14, 38, 62, 86], dtype=int32)
>>> jnp.matmul(M, x)
Array([14, 38, 62, 86], dtype=int32)
Here are some alternative einsum calling conventions to compute the same result:

>>> jnp.einsum('ij,j', M, x)  # implicit form
Array([14, 38, 62, 86], dtype=int32)
>>> jnp.einsum(M, (0, 1), x, (1,), (0,))  # explicit form via indices
Array([14, 38, 62, 86], dtype=int32)
>>> jnp.einsum(M, (0, 1), x, (1,))  # implicit form via indices
Array([14, 38, 62, 86], dtype=int32)
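To make the subscript notation concrete, the following sketch spells out what 'ij,j->i' computes: each M[i, j] is multiplied by x[j], and the index j, which does not appear in the output subscripts, is summed away. The variable names via_einsum and via_broadcast are illustrative only.

```python
import jax.numpy as jnp

M = jnp.arange(16).reshape(4, 4)
x = jnp.arange(4)

# 'ij,j->i': multiply M[i, j] by x[j], then sum over j, the index
# that is missing from the output subscripts.
via_einsum = jnp.einsum('ij,j->i', M, x)

# The same computation written as a broadcasted product followed by
# an explicit reduction over the second axis.
via_broadcast = (M * x[None, :]).sum(axis=1)

print(via_einsum)     # [14 38 62 86]
print(via_broadcast)  # [14 38 62 86]
```

Reading the subscripts this way (shared labels are multiplied elementwise, labels absent from the output are summed) carries over to the other examples in this section.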
Outer product
>>> jnp.einsum("i,j->ij", x, y)
Array([[ 0,  0,  0,  0],
       [ 5,  4,  3,  2],
       [10,  8,  6,  4],
       [15, 12,  9,  6]], dtype=int32)
>>> jnp.outer(x, y)
Array([[ 0,  0,  0,  0],
       [ 5,  4,  3,  2],
       [10,  8,  6,  4],
       [15, 12,  9,  6]], dtype=int32)
Some other ways of computing outer products:
>>> jnp.einsum("i,j", x, y)  # implicit form
Array([[ 0,  0,  0,  0],
       [ 5,  4,  3,  2],
       [10,  8,  6,  4],
       [15, 12,  9,  6]], dtype=int32)
>>> jnp.einsum(x, (0,), y, (1,), (0, 1))  # explicit form via indices
Array([[ 0,  0,  0,  0],
       [ 5,  4,  3,  2],
       [10,  8,  6,  4],
       [15, 12,  9,  6]], dtype=int32)
>>> jnp.einsum(x, (0,), y, (1,))  # implicit form via indices
Array([[ 0,  0,  0,  0],
       [ 5,  4,  3,  2],
       [10,  8,  6,  4],
       [15, 12,  9,  6]], dtype=int32)
1D array sum
>>> jnp.einsum("i->", x)  # requires explicit form
Array(6, dtype=int32)
>>> jnp.einsum(x, (0,), ())  # explicit form via indices
Array(6, dtype=int32)
>>> jnp.sum(x)
Array(6, dtype=int32)
Sum along an axis
>>> jnp.einsum("...j->...", M)  # requires explicit form
Array([ 6, 22, 38, 54], dtype=int32)
>>> jnp.einsum(M, (..., 0), (...,))  # explicit form via indices
Array([ 6, 22, 38, 54], dtype=int32)
>>> M.sum(-1)
Array([ 6, 22, 38, 54], dtype=int32)
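The ellipsis used above also works with several operands, where it stands for any leading axes they share. As a minimal sketch, the arrays A and v below are made up for illustration; the example applies a batched matrix-vector product and checks it against jnp.matmul.

```python
import jax.numpy as jnp

# A batch of 3 matrices of shape (2, 4) and a batch of 3 vectors of length 4.
A = jnp.arange(24).reshape(3, 2, 4)
v = jnp.arange(12).reshape(3, 4)

# '...' matches the shared leading (batch) axis; j is summed over.
batched = jnp.einsum('...ij,...j->...i', A, v)

# Reference: promote each vector to a column matrix, matmul, drop the column axis.
reference = jnp.matmul(A, v[..., None])[..., 0]

print(batched.shape)  # (3, 2)
print(bool(jnp.array_equal(batched, reference)))  # True
```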
Matrix transpose
>>> y = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
>>> jnp.einsum("ij->ji", y)  # explicit form
Array([[1, 4],
       [2, 5],
       [3, 6]], dtype=int32)
>>> jnp.einsum("ji", y)  # implicit form
Array([[1, 4],
       [2, 5],
       [3, 6]], dtype=int32)
>>> jnp.einsum(y, (1, 0))  # implicit form via indices
Array([[1, 4],
       [2, 5],
       [3, 6]], dtype=int32)
>>> jnp.einsum(y, (0, 1), (1, 0))  # explicit form via indices
Array([[1, 4],
       [2, 5],
       [3, 6]], dtype=int32)
>>> jnp.transpose(y)
Array([[1, 4],
       [2, 5],
       [3, 6]], dtype=int32)
Matrix diagonal
>>> jnp.einsum("ii->i", M)
Array([ 0,  5, 10, 15], dtype=int32)
>>> jnp.diagonal(M)
Array([ 0,  5, 10, 15], dtype=int32)
Matrix trace
>>> jnp.einsum("ii", M)
Array(30, dtype=int32)
>>> jnp.trace(M)
Array(30, dtype=int32)
Tensor products
>>> x = jnp.arange(30).reshape(2, 3, 5)
>>> y = jnp.arange(60).reshape(3, 4, 5)
>>> jnp.einsum('ijk,jlk->il', x, y)  # explicit form
Array([[ 3340,  3865,  4390,  4915],
       [ 8290,  9940, 11590, 13240]], dtype=int32)
>>> jnp.tensordot(x, y, axes=[(1, 2), (0, 2)])
Array([[ 3340,  3865,  4390,  4915],
       [ 8290,  9940, 11590, 13240]], dtype=int32)
>>> jnp.einsum('ijk,jlk', x, y)  # implicit form
Array([[ 3340,  3865,  4390,  4915],
       [ 8290,  9940, 11590, 13240]], dtype=int32)
>>> jnp.einsum(x, (0, 1, 2), y, (1, 3, 2), (0, 3))  # explicit form via indices
Array([[ 3340,  3865,  4390,  4915],
       [ 8290,  9940, 11590, 13240]], dtype=int32)
>>> jnp.einsum(x, (0, 1, 2), y, (1, 3, 2))  # implicit form via indices
Array([[ 3340,  3865,  4390,  4915],
       [ 8290,  9940, 11590, 13240]], dtype=int32)
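The precision and preferred_element_type parameters described earlier can be passed to the same kind of contraction. The sketch below is illustrative: the all-ones float16 inputs and the variable names are assumptions made for the example, and whether precision changes the numerical result depends on the backend.

```python
import jax
import jax.numpy as jnp

# Illustrative low-precision inputs; each output element sums 3 * 5 = 15 ones.
x = jnp.ones((2, 3, 5), dtype=jnp.float16)
y = jnp.ones((3, 4, 5), dtype=jnp.float16)

# Default: the result dtype follows the input dtypes.
out_default = jnp.einsum('ijk,jlk->il', x, y)

# Request float32 accumulation and a float32 result.
out_f32 = jnp.einsum('ijk,jlk->il', x, y,
                     preferred_element_type=jnp.float32)

# Request the highest available precision for float32 inputs.
x32, y32 = x.astype(jnp.float32), y.astype(jnp.float32)
out_highest = jnp.einsum('ijk,jlk->il', x32, y32,
                         precision=jax.lax.Precision.HIGHEST)

print(out_default.dtype)  # float16
print(out_f32.dtype)      # float32
```

With inputs this small every variant gives the same values; the difference is intended to matter for long reductions over low-precision inputs.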
Chained dot products
>>> w = jnp.arange(5, 9).reshape(2, 2)
>>> x = jnp.arange(6).reshape(2, 3)
>>> y = jnp.arange(-2, 4).reshape(3, 2)
>>> z = jnp.array([[2, 4, 6], [3, 5, 7]])
>>> jnp.einsum('ij,jk,kl,lm->im', w, x, y, z)
Array([[ 481,  831, 1181],
       [ 651, 1125, 1599]], dtype=int32)
>>> jnp.einsum(w, (0, 1), x, (1, 2), y, (2, 3), z, (3, 4))  # implicit, via indices
Array([[ 481,  831, 1181],
       [ 651, 1125, 1599]], dtype=int32)
>>> w @ x @ y @ z  # direct chain of matmuls
Array([[ 481,  831, 1181],
       [ 651, 1125, 1599]], dtype=int32)
>>> jnp.linalg.multi_dot([w, x, y, z])
Array([[ 481,  831, 1181],
       [ 651, 1125, 1599]], dtype=int32)
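Chained contractions are where the optimize parameter matters most, since it controls the order in which the pairwise contractions are carried out. The following sketch reuses the arrays above and compares several settings, including a path pre-computed with jnp.einsum_path. The variable names are illustrative, and all settings are expected to give the same values and to differ only in how the work is scheduled.

```python
import jax.numpy as jnp

w = jnp.arange(5, 9).reshape(2, 2)
x = jnp.arange(6).reshape(2, 3)
y = jnp.arange(-2, 4).reshape(3, 2)
z = jnp.array([[2, 4, 6], [3, 5, 7]])

subscripts = 'ij,jk,kl,lm->im'

# Default: optimize='auto' picks a contraction order via opt_einsum.
auto = jnp.einsum(subscripts, w, x, y, z)

# A named opt_einsum strategy, and the unoptimized order.
greedy = jnp.einsum(subscripts, w, x, y, z, optimize='greedy')
unoptimized = jnp.einsum(subscripts, w, x, y, z, optimize=False)

# Pre-compute a contraction path once, then reuse it.
path, path_info = jnp.einsum_path(subscripts, w, x, y, z, optimize='optimal')
from_path = jnp.einsum(subscripts, w, x, y, z, optimize=path)

print(path)       # list of operand-index pairs to contract, in order
print(from_path)
```

Pre-computing the path may be worthwhile when the same expression is evaluated repeatedly on arrays of the same shapes.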
