Default dtypes and the X64 flag
JAX strives to meet the needs of a range of numerical computing practitioners, who sometimes have conflicting preferences. When it comes to default dtypes, there are two different camps:
- Classic scientific computing practitioners (i.e. users of tools like numpy or scipy) tend to value accuracy of computations foremost: such users would prefer that computations default to the widest available representation: e.g. floating point values should default to float64, integers to int64, etc.
- AI researchers (i.e. folks implementing and training neural networks) tend to value speed over accuracy, to the point where they have developed special data types like bfloat16 and others which deliberately discard the least significant bits in order to speed up computation. For these users, the mere presence of a float64 value in their computation can lead to programs that are slow at best, and incompatible with their hardware at worst! These users would prefer that computations default to float32 or int32.
The main mechanism JAX offers for this is the jax_enable_x64 flag, which controls whether 64-bit values can be created at all. By default this flag is set to False (serving the needs of AI researchers and practitioners), but it can be set to True by users who value accuracy over computational speed.
Default setting: 32-bits everywhere
By default jax_enable_x64 is set to False, and so jax.numpy array creation functions will default to returning 32-bit values.
For example:
>>> import jax.numpy as jnp
>>> jnp.arange(5)
Array([0, 1, 2, 3, 4], dtype=int32)
>>> jnp.zeros(5)
Array([0., 0., 0., 0., 0.], dtype=float32)
>>> jnp.ones(5, dtype=int)
Array([1, 1, 1, 1, 1], dtype=int32)
Beyond defaults, because 64-bit values can be so poisonous to AI workflows, having this flag set to False prevents you from creating 64-bit arrays at all! For example:
>>> jnp.arange(5, dtype='float64')
UserWarning: Explicitly requested dtype float64 requested in arange is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/jax-ml/jax#current-gotchas for more.
Array([0., 1., 2., 3., 4.], dtype=float32)
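Because the truncation surfaces only as a warning, it can be easy to miss in a long log. The following is a minimal sketch, not part of the original page, of how one might capture that warning programmatically with the standard library's `warnings` module and confirm the resulting dtype. The variable names `caught` and `x` are illustrative, and the flag is set to False explicitly so the sketch does not depend on the environment.

```python
import warnings

import jax
import jax.numpy as jnp

# Make the default explicit so the sketch behaves the same even if
# JAX_ENABLE_X64 happens to be set in the environment.
jax.config.update("jax_enable_x64", False)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")  # record every warning instead of printing it
    x = jnp.arange(5, dtype="float64")  # explicit 64-bit request

# With the flag disabled, the array comes back as float32.
print(x.dtype)

# Show whatever warnings were recorded while creating the array.
for w in caught:
    print(w.category.__name__, "-", str(w.message)[:70])
```

Checking `x.dtype` after array creation is a cheap way to notice silent truncation in code that expects 64-bit precision.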
The X64 flag: enabling 64-bit values
To work in the “other mode” where functions default to producing 64-bit values, you can set the jax_enable_x64 flag to True:
import jax
import jax.numpy as jnp

jax.config.update('jax_enable_x64', True)

print(repr(jnp.arange(5)))
print(repr(jnp.zeros(5)))
print(repr(jnp.ones(5, dtype=int)))
Array([0, 1, 2, 3, 4], dtype=int64)
Array([0., 0., 0., 0., 0.], dtype=float64)
Array([1, 1, 1, 1, 1], dtype=int64)
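Enabling the flag changes the defaults and makes 64-bit dtypes available; it does not force every array to 64 bits. As a small illustrative sketch (the variable names are assumptions, not from the original page), an explicitly requested 32-bit dtype is still honored when the flag is True:

```python
import jax
import jax.numpy as jnp

# Set once, before any arrays are created.
jax.config.update("jax_enable_x64", True)

default_float = jnp.zeros(3)                    # follows the 64-bit default
explicit_f32 = jnp.zeros(3, dtype=jnp.float32)  # explicit 32-bit request is kept
explicit_f64 = jnp.arange(3, dtype="float64")   # 64-bit request is now available

print(default_float.dtype, explicit_f32.dtype, explicit_f64.dtype)
```

This can be useful when most of a program needs 64-bit accuracy but a few large buffers are deliberately kept at 32 bits.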
The X64 configuration can also be set via the JAX_ENABLE_X64 shell environment variable, for example:
$ JAX_ENABLE_X64=1 python main.py
The X64 flag is intended as a global setting that should have one value for your whole program, set at the top of your main file. A common feature request is for the flag to be contextually configurable (e.g. enabling X64 just for one section of a long program): this turns out to be difficult to implement within JAX’s programming model, where code execution may happen in a different context than code compilation. There is ongoing work exploring the feasibility of relaxing this constraint, so stay tuned!
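One way to work with a global flag is for the main file to set it once at startup, and for code that depends on 64-bit precision to check the setting rather than toggle it. The sketch below assumes a hypothetical helper named `require_x64`; it is not a JAX API, and it simply reads the current value from `jax.config`.

```python
import jax
import jax.numpy as jnp


def require_x64() -> None:
    """Hypothetical guard: fail early if 64-bit mode is not enabled."""
    if not jax.config.jax_enable_x64:
        raise RuntimeError(
            "This code expects 64-bit values; set jax_enable_x64 to True "
            "at the top of the main file or export JAX_ENABLE_X64=1."
        )


# Top of the main file: choose the setting once for the whole program.
jax.config.update("jax_enable_x64", True)

# Later, in code that relies on 64-bit precision:
require_x64()
total = jnp.sum(jnp.arange(10))
print(total, total.dtype)
```

Failing early with a clear message is an assumption about what is convenient here; a program could equally log a warning and continue with 32-bit values.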
