Type promotion semantics

This document describes JAX’s type promotion rules, i.e., the result of jax.numpy.promote_types() for each pair of types. For some background on the considerations that went into the design of what is described below, see Design of Type Promotion Semantics for JAX.
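A quick way to see these rules in practice is to call jax.numpy.promote_types() directly. The sketch below, which assumes a standard JAX install with default settings, compares it with numpy.promote_types() for three pairs taken from the table further down. Both functions accept dtype names as strings.

```python
import numpy as np
import jax.numpy as jnp

# Pairs of dtype names to compare between JAX and NumPy promotion.
pairs = [("int16", "float16"), ("int64", "float32"), ("uint8", "int8")]

for a, b in pairs:
    jax_result = jnp.promote_types(a, b)  # JAX lattice join
    np_result = np.promote_types(a, b)    # NumPy's promotion table
    print(f"{a} + {b}: JAX -> {jax_result}, NumPy -> {np_result}")
```

On these pairs JAX should report float16, float32 and int16, while NumPy gives float32, float64 and int16. The first two rows are cases where JAX keeps the floating-point operand's width instead of widening it.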

JAX’s type promotion behavior is determined via the following type promotion lattice:

[Figure: JAX type promotion lattice]

where, for example:

  • b1 means np.bool_,

  • i2 means np.int16,

  • u4 means np.uint32,

  • bf means jnp.bfloat16,

  • f2 means np.float16,

  • c8 means np.complex64,

  • i* means Python int or weakly-typed int,

  • f* means Python float or weakly-typed float, and

  • c* means Python complex or weakly-typed complex.

(For more about weak types, see Weakly-typed values in JAX below.)

Promotion between any two types is given by their join on this lattice, which generates the following binary promotion table:

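|     | b1  | u1  | u2  | u4  | u8  | i1  | i2  | i4  | i8  | bf  | f2  | f4  | f8  | c8  | c16 | i*  | f*  | c*  |
|-----|-----|-----|-----|-----|-----|-----|-----|-----|-----|-----|-----|-----|-----|-----|-----|-----|-----|-----|
| b1  | b1  | u1  | u2  | u4  | u8  | i1  | i2  | i4  | i8  | bf  | f2  | f4  | f8  | c8  | c16 | i*  | f*  | c*  |
| u1  | u1  | u1  | u2  | u4  | u8  | i2  | i2  | i4  | i8  | bf  | f2  | f4  | f8  | c8  | c16 | u1  | f*  | c*  |
| u2  | u2  | u2  | u2  | u4  | u8  | i4  | i4  | i4  | i8  | bf  | f2  | f4  | f8  | c8  | c16 | u2  | f*  | c*  |
| u4  | u4  | u4  | u4  | u4  | u8  | i8  | i8  | i8  | i8  | bf  | f2  | f4  | f8  | c8  | c16 | u4  | f*  | c*  |
| u8  | u8  | u8  | u8  | u8  | u8  | f*  | f*  | f*  | f*  | bf  | f2  | f4  | f8  | c8  | c16 | u8  | f*  | c*  |
| i1  | i1  | i2  | i4  | i8  | f*  | i1  | i2  | i4  | i8  | bf  | f2  | f4  | f8  | c8  | c16 | i1  | f*  | c*  |
| i2  | i2  | i2  | i4  | i8  | f*  | i2  | i2  | i4  | i8  | bf  | f2  | f4  | f8  | c8  | c16 | i2  | f*  | c*  |
| i4  | i4  | i4  | i4  | i8  | f*  | i4  | i4  | i4  | i8  | bf  | f2  | f4  | f8  | c8  | c16 | i4  | f*  | c*  |
| i8  | i8  | i8  | i8  | i8  | f*  | i8  | i8  | i8  | i8  | bf  | f2  | f4  | f8  | c8  | c16 | i8  | f*  | c*  |
| bf  | bf  | bf  | bf  | bf  | bf  | bf  | bf  | bf  | bf  | bf  | f4  | f4  | f8  | c8  | c16 | bf  | bf  | c8  |
| f2  | f2  | f2  | f2  | f2  | f2  | f2  | f2  | f2  | f2  | f4  | f2  | f4  | f8  | c8  | c16 | f2  | f2  | c8  |
| f4  | f4  | f4  | f4  | f4  | f4  | f4  | f4  | f4  | f4  | f4  | f4  | f4  | f8  | c8  | c16 | f4  | f4  | c8  |
| f8  | f8  | f8  | f8  | f8  | f8  | f8  | f8  | f8  | f8  | f8  | f8  | f8  | f8  | c16 | c16 | f8  | f8  | c16 |
| c8  | c8  | c8  | c8  | c8  | c8  | c8  | c8  | c8  | c8  | c8  | c8  | c8  | c16 | c8  | c16 | c8  | c8  | c8  |
| c16 | c16 | c16 | c16 | c16 | c16 | c16 | c16 | c16 | c16 | c16 | c16 | c16 | c16 | c16 | c16 | c16 | c16 | c16 |
| i*  | i*  | u1  | u2  | u4  | u8  | i1  | i2  | i4  | i8  | bf  | f2  | f4  | f8  | c8  | c16 | i*  | f*  | c*  |
| f*  | f*  | f*  | f*  | f*  | f*  | f*  | f*  | f*  | f*  | bf  | f2  | f4  | f8  | c8  | c16 | f*  | f*  | c*  |
| c*  | c*  | c*  | c*  | c*  | c*  | c*  | c*  | c*  | c*  | c8  | c8  | c8  | c16 | c8  | c16 | c*  | c*  | c*  |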
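The table can be regenerated from the lattice by computing joins. The following pure-Python sketch is illustrative only and is not JAX's implementation: the edge list LATTICE is an assumption, reconstructed so that it reproduces the table above for its 18 types, and the helper names upper_bounds and join are hypothetical. The join of two types is the unique smallest element among their common upper bounds.

```python
from functools import lru_cache

# Assumed lattice edges: each type maps to the types directly above it.
LATTICE = {
    "b1": ["i*"],
    "i*": ["u1", "i1"],
    "u1": ["i2", "u2"], "u2": ["i4", "u4"], "u4": ["i8", "u8"], "u8": ["f*"],
    "i1": ["i2"], "i2": ["i4"], "i4": ["i8"], "i8": ["f*"],
    "f*": ["bf", "f2", "c*"],
    "bf": ["f4"], "f2": ["f4"], "f4": ["f8", "c8"], "f8": ["c16"],
    "c*": ["c8"], "c8": ["c16"], "c16": [],
}

@lru_cache(maxsize=None)
def upper_bounds(t):
    """All types reachable from t by following edges upward, including t."""
    result = {t}
    for parent in LATTICE[t]:
        result |= upper_bounds(parent)
    return frozenset(result)

def join(a, b):
    """Least upper bound: the common upper bound lying below all the others."""
    common = upper_bounds(a) & upper_bounds(b)
    for candidate in common:
        if common <= upper_bounds(candidate):
            return candidate
    raise ValueError(f"no unique join for {a!r} and {b!r}")

# Regenerate one row of the table, e.g. the u8 row.
COLUMNS = ["b1", "u1", "u2", "u4", "u8", "i1", "i2", "i4", "i8",
           "bf", "f2", "f4", "f8", "c8", "c16", "i*", "f*", "c*"]
print(" ".join(join("u8", t) for t in COLUMNS))
```

Running it prints the u8 row of the table. Computing joins from an upper-bound set keeps the rule in one place: adding a type only requires adding its edges, and every table entry follows.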

JAX’s type promotion rules differ from those of NumPy, as given by numpy.promote_types(), in a number of cells of the table above. There are three key classes of differences:

  • When promoting a weakly typed value against a typed JAX value of the same category, JAX always prefers the precision of the JAX value. For example, jnp.int16(1) + 1 will return int16 rather than promoting to int64 as in NumPy. Note that this applies only to Python scalar values; if the constant is a NumPy array, then the above lattice is used for type promotion. For example, jnp.int16(1) + np.array(1) will return int64 (when 64-bit types are enabled with jax_enable_x64).

  • When promoting an integer or boolean type against a floating-point or complex type, JAX always uses the floating-point or complex type.

  • JAX supports the bfloat16 non-standard 16-bit floating-point type (jax.numpy.bfloat16), which is useful for neural network training. The only notable promotion behavior is with respect to IEEE-754 float16, with which bfloat16 promotes to float32.
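The three classes of differences can be observed directly on JAX arrays. This minimal sketch assumes default JAX settings; each operation exercises one of the bullets above.

```python
import jax.numpy as jnp

# 1. A Python scalar is weakly typed, so the JAX value's precision wins.
a = jnp.int16(1) + 1
# 2. Integer combined with floating point: the floating-point type is used.
b = jnp.int32(1) + jnp.float16(1)
# 3. bfloat16 combined with IEEE float16 promotes to float32.
c = jnp.bfloat16(1) + jnp.float16(1)

for name, value in [("a", a), ("b", b), ("c", c)]:
    print(name, value.dtype)  # int16, float16, float32
```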

The differences between NumPy and JAX are motivated by the fact that accelerator devices, such as GPUs and TPUs, either pay a significant performance penalty to use 64-bit floating-point types (GPUs) or do not support 64-bit floating-point types at all (TPUs). Classic NumPy’s promotion rules are too willing to overpromote to 64-bit types, which is problematic for a system designed to run on accelerators.

JAX uses floating-point promotion rules that are more suited to modern accelerator devices and are less aggressive about promoting floating-point types. The promotion rules used by JAX for floating-point types are similar to those used by PyTorch.

Effects of Python operator dispatch

Keep in mind that Python operators like + will dispatch based on the Python type of the two values being added. This means that, for example, np.int16(1) + 1 will promote using NumPy rules, whereas jnp.int16(1) + 1 will promote using JAX rules. This can lead to potentially confusing non-associative promotion semantics when the two types of promotion are combined, for example with np.int16(1) + 1 + jnp.int16(1).
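The sketch below makes the dispatch rule concrete. It assumes default JAX settings. The dtype NumPy picks for np.int16(1) + 1 depends on the installed NumPy version (NumPy 2 changed its scalar promotion rules relative to NumPy 1.x), so only the JAX-side result is pinned down in the comments.

```python
import numpy as np
import jax
import jax.numpy as jnp

# NumPy scalar on the left: + is handled by NumPy, so NumPy's rules apply.
np_result = np.int16(1) + 1
# JAX array on the left: + is handled by JAX, so the 1 is weakly typed.
jax_result = jnp.int16(1) + 1

print(type(np_result).__name__, np_result.dtype)  # depends on NumPy version
print(isinstance(jax_result, jax.Array), jax_result.dtype)  # True int16

# Grouping changes which library handles each step, so these two
# results need not share a dtype.
left = (np.int16(1) + 1) + jnp.int16(1)
right = np.int16(1) + (1 + jnp.int16(1))
print(left.dtype, right.dtype)
```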

Weakly-typed values in JAX

Weakly-typed values in JAX can in most cases be thought of as having promotion behavior equivalent to that of Python scalars, such as the integer scalar 2 in the following:

```python
>>> import jax.numpy as jnp
>>> x = jnp.arange(5, dtype='int8')
>>> 2 * x
Array([0, 2, 4, 6, 8], dtype=int8)
```

JAX’s weak type framework is designed to prevent unwanted type promotion within binary operations between JAX values and values with no explicitly user-specified type, such as Python scalar literals. For example, if 2 were not treated as weakly typed, the expression above would lead to an implicit type promotion:

```python
>>> jnp.int32(2) * x
Array([0, 2, 4, 6, 8], dtype=int32)
```
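The same contrast can be checked without running any arithmetic by asking jax.numpy.result_type() for the promoted dtype. A minimal sketch, assuming default settings:

```python
import jax.numpy as jnp

x = jnp.arange(5, dtype="int8")

# result_type computes the promoted dtype without performing the operation.
weak = jnp.result_type(2, x)               # Python int: weakly typed
strong = jnp.result_type(jnp.int32(2), x)  # explicitly typed int32 array
print(weak, strong)  # int8 int32
```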

When used in JAX, Python scalars are sometimes promoted to jax.Array objects, for example during JIT compilation. To maintain the desired promotion semantics in this case, jax.Array objects carry a weak_type flag that can be seen in an array’s string representation:

```python
>>> jnp.asarray(2)
Array(2, dtype=int32, weak_type=True)
```

If the dtype is specified explicitly, it will instead result in a standard strongly-typed array value:

```python
>>> jnp.asarray(2, dtype='int32')
Array(2, dtype=int32)
```
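To see the weak_type flag at work under JIT compilation, consider a jitted function that takes a scalar argument. In this sketch (the function name scale is hypothetical), a Python int passed through jax.jit keeps the int8 computation in int8, whereas an explicitly typed int32 array promotes the result:

```python
import jax
import jax.numpy as jnp

@jax.jit
def scale(x, s):
    # Under jit, s is traced. A Python scalar argument is traced as a
    # weakly typed value, so it does not widen x's dtype.
    return s * x

x = jnp.arange(5, dtype="int8")
print(scale(x, 2).dtype)                              # int8
print(scale(x, jnp.asarray(2, dtype="int32")).dtype)  # int32
print(repr(jnp.asarray(2)))                           # includes weak_type=True
```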

Strict dtype promotion

In some contexts it can be useful to disable implicit type promotion behavior, and instead require all promotions to be explicit. This can be done in JAX by setting the jax_numpy_dtype_promotion flag to 'strict'. Locally, it can be done with a context manager:

```python
>>> import jax
>>> import jax.numpy as jnp
>>> x = jnp.float32(1)
>>> y = jnp.int32(1)
>>> with jax.numpy_dtype_promotion('strict'):
...   z = x + y
...
Traceback (most recent call last):
TypePromotionError: Input dtypes ('float32', 'int32') have no available implicit dtype promotion path when jax_numpy_dtype_promotion=strict. Try explicitly casting inputs to the desired output type, or set jax_numpy_dtype_promotion=standard.
```

For convenience, strict promotion mode will still allow safe weakly-typed promotions, so you can still write code that mixes JAX arrays and Python scalars:

```python
>>> with jax.numpy_dtype_promotion('strict'):
...   z = x + 1
>>> print(z)
2.0
```
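The strict-mode error message suggests casting inputs explicitly. A minimal sketch of that fix, reusing the same x and y as above (redefined here so the block runs on its own):

```python
import jax
import jax.numpy as jnp

x = jnp.float32(1)
y = jnp.int32(1)

with jax.numpy_dtype_promotion('strict'):
    # Casting y explicitly gives both operands the same dtype,
    # so no implicit promotion is required.
    z = x + y.astype(jnp.float32)
    try:
        x + y  # float32 with int32: no implicit path in strict mode
        error_name = None
    except Exception as err:
        error_name = type(err).__name__

print(z.dtype, error_name)  # float32 TypePromotionError
```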

If you would prefer to set the configuration globally, you can do so using the standard configuration update:

```python
import jax

jax.config.update('jax_numpy_dtype_promotion', 'strict')
```

To restore the default standard type promotion, set this configuration to 'standard':

```python
jax.config.update('jax_numpy_dtype_promotion', 'standard')
```
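Because the global setting affects all subsequent code, it can be worth restoring it in a finally block. The sketch below is one way to do that, not a prescribed pattern. It checks that weak Python scalars still promote in strict mode and that standard mode promotes float32 with int32 implicitly.

```python
import jax
import jax.numpy as jnp

jax.config.update('jax_numpy_dtype_promotion', 'strict')
try:
    ok = jnp.float32(1) + 1  # weakly typed Python scalar: still allowed
    try:
        jnp.float32(1) + jnp.int32(1)
        strict_error = None
    except Exception as err:
        strict_error = type(err).__name__
finally:
    # Restore the default so the setting does not leak into other code.
    jax.config.update('jax_numpy_dtype_promotion', 'standard')

# Back in standard mode, the mixed operation promotes implicitly.
restored = jnp.float32(1) + jnp.int32(1)
print(ok.dtype, strict_error, restored.dtype)  # float32 TypePromotionError float32
```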
