
Rank promotion warning

NumPy broadcasting rules allow the automatic promotion of arguments from one rank (number of array axes) to another. This behavior can be convenient when intended, but it can also lead to surprising bugs where a silent rank promotion masks an underlying shape error.
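As an illustration of such a masked shape error, here is a minimal sketch with hypothetical data: a mean-squared-error computation in which the predictions carry an extra trailing axis of size 1. Because the two operands have different ranks, broadcasting silently produces a (3, 3) matrix of pairwise differences instead of reporting a shape mismatch. The variable names are illustrative only.

```python
from jax import numpy as jnp

# Hypothetical data: predictions carry a trailing axis of size 1,
# while the targets are stored as a flat vector.
predictions = jnp.array([[1.0], [2.0], [3.0]])  # shape (3, 1), rank 2
targets = jnp.array([1.0, 2.0, 4.0])            # shape (3,),   rank 1

# Intended: one difference per example, shape (3,).
# Actual: targets is promoted to shape (1, 3), and the operands then
# broadcast to shape (3, 3), so no error is reported.
diff = predictions - targets
mse = jnp.mean(diff ** 2)  # averages 9 pairwise differences, not 3

# Removing the stray axis first gives the intended per-example result.
diff_intended = predictions[:, 0] - targets
mse_intended = jnp.mean(diff_intended ** 2)

print(diff.shape, diff_intended.shape)  # (3, 3) (3,)
```

Here mse comes out as 21/9 rather than the intended 1/3, and nothing in the code signals the problem.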

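Here’s an example of rank promotion, in which the rank-1 array y (shape (3,)) is treated as if it had shape (1, 3) and then broadcast against the rank-2 array x (shape (4, 3)):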

```python
>>> from jax import numpy as jnp
>>> x = jnp.arange(12).reshape(4, 3)
>>> y = jnp.array([0, 1, 0])
>>> x + y
Array([[ 0,  2,  2],
       [ 3,  5,  5],
       [ 6,  8,  8],
       [ 9, 11, 11]], dtype=int32)
```
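If this broadcast is intended, it can be written explicitly so that no rank promotion takes place: once both operands have the same rank, only ordinary size-1 broadcasting is involved. The following is a minimal sketch under that assumption; it uses y.reshape(1, 3) (y[None, :] would be an equivalent spelling) and checks that the addition is accepted even with the "raise" setting described below.

```python
import jax
from jax import numpy as jnp

x = jnp.arange(12).reshape(4, 3)  # shape (4, 3), rank 2
y = jnp.array([0, 1, 0])          # shape (3,),   rank 1

# Give y an explicit leading axis of size 1: both operands now have rank 2,
# so ordinary size-1 broadcasting applies and no rank promotion is needed.
y_row = y.reshape(1, 3)

# Even with rank promotion turned into an error, this addition is accepted.
with jax.numpy_rank_promotion("raise"):
  z = x + y_row

print(z.shape)  # (4, 3)
```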

To avoid potential surprises, jax.numpy is configurable so that expressions requiring rank promotion can lead to a warning or an error, or can be allowed just as in regular NumPy. The configuration option is named jax_numpy_rank_promotion, and it can take on the string values "allow", "warn", and "raise". The default setting is "allow", which allows rank promotion without a warning or error. The "raise" setting raises an error on rank promotion, and "warn" raises a warning on the first occurrence of rank promotion.
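To compare the three settings side by side, here is a minimal sketch that evaluates the rank-promoting sum x + y from the example above under each value, using the jax.numpy_rank_promotion context manager together with Python's warnings module. The helper name outcome is hypothetical, and the sketch assumes the "raise" setting signals the error as a ValueError. The "always" filter is used so that Python's default once-per-location warning filtering does not hide the warning.

```python
import warnings

import jax
from jax import numpy as jnp

x = jnp.arange(12).reshape(4, 3)  # rank 2
y = jnp.array([0, 1, 0])          # rank 1

def outcome(setting):
  """Hypothetical helper: report how x + y behaves under a given setting."""
  with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")  # record warnings even if seen before
    try:
      with jax.numpy_rank_promotion(setting):
        x + y
    except ValueError:
      return "error"
  return "warning" if caught else "silent"

results = {s: outcome(s) for s in ("allow", "warn", "raise")}
print(results)
```

Under the assumptions above, this should report "silent" for "allow", "warning" for "warn", and "error" for "raise".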

Rank promotion can be enabled or disabled locally with the jax.numpy_rank_promotion() context manager:

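```python
import jax

# x and y are the arrays from the example above; inside this block,
# the rank-promoting addition emits a warning.
with jax.numpy_rank_promotion("warn"):
  z = x + y
```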
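Because the setting is scoped to the with block, contexts can also be nested, with the innermost one taking precedence. A minimal sketch of one use of this: a hypothetical helper, add_row_bias, that deliberately relies on rank promotion opts back in locally, while the surrounding code keeps rank promotion as an error (the sketch assumes that error is a ValueError).

```python
import jax
from jax import numpy as jnp

def add_row_bias(m, b):
  """Hypothetical helper that deliberately relies on rank promotion."""
  # Opt back in locally, regardless of the caller's setting.
  with jax.numpy_rank_promotion("allow"):
    return m + b

x = jnp.arange(12).reshape(4, 3)
y = jnp.array([0, 1, 0])

with jax.numpy_rank_promotion("raise"):
  z = add_row_bias(x, y)  # accepted: the inner "allow" context applies
  try:
    x + y                 # rejected: only the outer "raise" context applies
    raised = False
  except ValueError:
    raised = True

print(z.shape, raised)
```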

This configuration can also be set globally in several ways. One is by using jax.config in your code:

```python
import jax
jax.config.update("jax_numpy_rank_promotion", "warn")
```
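A global setting persists until it is changed again, so in a test suite it can be useful to restore the default afterwards. A minimal sketch, assuming the "raise" setting signals the error as a ValueError, that turns rank promotion into an error program-wide and then resets the option to "allow" in a finally clause:

```python
import jax
from jax import numpy as jnp

x = jnp.arange(12).reshape(4, 3)  # rank 2
y = jnp.array([0, 1, 0])          # rank 1

# Turn rank promotion into an error for the whole program.
jax.config.update("jax_numpy_rank_promotion", "raise")
try:
  x + y
  rejected = False
except ValueError:
  rejected = True
finally:
  # Restore the default so later code (e.g. other tests) is unaffected.
  jax.config.update("jax_numpy_rank_promotion", "allow")

print(rejected)
```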

You can also set the option using the environment variable JAX_NUMPY_RANK_PROMOTION, for example as JAX_NUMPY_RANK_PROMOTION='warn'. Finally, when using absl-py, the option can be set with a command-line flag.
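The environment variable needs to be in place before jax is imported, which usually means setting it in the shell that launches the program. To stay in one language, the sketch below instead sets it from Python for a fresh child interpreter started with subprocess. The child program and the helper name run_with_setting are hypothetical, and the sketch assumes the "raise" setting signals the error as a ValueError.

```python
import os
import subprocess
import sys

# Child program: attempts a rank-promoting addition and reports the outcome.
CHILD = """
from jax import numpy as jnp
x = jnp.arange(12).reshape(4, 3)
y = jnp.array([0, 1, 0])
try:
  x + y
  print("ok")
except ValueError:
  print("ValueError")
"""

def run_with_setting(value):
  """Hypothetical helper: run CHILD with JAX_NUMPY_RANK_PROMOTION set."""
  env = dict(os.environ, JAX_NUMPY_RANK_PROMOTION=value)
  result = subprocess.run(
      [sys.executable, "-c", CHILD],
      env=env, capture_output=True, text=True, check=True,
  )
  # The last line of stdout is the child's report.
  return result.stdout.strip().splitlines()[-1]

outcomes = {v: run_with_setting(v) for v in ("allow", "raise")}
print(outcomes)
```

Running each setting in its own process keeps the parent interpreter's configuration untouched.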
