Rank promotion warning
NumPy broadcasting rules allow the automatic promotion of arguments from one rank (number of array axes) to another. This behavior can be convenient when intended, but it can also lead to surprising bugs where a silent rank promotion masks an underlying shape error.
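To make that failure mode concrete, here is a minimal sketch of a shape bug that rank promotion can hide. The names `predictions` and `targets` are hypothetical. The point is that a column of shape `(4, 1)` combined with a vector of shape `(4,)` silently produces a `(4, 4)` result instead of failing.

```python
import jax.numpy as jnp

# Hypothetical data: 4 predictions stored as a column (shape (4, 1)) and
# 4 targets stored as a flat vector (shape (4,)).
predictions = jnp.array([[1.0], [2.0], [3.0], [4.0]])
targets = jnp.array([1.0, 2.0, 3.0, 4.0])

# Intended: an elementwise difference of shape (4,).
# Actual: `targets` is rank-promoted to shape (1, 4), then both operands are
# broadcast to (4, 4), so every prediction is paired with every target.
residuals = predictions - targets
print(residuals.shape)  # (4, 4)

# The loss still evaluates without any error, but it is the wrong quantity:
# the predictions match the targets exactly, yet this "MSE" is 2.5, not 0.
mse = jnp.mean(residuals ** 2)

# Dropping the trailing axis first gives the intended elementwise result.
intended_mse = jnp.mean((predictions[:, 0] - targets) ** 2)
print(mse, intended_mse)
```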
Here’s an example of rank promotion:
```python
>>> from jax import numpy as jnp
>>> x = jnp.arange(12).reshape(4, 3)
>>> y = jnp.array([0, 1, 0])
>>> x + y
Array([[ 0,  2,  2],
       [ 3,  5,  5],
       [ 6,  8,  8],
       [ 9, 11, 11]], dtype=int32)
```
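Under the NumPy rules, the lower-rank operand is treated as if size-1 axes had been prepended to its shape, and those size-1 axes are then broadcast. The sketch below checks that equivalence for the arrays above by doing both steps explicitly with `reshape` and `jnp.broadcast_to`. It illustrates the rule; it is not a description of how JAX implements it internally.

```python
import jax.numpy as jnp

x = jnp.arange(12).reshape(4, 3)  # rank 2, shape (4, 3)
y = jnp.array([0, 1, 0])          # rank 1, shape (3,)

# Step 1 (rank promotion): prepend a size-1 axis, shape (3,) -> (1, 3).
# y[None, :] or jnp.expand_dims(y, 0) would do the same.
y_promoted = y.reshape(1, 3)

# Step 2 (broadcasting): stretch the size-1 axis across the 4 rows of x.
y_broadcast = jnp.broadcast_to(y_promoted, x.shape)

implicit = x + y
explicit = x + y_broadcast
print(jnp.array_equal(implicit, explicit))  # True
```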
To avoid potential surprises, `jax.numpy` is configurable so that expressions requiring rank promotion can lead to a warning or an error, or can be allowed just as in regular NumPy. The configuration option is named `jax_numpy_rank_promotion`, and it can take on the string values `allow`, `warn`, and `raise`. The default setting is `allow`, which permits rank promotion without a warning or error. The `raise` setting raises an error on rank promotion, and `warn` raises a warning on the first occurrence of rank promotion.
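As a rough illustration of the `raise` setting, the sketch below uses the `jax.numpy_rank_promotion()` context manager introduced just below. It assumes the error surfaces as a `ValueError`, which is what recent JAX versions raise as far as we know. It also shows that adding the missing axis explicitly (here with `b[None, :]`) satisfies the stricter setting, because both operands then have the same rank and only ordinary size-1 broadcasting remains.

```python
import jax
import jax.numpy as jnp

a = jnp.ones((3, 2))
b = jnp.array([10.0, 20.0])  # rank 1, while `a` is rank 2

raised = False
with jax.numpy_rank_promotion("raise"):
  try:
    a + b  # needs rank promotion, so it is rejected
  except ValueError as err:  # assumption: the exception type current JAX uses
    raised = True
    print("rejected:", err)

  # An explicit leading axis keeps both operands at rank 2, so no rank
  # promotion is needed; broadcasting a size-1 axis is still allowed.
  c = a + b[None, :]

print(raised, c.shape)  # expected: True (3, 2)
```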
Rank promotion can be enabled or disabled locally with the `jax.numpy_rank_promotion()` context manager:
```python
import jax

with jax.numpy_rank_promotion("warn"):
  z = x + y  # x and y as defined in the example above
```
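The `warn` setting leaves the result unchanged and only emits a Python warning. A minimal sketch of observing that warning with the standard `warnings` module follows. The exact message wording is not something to rely on, so the sketch just prints the warning category and the start of each message.

```python
import warnings

import jax
import jax.numpy as jnp

x = jnp.ones((2, 5))
y = jnp.arange(5.0)  # rank 1 vs. rank 2: adding them triggers rank promotion

with warnings.catch_warnings(record=True) as caught:
  warnings.simplefilter("always")  # record every warning, even repeated ones
  with jax.numpy_rank_promotion("warn"):
    z = x + y  # the sum is still computed; a warning is emitted as well

print(z.shape)  # (2, 5)
for w in caught:
  print(w.category.__name__, str(w.message)[:60])
```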
This configuration can also be set globally in several ways. One is by using `jax.config` in your code:
```python
import jax
jax.config.update("jax_numpy_rank_promotion", "warn")
```
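A global update affects every later `jax.numpy` call in the process. Where a `with` block does not fit, for example in separate test setup and teardown hooks, one possible pattern is to save the current value, switch to `raise`, and restore the saved value afterwards. This is a sketch under that assumption; it reads the current value through the attribute `jax.config.jax_numpy_rank_promotion`, which works in the JAX versions we know of, but is worth checking against your version.

```python
import jax
import jax.numpy as jnp

# Remember the current global value so it can be restored afterwards.
previous = jax.config.jax_numpy_rank_promotion

jax.config.update("jax_numpy_rank_promotion", "raise")
try:
  # While "raise" is active, something like jnp.ones((2, 3)) + jnp.ones(3)
  # would fail; operands of equal rank still broadcast normally.
  ok = jnp.ones((2, 3)) + jnp.ones((1, 3))
finally:
  jax.config.update("jax_numpy_rank_promotion", previous)

print(previous, jax.config.jax_numpy_rank_promotion, ok.shape)
```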
You can also set the option using the environment variable `JAX_NUMPY_RANK_PROMOTION`, for example as `JAX_NUMPY_RANK_PROMOTION='warn'`. Finally, when using `absl-py`, the option can be set with a command-line flag.
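As far as we know, JAX reads the environment variable once, when it is first imported, so the variable has to be set before `import jax` runs, typically in the shell that launches the program. The sketch below demonstrates this from Python by starting a fresh interpreter with the variable set and printing the resulting option value. It assumes only the standard library and an installed JAX.

```python
import os
import subprocess
import sys

# Copy the current environment and add the rank promotion setting.
env = dict(os.environ, JAX_NUMPY_RANK_PROMOTION="warn")

# A fresh interpreter imports JAX with the variable already set.
result = subprocess.run(
    [sys.executable, "-c",
     "import jax; print(jax.config.jax_numpy_rank_promotion)"],
    env=env, capture_output=True, text=True, check=True,
)
print(result.stdout.strip())  # expected: warn
```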
