JAX documentation


jax.numpy.einsum_path

jax.numpy.einsum_path(subscripts, /, *operands, optimize='auto')

Evaluates the optimal contraction path without evaluating the einsum.

JAX implementation of numpy.einsum_path(). This function calls into the opt_einsum package and makes use of its optimization routines.
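Because only the contraction order is planned, the path in this sketch depends on the operand shapes rather than their values. The snippet below reuses the shapes from the example on this page and assumes `optimize="optimal"`; it checks that arrays of zeros and arrays of ones yield the same path.

```python
import jax.numpy as jnp

# Two sets of operands with identical shapes but different values.
shapes = [(2, 3), (3, 100), (100, 5)]
zeros = [jnp.zeros(s) for s in shapes]
ones = [jnp.ones(s) for s in shapes]

# einsum_path only plans the contraction order; no einsum is evaluated here.
path_zeros, _ = jnp.einsum_path("ij,jk,kl", *zeros, optimize="optimal")
path_ones, _ = jnp.einsum_path("ij,jk,kl", *ones, optimize="optimal")

print(path_zeros)               # [(1, 2), (0, 1)]
print(path_zeros == path_ones)  # True
```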

Parameters:
  • subscripts – string containing the axis names of each operand, separated by commas, such as "ij,jk,kl".

  • *operands – sequence of one or more arrays corresponding to the subscripts.

  • optimize (bool | str | list[tuple[int, ...]]) – specify how to optimize the order of computation. In JAX this defaults to "auto". Other options are True (same as "optimal"), False (unoptimized), or any string supported by opt_einsum, which includes "optimal", "greedy", "eager", and others. A precomputed path, given as a list of tuples of integers, is also accepted.
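A sketch comparing several `optimize` settings on the shapes used in the example on this page. It assumes the second return value is opt_einsum's `PathInfo` object and reads its `opt_cost` attribute (the "Optimized FLOP count" shown when that object is printed); the dictionary names are illustrative.

```python
import jax.numpy as jnp

x, y, z = jnp.ones((2, 3)), jnp.ones((3, 100)), jnp.ones((100, 5))

paths = {}
costs = {}
for strategy in ["optimal", True, "greedy", False]:
    path, info = jnp.einsum_path("ij,jk,kl", x, y, z, optimize=strategy)
    paths[strategy] = path
    # Assumption: `info` is opt_einsum's PathInfo, whose opt_cost attribute
    # holds the estimated FLOP count of the chosen path.
    costs[strategy] = int(info.opt_cost)

print(paths)
print(costs["optimal"])  # 3060
```

Passing `True` is expected to give the same path as `"optimal"`, while `False` skips the search and may cost more FLOPs.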

Returns:

A tuple containing the path that may be passed to einsum(), and a printable object representing this optimal path.

Return type:

tuple[list[tuple[int, ...]], Any]
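Each tuple in the returned path lists positions in the current list of operands: those operands are contracted together, removed from the list, and the result is appended at the end (the convention shared by opt_einsum and numpy.einsum_path). The sketch below replays a path on operand labels to make that bookkeeping visible; the labels `x`, `y`, `z` are illustrative only.

```python
import jax.numpy as jnp

x, y, z = jnp.ones((2, 3)), jnp.ones((3, 100)), jnp.ones((100, 5))
path, path_info = jnp.einsum_path("ij,jk,kl", x, y, z, optimize="optimal")

# The second element is a printable summary of the plan.
summary = str(path_info)

# Replay the path on labels: contracted positions are popped (highest index
# first so earlier positions stay valid) and the result goes to the end.
terms = ["x", "y", "z"]
steps = []
for step in path:
    picked = [terms[i] for i in step]
    for i in sorted(step, reverse=True):
        terms.pop(i)
    terms.append("(" + "*".join(picked) + ")")
    steps.append(terms[-1])

print(steps)  # ['(y*z)', '(x*(y*z))']
```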

Examples

>>> import jax
>>> import jax.numpy as jnp
>>> key1, key2, key3 = jax.random.split(jax.random.key(0), 3)
>>> x = jax.random.randint(key1, minval=-5, maxval=5, shape=(2, 3))
>>> y = jax.random.randint(key2, minval=-5, maxval=5, shape=(3, 100))
>>> z = jax.random.randint(key3, minval=-5, maxval=5, shape=(100, 5))
>>> path, path_info = jnp.einsum_path("ij,jk,kl", x, y, z, optimize="optimal")
>>> print(path)
[(1, 2), (0, 1)]
>>> print(path_info)
  Complete contraction:  ij,jk,kl->il
         Naive scaling:  4
     Optimized scaling:  3
      Naive FLOP count:  9.000e+3
  Optimized FLOP count:  3.060e+3
   Theoretical speedup:  2.941e+0
  Largest intermediate:  1.500e+1 elements
--------------------------------------------------------------------------------
scaling        BLAS                current                             remaining
--------------------------------------------------------------------------------
   3           GEMM              kl,jk->lj                             ij,lj->il
   3           GEMM              lj,ij->il                                il->il

Use the computed path in einsum():

>>> jnp.einsum("ij,jk,kl", x, y, z, optimize=path)
Array([[-754,  324, -142,   82,   50],
       [ 408,  -50,   87,  -29,    7]], dtype=int32)
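Because the path is an ordinary Python list, it can be computed once and reused. The following sketch (the function name `contract` is hypothetical) captures the precomputed path inside a `jax.jit`-compiled function and checks the result against an unoptimized contraction; integer inputs keep the comparison exact regardless of contraction order.

```python
import jax
import jax.numpy as jnp

key1, key2, key3 = jax.random.split(jax.random.key(0), 3)
x = jax.random.randint(key1, minval=-5, maxval=5, shape=(2, 3))
y = jax.random.randint(key2, minval=-5, maxval=5, shape=(3, 100))
z = jax.random.randint(key3, minval=-5, maxval=5, shape=(100, 5))

# Plan once, outside of any traced function.
path, _ = jnp.einsum_path("ij,jk,kl", x, y, z, optimize="optimal")

@jax.jit
def contract(a, b, c):
    # `path` is a plain Python list captured from the enclosing scope,
    # so jit treats it as a constant rather than a traced input.
    return jnp.einsum("ij,jk,kl", a, b, c, optimize=path)

out = contract(x, y, z)
ref = jnp.einsum("ij,jk,kl", x, y, z, optimize=False)

print(out.shape)                 # (2, 5)
print(bool((out == ref).all()))  # True
```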