# Change log

Best viewed here. For the changes specific to the experimental Pallas APIs,
see the Pallas Changelog.

JAX follows effort-based versioning; for a discussion of this and JAX's API
compatibility policy, refer to API compatibility. For the Python and NumPy
version support policy, refer to Python and NumPy version support policy.
## Unreleased

Changes:

* JAX tracers that are not of `Array` type (e.g., of `Ref` type) will no
  longer report themselves to be instances of `Array`.
* Using `jax.shard_map` in explicit mode will raise an error if the
  PartitionSpec of an input does not match the PartitionSpec specified in
  `in_specs`. In other words, it acts like an assert instead of an implicit
  reshard. `in_specs` is an optional argument, so you can omit it and
  `shard_map` will infer the `PartitionSpec` from the argument. If you want
  to reshard your inputs, apply `jax.reshard` to the arguments and then pass
  the results to `shard_map`.
New features:

* Added a debug config `jax_compilation_cache_check_contents`. If set,
  `get()` reports a cache miss when called on a value that has not been
  `put()` by the current process, even if the value is actually in the disk
  cache. When a value is `put()`, we verify that its contents match.
## JAX 0.9.0 (January 20, 2026)

New features:

* Added `jax.thread_guard()`, a context manager that detects when devices
  are used by multiple threads in multi-controller JAX.
Bug fixes:

* Fixed a workspace size calculation error for pivoted QR
  (`magma_zgeqp3_gpu`) in MAGMA 2.9.0 when using `use_magma=True` and
  `pivoting=True` (#34145).
Deprecations:

* The flag `jax_collectives_common_channel_id` was removed.
* The `jax_pmap_no_rank_reduction` config state has been removed. The
  no-rank-reduction behavior is now the only supported behavior: a
  `jax.pmap`-ped function `f` sees inputs of the same rank as the input to
  `jax.pmap(f)`. For example, if `jax.pmap(f)` receives shape `(8, 128)` on
  8 devices, then `f` receives shape `(1, 128)`.
* Setting the `jax_pmap_shmap_merge` config state is deprecated in JAX
  v0.9.0 and will be removed in JAX v0.10.0.
* `jax.numpy.fix()` is deprecated, anticipating the deprecation of
  `numpy.fix()` in NumPy v2.5.0. `jax.numpy.trunc()` is a drop-in
  replacement.
Changes:

* `jax.export()` now supports explicit sharding. This required a new export
  serialization format version that includes the `NamedSharding`, including
  the abstract mesh and the partition spec. As part of this change we have
  added a restriction on the use of exported modules: when calling them, the
  abstract mesh must match the one used at export time, including the axis
  names. Previously, only the number of devices mattered.
## JAX 0.8.2 (December 18, 2025)

Deprecations:

* `jax.lax.pvary` has been deprecated. Please use
  `jax.lax.pcast(..., to='varying')` as the replacement.
* Complex arguments passed to `jax.numpy.arange()` now result in a
  deprecation warning, because the output is poorly defined.
* From `jax.core`, a number of symbols are newly deprecated, including:
  `call_impl`, `get_aval`, `mapped_aval`, `subjaxprs`, `set_current_trace`,
  `take_current_trace`, `traverse_jaxpr_params`, `unmapped_aval`,
  `AbstractToken`, and `TraceTag`.
* All symbols in `jax.interpreters.pxla` are deprecated. These are primarily
  JAX-internal APIs, and users should not rely on them.
Changes:

* JAX's `Tracer` no longer inherits from `jax.Array` at runtime. However,
  `jax.Array` now uses a custom metaclass such that `isinstance(x, Array)`
  is true if an object `x` represents a traced `Array`. Only some `Tracer`s
  represent `Array`s, so it is not correct for `Tracer` to inherit from
  `Array`. For the moment, during Python type checking, we continue to
  declare `Tracer` as a subclass of `Array`; however, we expect to remove
  this in a future release.
* `jax.experimental.si_vjp` has been deleted. `jax.vjp` subsumes its
  functionality.
## JAX 0.8.1 (November 18, 2025)

New features:

* `jax.jit()` now supports the decorator factory pattern; i.e., instead of
  writing

  ```python
  @functools.partial(jax.jit, static_argnames=['n'])
  def f(x, n):
    ...
  ```

  you may write

  ```python
  @jax.jit(static_argnames=['n'])
  def f(x, n):
    ...
  ```
Changes:

* `jax.lax.linalg.eigh()` now accepts an `implementation` argument to
  select between QR (CPU/GPU), Jacobi (GPU/TPU), and QDWH (TPU)
  implementations. The `EighImplementation` enum is publicly exported from
  `jax.lax.linalg`.
* `jax.lax.linalg.svd()` now implements an `algorithm` that uses the polar
  decomposition on CUDA GPUs. This is also an alias for the existing
  algorithm on TPUs.
Bug fixes:

* Fixed a bug introduced in JAX 0.7.2 where `eigh` failed for large
  matrices on GPU (#33062).
Deprecations:

* `jax.sharding.PmapSharding` is now deprecated. Please use
  `jax.NamedSharding` instead.
* `jax.device_put_replicated` is now deprecated. Please use
  `jax.device_put` with the appropriate sharding instead.
* `jax.device_put_sharded` is now deprecated. Please use `jax.device_put`
  with the appropriate sharding instead.
* The default `axis_types` of `jax.make_mesh` will change in JAX v0.9.0 to
  return `jax.sharding.AxisType.Explicit`. Leaving `axis_types` unspecified
  will raise a `DeprecationWarning`.
* `jax.cloud_tpu_init` and its contents were deprecated. There is no reason
  for a user to import or use the contents of this module; JAX handles this
  for you automatically if needed.
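As a migration sketch for the deprecated `jax.device_put_replicated` and
`jax.device_put_sharded` helpers above, `jax.device_put` with an explicit
sharding covers the same use case. This is a minimal single-device sketch;
the one-device mesh and the axis name `'x'` are illustrative assumptions,
and real replication would use multiple devices:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Illustrative one-device mesh; real replication would span several devices.
mesh = Mesh(np.array(jax.devices()[:1]), ('x',))

x = jnp.arange(8)
# An empty PartitionSpec replicates the array across the mesh, which is
# the role jax.device_put_replicated used to play.
y = jax.device_put(x, NamedSharding(mesh, P()))
```

The same pattern covers `device_put_sharded` by using a non-empty
`PartitionSpec` such as `P('x')`.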
## JAX 0.8.0 (October 15, 2025)

Breaking changes:

* JAX is changing the default `jax.pmap` implementation to one implemented
  in terms of `jax.jit` and `jax.shard_map`. `jax.pmap` is in maintenance
  mode and we encourage all new code to use `jax.shard_map` directly. See
  the migration guide for more information.
* The `auto=` parameter of `jax.experimental.shard_map.shard_map` has been
  removed. This means that `jax.experimental.shard_map.shard_map` no longer
  supports nesting. If you want to nest `shard_map` calls, please use
  `jax.shard_map`.
* JAX no longer allows passing objects that support `__jax_array__`
  directly to, e.g., jit-ed functions. Call `jax.numpy.asarray` on them
  first.
* `jax.numpy.cov()` now returns NaN for empty arrays (#32305), and matches
  NumPy 2.2 behavior for single-row design matrices (#32308).
* JAX no longer accepts `Array` values where a `dtype` value is expected.
  Call `.dtype` on these values first.
* The deprecated function `jax.interpreters.mlir.custom_call()` was removed.
* The `jax.util`, `jax.extend.ffi`, and `jax.experimental.host_callback`
  modules have been removed. All public APIs within these modules were
  deprecated and removed in v0.7.0 or earlier.
* The deprecated symbol `jax.custom_derivatives.custom_jvp_call_jaxpr_p`
  was removed.
* `jax.experimental.multihost_utils.process_allgather` raises an error when
  the input is a `jax.Array`, not fully addressable, and `tiled=False`. To
  fix this, pass `tiled=True` to your `process_allgather` invocation.
* From `jax.experimental.compilation_cache`, the deprecated symbols
  `is_initialized` and `initialize_cache` were removed.
* The deprecated function `jax.interpreters.xla.canonicalize_dtype()` was
  removed.
* `jaxlib.hlo_helpers` has been removed. Use `jax.ffi` instead.
* The option `jax_cpu_enable_gloo_collectives` has been removed. Use
  `jax_cpu_collectives_implementation` instead.
* The previously-deprecated `interpolation` argument to
  `jax.numpy.percentile()` and `jax.numpy.quantile()` has been removed; use
  `method` instead.
* The JAX-internal `for_loop` primitive was removed. Its functionality,
  reading from and writing to refs in the loop body, is now directly
  supported by `jax.lax.fori_loop()`. If you need help updating your code,
  please file a bug.
* `jax.numpy.trim_zeros()` now errors for non-1D input.
* The `where` argument to `jax.numpy.sum()` and other reductions is now
  required to be boolean. Non-boolean values have resulted in a
  `DeprecationWarning` since JAX v0.5.0.
* The deprecated functions in `jax.dlpack`, `jax.errors`,
  `jax.lib.xla_bridge`, `jax.lib.xla_client`, and `jax.lib.xla_extension`
  were removed.
* `jax.interpreters.mlir.dense_bool_array` was removed. Use MLIR APIs to
  construct attributes instead.
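The `__jax_array__` change above means conversion must now be explicit
before calling a jit-ed function. A small migration sketch; the `Angle`
class is a hypothetical example type, not a JAX API:

```python
import jax
import jax.numpy as jnp

class Angle:
    """Hypothetical user type that knows how to become a JAX array."""
    def __init__(self, degrees):
        self.degrees = degrees
    def __jax_array__(self):
        return jnp.asarray(self.degrees) * (jnp.pi / 180.0)

@jax.jit
def double(x):
    return 2.0 * x

a = Angle(90.0)
# Passing `a` directly to the jitted function is no longer allowed;
# convert explicitly first:
result = double(jnp.asarray(a))
```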
Changes:

* `jax.numpy.linalg.eig()` now returns a namedtuple (with attributes
  `eigenvalues` and `eigenvectors`) instead of a plain tuple.
* `jax.grad()` and `jax.vjp()` will now always round primals to `float32`
  if `float64` mode is not enabled.
* `jax.dlpack.from_dlpack()` now accepts arrays with non-default layouts,
  for example, transposed ones.
* The default nonsymmetric eigendecomposition on NVIDIA GPUs now uses
  cusolver. The MAGMA and LAPACK implementations are still available via
  the new `implementation` argument to `jax.lax.linalg.eig()` (#27265). The
  `use_magma` argument is now deprecated in favor of `implementation`.
* `jax.numpy.trim_zeros()` now follows NumPy 2.2 in supporting
  multi-dimensional inputs.
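For the `eig` return-type change above, positional unpacking continues to
work, so older code is unaffected; the named attributes are the new part. A
quick sketch, unpacking positionally so it runs with either return type:

```python
import jax.numpy as jnp

a = jnp.array([[2.0, 1.0],
               [0.0, 3.0]])
# As of 0.8.0 the result is a namedtuple with .eigenvalues and
# .eigenvectors; tuple unpacking works with both old and new return types.
w, v = jnp.linalg.eig(a)
# An upper-triangular matrix has its eigenvalues on the diagonal: 2 and 3.
```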
Deprecations:

* `jax.experimental.enable_x64()` and `jax.experimental.disable_x64()` are
  deprecated in favor of the new non-experimental context manager
  `jax.enable_x64()`.
* `jax.experimental.shard_map.shard_map()` is deprecated; going forward use
  `jax.shard_map()`.
* `jax.experimental.pjit.pjit()` is deprecated; going forward use
  `jax.jit()`.
## JAX 0.7.2 (September 16, 2025)

Breaking changes:

* `jax.dlpack.from_dlpack()` no longer accepts a DLPack capsule. This
  behavior was deprecated and is now removed. The function must be called
  with an array implementing `__dlpack__` and `__dlpack_device__`.
Changes:

* The minimum supported NumPy version is now 2.0. Since SciPy 1.13 is
  required for NumPy 2.0 support, the minimum supported SciPy version is
  now 1.13.
* JAX now represents constants in its internal jaxpr representation as a
  `TypedNdArray`, which is a private JAX type that duck-types as a
  `numpy.ndarray`. This type may be exposed to users via `custom_jvp`
  rules, for example, and may break code that uses
  `isinstance(x, np.ndarray)`. If this breaks your code, you may convert
  these arrays to classic NumPy arrays using `np.asarray(x)`.
Bug fixes:

* `arr.view(dtype=None)` now returns the array unchanged, matching NumPy's
  semantics. Previously it returned the array with a float dtype.
* `jax.random.randint` now produces a less-biased distribution for 8-bit
  and 16-bit integer types (#27742). To restore the previous biased
  behavior, you may temporarily set the `jax_safer_randint` configuration
  to `False`, but note this is a temporary config that will be removed in a
  future release.
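The `jax.random.randint` fix above concerns narrow integer dtypes; basic
usage is unchanged. A small sketch of 8-bit sampling (the shape and bounds
are arbitrary choices for illustration):

```python
import jax
import jax.numpy as jnp

key = jax.random.key(0)
# Uniform integers in [0, 10) with an 8-bit dtype; as of 0.7.2 this path
# uses the less-biased sampling algorithm by default.
samples = jax.random.randint(key, shape=(1000,), minval=0, maxval=10,
                             dtype=jnp.int8)
```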
Deprecations:

* The parameters `enable_xla` and `native_serialization` for
  `jax2tf.convert` are deprecated and will be removed in a future version
  of JAX. These were used for jax2tf with non-native serialization, which
  has now been removed.
* Setting the config state `jax_pmap_no_rank_reduction` to `False` is
  deprecated. By default, `jax_pmap_no_rank_reduction` will be set to
  `True` and `jax.pmap` shards will not have their rank reduced, keeping
  the same rank as their enclosing array.
## JAX 0.7.1 (August 20, 2025)

New features:

* JAX now ships Python 3.14 and 3.14t wheels.
* JAX now ships Python 3.13t and 3.14t wheels on Mac. Previously we only
  offered free-threading builds on Linux.
Changes:

* Exposed `jax.set_mesh`, which acts as a global setter and a context
  manager. Removed `jax.sharding.use_mesh` in favor of `jax.set_mesh`.
* JAX is now built using CUDA 12.9. All versions of CUDA 12.1 or newer
  remain supported.
* `jax.lax.dot()` now implements the general dot product via the optional
  `dimension_numbers` argument.
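The `dimension_numbers` argument mentioned above follows the convention
long used by `jax.lax.dot_general`, sketched here with the stable
`dot_general` API (an ordinary matrix product expressed as contraction
dimensions):

```python
import jax.numpy as jnp
from jax import lax

a = jnp.arange(6.0).reshape(2, 3)
b = jnp.ones((3, 4))
# Contract a's axis 1 against b's axis 0, with no batch dimensions:
# this is exactly a @ b.
dn = (((1,), (0,)), ((), ()))
c = lax.dot_general(a, b, dimension_numbers=dn)
```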
Deprecations:

* `jax.lax.zeros_like_array()` is deprecated. Please use
  `jax.numpy.zeros_like()` instead.
* Attempting to import `jax.experimental.host_callback` now results in a
  `DeprecationWarning`, and will result in an `ImportError` starting in JAX
  v0.8.0. Its APIs have raised `NotImplementedError` since JAX version
  0.4.35.
* In `jax.lax.dot()`, passing the `precision` and `preferred_element_type`
  arguments by position is deprecated. Pass them by explicit keyword
  instead.
* Several dozen internal APIs have been deprecated from
  `jax.interpreters.ad`, `jax.interpreters.batching`, and
  `jax.interpreters.partial_eval`; they are used rarely if ever outside JAX
  itself, and most are deprecated without any public replacement.
## JAX 0.7.0 (July 22, 2025)

New features:

* Added `jax.P`, which is an alias for `jax.sharding.PartitionSpec`.
* The `jax.numpy.ndarray.at` indexing methods now support a
  `wrap_negative_indices` argument, which defaults to `True` to match the
  current behavior (#29434).
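The default `wrap_negative_indices=True` above preserves the long-standing
NumPy-style wrapping behavior, which this sketch shows without passing the
new argument (so it also runs on older releases):

```python
import jax.numpy as jnp

x = jnp.arange(5)
# Negative indices wrap around by default, as in NumPy: index -1 is the
# last element. .at[...] returns an updated copy; x is unchanged.
y = x.at[-1].set(99)
```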
Breaking changes:

* JAX is migrating from GSPMD to Shardy by default. See the migration guide
  for more information.
* JAX autodiff is switching to using direct linearization by default
  (instead of implementing linearization via JVP and partial eval). See the
  migration guide for more information.
* `jax.stages.OutInfo` has been replaced with `jax.ShapeDtypeStruct`.
* `jax.jit()` now requires `fun` to be passed by position, and additional
  arguments to be passed by keyword. Doing otherwise will result in an
  error starting in v0.7.x. This raised a `DeprecationWarning` in v0.6.x.
* The minimum Python version is now 3.11. 3.11 will remain the minimum
  supported version until July 2026.
* Layout API renames:
  * `Layout`, `.layout`, `.input_layouts`, and `.output_layouts` have been
    renamed to `Format`, `.format`, `.input_formats`, and `.output_formats`.
  * `DeviceLocalLayout` and `.device_local_layout` have been renamed to
    `Layout` and `.layout`.
* The `jax.experimental.shard` module has been deleted and all the APIs
  have been moved to the `jax.sharding` endpoint. So use
  `jax.sharding.reshard`, `jax.sharding.auto_axes`, and
  `jax.sharding.explicit_axes` instead of their experimental endpoints.
* `lax.infeed` and `lax.outfeed` were removed, after being deprecated in
  JAX 0.6. The `transfer_to_infeed` and `transfer_from_outfeed` methods
  were also removed from the `Device` objects.
* The `jax.extend.core.primitives.pjit_p` primitive has been renamed to
  `jit_p`, and its `name` attribute has changed from `"pjit"` to `"jit"`.
  This affects the string representations of jaxprs. The same primitive is
  no longer exported from the `jax.experimental.pjit` module.
* The (undocumented) function
  `jax.extend.backend.add_clear_backends_callback` has been removed. Users
  should use `jax.extend.backend.register_backend_cache` instead.
* An `out_sharding` argument was added to `x.at[y].set` and `x.at[y].add`.
  The previous behavior of propagating the operand sharding was removed.
  Please use `x.at[y].set/add(z, out_sharding=jax.typeof(x).sharding)` to
  retain the previous behavior if the scatter op requires collectives.
Deprecations:

* `jax.dlpack.SUPPORTED_DTYPES` is deprecated; please use the new
  `jax.dlpack.is_supported_dtype()` function.
* `jax.scipy.special.sph_harm()` has been deprecated following a similar
  deprecation in SciPy; use `jax.scipy.special.sph_harm_y()` instead.
* From `jax.interpreters.xla`, the previously deprecated symbols
  `abstractify` and `pytype_aval_mappings` have been removed.
* `jax.interpreters.xla.canonicalize_dtype()` is deprecated. For
  canonicalizing dtypes, prefer `jax.dtypes.canonicalize_dtype()`. For
  checking whether an object is a valid JAX input, prefer
  `jax.core.valid_jaxtype()`.
* From `jax.core`, the previously deprecated symbols `AxisName`,
  `ConcretizationTypeError`, `axis_frame`, `call_p`, `closed_call_p`,
  `get_type`, `trace_state_clean`, `typematch`, and `typecheck` have been
  removed.
* From `jax.lib.xla_client`, the previously deprecated symbols
  `DeviceAssignment`, `get_topology_for_devices`, and `mlir_api_version`
  have been removed.
* `jax.extend.ffi` was removed after being deprecated in v0.5.0. Use
  `jax.ffi` instead.
* `jax.lib.xla_bridge.get_compile_options()` is deprecated, and replaced by
  `jax.extend.backend.get_compile_options()`.
## JAX 0.6.2 (June 17, 2025)

New features:

* Added `jax.tree.broadcast()`, which implements a pytree prefix
  broadcasting helper.
Changes:

* The minimum NumPy version is 1.26 and the minimum SciPy version is 1.12.
## JAX 0.6.1 (May 21, 2025)

New features:

* Added `jax.lax.axis_size()`, which returns the size of the mapped axis
  given its name.
Changes:

* Additional checking for the versions of CUDA package dependencies was
  re-enabled, having been accidentally disabled in a previous release.
* JAX nightly packages are now published to artifact registry. To install
  these packages, see the JAX installation guide.
* `jax.sharding.PartitionSpec` no longer inherits from a tuple.
* `jax.ShapeDtypeStruct` is immutable now. Please use the `.update` method
  to update your `ShapeDtypeStruct` instead of doing in-place updates.
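The immutability of `jax.ShapeDtypeStruct` mainly affects code that
mutated `.shape` or `.dtype` in place; typical abstract-evaluation usage is
unchanged. A small sketch using `jax.eval_shape` (the function being
shape-evaluated is an arbitrary example):

```python
import jax
import jax.numpy as jnp

# An abstract array spec: shape and dtype, but no data.
spec = jax.ShapeDtypeStruct((4, 8), jnp.float32)

# eval_shape computes output shapes/dtypes abstractly, without running
# any actual computation.
out = jax.eval_shape(lambda x: x.T @ x, spec)
```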
Deprecations:

* `jax.custom_derivatives.custom_jvp_call_jaxpr_p` is deprecated, and will
  be removed in JAX v0.7.0.
## JAX 0.6.0 (April 16, 2025)

Breaking changes:

* `jax.numpy.array()` no longer accepts `None`. This behavior was
  deprecated since November 2023 and is now removed.
* Removed the `config.jax_data_dependent_tracing_fallback` config option,
  which was added temporarily in v0.4.36 to allow users to opt out of the
  new "stackless" tracing machinery.
* Removed the `config.jax_eager_pmap` config option.
* Disallowed calling the `lower` and `trace` AOT APIs on the result of
  `jax.jit` if there have been subsequent wrappers applied. Previously this
  worked, but silently ignored the wrappers. The workaround is to apply
  `jax.jit` last among the wrappers, and similarly for `jax.pmap`. See
  #27873.
* The `cuda12_pip` extra for `jax` has been removed; use
  `pip install jax[cuda12]` instead.
Changes:

* The minimum cuDNN version is v9.8.
* JAX is now built using CUDA 12.8. All versions of CUDA 12.1 or newer
  remain supported.
* JAX package extras are now updated to use a dash instead of an
  underscore, to align with PEP 685. For instance, if you were previously
  using `pip install jax[cuda12_local]` to install JAX, run
  `pip install jax[cuda12-local]` instead.
* `jax.jit()` now requires `fun` to be passed by position, and additional
  arguments to be passed by keyword. Doing otherwise will result in a
  `DeprecationWarning` in v0.6.x, and an error starting in v0.7.x.
Deprecations:

* `jax.tree_util.build_tree()` is deprecated. Use `jax.tree.unflatten()`
  instead.
* Implemented host callback handlers for CPU and GPU devices using XLA's
  FFI and removed the existing CPU/GPU handlers using XLA's custom call.
* All APIs in `jax.lib.xla_extension` are now deprecated.
* `jax.interpreters.mlir.hlo` and `jax.interpreters.mlir.func_dialect`,
  which were accidental exports, have been removed. If needed, they are
  available from `jax.extend.mlir`.
* `jax.interpreters.mlir.custom_call` is deprecated. The APIs provided by
  `jax.ffi` should be used instead.
* The deprecated use of `jax.ffi.ffi_call()` with inline arguments is no
  longer supported. `ffi_call()` now unconditionally returns a callable.
* The following exports in `jax.lib.xla_client` are deprecated:
  `get_topology_for_devices`, `heap_profile`, `mlir_api_version`, `Client`,
  `CompileOptions`, `DeviceAssignment`, `Frame`, `HloSharding`,
  `OpSharding`, `Traceback`.
* The following internal APIs in `jax.util` are deprecated:
  `HashableFunction`, `as_hashable_function`, `cache`, `safe_map`,
  `safe_zip`, `split_dict`, `split_list`, `split_list_checked`,
  `split_merge`, `subvals`, `toposort`, `unzip2`, `wrap_name`, and `wraps`.
* `jax.dlpack.to_dlpack` has been deprecated. You can usually pass a JAX
  `Array` directly to the `from_dlpack` function of another framework. If
  you need the functionality of `to_dlpack`, use the `__dlpack__` attribute
  of an array.
* `jax.lax.infeed`, `jax.lax.infeed_p`, `jax.lax.outfeed`, and
  `jax.lax.outfeed_p` are deprecated and will be removed in JAX v0.7.0.
* Several previously-deprecated APIs have been removed, including:
  * From `jax.lib.xla_client`: `ArrayImpl`, `FftType`, `PaddingType`,
    `PrimitiveType`, `XlaBuilder`, `dtype_to_etype`, `ops`,
    `register_custom_call_target`, `shape_from_pyval`, `Shape`,
    `XlaComputation`.
  * From `jax.lib.xla_extension`: `ArrayImpl`, `XlaRuntimeError`.
  * From `jax`: `jax.treedef_is_leaf`, `jax.tree_flatten`, `jax.tree_map`,
    `jax.tree_leaves`, `jax.tree_structure`, `jax.tree_transpose`, and
    `jax.tree_unflatten`. Replacements can be found in `jax.tree` or
    `jax.tree_util`.
  * From `jax.core`: `AxisSize`, `ClosedJaxpr`, `EvalTrace`, `InDBIdx`,
    `InputType`, `Jaxpr`, `JaxprEqn`, `Literal`, `MapPrimitive`,
    `OpaqueTraceState`, `OutDBIdx`, `Primitive`, `Token`,
    `TRACER_LEAK_DEBUGGER_WARNING`, `Var`, `concrete_aval`,
    `dedup_referents`, `escaped_tracer_error`, `extend_axis_env_nd`,
    `full_lower`, `get_referent`, `jaxpr_as_fun`, `join_effects`,
    `lattice_join`, `leaked_tracer_error`, `maybe_find_leaked_tracers`,
    `raise_to_shaped`, `raise_to_shaped_mappings`, `reset_trace_state`,
    `str_eqn_compact`, `substitute_vars_in_output_ty`, `typecompat`, and
    `used_axis_names_jaxpr`. Most have no public replacement, though a few
    are available at `jax.extend.core`.
  * The `vectorized` argument to `pure_callback()` and `ffi_call()`. Use
    the `vmap_method` parameter instead.
## jax 0.5.3 (Mar 19, 2025)

New features:

* Added an `allow_negative_indices` option to `jax.lax.dynamic_slice()`,
  `jax.lax.dynamic_update_slice()`, and related functions. The default is
  true, matching the current behavior. If set to false, JAX does not need
  to emit code clamping negative indices, which improves code size.
* Added a `replace` option to `jax.random.categorical()` to enable sampling
  without replacement.
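The clamping that `allow_negative_indices=False` lets the compiler skip is
part of `dynamic_slice`'s long-standing semantics: out-of-range start
indices are adjusted so the slice stays in bounds. A sketch of the default
behavior, not using the new option so it also runs on older releases:

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(10)
# Start index 8 with slice size 4 would run past the end, so the start is
# clamped to 6 and the full slice stays in bounds.
y = lax.dynamic_slice(x, (8,), (4,))
```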
## jax 0.5.2 (Mar 4, 2025)

Patch release of 0.5.1.

Bug fixes:

* Fixes TPU metric logging and `tpu-info`, which were broken in 0.5.1.
## jax 0.5.1 (Feb 24, 2025)

Breaking changes:

* The jit tracing cache now keys on input `NamedSharding`s. Previously, the
  tracing cache did not include sharding information at all (although
  subsequent jit caches did, like the lowering and compilation caches), so
  two equivalent shardings of different types would not retrace, but now
  they do. For example:

  ```python
  @jax.jit
  def f(x):
    return x

  # inp1.sharding is of type SingleDeviceSharding
  inp1 = jnp.arange(8)
  f(inp1)

  mesh = jax.make_mesh((1,), ('x',))
  # inp2.sharding is of type NamedSharding
  inp2 = jax.device_put(jnp.arange(8), NamedSharding(mesh, P('x')))
  f(inp2)  # tracing cache miss
  ```

  In the above example, calling `f(inp1)` and then `f(inp2)` will lead to a
  tracing cache miss because the shardings have changed on the abstract
  values while tracing.
New features:

* Added an experimental `jax.experimental.custom_dce.custom_dce()`
  decorator to support customizing the behavior of opaque functions under
  JAX-level dead code elimination (DCE). See #25956 for more details.
* Added low-level reduction APIs in `jax.lax`: `jax.lax.reduce_sum()`,
  `jax.lax.reduce_prod()`, `jax.lax.reduce_max()`, `jax.lax.reduce_min()`,
  `jax.lax.reduce_and()`, `jax.lax.reduce_or()`, and
  `jax.lax.reduce_xor()`.
* `jax.lax.linalg.qr()` and `jax.scipy.linalg.qr()` now support
  column-pivoting on CPU and GPU. See #20282 and #25955 for more details.
* Added `jax.random.multinomial()`.
Changes:

* `JAX_CPU_COLLECTIVES_IMPLEMENTATION` and `JAX_NUM_CPU_DEVICES` now work
  as environment variables. Before, they could only be specified via
  `jax.config` or flags. `JAX_CPU_COLLECTIVES_IMPLEMENTATION` now defaults
  to `'gloo'`, meaning multi-process CPU communication works
  out-of-the-box.
* The `jax[tpu]` TPU extra no longer depends on the `libtpu-nightly`
  package. This package may safely be removed if it is present on your
  machine; JAX now uses `libtpu` instead.
Deprecations:

* The internal function `linear_util.wrap_init` and the constructor
  `core.Jaxpr` now must take a non-empty `core.DebugInfo` kwarg. For a
  limited time, a `DeprecationWarning` is printed if
  `jax.extend.linear_util.wrap_init` is used without debugging info. A
  downstream effect of this is that several other internal functions need
  debug info. This change does not affect public APIs. See
  https://github.com/jax-ml/jax/issues/26480 for more detail.
* In `jax.numpy.ndim()`, `jax.numpy.shape()`, and `jax.numpy.size()`,
  non-arraylike inputs (such as lists, tuples, etc.) are now deprecated.
Bug fixes:

* TPU runtime startup and shutdown time should be significantly improved on
  TPU v5e and newer (from around 17s to around 8s). If not already set, you
  may need to enable transparent hugepages in your VM image
  (`sudo sh -c 'echo always > /sys/kernel/mm/transparent_hugepage/enabled'`).
  We hope to improve this further in future releases.
* The persistent compilation cache no longer writes an access time file if
  `JAX_COMPILATION_CACHE_MAX_SIZE` is unset or set to -1, i.e. if the LRU
  eviction policy isn't enabled. This should improve performance when using
  the cache with large-scale network storage.
## jax 0.5.0 (Jan 17, 2025)

As of this release, JAX now uses effort-based versioning. Since this
release makes a breaking change to PRNG key semantics that may require
users to update their code, we are bumping the "meso" version of JAX to
signify this.

Breaking changes:

* Enabled `jax_threefry_partitionable` by default (see the update note).
* This release drops support for Mac x86 wheels. Mac ARM of course remains
  supported. For a recent discussion, see
  https://github.com/jax-ml/jax/discussions/22936.

  Two key factors motivated this decision:

  * The Mac x86 build (only) has a number of test failures and crashes. We
    would prefer to ship no release than a broken release.
  * Mac x86 hardware is end-of-life and cannot be easily obtained for
    developers at this point. So it is difficult for us to fix this kind of
    problem even if we wanted to.

  We are open to re-adding support for Mac x86 if the community is willing
  to help support that platform: in particular, we would need the JAX test
  suite to pass cleanly on Mac x86 before we could ship releases again.
Changes:

* The minimum NumPy version is now 1.25. NumPy 1.25 will remain the minimum
  supported version until June 2025.
* The minimum SciPy version is now 1.11. SciPy 1.11 will remain the minimum
  supported version until June 2025.
* `jax.numpy.einsum()` now defaults to `optimize='auto'` rather than
  `optimize='optimal'`. This avoids exponentially-scaling trace time in the
  case of many arguments (#25214).
* `jax.numpy.linalg.solve()` no longer supports batched 1D arguments on the
  right-hand side. To recover the previous behavior in these cases, use
  `solve(a, b[..., None]).squeeze(-1)`.
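The `solve` workaround above can be applied directly: add a trailing axis
so each right-hand side is a column matrix, solve, then squeeze back to
vectors. A sketch (the batch of scaled identity matrices is an arbitrary
example, chosen so the expected solution is obvious):

```python
import jax.numpy as jnp

a = jnp.stack([2.0 * jnp.eye(3)] * 4)  # batch of four (3, 3) systems
b = jnp.ones((4, 3))                   # batch of four length-3 vectors
# The recommended replacement for the removed batched-1D-rhs behavior:
x = jnp.linalg.solve(a, b[..., None]).squeeze(-1)
# Each system is 2*I @ x = 1, so every entry of x is 0.5.
```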
New features:

* `jax.numpy.fft.fftn()`, `jax.numpy.fft.rfftn()`, `jax.numpy.fft.ifftn()`,
  and `jax.numpy.fft.irfftn()` now support transforms in more than 3
  dimensions, which was previously the limit. See #25606 for more details.
* Support added for user-defined state in the FFI via the new
  `jax.ffi.register_ffi_type_id()` function.
* The AOT lowering `.as_text()` method now supports the `debug_info` option
  to include debugging information, e.g. source location, in the output.
Deprecations:

* From `jax.interpreters.xla`, `abstractify` and `pytype_aval_mappings` are
  now deprecated, having been replaced by symbols of the same name in
  `jax.core`.
* `jax.scipy.special.lpmn()` and `jax.scipy.special.lpmn_values()` are
  deprecated, following their deprecation in SciPy v1.15.0. There are no
  plans to replace these deprecated functions with new APIs.
* The `jax.extend.ffi` submodule was moved to `jax.ffi`, and the previous
  import path is deprecated.
Deletions:

* The `jax_enable_memories` flag has been deleted; the behavior it enabled
  is now on by default.
* From `jax.lib.xla_client`, the previously-deprecated `Device` and
  `XlaRuntimeError` symbols have been removed; instead use `jax.Device` and
  `jax.errors.JaxRuntimeError`, respectively.
* The `jax.experimental.array_api` module has been removed after being
  deprecated in JAX v0.4.32. Since that release, `jax.numpy` supports the
  array API directly.
## jax 0.4.38 (Dec 17, 2024)

Breaking changes:

* `XlaExecutable.cost_analysis` now returns a `dict[str, float]` (instead
  of a single-element `list[dict[str, float]]`).
Changes:

* `jax.tree.flatten_with_path` and `jax.tree.map_with_path` are added as
  shortcuts of the corresponding `tree_util` functions.
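The new `jax.tree` shortcuts wrap functionality that already exists in
`jax.tree_util`, sketched here via the long-standing
`tree_flatten_with_path` (the example tree is arbitrary):

```python
from jax import tree_util

tree = {'a': 1, 'b': (2, 3)}
# Each leaf comes paired with the key path that leads to it.
leaves_with_paths, treedef = tree_util.tree_flatten_with_path(tree)
paths = [tree_util.keystr(path) for path, _ in leaves_with_paths]
leaves = [leaf for _, leaf in leaves_with_paths]
```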
Deprecations:

* A number of APIs in the internal `jax.core` namespace have been
  deprecated. Most were no-ops, were little-used, or can be replaced by
  APIs of the same name in `jax.extend.core`; see the documentation for
  `jax.extend` for information on the compatibility guarantees of these
  semi-public extensions.
* Several previously-deprecated APIs have been removed, including:
  * From `jax.core`: `check_eqn`, `check_type`, `check_valid_jaxtype`, and
    `non_negative_dim`.
  * From `jax.lib.xla_bridge`: `xla_client` and `default_backend`.
  * From `jax.lib.xla_client`: `_xla` and `bfloat16`.
  * From `jax.numpy`: `round_`.
New features:

* `jax.export.export()` can be used for device-polymorphic export with
  shardings constructed with `jax.sharding.AbstractMesh()`. See the
  jax.export documentation.
* Added `jax.lax.split()`. This is a primitive version of
  `jax.numpy.split()`, added because it yields a more compact transpose
  during automatic differentiation.
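`jax.lax.split()` mirrors the semantics of `jax.numpy.split()`, which this
sketch shows using the long-stable `jnp.split` (the split points are
arbitrary):

```python
import jax.numpy as jnp

x = jnp.arange(6)
# Split before indices 2 and 4, yielding three pieces.
parts = jnp.split(x, [2, 4])
```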
## jax 0.4.37 (Dec 9, 2024)

This is a patch release of jax 0.4.36. Only "jax" was released at this
version.

Bug fixes:

* Fixed a bug where `jit` would error if an argument was named `f`
  (#25329).
* Fixed a bug that would throw an `index out of range` error in
  `jax.lax.while_loop()` if the user registered a pytree node class with
  different aux data for `flatten` and `flatten_with_path`.
* Pinned a new libtpu release (0.0.6) that fixes a compiler bug on TPU v6e.
## jax 0.4.36 (Dec 5, 2024)

Breaking changes:

* This release lands "stackless", an internal change to JAX's tracing
  machinery. We made trace dispatch purely a function of context rather
  than a function of both context and data. This let us delete a lot of
  machinery for managing data-dependent tracing: levels, sublevels,
  `post_process_call`, `new_base_main`, `custom_bind`, and so on. The
  change should only affect users that use JAX internals. If you do use JAX
  internals then you may need to update your code (see
  https://github.com/jax-ml/jax/commit/c36e1f7c1ad4782060cbc8e8c596d85dfb83986f
  for clues about how to do this). There might also be version skew issues
  with JAX libraries that do this. If you find this change breaks your
  non-JAX-internals-using code then try the
  `config.jax_data_dependent_tracing_fallback` flag as a workaround, and if
  you need help updating your code then please file a bug.
* `jax.experimental.jax2tf.convert()` with `native_serialization=False` or
  with `enable_xla=False` has been deprecated since July 2024, with JAX
  version 0.4.31. Now we have removed support for these use cases. `jax2tf`
  with native serialization will still be supported.
* In `jax.interpreters.xla`, the `xb`, `xc`, and `xe` symbols have been
  removed after being deprecated in JAX v0.4.31. Instead use
  `xb = jax.lib.xla_bridge`, `xc = jax.lib.xla_client`, and
  `xe = jax.lib.xla_extension`.
* The deprecated module `jax.experimental.export` has been removed. It was
  replaced by `jax.export` in JAX v0.4.30. See the migration guide for
  information on migrating to the new API.
* The `initial` argument to `jax.nn.softmax()` and `jax.nn.log_softmax()`
  has been removed, after being deprecated in v0.4.27.
* Calling `np.asarray` on typed PRNG keys (i.e. keys produced by
  `jax.random.key()`) now raises an error. Previously, this returned a
  scalar object array.
* The following deprecated methods and functions in `jax.export` have been
  removed:
  * `jax.export.DisabledSafetyCheck.shape_assertions`: it already had no
    effect.
  * `jax.export.Exported.lowering_platforms`: use `platforms`.
  * `jax.export.Exported.mlir_module_serialization_version`: use
    `calling_convention_version`.
  * `jax.export.Exported.uses_shape_polymorphism`: use
    `uses_global_constants`.
  * The `lowering_platforms` kwarg for `jax.export.export()`: use
    `platforms` instead.
* The kwargs `symbolic_scope` and `symbolic_constraints` of
  `jax.export.symbolic_args_specs()` have been removed. They were
  deprecated in June 2024. Use `scope` and `constraints` instead.
* Hashing of tracers, which has been deprecated since version 0.4.30, now
  results in a `TypeError`.
* Refactor: the JAX build CLI (build/build.py) now uses a subcommand
  structure and replaces the previous build.py usage. Run
  `python build/build.py --help` for more details. Brief overview of the
  new subcommand options:
  * `build`: builds JAX wheel packages, e.g.
    `python build/build.py build --wheels=jaxlib,jax-cuda-pjrt`.
  * `requirements_update`: updates requirements_lock.txt files.
* `jax.scipy.linalg.toeplitz()` now does implicit batching on
  multi-dimensional inputs. To recover the previous behavior, you can call
  `jax.numpy.ravel()` on the function inputs.
* `jax.scipy.special.gamma()` and `jax.scipy.special.gammasgn()` now return
  NaN for negative integer inputs, to match the behavior of SciPy from
  https://github.com/scipy/scipy/pull/21827.
* `jax.clear_backends` was removed after being deprecated in v0.4.26.
* We removed the custom call "__gpu$xla.gpu.triton" from the list of custom
  calls whose export stability we guarantee. This is because this custom
  call relies on Triton IR, which is not guaranteed to be stable. If you
  need to export code that uses this custom call, you can use the
  `disabled_checks` parameter. See more details in the documentation.
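Regarding the typed-PRNG-key change above: to get at a key's raw bits on
the host, go through `jax.random.key_data` rather than `np.asarray` on the
key itself. A minimal sketch:

```python
import jax
import numpy as np

key = jax.random.key(0)          # a typed PRNG key
# np.asarray(key) now raises; extract the underlying bits explicitly:
raw = jax.random.key_data(key)
host = np.asarray(raw)           # an ordinary unsigned-integer array
```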
New features:

* `jax.jit()` got a new `compiler_options: dict[str, Any]` argument, for
  passing compilation options to XLA. For the moment it's undocumented and
  may be in flux.
* `jax.tree_util.register_dataclass()` now allows metadata fields to be
  declared inline via `dataclasses.field()`. See the function documentation
  for examples.
* `jax.lax.linalg.eig()` and the related `jax.numpy` functions
  (`jax.numpy.linalg.eig()` and `jax.numpy.linalg.eigvals()`) are now
  supported on GPU. See #24663 for more details.
* Added two new configuration flags, `jax_exec_time_optimization_effort`
  and `jax_memory_fitting_effort`, to control the amount of effort the
  compiler spends minimizing execution time and memory usage, respectively.
  Valid values are between -1.0 and 1.0; the default is 0.0.
Bug fixes:

* Fixed a bug where the GPU implementations of LU and QR decomposition
  would result in an indexing overflow for batch sizes close to int32 max.
  See #24843 for more details.
Deprecations:

* `jax.lib.xla_extension.ArrayImpl` and `jax.lib.xla_client.ArrayImpl` are
  deprecated; use `jax.Array` instead.
* `jax.lib.xla_extension.XlaRuntimeError` is deprecated; use
  `jax.errors.JaxRuntimeError` instead.
## jax 0.4.35 (Oct 22, 2024)

Breaking changes:

* `jax.numpy.isscalar()` now returns True for any array-like object with
  zero dimensions. Previously it only returned True for zero-dimensional
  array-like objects with a weak dtype.
* `jax.experimental.host_callback` has been deprecated since March 2024,
  with JAX version 0.4.26. Now we have removed it. See #20385 for a
  discussion of alternatives.
Changes:
* `jax.lax.FftType` was introduced as a public name for the enum of FFT
  operations. The semi-public API `jax.lib.xla_client.FftType` has been
  deprecated.
* TPU: JAX now installs TPU support from the `libtpu` package rather than
  `libtpu-nightly`. For the next few releases JAX will pin an empty version of
  `libtpu-nightly` as well as `libtpu` to ease the transition; that dependency
  will be removed in Q1 2025.
Deprecations:
* The semi-public API `jax.lib.xla_client.PaddingType` has been deprecated.
  No JAX APIs consume this type, so there is no replacement.
* The default behavior of `jax.pure_callback()` and
  `jax.extend.ffi.ffi_call()` under `vmap` has been deprecated and so has
  the `vectorized` parameter to those functions. The `vmap_method` parameter
  should be used instead for better defined behavior. See the discussion in
  #23881 for more details.
* The semi-public API `jax.lib.xla_client.register_custom_call_target` has
  been deprecated. Use the JAX FFI instead.
* The semi-public APIs `jax.lib.xla_client.dtype_to_etype`,
  `jax.lib.xla_client.ops`, `jax.lib.xla_client.shape_from_pyval`,
  `jax.lib.xla_client.PrimitiveType`, `jax.lib.xla_client.Shape`,
  `jax.lib.xla_client.XlaBuilder`, and `jax.lib.xla_client.XlaComputation`
  have been deprecated. Use StableHLO instead.
jax 0.4.34 (October 4, 2024)#
New Functionality
* This release includes wheels for Python 3.13. Free-threading mode is not
  yet supported.
* `jax.errors.JaxRuntimeError` has been added as a public alias for the
  formerly private `XlaRuntimeError` type.
Breaking changes
* The `jax_pmap_no_rank_reduction` flag is set to `True` by default.
  * `array[0]` on a pmap result now introduces a reshape (use `array[0:1]`
    instead).
  * The per-shard shape (accessible via `jax_array.addressable_shards` or
    `jax_array.addressable_data(0)`) now has a leading `(1, ...)`. Update
    code that directly accesses shards accordingly. The rank of the per-shard
    shape now matches that of the global shape, which is the same behavior as
    `jit`. This avoids costly reshapes when passing results from `pmap` into
    `jit`.
* `jax.experimental.host_callback` has been deprecated since March 2024, with
  JAX version 0.4.26. Now we set the default value of the
  `--jax_host_callback_legacy` configuration flag to `False`, which means
  that if your code uses `jax.experimental.host_callback` APIs, those API
  calls will be implemented in terms of the new `jax.experimental.io_callback`
  API. If this breaks your code, for a very limited time, you can set
  `--jax_host_callback_legacy` to `True`. Soon we will remove that
  configuration option, so you should instead transition to using the
  new JAX callback APIs. See #20385 for a discussion.
Deprecations
* In `jax.numpy.trim_zeros()`, non-array-like arguments or array-like
  arguments with `ndim != 1` are now deprecated, and in the future will
  result in an error.
* Internal pretty-printing tools `jax.core.pp_*` have been removed, after
  being deprecated in JAX v0.4.30.
* `jax.lib.xla_client.Device` is deprecated; use `jax.Device` instead.
* `jax.lib.xla_client.XlaRuntimeError` has been deprecated. Use
  `jax.errors.JaxRuntimeError` instead.
Deletion:
* `jax.xla_computation` is deleted. It's been 3 months since its deprecation
  in the 0.4.30 JAX release. Please use the AOT APIs to get the same
  functionality as `jax.xla_computation`.
  * `jax.xla_computation(fn)(*args, **kwargs)` can be replaced with
    `jax.jit(fn).lower(*args, **kwargs).compiler_ir('hlo')`.
  * You can also use the `.out_info` property of `jax.stages.Lowered` to get
    the output information (like tree structure, shape and dtype).
  * For cross-backend lowering, you can replace
    `jax.xla_computation(fn, backend='tpu')(*args, **kwargs)` with
    `jax.jit(fn).trace(*args, **kwargs).lower(lowering_platforms=('tpu',)).compiler_ir('hlo')`.
* `jax.ShapeDtypeStruct` no longer accepts the `named_shape` argument.
  The argument was only used by `xmap`, which was removed in 0.4.31.
* `jax.tree.map(f, None, non-None)`, which previously emitted a
  `DeprecationWarning`, now raises an error. `None` is only a tree-prefix of
  itself. To preserve the previous behavior, you can ask `jax.tree.map` to
  treat `None` as a leaf value by writing:
  `jax.tree.map(lambda x, y: None if x is None else f(x, y), a, b, is_leaf=lambda x: x is None)`.
* `jax.sharding.XLACompatibleSharding` has been removed. Please use
  `jax.sharding.Sharding`.
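The `jax.tree.map` workaround quoted above can be written out as a runnable sketch (the trees `a` and `b` are illustrative):

```python
import jax

def f(x, y):
  return x + y

a = {"w": 1.0, "b": None}
b = {"w": 2.0, "b": None}

# Treat None as a leaf so that None entries are passed through unchanged:
out = jax.tree.map(lambda x, y: None if x is None else f(x, y),
                   a, b, is_leaf=lambda x: x is None)
```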
Bug fixes
* Fixed a bug where `jax.numpy.cumsum()` would produce incorrect outputs
  if a non-boolean input was provided and `dtype=bool` was specified.
* Edited the implementation of `jax.numpy.ldexp()` to get the correct
  gradient.
jax 0.4.33 (September 16, 2024)#
This is a patch release on top of jax 0.4.32 that fixes two bugs found in
that release.

A TPU-only data corruption bug was found in the version of libtpu pinned by
JAX 0.4.32, which manifested only if multiple TPU slices were present in the
same job, for example, if training on multiple v5e slices. This release fixes
that issue by pinning a fixed version of libtpu.

This release fixes an inaccurate result for F64 tanh on CPU (#23590).
jax 0.4.32 (September 11, 2024)#
Note: This release was yanked from PyPI because of a data corruption bug on
TPU. See the 0.4.33 release notes for more details.
New Functionality
* Added `jax.extend.ffi.ffi_call()` and `jax.extend.ffi.ffi_lowering()`
  to support the use of the new Foreign Function Interface (FFI) to interface
  with custom C++ and CUDA code from JAX.
Changes
* The `jax_enable_memories` flag is set to `True` by default.
* `jax.numpy` now supports v2023.12 of the Python Array API Standard.
  See Python Array API standard for more information.
* Computations on the CPU backend may now be dispatched asynchronously in
  more cases. Previously non-parallel computations were always dispatched
  synchronously. You can recover the old behavior by setting
  `jax.config.update('jax_cpu_enable_async_dispatch', False)`.
* Added a new `jax.process_indices()` function to replace the
  `jax.host_ids()` function that was deprecated in JAX v0.2.13.
* To align with the behavior of `numpy.fabs`, `jax.numpy.fabs` has been
  modified to no longer support `complex` dtypes.
* `jax.tree_util.register_dataclass` now checks that `data_fields` and
  `meta_fields` include all dataclass fields with `init=True` and only them,
  if `nodetype` is a dataclass.
* Several `jax.numpy` functions now have full `ufunc` interfaces, including
  `add`, `multiply`, `bitwise_and`, `bitwise_or`, `bitwise_xor`,
  `logical_and`, `logical_or`, and `logical_xor`. In future releases we plan
  to expand these to other ufuncs.
* Added `jax.lax.optimization_barrier(), which allows users to prevent
  compiler optimizations such as common-subexpression elimination and to
  control scheduling.
Breaking changes
* The MHLO MLIR dialect (`jax.extend.mlir.mhlo`) has been removed. Use the
  `stablehlo` dialect instead.
Deprecations
* Complex inputs to `jax.numpy.clip()` and `jax.numpy.hypot()` are
  no longer allowed, after being deprecated since JAX v0.4.27.
* Deprecated the following APIs:
  * `jax.lib.xla_bridge.xla_client`: use `jax.lib.xla_client` directly.
  * `jax.lib.xla_bridge.get_backend`: use `jax.extend.backend.get_backend()`.
  * `jax.lib.xla_bridge.default_backend`: use
    `jax.extend.backend.default_backend()`.
* The `jax.experimental.array_api` module is deprecated, and importing it is
  no longer required to use the Array API. `jax.numpy` supports the array API
  directly; see Python Array API standard for more information.
* The internal utilities `jax.core.check_eqn`, `jax.core.check_type`, and
  `jax.core.check_valid_jaxtype` are now deprecated, and will be removed in
  the future.
* `jax.numpy.round_` has been deprecated, following removal of the
  corresponding API in NumPy 2.0. Use `jax.numpy.round()` instead.
* Passing a DLPack capsule to `jax.dlpack.from_dlpack()` is deprecated.
  The argument to `jax.dlpack.from_dlpack()` should be an array from
  another framework that implements the `__dlpack__` protocol.
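A sketch of the non-deprecated `from_dlpack` usage: pass an array object that implements `__dlpack__` (here a NumPy array), not a legacy capsule:

```python
import numpy as np
import jax.dlpack

x_np = np.arange(4.0)
# The argument implements the __dlpack__ protocol; no capsule needed:
x_jax = jax.dlpack.from_dlpack(x_np)
```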
jaxlib 0.4.32 (September 11, 2024)#
Note: This release was yanked from PyPI because of a data corruption bug on
TPU. See the 0.4.33 release notes for more details.
Breaking changes
* This release of jaxlib switched to a new version of the CPU backend, which
  should compile faster and leverage parallelism better. If you experience
  any problems due to this change, you can temporarily enable the old CPU
  backend by setting the environment variable
  `XLA_FLAGS=--xla_cpu_use_thunk_runtime=false`. If you need to do this,
  please file a JAX bug with instructions to reproduce.
* Hermetic CUDA support is added. Hermetic CUDA uses a specific downloadable
  version of CUDA instead of the user's locally installed CUDA. Bazel will
  download CUDA, CUDNN and NCCL distributions, and then use CUDA libraries
  and tools as dependencies in various Bazel targets. This enables more
  reproducible builds for JAX and its supported CUDA versions.
Changes
* SparseCore profiling is added. JAX now supports profiling SparseCore on
  TPU v5p chips. These traces will be viewable in Tensorboard Profiler's
  TraceViewer.
jax 0.4.31 (July 29, 2024)#
Deletion
* xmap has been deleted. Please use `shard_map()` as the replacement.
Changes
* The minimum CuDNN version is v9.1. This was true in previous releases also,
  but we now declare this version constraint formally.
* The minimum Python version is now 3.10. 3.10 will remain the minimum
  supported version until July 2025.
* The minimum NumPy version is now 1.24. NumPy 1.24 will remain the minimum
  supported version until December 2024.
* The minimum SciPy version is now 1.10. SciPy 1.10 will remain the minimum
  supported version until January 2025.
* `jax.numpy.ceil()`, `jax.numpy.floor()` and `jax.numpy.trunc()` now return
  output of the same dtype as the input, i.e. they no longer upcast integer
  or boolean inputs to floating point.
* `libdevice.10.bc` is no longer bundled with CUDA wheels. It must be
  installed either as a part of a local CUDA installation, or via NVIDIA's
  CUDA pip wheels.
* `jax.experimental.pallas.BlockSpec` now expects `block_shape` to be passed
  before `index_map`. The old argument order is deprecated and will be
  removed in a future release.
* Updated the repr of GPU devices to be more consistent with TPUs/CPUs. For
  example, `cuda(id=0)` will now be `CudaDevice(id=0)`.
* Added the `device` property and `to_device` method to `jax.Array`, as
  part of JAX's Array API support.
Deprecations
* Removed a number of previously-deprecated internal APIs related to
  polymorphic shapes. From `jax.core`: removed `canonicalize_shape`,
  `dimension_as_value`, `definitely_equal`, and `symbolic_equal_dim`.
* HLO lowering rules should no longer wrap singleton `ir.Value`s in tuples.
  Instead, return singleton `ir.Value`s unwrapped. Support for wrapped values
  will be removed in a future version of JAX.
* `jax.experimental.jax2tf.convert()` with `native_serialization=False` or
  `enable_xla=False` is now deprecated and this support will be removed in
  a future version. Native serialization has been the default since JAX
  0.4.16 (September 2023).
* The previously-deprecated function `jax.random.shuffle` has been removed;
  instead use `jax.random.permutation` with `independent=True`.
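A sketch of the suggested replacement for the removed `jax.random.shuffle`:

```python
import jax
import jax.numpy as jnp

key = jax.random.key(0)
x = jnp.arange(8).reshape(4, 2)
# Former jax.random.shuffle(key, x) becomes:
shuffled = jax.random.permutation(key, x, axis=0, independent=True)
```

With `independent=True`, each slice along the remaining axes is shuffled independently, matching the old `shuffle` semantics.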
jaxlib 0.4.31 (July 29, 2024)#
Bug fixes
* Fixed a bug that meant that negative `static_argnums` to a jit were
  mishandled by the jit dispatch fast path.
* Fixed a bug that meant that triangular solves of batches of singular
  matrices produced nonsensical finite values, instead of inf or nan
  (#3589, #15429).
jax 0.4.30 (June 18, 2024)#
Changes
* JAX supports ml_dtypes >= 0.2. In the 0.4.29 release, the ml_dtypes version
  was bumped to 0.4.0 but this has been rolled back in this release to give
  users of both TensorFlow and JAX more time to migrate to a newer TensorFlow
  release.
* `jax.experimental.mesh_utils` can now create an efficient mesh for TPU v5e.
* jax now depends on jaxlib directly. This change was enabled by the CUDA
  plugin switch: there are no longer multiple jaxlib variants. You can
  install a CPU-only jax with `pip install jax`, no extras required.
* Added an API for exporting and serializing JAX functions. This used
  to exist in `jax.experimental.export` (which is being deprecated),
  and will now live in `jax.export`. See the documentation.
Deprecations
* Internal pretty-printing tools `jax.core.pp_*` are deprecated, and will be
  removed in a future release.
* Hashing of tracers is deprecated, and will lead to a `TypeError` in a
  future JAX release. This previously was the case, but there was an
  inadvertent regression in the last several JAX releases.
* `jax.experimental.export` is deprecated. Use `jax.export` instead.
  See the migration guide.
* Passing an array in place of a dtype is now deprecated in most cases;
  e.g. for arrays `x` and `y`, `x.astype(y)` will raise a warning. To silence
  it use `x.astype(y.dtype)`.
* `jax.xla_computation` is deprecated and will be removed in a future
  release. Please use the AOT APIs to get the same functionality as
  `jax.xla_computation`.
  * `jax.xla_computation(fn)(*args, **kwargs)` can be replaced with
    `jax.jit(fn).lower(*args, **kwargs).compiler_ir('hlo')`.
  * You can also use the `.out_info` property of `jax.stages.Lowered` to get
    the output information (like tree structure, shape and dtype).
  * For cross-backend lowering, you can replace
    `jax.xla_computation(fn, backend='tpu')(*args, **kwargs)` with
    `jax.jit(fn).trace(*args, **kwargs).lower(lowering_platforms=('tpu',)).compiler_ir('hlo')`.
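The AOT replacement for `jax.xla_computation` can be sketched as follows (the function `fn` is illustrative):

```python
import jax
import jax.numpy as jnp

def fn(x):
  return jnp.sin(x) ** 2

# jax.xla_computation(fn)(x) becomes:
lowered = jax.jit(fn).lower(jnp.ones(4))
hlo = lowered.compiler_ir('hlo')  # the HLO computation, as before
# lowered.out_info carries output tree structure, shape, and dtype.
```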
jaxlib 0.4.30 (June 18, 2024)#
* Support for monolithic CUDA jaxlibs has been dropped. You must use the
  plugin-based installation (`pip install jax[cuda12]` or
  `pip install jax[cuda12_local]`).
jax 0.4.29 (June 10, 2024)#
Changes
* We anticipate that this will be the last release of JAX and jaxlib
  supporting a monolithic CUDA jaxlib. Future releases will use the CUDA
  plugin jaxlib (e.g. `pip install jax[cuda12]`).
* JAX now requires ml_dtypes version 0.4.0 or newer.
* Removed backwards-compatibility support for old usage of the
  `jax.experimental.export` API. It is not possible anymore to use
  `from jax.experimental.export import export`; instead you should use
  `from jax.experimental import export`. The removed functionality has been
  deprecated since 0.4.24.
* Added an `is_leaf` argument to `jax.tree.all()` and
  `jax.tree_util.tree_all()`.
Deprecations
* `jax.sharding.XLACompatibleSharding` is deprecated. Please use
  `jax.sharding.Sharding`.
* `jax.experimental.Exported.in_shardings` has been renamed to
  `jax.experimental.Exported.in_shardings_hlo`. Same for `out_shardings`.
  The old names will be removed after 3 months.
* Removed a number of previously-deprecated APIs.
* The `tol` argument of `jax.numpy.linalg.matrix_rank()` is being deprecated
  and will soon be removed. Use `rtol` instead.
* The `rcond` argument of `jax.numpy.linalg.pinv()` is being deprecated and
  will soon be removed. Use `rtol` instead.
* The deprecated `jax.config` submodule has been removed. To configure JAX
  use `import jax` and then reference the config object via `jax.config`.
* `jax.random` APIs no longer accept batched keys, where previously some did
  unintentionally. Going forward, we recommend explicit use of `jax.vmap()`
  in such cases.
* In `jax.scipy.special.beta()`, the `x` and `y` parameters have been renamed
  to `a` and `b` for consistency with other `beta` APIs.
New Functionality
* Added `jax.experimental.Exported.in_shardings_jax()` to construct
  shardings that can be used with the JAX APIs from the `HloSharding`s
  that are stored in the `Exported` objects.
jaxlib 0.4.29 (June 10, 2024)#
Bug fixes
* Fixed a bug where XLA sharded some concatenation operations incorrectly,
  which manifested as an incorrect output for cumulative reductions (#21403).
* Fixed a bug where XLA:CPU miscompiled certain matmul fusions
  (https://github.com/openxla/xla/pull/13301).
* Fixed a compiler crash on GPU (https://github.com/jax-ml/jax/issues/21396).
Deprecations
* `jax.tree.map(f, None, non-None)` now emits a `DeprecationWarning`, and
  will raise an error in a future version of jax. `None` is only a
  tree-prefix of itself. To preserve the current behavior, you can ask
  `jax.tree.map` to treat `None` as a leaf value by writing:
  `jax.tree.map(lambda x, y: None if x is None else f(x, y), a, b, is_leaf=lambda x: x is None)`.
jax 0.4.28 (May 9, 2024)#
Bug fixes
* Reverted a change to `make_jaxpr` that was breaking Equinox (#21116).
Deprecations & removals
* The `kind` argument to `jax.numpy.sort()` and `jax.numpy.argsort()` is now
  removed. Use `stable=True` or `stable=False` instead.
* Removed `get_compute_capability` from the `jax.experimental.pallas.gpu`
  module. Use the `compute_capability` attribute of a GPU device, returned by
  `jax.devices()` or `jax.local_devices()`, instead.
* The `newshape` argument to `jax.numpy.reshape()` is being deprecated and
  will soon be removed. Use `shape` instead.
Changes
The minimum jaxlib version of this release is 0.4.27.
jaxlib 0.4.28 (May 9, 2024)#
Bug fixes
* Fixed a memory corruption bug in the type name of Array and JIT Python
  objects in Python 3.10 or earlier.
* Fixed a warning `'+ptx84' is not a recognized feature for this target`
  under CUDA 12.4.
* Fixed a slow compilation problem on CPU.
Changes
The Windows build is now built with Clang instead of MSVC.
jax 0.4.27 (May 7, 2024)#
New Functionality
* Added `jax.numpy.unstack()` and `jax.numpy.cumulative_sum()`, following
  their addition in the array API 2023 standard, soon to be adopted by NumPy.
* Added a new config option `jax_cpu_collectives_implementation` to select
  the implementation of cross-process collective operations used by the CPU
  backend. Choices available are `'none'` (default), `'gloo'` and `'mpi'`
  (requires jaxlib 0.4.26). If set to `'none'`, cross-process collective
  operations are disabled.
Changes
* `jax.pure_callback()`, `jax.experimental.io_callback()` and
  `jax.debug.callback()` now use `jax.Array` instead of `np.ndarray`. You can
  recover the old behavior by transforming the arguments via
  `jax.tree.map(np.asarray, args)` before passing them to the callback.
* `complex_arr.astype(bool)` now follows the same semantics as NumPy,
  returning False where `complex_arr` is equal to `0 + 0j`, and True
  otherwise.
* `core.Token` now is a non-trivial class which wraps a `jax.Array`. It can
  be created and threaded in and out of computations to build up
  dependencies. The singleton object `core.token` has been removed; users
  should now create and use fresh `core.Token` objects instead.
* On GPU, the Threefry PRNG implementation no longer lowers to a kernel call
  by default. This choice can improve runtime memory usage at a compile-time
  cost. Prior behavior, which produces a kernel call, can be recovered with
  `jax.config.update('jax_threefry_gpu_kernel_lowering', True)`. If the new
  default causes issues, please file a bug. Otherwise, we intend to remove
  this flag in a future release.
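A sketch of recovering the old `np.ndarray` callback behavior described above (the callback here is illustrative):

```python
import numpy as np
import jax
import jax.numpy as jnp

def np_callback(x):
  # Callbacks now receive jax.Array; convert back to NumPy explicitly:
  (x,) = jax.tree.map(np.asarray, (x,))
  return np.sin(x)

x = jnp.arange(3.0)
result = jax.pure_callback(
    np_callback, jax.ShapeDtypeStruct(x.shape, x.dtype), x)
```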
Deprecations & Removals
* Pallas now exclusively uses XLA for compiling kernels on GPU. The old
  lowering pass via Triton Python APIs has been removed and the
  `JAX_TRITON_COMPILE_VIA_XLA` environment variable no longer has any effect.
* `jax.numpy.clip()` has a new argument signature: `a`, `a_min`, and `a_max`
  are deprecated in favor of `x` (positional only), `min`, and `max`
  (#20550).
* The `device()` method of JAX arrays has been removed, after being
  deprecated since JAX v0.4.21. Use `arr.devices()` instead.
* The `initial` argument to `jax.nn.softmax()` and `jax.nn.log_softmax()` is
  deprecated; empty inputs to softmax are now supported without setting this.
* In `jax.jit()`, passing invalid `static_argnums` or `static_argnames` now
  leads to an error rather than a warning.
* The minimum jaxlib version is now 0.4.23.
* The `jax.numpy.hypot()` function now issues a deprecation warning when
  passing complex-valued inputs to it. This will raise an error when the
  deprecation is completed.
* Scalar arguments to `jax.numpy.nonzero()`, `jax.numpy.where()`, and
  related functions now raise an error, following a similar change in NumPy.
* The config option `jax_cpu_enable_gloo_collectives` is deprecated.
  Use `jax.config.update('jax_cpu_collectives_implementation', 'gloo')`
  instead.
* The `jax.Array.device_buffer` and `jax.Array.device_buffers` methods have
  been removed after being deprecated in JAX v0.4.22. Instead use
  `jax.Array.addressable_shards` and `jax.Array.addressable_data()`.
* The `condition`, `x`, and `y` parameters of `jax.numpy.where` are now
  positional-only, following deprecation of the keywords in JAX v0.4.21.
* Non-array arguments to functions in `jax.lax.linalg` now must be specified
  by keyword. Previously, this raised a `DeprecationWarning`.
* Array-like arguments are now required in several `jax.numpy` APIs,
  including `apply_along_axis()`, `apply_over_axes()`, `inner()`, `outer()`,
  `cross()`, `kron()`, and `lexsort()`.
Bug fixes
* `jax.numpy.astype()` will now always return a copy when `copy=True`.
  Previously, no copy would be made when the output array would have the same
  dtype as the input array. This may result in some increased memory usage.
  The default value is set to `copy=False` to preserve backwards
  compatibility.
jaxlib 0.4.27 (May 7, 2024)#
jax 0.4.26 (April 3, 2024)#
New Functionality
* Added `jax.numpy.trapezoid()`, following the addition of this function in
  NumPy 2.0.
Changes
* Complex-valued `jax.numpy.geomspace()` now chooses the logarithmic spiral
  branch consistent with that of NumPy 2.0.
* The behavior of `lax.rng_bit_generator`, and in turn the `'rbg'` and
  `'unsafe_rbg'` PRNG implementations, under `jax.vmap` has changed so that
  mapping over keys results in random generation only from the first key in
  the batch.
* Docs now use `jax.random.key` for construction of PRNG key arrays rather
  than `jax.random.PRNGKey`.
Deprecations & Removals
* `jax.tree_map()` is deprecated; use `jax.tree.map` instead, or for backward
  compatibility with older JAX versions, use `jax.tree_util.tree_map()`.
* `jax.clear_backends()` is deprecated as it does not necessarily do what its
  name suggests and can lead to unexpected consequences; e.g., it will not
  destroy existing backends and release corresponding owned resources. Use
  `jax.clear_caches()` if you only want to clean up compilation caches. For
  backward compatibility, or if you really need to switch/reinitialize the
  default backend, use `jax.extend.backend.clear_backends()`.
* The `jax.experimental.maps` module and `jax.experimental.maps.xmap` are
  deprecated. Use `jax.experimental.shard_map` or `jax.vmap` with the
  `spmd_axis_name` argument for expressing SPMD device-parallel computations.
* The `jax.experimental.host_callback` module is deprecated. Use instead the
  new JAX external callbacks. Added the `JAX_HOST_CALLBACK_LEGACY` flag to
  assist in the transition to the new callbacks. See #20385 for a discussion.
* Passing arguments to `jax.numpy.array_equal()` and
  `jax.numpy.array_equiv()` that cannot be converted to a JAX array now
  results in an exception.
* The deprecated flag `jax_parallel_functions_output_gda` has been removed.
  This flag was long deprecated and did nothing; its use was a no-op.
* The previously-deprecated imports `jax.interpreters.ad.config` and
  `jax.interpreters.ad.source_info_util` have now been removed. Use
  `jax.config` and `jax.extend.source_info_util` instead.
* JAX export does not support older serialization versions anymore. Version 9
  has been supported since October 27th, 2023 and has become the default
  since February 1, 2024. See a description of the versions. This change
  could break clients that set a specific JAX serialization version lower
  than 9.
jaxlib 0.4.26 (April 3, 2024)#
Changes
JAX now supports CUDA 12.1 or newer only. Support for CUDA 11.8 has beendropped.
JAX now supports NumPy 2.0.
jax 0.4.25 (Feb 26, 2024)#
New Features
* Added CUDA Array Interface import support (requires jaxlib 0.4.24).
* JAX arrays now support NumPy-style scalar boolean indexing, e.g. `x[True]`
  or `x[False]`.
* Added the `jax.tree` module, with a more convenient interface for
  referencing functions in `jax.tree_util`.
* `jax.tree.transpose()` (i.e. `jax.tree_util.tree_transpose()`) now accepts
  `inner_treedef=None`, in which case the inner treedef will be automatically
  inferred.
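A sketch of the `jax.tree.transpose` inference described above:

```python
import jax

outer = jax.tree.structure(["*", "*"])  # structure of the outer list
data = [{"a": 1, "b": 2}, {"a": 3, "b": 4}]
# inner_treedef=None: the inner treedef is inferred automatically
transposed = jax.tree.transpose(outer, None, data)
```

This turns a list of dicts into a dict of lists.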
Changes
* Pallas now uses XLA instead of the Triton Python APIs to compile Triton
  kernels. You can revert to the old behavior by setting the
  `JAX_TRITON_COMPILE_VIA_XLA` environment variable to `"0"`.
* Several deprecated APIs in `jax.interpreters.xla` that were removed in
  v0.4.24 have been re-added in v0.4.25, including
  `backend_specific_translations`, `translations`, `register_translation`,
  `xla_destructure`, `TranslationRule`, `TranslationContext`, and `XLAOp`.
  These are still considered deprecated, and will be removed again in the
  future when better replacements are available. Refer to #19816 for
  discussion.
Deprecations & Removals
* `jax.numpy.linalg.solve()` now shows a deprecation warning for batched 1D
  solves with `b.ndim > 1`. In the future these will be treated as batched 2D
  solves.
* Conversion of a non-scalar array to a Python scalar now raises an error,
  regardless of the size of the array. Previously a deprecation warning was
  raised in the case of non-scalar arrays of size 1. This follows a similar
  deprecation in NumPy.
* The previously deprecated configuration APIs have been removed following a
  standard 3 months deprecation cycle (see API compatibility). These include:
  * the `jax.config.config` object and
  * the `define_*_state` and `DEFINE_*` methods of `jax.config`.
* Importing the `jax.config` submodule via `import jax.config` is deprecated.
  To configure JAX use `import jax` and then reference the config object via
  `jax.config`.
* The minimum jaxlib version is now 0.4.20.
jaxlib 0.4.25 (Feb 26, 2024)#
jax 0.4.24 (Feb 6, 2024)#
Changes
* JAX lowering to StableHLO does not depend on physical devices anymore.
  If your primitive wraps custom_partitioning or JAX callbacks in the
  lowering rule, i.e. the function passed to the `rule` parameter of
  `mlir.register_lowering`, then add your primitive to the
  `jax._src.dispatch.prim_requires_devices_during_lowering` set. This is
  needed because custom_partitioning and JAX callbacks need physical devices
  to create `Sharding`s during lowering. This is a temporary state until we
  can create `Sharding`s without physical devices.
* `jax.numpy.argsort()` and `jax.numpy.sort()` now support the `stable` and
  `descending` arguments.
* Several changes to the handling of shape polymorphism (used in
  `jax.experimental.jax2tf` and `jax.experimental.export`):
  * cleaner pretty-printing of symbolic expressions (#19227)
  * added the ability to specify symbolic constraints on the dimension
    variables. This makes shape polymorphism more expressive, and gives a way
    to work around limitations in the reasoning about inequalities.
    See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints.
  * with the addition of symbolic constraints (#19235) we now consider
    dimension variables from different scopes to be different, even if they
    have the same name. Symbolic expressions from different scopes cannot
    interact, e.g., in arithmetic operations. Scopes are introduced by
    `jax.experimental.jax2tf.convert()`,
    `jax.experimental.export.symbolic_shape()`, and
    `jax.experimental.export.symbolic_args_specs()`. The scope of a symbolic
    expression `e` can be read with `e.scope` and passed into the above
    functions to direct them to construct symbolic expressions in a given
    scope.
  * simplified and faster equality comparisons, where we consider two
    symbolic dimensions to be equal if the normalized form of their
    difference reduces to 0 (#19231; note that this may result in
    user-visible behavior changes)
  * improved the error messages for inconclusive inequality comparisons
    (#19235).
  * the `core.non_negative_dim` API (introduced recently) was deprecated and
    `core.max_dim` and `core.min_dim` were introduced (#18953) to express
    `max` and `min` for symbolic dimensions. You can use `core.max_dim(d, 0)`
    instead of `core.non_negative_dim(d)`.
  * `shape_poly.is_poly_dim` is deprecated in favor of
    `export.is_symbolic_dim` (#19282).
  * `export.args_specs` is deprecated in favor of
    `export.symbolic_args_specs` (#19283).
  * `shape_poly.PolyShape` and `jax2tf.PolyShape` are deprecated; use
    strings for polymorphic shape specifications (#19284).
* The JAX default native serialization version is now 9. This is relevant for
  `jax.experimental.jax2tf` and `jax.experimental.export`. See the
  description of version numbers.
* Refactored the API for `jax.experimental.export`. Instead of
  `from jax.experimental.export import export` you should now use
  `from jax.experimental import export`. The old way of importing will
  continue to work for a deprecation period of 3 months.
* Added `jax.scipy.stats.sem()`.
* `jax.numpy.unique()` with `return_inverse=True` returns inverse indices
  reshaped to the dimension of the input, following a similar change to
  `numpy.unique()` in NumPy 2.0.
* `jax.numpy.sign()` now returns `x / abs(x)` for nonzero complex inputs.
  This is consistent with the behavior of `numpy.sign()` in NumPy version
  2.0.
* `jax.scipy.special.logsumexp()` with `return_sign=True` now uses the
  NumPy 2.0 convention for the complex sign, `x / abs(x)`. This is consistent
  with the behavior of `scipy.special.logsumexp()` in SciPy v1.13.
* JAX now supports the bool DLPack type for both import and export.
  Previously bool values could not be imported and were exported as integers.
Deprecations & Removals
* A number of previously deprecated functions have been removed, following a
  standard 3+ month deprecation cycle (see API compatibility). This includes:
  * From `jax.core`: `TracerArrayConversionError`,
    `TracerIntegerConversionError`, `UnexpectedTracerError`,
    `as_hashable_function`, `collections`, `dtypes`, `lu`, `map`,
    `namedtuple`, `partial`, `pp`, `ref`, `safe_zip`, `safe_map`,
    `source_info_util`, `total_ordering`, `traceback_util`, `tuple_delete`,
    `tuple_insert`, and `zip`.
  * From `jax.lax`: `dtypes`, `itertools`, `naryop`, `naryop_dtype_rule`,
    `standard_abstract_eval`, `standard_naryop`, `standard_primitive`,
    `standard_unop`, `unop`, and `unop_dtype_rule`.
  * The `jax.linear_util` submodule and all its contents.
  * The `jax.prng` submodule and all its contents.
  * From `jax.random`: `PRNGKeyArray`, `KeyArray`, `default_prng_impl`,
    `threefry_2x32`, `threefry2x32_key`, `threefry2x32_p`, `rbg_key`, and
    `unsafe_rbg_key`.
  * From `jax.tree_util`: `register_keypaths`, `AttributeKeyPathEntry`, and
    `GetItemKeyPathEntry`.
  * From `jax.interpreters.xla`: `backend_specific_translations`,
    `translations`, `register_translation`, `xla_destructure`,
    `TranslationRule`, `TranslationContext`, `axis_groups`, `ShapedArray`,
    `ConcreteArray`, `AxisEnv`, `backend_compile`, and `XLAOp`.
  * From `jax.numpy`: `NINF`, `NZERO`, `PZERO`, `row_stack`, `issubsctype`,
    `trapz`, and `in1d`.
  * From `jax.scipy.linalg`: `tril` and `triu`.
* The previously-deprecated method `PRNGKeyArray.unsafe_raw_array` has been
  removed. Use `jax.random.key_data()` instead.
* `bool(empty_array)` now raises an error rather than returning `False`. This
  previously raised a deprecation warning, and follows a similar change in
  NumPy.
* Support for the mhlo MLIR dialect has been deprecated. JAX no longer uses
  the mhlo dialect, in favor of stablehlo. APIs that refer to "mhlo" will be
  removed in the future. Use the "stablehlo" dialect instead.
* `jax.random`: passing batched keys directly to random number generation
  functions, such as `bits()`, `gamma()`, and others, is deprecated and will
  emit a `FutureWarning`. Use `jax.vmap` for explicit batching.
* `jax.lax.tie_in()` is deprecated: it has been a no-op since JAX v0.2.0.
jaxlib 0.4.24 (Feb 6, 2024)#
Changes
JAX now supports CUDA 12.3 and CUDA 11.8. Support for CUDA 12.2 has beendropped.
* `cost_analysis` now works with cross-compiled `Compiled` objects (i.e. when
  using `.lower().compile()` with a topology object, e.g., to compile for
  Cloud TPU from a non-TPU computer).
* Added CUDA Array Interface import support (requires jax 0.4.25).
jax 0.4.23 (Dec 13, 2023)#
jaxlib 0.4.23 (Dec 13, 2023)#
* Fixed a bug that caused verbose logging from the GPU compiler during
  compilation.
jax 0.4.22 (Dec 13, 2023)#
Deprecations
* The `device_buffer` and `device_buffers` properties of JAX arrays are
  deprecated. Explicit buffers have been replaced by the more flexible array
  sharding interface, but the previous outputs can be recovered this way:
  * `arr.device_buffer` becomes `arr.addressable_data(0)`
  * `arr.device_buffers` becomes `[x.data for x in arr.addressable_shards]`
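A sketch of the replacements above:

```python
import jax.numpy as jnp

arr = jnp.arange(8.0)
first = arr.addressable_data(0)                    # was: arr.device_buffer
shards = [x.data for x in arr.addressable_shards]  # was: arr.device_buffers
```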
jaxlib 0.4.22 (Dec 13, 2023)#
jax 0.4.21 (Dec 4 2023)#
New Features
* Added `jax.nn.squareplus`.
Changes
* The minimum jaxlib version is now 0.4.19.
* Released wheels are now built with clang instead of gcc.
* Enforce that the device backend has not been initialized prior to calling
  `jax.distributed.initialize()`.
* Automate arguments to `jax.distributed.initialize()` in cloud TPU
  environments.
Deprecations
* The previously-deprecated `sym_pos` argument has been removed from
  `jax.scipy.linalg.solve()`. Use `assume_a='pos'` instead.
* Passing `None` to `jax.array()` or `jax.asarray()`, either directly or
  within a list or tuple, is deprecated and now raises a `FutureWarning`.
  It currently is converted to NaN, and in the future will raise a
  `TypeError`.
* Passing the `condition`, `x`, and `y` parameters to `jax.numpy.where` by
  keyword arguments has been deprecated, to match `numpy.where`.
* Passing arguments to `jax.numpy.array_equal()` and
  `jax.numpy.array_equiv()` that cannot be converted to a JAX array is
  deprecated and now raises a `DeprecationWarning`. Currently the functions
  return False; in the future this will raise an exception.
* The `device()` method of JAX arrays is deprecated. Depending on the
  context, it may be replaced with one of the following:
  * `jax.Array.devices()` returns the set of all devices used by the array.
  * `jax.Array.sharding` gives the sharding configuration used by the array.
jaxlib 0.4.21 (Dec 4 2023)#
Changes
* In preparation for adding distributed CPU support, JAX now treats CPU
  devices identically to GPU and TPU devices, that is:
  * `jax.devices()` includes all devices present in a distributed job, even
    those not local to the current process. `jax.local_devices()` still only
    includes devices local to the current process, so if the change to
    `jax.devices()` breaks you, you most likely want to use
    `jax.local_devices()` instead.
  * CPU devices now receive a globally unique ID number within a distributed
    job; previously CPU devices would receive a process-local ID number.
  * The `process_index` of each CPU device will now match any GPU or TPU
    devices within the same process; previously the `process_index` of a CPU
    device was always 0.
On NVIDIA GPU, JAX now prefers a Jacobi SVD solver for matrices up to1024x1024. The Jacobi solver appears faster than the non-Jacobi version.
Bug fixes
Fixed error/hang when an array with non-finite values is passed to anon-symmetric eigendecomposition (#18226). Arrays with non-finite values nowproduce arrays full of NaNs as outputs.
jax 0.4.20 (Nov 2, 2023)#
jaxlib 0.4.20 (Nov 2, 2023)#
Bug fixes

- Fixed some type confusion between E4M3 and E5M2 float8 types.
jax 0.4.19 (Oct 19, 2023)#
New Features

- Added `jax.typing.DTypeLike`, which can be used to annotate objects that are convertible to JAX dtypes.
- Added `jax.numpy.fill_diagonal`.

Changes

- JAX now requires SciPy 1.9 or newer.

Bug fixes

- Only process 0 in a multicontroller distributed JAX program will write persistent compilation cache entries. This fixes write contention if the cache is placed on a network file system such as GCS.
- The version check for cusolver and cufft no longer considers the patch versions when determining if the installed version of these libraries is at least as new as the versions against which JAX was built.
jaxlib 0.4.19 (Oct 19, 2023)#
Changes

- jaxlib will now always prefer pip-installed NVIDIA CUDA libraries (nvidia-… packages) over any other CUDA installation if they are installed, including installations named in `LD_LIBRARY_PATH`. If this causes problems and the intent is to use a system-installed CUDA, the fix is to remove the pip-installed CUDA library packages.
jax 0.4.18 (Oct 6, 2023)#
jaxlib 0.4.18 (Oct 6, 2023)#
Changes

- CUDA jaxlibs now depend on the user to install a compatible NCCL version. If using the recommended `cuda12_pip` installation, NCCL should be installed automatically. Currently, NCCL 2.16 or newer is required.
- We now provide Linux aarch64 wheels, both with and without NVIDIA GPU support.
- `jax.Array.item()` now supports optional index arguments.

Deprecations

- A number of internal utilities and inadvertent exports in `jax.lax` have been deprecated, and will be removed in a future release.
  - `jax.lax.dtypes`: use `jax.dtypes` instead.
  - `jax.lax.itertools`: use `itertools` instead.
  - `naryop`, `naryop_dtype_rule`, `standard_abstract_eval`, `standard_naryop`, `standard_primitive`, `standard_unop`, `unop`, and `unop_dtype_rule` are internal utilities, now deprecated without replacement.

Bug fixes

- Fixed Cloud TPU regression where compilation would OOM due to smem.
jax 0.4.17 (Oct 3, 2023)#
New features

- Added new `jax.numpy.bitwise_count()` function, matching the API of the similar function recently added to NumPy.

Deprecations

- Removed the deprecated module `jax.abstract_arrays` and all its contents.
- Named key constructors in `jax.random` are deprecated. Pass the `impl` argument to `jax.random.PRNGKey()` or `jax.random.key()` instead:
  - `random.threefry2x32_key(seed)` becomes `random.PRNGKey(seed, impl='threefry2x32')`
  - `random.rbg_key(seed)` becomes `random.PRNGKey(seed, impl='rbg')`
  - `random.unsafe_rbg_key(seed)` becomes `random.PRNGKey(seed, impl='unsafe_rbg')`
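A minimal sketch of the replacement pattern: the `impl` argument selects the bit generator, with `jax.random.PRNGKey` returning the raw uint32 form and `jax.random.key` returning a typed key array:

```python
import jax

# Instead of the removed named constructor (e.g. random.threefry2x32_key),
# select the implementation via the `impl` argument.
raw_key = jax.random.PRNGKey(0, impl='threefry2x32')   # raw uint32 key, shape (2,)
typed_key = jax.random.key(0, impl='threefry2x32')     # typed key array, shape ()
```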
Changes:

- CUDA: JAX now verifies that the CUDA libraries it finds are at least as new as the CUDA libraries that JAX was built against. If older libraries are found, JAX raises an exception since that is preferable to mysterious failures and crashes.
- Removed the "No GPU/TPU found" warning. Instead, JAX warns if, on Linux, an NVIDIA GPU or a Google TPU is found but not used and `--jax_platforms` was not specified.
- `jax.scipy.stats.mode()` now returns a 0 count if the mode is taken across a size-0 axis, matching the behavior of `scipy.stats.mode` in SciPy 1.11.
- Most `jax.numpy` functions and attributes now have fully-defined type stubs. Previously many of these were treated as `Any` by static type checkers like `mypy` and `pytype`.
jaxlib 0.4.17 (Oct 3, 2023)#
Changes:

- Python 3.12 wheels were added in this release.
- The CUDA 12 wheels now require CUDA 12.2 or newer and cuDNN 8.9.4 or newer.

Bug fixes:

- Fixed log spam from ABSL when the JAX CPU backend was initialized.
jax 0.4.16 (Sept 18, 2023)#
Changes

- Added `jax.numpy.ufunc`, as well as `jax.numpy.frompyfunc()`, which can convert any scalar-valued function into a `numpy.ufunc`-like object, with methods such as `outer()`, `reduce()`, `accumulate()`, `at()`, and `reduceat()` (#17054).
- When not running under IPython: when an exception is raised, JAX now filters out the entirety of its internal frames from tracebacks (without the "unfiltered stack trace" that previously appeared). This should produce much friendlier-looking tracebacks. See here for an example. This behavior can be changed by setting `JAX_TRACEBACK_FILTERING=remove_frames` (for two separate unfiltered/filtered tracebacks, which was the old behavior) or `JAX_TRACEBACK_FILTERING=off` (for one unfiltered traceback).
- jax2tf default serialization version is now 7, which introduces new shape safety assertions.
- Devices passed to `jax.sharding.Mesh` should be hashable. This specifically applies to mock devices or user-created devices. `jax.devices()` are already hashable.
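The new ufunc machinery can be sketched as follows, wrapping a scalar function so that it gains broadcasting `__call__` and methods such as `outer()`:

```python
import jax.numpy as jnp

# Wrap a scalar function into a ufunc-like object.
add = jnp.frompyfunc(lambda a, b: a + b, nin=2, nout=1)

five = add(2, 3)                                  # scalar application
table = add.outer(jnp.arange(3), jnp.arange(4))   # (3, 4) addition table
```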
Breaking changes:

- jax2tf now uses native serialization by default. See the jax2tf documentation for details and for mechanisms to override the default.
- The option `--jax_coordination_service` has been removed. It is now always `True`.
- `jax.jaxpr_util` has been removed from the public JAX namespace.
- `JAX_USE_PJRT_C_API_ON_TPU` no longer has an effect (i.e. it always defaults to true).
- The backwards-compatibility flag `--jax_host_callback_ad_transforms`, introduced in December 2021, has been removed.
Deprecations:

- Several `jax.numpy` APIs have been deprecated following NumPy NEP-52:
  - `jax.numpy.NINF` has been deprecated. Use `-jax.numpy.inf` instead.
  - `jax.numpy.PZERO` has been deprecated. Use `0.0` instead.
  - `jax.numpy.NZERO` has been deprecated. Use `-0.0` instead.
  - `jax.numpy.issubsctype(x, t)` has been deprecated. Use `jax.numpy.issubdtype(x.dtype, t)`.
  - `jax.numpy.row_stack` has been deprecated. Use `jax.numpy.vstack` instead.
  - `jax.numpy.in1d` has been deprecated. Use `jax.numpy.isin` instead.
  - `jax.numpy.trapz` has been deprecated. Use `jax.scipy.integrate.trapezoid` instead.
- `jax.scipy.linalg.tril` and `jax.scipy.linalg.triu` have been deprecated, following SciPy. Use `jax.numpy.tril` and `jax.numpy.triu` instead.
- `jax.lax.prod` has been removed after being deprecated in JAX v0.4.11. Use the built-in `math.prod` instead.
- A number of exports from `jax.interpreters.xla` related to defining HLO lowering rules for custom JAX primitives have been deprecated. Custom primitives should be defined using the StableHLO lowering utilities in `jax.interpreters.mlir` instead.
- The following previously-deprecated functions have been removed after a three-month deprecation period:
  - `jax.abstract_arrays.ShapedArray`: use `jax.core.ShapedArray`.
  - `jax.abstract_arrays.raise_to_shaped`: use `jax.core.raise_to_shaped`.
  - `jax.numpy.alltrue`: use `jax.numpy.all`.
  - `jax.numpy.sometrue`: use `jax.numpy.any`.
  - `jax.numpy.product`: use `jax.numpy.prod`.
  - `jax.numpy.cumproduct`: use `jax.numpy.cumprod`.
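The NEP-52 replacements above are drop-in; a few of them side by side:

```python
import jax.numpy as jnp
from jax.scipy.integrate import trapezoid

neg_inf = -jnp.inf                                      # instead of jnp.NINF
stacked = jnp.vstack([jnp.zeros(3), jnp.ones(3)])       # instead of jnp.row_stack
mask = jnp.isin(jnp.array([1, 2, 3]), jnp.array([2]))   # instead of jnp.in1d
area = trapezoid(jnp.array([0.0, 1.0, 2.0]))            # instead of jnp.trapz
```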
Deprecations/removals:

- The internal submodule `jax.prng` is now deprecated. Its contents are available at `jax.extend.random`.
- The internal submodule path `jax.linear_util` has been deprecated. Use `jax.extend.linear_util` instead (part of `jax.extend`: a module for extensions).
- `jax.random.PRNGKeyArray` and `jax.random.KeyArray` are deprecated. Use `jax.Array` for type annotations, and `jax.dtypes.issubdtype(arr.dtype, jax.dtypes.prng_key)` for runtime detection of typed prng keys.
- The method `PRNGKeyArray.unsafe_raw_array` is deprecated. Use `jax.random.key_data()` instead.
- `jax.experimental.pjit.with_sharding_constraint` is deprecated. Use `jax.lax.with_sharding_constraint` instead.
- The internal utilities `jax.core.is_opaque_dtype` and `jax.core.has_opaque_dtype` have been removed. Opaque dtypes have been renamed to extended dtypes; use `jnp.issubdtype(dtype, jax.dtypes.extended)` instead (available since JAX v0.4.14).
- The utility `jax.interpreters.xla.register_collective_primitive` has been removed. This utility did nothing useful in recent JAX releases and calls to it can be safely removed.
jaxlib 0.4.16 (Sept 18, 2023)#
Changes:

- Sparse CSR matrix multiplications via the experimental JAX sparse APIs no longer use a deterministic algorithm on NVIDIA GPUs. This change was made to improve compatibility with CUDA 12.2.1.

Bug fixes:

- Fixed a crash on Windows due to a fatal LLVM error related to out-of-order sections and IMAGE_REL_AMD64_ADDR32NB relocations (https://github.com/openxla/xla/commit/cb732a921f0c4184995cbed82394931011d12bd4).
jax 0.4.14 (July 27, 2023)#
Changes

- `jax.jit` takes `donate_argnames` as an argument. Its semantics are similar to `static_argnames`. If neither `donate_argnums` nor `donate_argnames` is provided, no arguments are donated. If `donate_argnums` is not provided but `donate_argnames` is, or vice versa, JAX uses `inspect.signature(fun)` to find any positional arguments that correspond to `donate_argnames` (or vice versa). If both `donate_argnums` and `donate_argnames` are provided, `inspect.signature` is not used, and only actual parameters listed in either `donate_argnums` or `donate_argnames` will be donated.
- `jax.random.gamma()` has been re-factored to a more efficient algorithm with more robust endpoint behavior (#16779). This means that the sequence of values returned for a given `key` will change between JAX v0.4.13 and v0.4.14 for `gamma` and related samplers (including `jax.random.ball()`, `jax.random.beta()`, `jax.random.chisquare()`, `jax.random.dirichlet()`, `jax.random.generalized_normal()`, `jax.random.loggamma()`, `jax.random.t()`).
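A minimal sketch of name-based donation: `donate_argnames` marks the buffer of `x` for reuse, just as `donate_argnums=(0,)` would by position (on backends without donation support, a warning is issued and the call still succeeds):

```python
import jax
import jax.numpy as jnp

def f(x, y):
    return x + y

# Donate the buffer of `x` by name rather than by position.
g = jax.jit(f, donate_argnames='x')

out = g(jnp.ones(4), jnp.ones(4))  # x's buffer may be reused for `out`
```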
Deletions

- `in_axis_resources` and `out_axis_resources` have been deleted from pjit since it has been more than 3 months since their deprecation. Please use `in_shardings` and `out_shardings` as the replacement. This is a safe and trivial name replacement. It does not change any of the current pjit semantics and doesn't break any code. You can still pass in `PartitionSpec`s to `in_shardings` and `out_shardings`.
Deprecations

- Python 3.8 support has been dropped as per https://docs.jax.dev/en/latest/deprecation.html
- JAX now requires NumPy 1.22 or newer as per https://docs.jax.dev/en/latest/deprecation.html
- Passing optional arguments to `jax.numpy.ndarray.at` by position is no longer supported, after being deprecated in JAX version 0.4.7. For example, instead of `x.at[i].get(True)`, use `x.at[i].get(indices_are_sorted=True)`.
- The following `jax.Array` methods have been removed, after being deprecated in JAX v0.4.5:
  - `jax.Array.broadcast`: use `jax.lax.broadcast()` instead.
  - `jax.Array.broadcast_in_dim`: use `jax.lax.broadcast_in_dim()` instead.
  - `jax.Array.split`: use `jax.numpy.split()` instead.
- The following APIs have been removed after previous deprecation:
  - `jax.ad`: use `jax.interpreters.ad`.
  - `jax.curry`: use `curry = lambda f: partial(partial, f)`.
  - `jax.partial_eval`: use `jax.interpreters.partial_eval`.
  - `jax.pxla`: use `jax.interpreters.pxla`.
  - `jax.xla`: use `jax.interpreters.xla`.
  - `jax.ShapedArray`: use `jax.core.ShapedArray`.
  - `jax.interpreters.pxla.device_put`: use `jax.device_put()`.
  - `jax.interpreters.pxla.make_sharded_device_array`: use `jax.make_array_from_single_device_arrays()`.
  - `jax.interpreters.pxla.ShardedDeviceArray`: use `jax.Array`.
  - `jax.numpy.DeviceArray`: use `jax.Array`.
  - `jax.stages.Compiled.compiler_ir`: use `jax.stages.Compiled.as_text()`.
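The `.at` change above in practice: optional arguments now go by keyword only.

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3, 4])

# Old (removed): x.at[1].get(True)
# New: optional arguments must be passed by keyword.
v = x.at[1].get(indices_are_sorted=True)
```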
Breaking changes

- JAX now requires ml_dtypes version 0.2.0 or newer.
- To fix a corner case, calls to `jax.lax.cond()` with five arguments will always resolve to the "common operands" `cond` behavior (as documented) if the second and third arguments are callable, even if other operands are callable as well. See #16413.
- The deprecated config options `jax_array` and `jax_jit_pjit_api_merge`, which did nothing, have been removed. These options have been true by default for many releases.
New features

- JAX now supports a configuration flag `--jax_serialization_version` and a `JAX_SERIALIZATION_VERSION` environment variable to control the serialization version (#16746).
- jax2tf in presence of shape polymorphism now generates code that checks certain shape constraints, if the serialization version is at least 7. See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism.
jaxlib 0.4.14 (July 27, 2023)#
Deprecations

- Python 3.8 support has been dropped as per https://docs.jax.dev/en/latest/deprecation.html
jax 0.4.13 (June 22, 2023)#
Changes

- `jax.jit` now allows `None` to be passed to `in_shardings` and `out_shardings`. The semantics are as follows:
  - For `in_shardings`, JAX will mark the input as replicated, but this behavior can change in the future.
  - For `out_shardings`, we will rely on the XLA GSPMD partitioner to determine the output shardings.
- `jax.experimental.pjit.pjit` also allows `None` to be passed to `in_shardings` and `out_shardings`. The semantics are as follows:
  - If the mesh context manager is *not* provided, JAX has the freedom to choose whatever sharding it wants.
    - For `in_shardings`, JAX will mark the input as replicated, but this behavior can change in the future.
    - For `out_shardings`, we will rely on the XLA GSPMD partitioner to determine the output shardings.
  - If the mesh context manager is provided, `None` will imply that the value will be replicated on all devices of the mesh.
- `Executable.cost_analysis()` works on Cloud TPU.
- Added a warning if a non-allowlisted `jaxlib` plugin is in use.
- Added `jax.tree_util.tree_leaves_with_path`.
- `None` is not a valid input to `jax.experimental.multihost_utils.host_local_array_to_global_array` or `jax.experimental.multihost_utils.global_array_to_host_local_array`. Please use `jax.sharding.PartitionSpec()` if you want to replicate your input.
Bug fixes

- Fixed incorrect wheel name in CUDA 12 releases (#16362); the correct wheel is named `cudnn89` instead of `cudnn88`.
Deprecations

- The `native_serialization_strict_checks` parameter to `jax.experimental.jax2tf.convert()` is deprecated in favor of the new `native_serialization_disabled_checks` (#16347).
jaxlib 0.4.13 (June 22, 2023)#
Changes

- Added Windows CPU-only wheels to the `jaxlib` PyPI release.

Bug fixes

- `__cuda_array_interface__` was broken in previous jaxlib versions and is now fixed (#16440).
- Concurrent CUDA kernel tracing is now enabled by default on NVIDIA GPUs.
jax 0.4.12 (June 8, 2023)#
Changes

Deprecations

- `jax.abstract_arrays` and its contents are now deprecated. See related functionality in `jax.core`.
- `jax.numpy.alltrue`: use `jax.numpy.all`. This follows the deprecation of `numpy.alltrue` in NumPy version 1.25.0.
- `jax.numpy.sometrue`: use `jax.numpy.any`. This follows the deprecation of `numpy.sometrue` in NumPy version 1.25.0.
- `jax.numpy.product`: use `jax.numpy.prod`. This follows the deprecation of `numpy.product` in NumPy version 1.25.0.
- `jax.numpy.cumproduct`: use `jax.numpy.cumprod`. This follows the deprecation of `numpy.cumproduct` in NumPy version 1.25.0.
- `jax.sharding.OpShardingSharding` has been removed since it has been 3 months since it was deprecated.
jaxlib 0.4.12 (June 8, 2023)#
Changes

- Includes PTX/SASS for Hopper (SM version 9.0+) GPUs. Previous versions of jaxlib should work on Hopper but would have a long JIT-compilation delay the first time a JAX operation was executed.

Bug fixes

- Fixes incorrect source line information in JAX-generated Python tracebacks under Python 3.11.
- Fixes crash when printing local variables of frames in JAX-generated Python tracebacks (#16027).
jax 0.4.11 (May 31, 2023)#
Deprecations

- The following APIs have been removed after a 3-month deprecation period, in accordance with the API compatibility policy:
  - `jax.experimental.PartitionSpec`: use `jax.sharding.PartitionSpec`.
  - `jax.experimental.maps.Mesh`: use `jax.sharding.Mesh`.
  - `jax.experimental.pjit.NamedSharding`: use `jax.sharding.NamedSharding`.
  - `jax.experimental.pjit.PartitionSpec`: use `jax.sharding.PartitionSpec`.
  - `jax.experimental.pjit.FROM_GDA`: instead pass sharded `jax.Array` objects as input and remove the optional `in_shardings` argument to `pjit`.
  - `jax.interpreters.pxla.PartitionSpec`: use `jax.sharding.PartitionSpec`.
  - `jax.interpreters.pxla.Mesh`: use `jax.sharding.Mesh`.
  - `jax.interpreters.xla.Buffer`: use `jax.Array`.
  - `jax.interpreters.xla.Device`: use `jax.Device`.
  - `jax.interpreters.xla.DeviceArray`: use `jax.Array`.
  - `jax.interpreters.xla.device_put`: use `jax.device_put`.
  - `jax.interpreters.xla.xla_call_p`: use `jax.experimental.pjit.pjit_p`.
  - The `axis_resources` argument of `with_sharding_constraint` is removed. Please use `shardings` instead.
jaxlib 0.4.11 (May 31, 2023)#
Changes

- Added a `memory_stats()` method to `Device`s. If supported, this returns a dict of string stat names with int values, e.g. `"bytes_in_use"`, or `None` if the platform doesn't support memory statistics. The exact stats returned may vary across platforms. Currently only implemented on Cloud TPU.
- Re-added support for the Python buffer protocol (`memoryview`) on CPU devices.
jax 0.4.10 (May 11, 2023)#
jaxlib 0.4.10 (May 11, 2023)#
Changes

- Fixed the `'apple-m1' is not a recognized processor for this target (ignoring processor)` issue that prevented the previous release from running on Mac M1.
jax 0.4.9 (May 9, 2023)#
Changes

- The flags `experimental_cpp_jit`, `experimental_cpp_pjit` and `experimental_cpp_pmap` have been removed. They are now always on.
- Accuracy of singular value decomposition (SVD) on TPU has been improved (requires jaxlib 0.4.9).
Deprecations

- `jax.experimental.gda_serialization` is deprecated and has been renamed to `jax.experimental.array_serialization`. Please change your imports to use `jax.experimental.array_serialization`.
- The `in_axis_resources` and `out_axis_resources` arguments of pjit have been deprecated. Please use `in_shardings` and `out_shardings` respectively.
- The function `jax.numpy.msort` has been removed. It has been deprecated since JAX v0.4.1. Use `jnp.sort(a, axis=0)` instead.
- The `in_parts` and `out_parts` arguments have been removed from `jax.xla_computation` since they were only used with sharded_jit, and sharded_jit is long gone.
- The `instantiate_const_outputs` argument has been removed from `jax.xla_computation` since it has been unused for a very long time.
jaxlib 0.4.9 (May 9, 2023)#
jax 0.4.8 (March 29, 2023)#
Breaking changes

- A major component of the Cloud TPU runtime has been upgraded. This enables the following new features on Cloud TPU:
  - `jax.debug.print()`, `jax.debug.callback()`, and `jax.debug.breakpoint()` now work on Cloud TPU.
  - Automatic TPU memory defragmentation.

  `jax.experimental.host_callback()` is no longer supported on Cloud TPU with the new runtime component. Please file an issue on the JAX issue tracker if the new `jax.debug` APIs are insufficient for your use case.

  The old runtime component will be available for at least the next three months by setting the environment variable `JAX_USE_PJRT_C_API_ON_TPU=false`. If you find you need to disable the new runtime for any reason, please let us know on the JAX issue tracker.
Changes

- The minimum jaxlib version has been bumped from 0.4.6 to 0.4.7.

Deprecations

- CUDA 11.4 support has been dropped. JAX GPU wheels only support CUDA 11.8 and CUDA 12. Older CUDA versions may work if jaxlib is built from source.
- The `global_arg_shapes` argument of pmap only worked with sharded_jit and has been removed from pmap. Please migrate to pjit and remove `global_arg_shapes` from pmap.
jax 0.4.7 (March 27, 2023)#
Changes

- As per https://docs.jax.dev/en/latest/jax_array_migration.html#jax-array-migration:
  - `jax.config.jax_array` cannot be disabled anymore.
  - `jax.config.jax_jit_pjit_api_merge` cannot be disabled anymore.
- `jax.experimental.jax2tf.convert()` now supports the `native_serialization` parameter to use JAX's native lowering to StableHLO to obtain a StableHLO module for the entire JAX function instead of lowering each JAX primitive to a TensorFlow op. This simplifies the internals and increases the confidence that what you serialize matches the JAX native semantics. See documentation. As part of this change, the config flag `--jax2tf_default_experimental_native_lowering` has been renamed to `--jax2tf_native_serialization`.
- JAX now depends on `ml_dtypes`, which contains definitions of NumPy types like bfloat16. These definitions were previously internal to JAX, but have been split into a separate package to facilitate sharing them with other projects.
- JAX now requires NumPy 1.21 or newer and SciPy 1.7 or newer.
Deprecations

- The type `jax.numpy.DeviceArray` is deprecated. Use `jax.Array` instead, for which it is an alias.
- The type `jax.interpreters.pxla.ShardedDeviceArray` is deprecated. Use `jax.Array` instead.
- Passing additional arguments to `jax.numpy.ndarray.at` by position is deprecated. For example, instead of `x.at[i].get(True)`, use `x.at[i].get(indices_are_sorted=True)`.
- `jax.interpreters.xla.device_put` is deprecated. Please use `jax.device_put`.
- `jax.interpreters.pxla.device_put` is deprecated. Please use `jax.device_put`.
- `jax.experimental.pjit.FROM_GDA` is deprecated. Please pass in sharded `jax.Array`s as input and remove the `in_shardings` argument to pjit, since it is optional.
jaxlib 0.4.7 (March 27, 2023)#
Changes:

- jaxlib now depends on `ml_dtypes`, which contains definitions of NumPy types like bfloat16. These definitions were previously internal to JAX, but have been split into a separate package to facilitate sharing them with other projects.
jax 0.4.6 (Mar 9, 2023)#
Changes

- `jax.tree_util` now contains a set of APIs that allow users to define keys for their custom pytree nodes. This includes:
  - `tree_flatten_with_path`, which flattens a tree and returns not only each leaf but also its key path.
  - `tree_map_with_path`, which can map a function that takes the key path as an argument.
  - `register_pytree_with_keys`, to register how the key path and leaves should look in a custom pytree node.
  - `keystr`, which pretty-prints a key path.
- `jax2tf.call_tf()` has a new parameter `output_shape_dtype` (default `None`) that can be used to declare the output shape and type of the result. This enables `jax2tf.call_tf()` to work in the presence of shape polymorphism. (#14734).
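The key-path APIs above can be sketched on a small dict pytree:

```python
import jax

params = {'w': 1.0, 'b': {'c': 2.0}}

# Flatten with key paths: each leaf comes paired with the path reaching it.
leaves_with_paths, treedef = jax.tree_util.tree_flatten_with_path(params)

# Map a function that also sees the key path of every leaf.
labeled = jax.tree_util.tree_map_with_path(
    lambda path, x: (jax.tree_util.keystr(path), x), params)
```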
Deprecations

- The old key-path APIs in `jax.tree_util` are deprecated and will be removed 3 months from Mar 10, 2023:
  - `register_keypaths`: use `jax.tree_util.register_pytree_with_keys()` instead.
  - `AttributeKeyPathEntry`: use `GetAttrKey` instead.
  - `GetitemKeyPathEntry`: use `SequenceKey` or `DictKey` instead.
jaxlib 0.4.6 (Mar 9, 2023)#
jax 0.4.5 (Mar 2, 2023)#
Deprecations

- `jax.sharding.OpShardingSharding` has been renamed to `jax.sharding.GSPMDSharding`. `jax.sharding.OpShardingSharding` will be removed in 3 months from Feb 17, 2023.
- The following `jax.Array` methods are deprecated and will be removed 3 months from Feb 23, 2023:
  - `jax.Array.broadcast`: use `jax.lax.broadcast()` instead.
  - `jax.Array.broadcast_in_dim`: use `jax.lax.broadcast_in_dim()` instead.
  - `jax.Array.split`: use `jax.numpy.split()` instead.
jax 0.4.4 (Feb 16, 2023)#
Changes

- The implementation of `jit` and `pjit` has been merged. Merging jit and pjit changes the internals of JAX without affecting the public API of JAX. Before, `jit` was a final-style primitive. Final style means that the creation of jaxpr was delayed as much as possible and transformations were stacked on top of each other. With the `jit`-`pjit` implementation merge, `jit` becomes an initial-style primitive, which means that we trace to jaxpr as early as possible. For more information see this section in autodidax. Moving to initial style should simplify JAX's internals and make development of features like dynamic shapes, etc., easier. You can disable it only via the environment variable, i.e. `os.environ['JAX_JIT_PJIT_API_MERGE'] = '0'`. The merge must be disabled via an environment variable, since it affects JAX at import time so it needs to be disabled before jax is imported.
- The `axis_resources` argument of `with_sharding_constraint` is deprecated. Please use `shardings` instead. There is no change needed if you were using `axis_resources` as an arg. If you were using it as a kwarg, then please use `shardings` instead. `axis_resources` will be removed after 3 months from Feb 13, 2023.
- Added the `jax.typing` module, with tools for type annotations of JAX functions.
- The following names have been deprecated:
  - `jax.xla.Device` and `jax.interpreters.xla.Device`: use `jax.Device`.
  - `jax.experimental.maps.Mesh`: use `jax.sharding.Mesh` instead.
  - `jax.experimental.pjit.NamedSharding`: use `jax.sharding.NamedSharding`.
  - `jax.experimental.pjit.PartitionSpec`: use `jax.sharding.PartitionSpec`.
  - `jax.interpreters.pxla.Mesh`: use `jax.sharding.Mesh`.
  - `jax.interpreters.pxla.PartitionSpec`: use `jax.sharding.PartitionSpec`.
Breaking Changes

- The `initial` argument to reduction functions like `jax.numpy.sum()` is now required to be a scalar, consistent with the corresponding NumPy API. The previous behavior of broadcasting the output against non-scalar `initial` values was an unintentional implementation detail (#14446).
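The scalar-`initial` behavior described above, in one line: the reduction simply starts from the scalar value.

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3])

# `initial` must be a scalar, as in NumPy: the sum starts from it.
total = jnp.sum(x, initial=10)   # 10 + 1 + 2 + 3
```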
jaxlib 0.4.4 (Feb 16, 2023)#
Breaking changes

- Support for NVIDIA Kepler series GPUs has been removed from the default `jaxlib` builds. If Kepler support is needed, it is still possible to build `jaxlib` from source with Kepler support (via the `--cuda_compute_capabilities=sm_35` option to `build.py`); however, note that CUDA 12 has completely dropped support for Kepler GPUs.
jax 0.4.3 (Feb 8, 2023)#
Breaking changes

- Deleted `jax.scipy.linalg.polar_unitary()`, which was a deprecated JAX extension to the scipy API. Use `jax.scipy.linalg.polar()` instead.

Changes

- `jax.Array` now has the non-blocking `is_ready()` method, which returns `True` if the array is ready (see also `jax.block_until_ready()`).

jaxlib 0.4.3 (Feb 8, 2023)#
jax 0.4.2 (Jan 24, 2023)#
Breaking changes

- Deleted `jax.experimental.callback`.
- Operations with dimensions in presence of jax2tf shape polymorphism have been generalized to work in more scenarios, by converting the symbolic dimension to JAX arrays. Operations involving symbolic dimensions and `np.ndarray` now can raise errors when the result is used as a shape value (#14106).
- jaxpr objects now raise an error on attribute setting in order to avoid problematic mutations (#14102).

Changes

- `jax2tf.call_tf()` has a new parameter `has_side_effects` (default `True`) that can be used to declare whether an instance can be removed or replicated by JAX optimizations such as dead-code elimination (#13980).
- Added more support for floordiv and mod for jax2tf shape polymorphism. Previously, certain division operations resulted in errors in presence of symbolic dimensions (#14108).
jaxlib 0.4.2 (Jan 24, 2023)#
Changes

- Set `JAX_USE_PJRT_C_API_ON_TPU=1` to enable the new Cloud TPU runtime, featuring automatic device memory defragmentation.
jax 0.4.1 (Dec 13, 2022)#
Changes

- Support for Python 3.7 has been dropped, in accordance with JAX's Python and NumPy version support policy.
- We introduce `jax.Array`, which is a unified array type that subsumes the `DeviceArray`, `ShardedDeviceArray`, and `GlobalDeviceArray` types in JAX. The `jax.Array` type helps make parallelism a core feature of JAX, simplifies and unifies JAX internals, and allows us to unify `jit` and `pjit`. `jax.Array` has been enabled by default in JAX 0.4 and makes some breaking changes to the `pjit` API. The jax.Array migration guide can help you migrate your codebase to `jax.Array`. You can also look at the Distributed arrays and automatic parallelization tutorial to understand the new concepts.
- `PartitionSpec` and `Mesh` are now out of experimental. The new API endpoints are `jax.sharding.PartitionSpec` and `jax.sharding.Mesh`. `jax.experimental.maps.Mesh` and `jax.experimental.PartitionSpec` are deprecated and will be removed in 3 months.
- `with_sharding_constraint`'s new public endpoint is `jax.lax.with_sharding_constraint`.
- If using ABSL flags together with `jax.config`, the ABSL flag values are no longer read or written after the JAX configuration options are initially populated from the ABSL flags. This change improves performance of reading `jax.config` options, which are used pervasively in JAX.
- The `jax2tf.call_tf` function now uses for TF lowering the first TF device of the same platform as used by the embedding JAX computation. Before, it was using the 0th device for the JAX-default backend.
- A number of `jax.numpy` functions now have their arguments marked as positional-only, matching NumPy.
- `jnp.msort` is now deprecated, following the deprecation of `np.msort` in NumPy 1.24. It will be removed in a future release, in accordance with the API compatibility policy. It can be replaced with `jnp.sort(a, axis=0)`.
jaxlib 0.4.1 (Dec 13, 2022)#
Changes

- Support for Python 3.7 has been dropped, in accordance with JAX's Python and NumPy version support policy.
- The behavior of `XLA_PYTHON_CLIENT_MEM_FRACTION=.XX` has been changed to allocate XX% of the total GPU memory instead of the previous behavior of using currently available GPU memory to calculate preallocation. Please refer to GPU memory allocation for more details.
- The deprecated method `.block_host_until_ready()` has been removed. Use `.block_until_ready()` instead.
jax 0.4.0 (Dec 12, 2022)#
The release was yanked.
jaxlib 0.4.0 (Dec 12, 2022)#
The release was yanked.
jax 0.3.25 (Nov 15, 2022)#
Changes

- `jax.numpy.linalg.pinv()` now supports the `hermitian` option.
- `jax.scipy.linalg.hessenberg()` is now supported on CPU only. Requires jaxlib > 0.3.24.
- New functions `jax.lax.linalg.hessenberg()`, `jax.lax.linalg.tridiagonal()`, and `jax.lax.linalg.householder_product()` were added. Householder reduction is currently CPU-only and tridiagonal reductions are supported on CPU and GPU only.
- The gradients of `svd` and `jax.numpy.linalg.pinv` are now computed more economically for non-square matrices.

Breaking Changes

- Deleted the `jax_experimental_name_stack` config option.
- Convert a string `axis_names` argument to the `jax.experimental.maps.Mesh` constructor into a singleton tuple instead of unpacking the string into a sequence of character axis names.
jaxlib 0.3.25 (Nov 15, 2022)#
Changes

- Added support for tridiagonal reductions on CPU and GPU.
- Added support for upper Hessenberg reductions on CPU.

Bugs

- Fixed a bug that meant that frames in tracebacks captured by JAX were incorrectly mapped to source lines under Python 3.10+.
jax 0.3.24 (Nov 4, 2022)#
Changes

- JAX should be faster to import. We now import scipy lazily, which accounted for a significant fraction of JAX's import time.
- Setting the env var `JAX_PERSISTENT_CACHE_MIN_COMPILE_TIME_SECS=$N` can be used to limit the number of cache entries written to the persistent cache. By default, computations that take 1 second or more to compile will be cached.
- Added `jax.scipy.stats.mode()`.
- The default device order used by `pmap` on TPU if no order is specified now matches `jax.devices()` for single-process jobs. Previously the two orderings differed, which could lead to unnecessary copies or out-of-memory errors. Requiring the orderings to agree simplifies matters.
Breaking Changes

- `jax.numpy.gradient()` now behaves like most other functions in `jax.numpy`, and forbids passing lists or tuples in place of arrays (#12958).
- Functions in `jax.numpy.linalg` and `jax.numpy.fft` now uniformly require inputs to be array-like: i.e. lists and tuples cannot be used in place of arrays. Part of #7737.

Deprecations

- `jax.sharding.MeshPspecSharding` has been renamed to `jax.sharding.NamedSharding`. The `jax.sharding.MeshPspecSharding` name will be removed in 3 months.
jaxlib 0.3.24 (Nov 4, 2022)#
Changes

- Buffer donation now works on CPU. This may break code that marked buffers for donation on CPU but relied on donation not being implemented.
jax 0.3.23 (Oct 12, 2022)#
Changes
Update Colab TPU driver version for new jaxlib release.
jax 0.3.22 (Oct 11, 2022)#
Changes

- Add `JAX_PLATFORMS=tpu,cpu` as the default setting in TPU initialization, so JAX will raise an error if the TPU cannot be initialized, instead of falling back to CPU. Set `JAX_PLATFORMS=''` to override this behavior and automatically choose an available backend (the original default), or set `JAX_PLATFORMS=cpu` to always use CPU regardless of whether a TPU is available.

Deprecations

- Several test utilities deprecated in JAX v0.3.8 are now removed from `jax.test_util`.
jaxlib 0.3.22 (Oct 11, 2022)#
jax 0.3.21 (Sep 30, 2022)#
Changes

- The persistent compilation cache will now warn instead of raising an exception on error (#12582), so program execution can continue if something goes wrong with the cache. Set `JAX_RAISE_PERSISTENT_CACHE_ERRORS=true` to revert this behavior.
jax 0.3.20 (Sep 28, 2022)#
jaxlib 0.3.20 (Sep 28, 2022)#
Bug fixes

- Fixes support for limiting the visible CUDA devices via `jax_cuda_visible_devices` in distributed jobs. This functionality is needed for the JAX/SLURM integration on GPU (#12533).
jax 0.3.19 (Sep 27, 2022)#
Fixes required jaxlib version.
jax 0.3.18 (Sep 26, 2022)#
Changes
Ahead-of-time lowering and compilation functionality (tracked in#7733) is stable and public. Seetheoverview and the API docsfor
jax.stages.Introduced
jax.Array, intended to be used for bothisinstancechecksand type annotations for array types in JAX. Notice that this included some subtlechanges to howisinstanceworks forjax.numpy.ndarrayfor jax-internalobjects, asjax.numpy.ndarrayis now a simple alias ofjax.Array.
Breaking changes
- `jax._src` is no longer imported into the public `jax` namespace. This may break users that were using JAX internals.
- `jax.soft_pmap` has been deleted. Please use `pjit` or `xmap` instead. `jax.soft_pmap` is undocumented. If it were documented, a deprecation period would have been provided.
jax 0.3.17 (Aug 31, 2022)#
Bugs
- Fix a corner case issue in the gradient of `lax.pow` with an exponent of zero (#12041).
Breaking changes
- `jax.checkpoint()`, also known as `jax.remat()`, no longer supports the `concrete` option, following the previous version's deprecation; see JEP 11830.
Changes
- Added `jax.pure_callback()`, which enables calling back to pure Python functions from compiled functions (e.g. functions decorated with `jax.jit` or `jax.pmap`).
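A minimal sketch of calling back into pure Python from a jitted function; the `host_sin` helper is a hypothetical stand-in for any pure host-side function:

```python
import numpy as np
import jax
import jax.numpy as jnp

# A pure Python/NumPy function executed on the host.
def host_sin(x):
    return np.sin(x).astype(np.float32)

@jax.jit
def f(x):
    # The second argument describes the shape/dtype of the callback's result.
    return jax.pure_callback(host_sin, jax.ShapeDtypeStruct(x.shape, jnp.float32), x)

y = f(jnp.float32(0.5))
```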
Deprecations:
- The deprecated `DeviceArray.tile()` method has been removed. Use `jax.numpy.tile()` (#11944).
- `DeviceArray.to_py()` has been deprecated. Use `np.asarray(x)` instead.
jax 0.3.16#
Breaking changes
- Support for NumPy 1.19 has been dropped, per the deprecation policy. Please upgrade to NumPy 1.20 or newer.
Changes
- Added `jax.debug`, which includes utilities for runtime value debugging such as `jax.debug.print()` and `jax.debug.breakpoint()`.
- Added new documentation for runtime value debugging.
Deprecations
- The `jax.mask()` and `jax.shapecheck()` APIs have been removed. See #11557.
- `jax.experimental.loops` has been removed. See #10278 for an alternative API.
- `jax.tree_util.tree_multimap()` has been removed. It has been deprecated since JAX release 0.3.5, and `jax.tree_util.tree_map()` is a direct replacement.
- Removed `jax.experimental.stax`; it has long been a deprecated alias of `jax.example_libraries.stax`.
- Removed `jax.experimental.optimizers`; it has long been a deprecated alias of `jax.example_libraries.optimizers`.
- `jax.checkpoint()`, also known as `jax.remat()`, has a new implementation switched on by default, meaning the old implementation is deprecated; see JEP 11830.
jax 0.3.15 (July 22, 2022)#
Changes
- `JaxTestCase` and `JaxTestLoader` have been removed from `jax.test_util`. These classes have been deprecated since v0.3.1 (#11248).
- Added `jax.scipy.gaussian_kde` (#11237).
- Binary operations between JAX arrays and built-in collections (`dict`, `list`, `set`, `tuple`) now raise a `TypeError` in all cases. Previously some cases (particularly equality and inequality) would return boolean scalars inconsistent with similar operations in NumPy (#11234).
- Several `jax.tree_util` routines accessed as top-level JAX package imports are now deprecated, and will be removed in a future JAX release in accordance with the API compatibility policy:
  - `jax.treedef_is_leaf()` is deprecated in favor of `jax.tree_util.treedef_is_leaf()`
  - `jax.tree_flatten()` is deprecated in favor of `jax.tree_util.tree_flatten()`
  - `jax.tree_leaves()` is deprecated in favor of `jax.tree_util.tree_leaves()`
  - `jax.tree_structure()` is deprecated in favor of `jax.tree_util.tree_structure()`
  - `jax.tree_transpose()` is deprecated in favor of `jax.tree_util.tree_transpose()`
  - `jax.tree_unflatten()` is deprecated in favor of `jax.tree_util.tree_unflatten()`
- The `sym_pos` argument of `jax.scipy.linalg.solve()` is deprecated in favor of `assume_a='pos'`, following a similar deprecation in `scipy.linalg.solve()`.
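The tree-utility migration amounts to replacing the top-level aliases with their `jax.tree_util` equivalents, e.g. (arbitrary example pytree):

```python
import jax

tree = {"a": 1, "b": (2, 3)}

# jax.tree_flatten / jax.tree_unflatten -> jax.tree_util.tree_flatten / tree_unflatten
leaves, treedef = jax.tree_util.tree_flatten(tree)
rebuilt = jax.tree_util.tree_unflatten(treedef, [x + 1 for x in leaves])
```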
jaxlib 0.3.15 (July 22, 2022)#
jax 0.3.14 (June 27, 2022)#
Breaking changes
- `jax.experimental.compilation_cache.initialize_cache()` does not support `max_cache_size_bytes` anymore and will not get that as an input.
- `JAX_PLATFORMS` now raises an exception when platform initialization fails.
Changes
Fixed compatibility problems with NumPy 1.23.
- `jax.numpy.linalg.slogdet()` now accepts an optional `method` argument that allows selection between an LU-decomposition based implementation and an implementation based on QR decomposition.
- `jax.numpy.linalg.qr()` now supports `mode="raw"`.
- `pickle`, `copy.copy`, and `copy.deepcopy` now have more complete support when used on jax arrays (#10659). In particular:
  - `pickle` and `deepcopy` previously returned `np.ndarray` objects when used on a `DeviceArray`; now `DeviceArray` objects are returned. For `deepcopy`, the copied array is on the same device as the original. For `pickle` the deserialized array will be on the default device.
  - Within function transformations (i.e. traced code), `deepcopy` and `copy` previously were no-ops. Now they use the same mechanism as `DeviceArray.copy()`.
  - Calling `pickle` on a traced array now results in an explicit `ConcretizationTypeError`.
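A minimal illustration of the copy/pickle behavior above (checked here against the modern `jax.Array` type, which `DeviceArray` has since become):

```python
import copy
import pickle
import jax
import jax.numpy as jnp

x = jnp.arange(4)

y = copy.deepcopy(x)               # a JAX array, on the same device as x
z = pickle.loads(pickle.dumps(x))  # deserialized onto the default device

assert isinstance(y, jax.Array) and isinstance(z, jax.Array)
```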
The implementation of singular value decomposition (SVD) andsymmetric/Hermitian eigendecomposition should be significantly faster onTPU, especially for matrices above 1000x1000 or so. Both now use a spectraldivide-and-conquer algorithm for eigendecomposition (QDWH-eig).
- `jax.numpy.ldexp()` no longer silently promotes all inputs to float64; instead it promotes to float32 for integer inputs of size int32 or smaller (#10921).
- Added a `create_perfetto_link` option to `jax.profiler.start_trace()` and `jax.profiler.trace()`. When used, the profiler will generate a link to the Perfetto UI to view the trace.
- Changed the semantics of `jax.profiler.start_server(...)` to store the keepalive globally, rather than requiring the user to keep a reference to it.
- Added `jax.random.ball()`.
- Added `jax.default_device()`.
- Added a `python -m jax.collect_profile` script to manually capture program traces as an alternative to the TensorBoard UI.
- Added a `jax.named_scope` context manager that adds profiler metadata to Python programs (similar to `jax.named_call`).
- In scatter-update operations (i.e. `jax.numpy.ndarray.at`), unsafe implicit dtype casts are deprecated, and now result in a `FutureWarning`. In a future release, this will become an error. An example of an unsafe implicit cast is `jnp.zeros(4, dtype=int).at[0].set(1.5)`, in which `1.5` previously was silently truncated to `1`.
- `jax.experimental.compilation_cache.initialize_cache()` now supports a GCS bucket path as input.
- `jax.numpy.roots()` is now better behaved when `strip_zeros=False` when coefficients have leading zeros (#11215).
jaxlib 0.3.14 (June 27, 2022)#
x86-64 Mac wheels now require Mac OS 10.14 (Mojave) or newer. Mac OS 10.14 was released in 2018, so this should not be a very onerous requirement.
The bundled version of NCCL was updated to 2.12.12, fixing some deadlocks.
The Python flatbuffers package is no longer a dependency of jaxlib.
jax 0.3.13 (May 16, 2022)#
jax 0.3.12 (May 15, 2022)#
Changes
Fixes #10717.
jax 0.3.11 (May 15, 2022)#
Changes
- `jax.lax.eigh()` now accepts an optional `sort_eigenvalues` argument that allows users to opt out of eigenvalue sorting on TPU.
Deprecations
- Non-array arguments to functions in `jax.lax.linalg` are now marked keyword-only. As a backward-compatibility step, passing keyword-only arguments positionally yields a warning, but in a future JAX release passing keyword-only arguments positionally will fail. However, most users should prefer to use `jax.numpy.linalg` instead.
- `jax.scipy.linalg.polar_unitary()`, which was a JAX extension to the scipy API, is deprecated. Use `jax.scipy.linalg.polar()` instead.
jax 0.3.10 (May 3, 2022)#
jaxlib 0.3.10 (May 3, 2022)#
Changes
A TF commit fixes an issue in the MHLO canonicalizer that caused constant folding to take a long time or crash for certain programs.
jax 0.3.9 (May 2, 2022)#
Changes
Added support for fully asynchronous checkpointing for GlobalDeviceArray.
jax 0.3.8 (April 29 2022)#
Changes
- `jax.numpy.linalg.svd()` on TPUs uses a qdwh-svd solver.
- `jax.numpy.linalg.cond()` on TPUs now accepts complex input.
- `jax.numpy.linalg.pinv()` on TPUs now accepts complex input.
- `jax.numpy.linalg.matrix_rank()` on TPUs now accepts complex input.
- `jax.scipy.cluster.vq.vq()` has been added.
- `jax.experimental.maps.mesh` has been deleted. Please use `jax.experimental.maps.Mesh`. Please see https://docs.jax.dev/en/latest/_autosummary/jax.experimental.maps.Mesh.html#jax.experimental.maps.Mesh for more information.
- `jax.scipy.linalg.qr()` now returns a length-1 tuple rather than the raw array when `mode='r'`, in order to match the behavior of `scipy.linalg.qr` (#10452).
- `jax.numpy.take_along_axis()` now takes an optional `mode` parameter that specifies the behavior of out-of-bounds indexing. By default, invalid values (e.g., NaN) will be returned for out-of-bounds indices. In previous versions of JAX, invalid indices were clamped into range. The previous behavior can be restored by passing `mode="clip"`.
- `jax.numpy.take()` now defaults to `mode="fill"`, which returns invalid values (e.g., NaN) for out-of-bounds indices.
- Scatter operations, such as `x.at[...].set(...)`, now have `"drop"` semantics. This has no effect on the scatter operation itself, but it means that when differentiated, the gradient of a scatter will yield zero cotangents for out-of-bounds indices. Previously out-of-bounds indices were clamped into range for the gradient, which was not mathematically correct.
- `jax.numpy.take_along_axis()` now raises a `TypeError` if its indices are not of an integer type, matching the behavior of `numpy.take_along_axis()`. Previously non-integer indices were silently cast to integers.
- `jax.numpy.ravel_multi_index()` now raises a `TypeError` if its `dims` argument is not of an integer type, matching the behavior of `numpy.ravel_multi_index()`. Previously a non-integer `dims` was silently cast to integers.
- `jax.numpy.split()` now raises a `TypeError` if its `axis` argument is not of an integer type, matching the behavior of `numpy.split()`. Previously a non-integer `axis` was silently cast to integers.
- `jax.numpy.indices()` now raises a `TypeError` if its dimensions are not of an integer type, matching the behavior of `numpy.indices()`. Previously non-integer dimensions were silently cast to integers.
- `jax.numpy.diag()` now raises a `TypeError` if its `k` argument is not of an integer type, matching the behavior of `numpy.diag()`. Previously a non-integer `k` was silently cast to integers.
- Added `jax.random.orthogonal()`.
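The new out-of-bounds indexing modes above can be sketched with a small example (arbitrary values):

```python
import jax.numpy as jnp

x = jnp.array([10.0, 20.0, 30.0])
idx = jnp.array([0, 5])  # index 5 is out of bounds

filled = jnp.take(x, idx)                # default mode="fill": NaN for invalid indices
clipped = jnp.take(x, idx, mode="clip")  # previous behavior: clamp into range
```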
Deprecations
- Many functions and objects available in `jax.test_util` are now deprecated and will raise a warning on import. This includes `cases_from_list`, `check_close`, `check_eq`, `device_under_test`, `format_shape_dtype_string`, `rand_uniform`, `skip_on_devices`, `with_config`, `xla_bridge`, and `_default_tolerance` (#10389). These, along with the previously-deprecated `JaxTestCase`, `JaxTestLoader`, and `BufferDonationTestCase`, will be removed in a future JAX release. Most of these utilities can be replaced by calls to standard Python & NumPy testing utilities found in e.g. `unittest`, `absl.testing`, `numpy.testing`, etc. JAX-specific functionality such as device checking can be replaced through the use of public APIs such as `jax.devices()`. Many of the deprecated utilities will still exist in `jax._src.test_util`, but these are not public APIs and as such may be changed or removed without notice in future releases.
jax 0.3.7 (April 15, 2022)#
Changes:
- Fixed a performance problem if the indices passed to `jax.numpy.take_along_axis()` were broadcasted (#10281).
- `jax.scipy.special.expit()` and `jax.scipy.special.logit()` now require their arguments to be scalars or JAX arrays. They also now promote integer arguments to floating point.
- The `DeviceArray.tile()` method is deprecated, because numpy arrays do not have a `tile()` method. As a replacement for this, use `jax.numpy.tile()` (#10266).
jaxlib 0.3.7 (April 15, 2022)#
Changes:
- Linux wheels are now built conforming to the `manylinux2014` standard, instead of `manylinux2010`.
jax 0.3.6 (April 12, 2022)#
jax 0.3.5 (April 7, 2022)#
Changes:
- Added `jax.random.loggamma()` and improved the behavior of `jax.random.beta()` and `jax.random.dirichlet()` for small parameter values (#9906).
- The private `lax_numpy` submodule is no longer exposed in the `jax.numpy` namespace (#10029).
- Added array creation routines `jax.numpy.frombuffer()`, `jax.numpy.fromfunction()`, and `jax.numpy.fromstring()` (#10049).
- `DeviceArray.copy()` now returns a `DeviceArray` rather than a `np.ndarray` (#10069).
- `jax.experimental.sharded_jit` has been deprecated and will be removed soon.
Deprecations:
- `jax.nn.normalize()` is being deprecated. Use `jax.nn.standardize()` instead (#9899).
- `jax.tree_util.tree_multimap()` is deprecated. Use `jax.tree_util.tree_map()` instead (#5746).
- `jax.experimental.sharded_jit` is deprecated. Use `pjit` instead.
jaxlib 0.3.5 (April 7, 2022)#
jax 0.3.4 (March 18, 2022)#
jax 0.3.3 (March 17, 2022)#
jax 0.3.2 (March 16, 2022)#
Changes:
- The functions `jax.ops.index_update` and `jax.ops.index_add`, which were deprecated in 0.2.22, have been removed. Please use the `.at` property on JAX arrays instead, e.g., `x.at[idx].set(y)`.
- Moved `jax.experimental.ann.approx_*_k` into `jax.lax`. These functions are optimized alternatives to `jax.lax.top_k`.
- `jax.numpy.broadcast_arrays()` and `jax.numpy.broadcast_to()` now require scalar or array-like inputs, and will fail if they are passed lists (part of #7737).
- The standard `jax[tpu]` install can now be used with Cloud TPU v4 VMs.
- `pjit` now works on CPU (in addition to previous TPU and GPU support).
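The `.at` replacement for the removed `jax.ops` functions can be sketched as follows (arbitrary values):

```python
import jax.numpy as jnp

x = jnp.zeros(4)

# Replaces jax.ops.index_update(x, 1, 3.0):
y = x.at[1].set(3.0)

# Replaces jax.ops.index_add(x, 1, 2.0):
z = x.at[1].add(2.0)

# Updates are functional: x itself is unchanged.
```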
jaxlib 0.3.2 (March 16, 2022)#
Changes
- `XlaComputation.as_hlo_text()` now supports printing large constants by passing the boolean flag `print_large_constants=True`.
Deprecations:
- The `.block_host_until_ready()` method on JAX arrays has been deprecated. Use `.block_until_ready()` instead.
jax 0.3.1 (Feb 18, 2022)#
Changes:
- `jax.test_util.JaxTestCase` and `jax.test_util.JaxTestLoader` are now deprecated. The suggested replacement is to use `parameterized.TestCase` directly. For tests that rely on custom asserts such as `JaxTestCase.assertAllClose()`, the suggested replacement is to use standard numpy testing utilities such as `numpy.testing.assert_allclose()`, which work directly with JAX arrays (#9620).
- `jax.test_util.JaxTestCase` now sets `jax_numpy_rank_promotion='raise'` by default (#9562). To recover the previous behavior, use the new `jax.test_util.with_config` decorator:

  ```python
  @jtu.with_config(jax_numpy_rank_promotion='allow')
  class MyTestCase(jtu.JaxTestCase):
    ...
  ```
- Added `jax.scipy.linalg.schur()`, `jax.scipy.linalg.sqrtm()`, `jax.scipy.signal.csd()`, `jax.scipy.signal.stft()`, `jax.scipy.signal.welch()`.
jax 0.3.0 (Feb 10, 2022)#
Changes
jax version has been bumped to 0.3.0. Please see the design doc for the explanation.
jaxlib 0.3.0 (Feb 10, 2022)#
Changes
Bazel 5.0.0 is now required to build jaxlib.
jaxlib version has been bumped to 0.3.0. Please see the design doc for the explanation.
jax 0.2.28 (Feb 1, 2022)#
- `jax.jit(f).lower(...).compiler_ir()` now defaults to the MHLO dialect if no `dialect=` is passed.
- `jax.jit(f).lower(...).compiler_ir(dialect='mhlo')` now returns an MLIR `ir.Module` object instead of its string representation.
jaxlib 0.1.76 (Jan 27, 2022)#
New features
Includes precompiled SASS for NVidia compute capability 8.0 GPUs (e.g. A100). Removes precompiled SASS for compute capability 6.1 so as not to increase the number of compute capabilities: GPUs with compute capability 6.1 can use the 6.0 SASS.
With jaxlib 0.1.76, JAX uses the MHLO MLIR dialect as its primary target compiler IR by default.
Breaking changes
Support for NumPy 1.18 has been dropped, per the deprecation policy. Please upgrade to a supported NumPy version.
Bug fixes
Fixed a bug where apparently identical pytreedef objects constructed by different routes do not compare as equal (#9066).
The JAX jit cache requires two static arguments to have identical types for a cache hit (#9311).
jax 0.2.27 (Jan 18 2022)#
Breaking changes:
Support for NumPy 1.18 has been dropped, per the deprecation policy. Please upgrade to a supported NumPy version.
- The host_callback primitives have been simplified to drop the special autodiff handling for `hcb.id_tap` and `id_print`. From now on, only the primals are tapped. The old behavior can be obtained (for a limited time) by setting the `JAX_HOST_CALLBACK_AD_TRANSFORMS` environment variable, or the `--jax_host_callback_ad_transforms` flag. Additionally, added documentation for how to implement the old behavior using JAX custom AD APIs (#8678).
- Sorting now matches the behavior of NumPy for `0.0` and `NaN` regardless of the bit representation. In particular, `0.0` and `-0.0` are now treated as equivalent, where previously `-0.0` was treated as less than `0.0`. Additionally all `NaN` representations are now treated as equivalent and sorted to the end of the array. Previously negative `NaN` values were sorted to the front of the array, and `NaN` values with different internal bit representations were not treated as equivalent, and were sorted according to those bit patterns (#9178).
- `jax.numpy.unique()` now treats `NaN` values in the same way as `np.unique` in NumPy versions 1.21 and newer: at most one `NaN` value will appear in the uniquified output (#9184).
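The new sort semantics can be sketched with a small example (arbitrary values):

```python
import jax.numpy as jnp

x = jnp.array([jnp.nan, -0.0, 1.0, 0.0])
s = jnp.sort(x)

# -0.0 and 0.0 now compare as equivalent, and all NaNs sort to the end,
# matching NumPy: s is [-0. (or 0.), 0., 1., nan].
```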
Bug fixes:
host_callback now supports ad_checkpoint.checkpoint (#8907).
New features:
- Added `jax.block_until_ready` (#8941).
- Added a new debugging flag/environment variable `JAX_DUMP_IR_TO=/path`. If set, JAX dumps the MHLO/HLO IR it generates for each computation to a file under the given path.
- Added `jax.ensure_compile_time_eval` to the public api (#7987).
- jax2tf now supports a flag `jax2tf_associative_scan_reductions` to change the lowering for associative reductions, e.g., `jnp.cumsum`, to behave like JAX on CPU and GPU (to use an associative scan). See the jax2tf README for more details (#9189).
jaxlib 0.1.75 (Dec 8, 2021)#
New features:
Support for python 3.10.
jax 0.2.26 (Dec 8, 2021)#
Bug fixes:
- Out-of-bounds indices to `jax.ops.segment_sum` will now be handled with `FILL_OR_DROP` semantics, as documented. This primarily affects the reverse-mode derivative, where gradients corresponding to out-of-bounds indices will now be returned as 0 (#8634).
- jax2tf will force the converted code to use XLA for the code fragments under `jax.jit`, e.g., most `jax.numpy` functions (#7839).
jaxlib 0.1.74 (Nov 17, 2021)#
Enabled peer-to-peer copies between GPUs. Previously, GPU copies were bounced via the host, which is usually slower.
Added experimental MLIR Python bindings for use by JAX.
jax 0.2.25 (Nov 10, 2021)#
New features:
- (Experimental) `jax.distributed.initialize` exposes a multi-host GPU backend.
- `jax.random.permutation` supports a new `independent` keyword argument (#8430).
Breaking changes
- Moved `jax.experimental.stax` to `jax.example_libraries.stax`.
- Moved `jax.experimental.optimizers` to `jax.example_libraries.optimizers`.
New features:
- Added `jax.lax.linalg.qdwh`.
jax 0.2.24 (Oct 19, 2021)#
jaxlib 0.1.73 (Oct 18, 2021)#
Multiple cuDNN versions are now supported for jaxlib GPU `cuda11` wheels:
- cuDNN 8.2 or newer. We recommend using the cuDNN 8.2 wheel if your cuDNN installation is new enough, since it supports additional functionality.
- cuDNN 8.0.5 or newer.
Breaking changes:
The install commands for GPU jaxlib are as follows:
```shell
pip install --upgrade pip

# Installs the wheel compatible with CUDA 11 and cuDNN 8.2 or newer.
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html

# Installs the wheel compatible with Cuda 11 and cudnn 8.2 or newer.
pip install jax[cuda11_cudnn82] -f https://storage.googleapis.com/jax-releases/jax_releases.html

# Installs the wheel compatible with Cuda 11 and cudnn 8.0.5 or newer.
pip install jax[cuda11_cudnn805] -f https://storage.googleapis.com/jax-releases/jax_releases.html
```
jax 0.2.22 (Oct 12, 2021)#
Breaking Changes
- Static arguments to `jax.pmap` must now be hashable. Unhashable static arguments have long been disallowed on `jax.jit`, but they were still permitted on `jax.pmap`; `jax.pmap` compared unhashable static arguments using object identity. This behavior is a footgun, since comparing arguments using object identity leads to recompilation each time the object identity changes. Instead, we now ban unhashable arguments: if a user of `jax.pmap` wants to compare static arguments by object identity, they can define `__hash__` and `__eq__` methods on their objects that do that, or wrap their objects in an object that has those operations with object identity semantics. Another option is to use `functools.partial` to encapsulate the unhashable static arguments into the function object.
- `jax.util.partial` was an accidental export that has now been removed. Use `functools.partial` from the Python standard library instead.
Deprecations
- The functions `jax.ops.index_update`, `jax.ops.index_add`, etc. are deprecated and will be removed in a future JAX release. Please use the `.at` property on JAX arrays instead, e.g., `x.at[idx].set(y)`. For now, these functions produce a `DeprecationWarning`.
New features:
- An optimized C++ code-path improving the dispatch time for `pmap` is now the default when using jaxlib 0.1.72 or newer. The feature can be disabled using the `--experimental_cpp_pmap` flag (or the `JAX_CPP_PMAP` environment variable).
- `jax.numpy.unique` now supports an optional `fill_value` argument (#8121).
jaxlib 0.1.72 (Oct 12, 2021)#
Breaking changes:
Support for CUDA 10.2 and CUDA 10.1 has been dropped. Jaxlib now supports CUDA 11.1+.
Bug fixes:
Fixes https://github.com/jax-ml/jax/issues/7461, which caused wrong outputs on all platforms due to incorrect buffer aliasing inside the XLA compiler.
jax 0.2.21 (Sept 23, 2021)#
Breaking Changes
- `jax.api` has been removed. Functions that were available as `jax.api.*` were aliases for functions in `jax.*`; please use the functions in `jax.*` instead.
- `jax.partial` and `jax.lax.partial` were accidental exports that have now been removed. Use `functools.partial` from the Python standard library instead.
- Boolean scalar indices now raise a `TypeError`; previously this silently returned wrong results (#7925).
- Many more `jax.numpy` functions now require array-like inputs, and will error if passed a list (#7747, #7802, #7907). See #7737 for a discussion of the rationale behind this change.
- When inside a transformation such as `jax.jit`, `jax.numpy.array` always stages the array it produces into the traced computation. Previously `jax.numpy.array` would sometimes produce an on-device array, even under a `jax.jit` decorator. This change may break code that used JAX arrays to perform shape or index computations that must be known statically; the workaround is to perform such computations using classic NumPy arrays instead.
- `jnp.ndarray` is now a true base-class for JAX arrays. In particular, this means that for a standard numpy array `x`, `isinstance(x, jnp.ndarray)` will now return `False` (#7927).
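The workaround for static shape computations can be sketched as follows (the `flatten` function is a hypothetical example):

```python
import numpy as np
import jax
import jax.numpy as jnp

@jax.jit
def flatten(x):
    # Shape math that must stay static should use classic NumPy,
    # since jnp.array(...) would be staged into the traced computation:
    n = int(np.prod(x.shape))
    return x.reshape(n)
```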
New features:
- Added a `jax.numpy.insert()` implementation (#7936).
jax 0.2.20 (Sept 2, 2021)#
Breaking Changes
jaxlib 0.1.71 (Sep 1, 2021)#
Breaking changes:
Support for CUDA 11.0 and CUDA 10.1 has been dropped. Jaxlib now supports CUDA 10.2 and CUDA 11.1+.
jax 0.2.19 (Aug 12, 2021)#
Breaking changes:
Support for NumPy 1.17 has been dropped, per the deprecation policy. Please upgrade to a supported NumPy version.
The `jit` decorator has been added around the implementation of a number of operators on JAX arrays. This speeds up dispatch times for common operators such as `+`. This change should largely be transparent to most users. However, there is one known behavioral change, which is that large integer constants may now produce an error when passed directly to a JAX operator (e.g., `x + 2**40`). The workaround is to cast the constant to an explicit type (e.g., `np.float64(2**40)`).
New features:
Improved the support for shape polymorphism in jax2tf for operations that need to use a dimension size in array computation, e.g., `jnp.mean` (#7317).
Bug fixes:
Fixed some leaked trace errors from the previous release (#7613).
jaxlib 0.1.70 (Aug 9, 2021)#
Breaking changes:
Support for Python 3.6 has been dropped, per the deprecation policy. Please upgrade to a supported Python version.
Support for NumPy 1.17 has been dropped, per the deprecation policy. Please upgrade to a supported NumPy version.
The host_callback mechanism now uses one thread per local device for making the calls to the Python callbacks. Previously there was a single thread for all devices. This means that the callbacks may now be called interleaved. The callbacks corresponding to one device will still be called in sequence.
jax 0.2.18 (July 21 2021)#
Breaking changes:
Support for Python 3.6 has been dropped, per the deprecation policy. Please upgrade to a supported Python version.
The minimum jaxlib version is now 0.1.69.
The `backend` argument to `jax.dlpack.from_dlpack()` has been removed.
New features:
Added a polar decomposition (`jax.scipy.linalg.polar()`).
Bug fixes:
Tightened the checks for `lax.argmin` and `lax.argmax` to ensure they are not used with an invalid `axis` value, or with an empty reduction dimension (#7196).
jaxlib 0.1.69 (July 9 2021)#
Fixed bugs in the TFRT CPU backend that resulted in incorrect results.
jax 0.2.17 (July 9 2021)#
Bug fixes:
Default to the older "stream_executor" CPU runtime for jaxlib <= 0.1.68 to work around #7229, which caused wrong outputs on CPU due to a concurrency problem.
New features:
- New SciPy function `jax.scipy.special.sph_harm()`.
- Reverse-mode autodiff functions (`jax.grad()`, `jax.value_and_grad()`, `jax.vjp()`, and `jax.linear_transpose()`) support a parameter that indicates which named axes should be summed over in the backward pass if they were broadcasted over in the forward pass. This enables use of these APIs in a non-per-example way inside maps (initially only `jax.experimental.maps.xmap()`) (#6950).
jax 0.2.16 (June 23 2021)#
jax 0.2.15 (June 23 2021)#
New features:
#7042 Turned on the TFRT CPU backend, with significant dispatch performance improvements on CPU.
- `jax2tf.convert()` supports inequalities and min/max for booleans (#6956).
- New SciPy function `jax.scipy.special.lpmn_values()`.
Breaking changes:
Support for NumPy 1.16 has been dropped, per the deprecation policy.
Bug fixes:
Fixed a bug that prevented round-tripping from JAX to TF and back: `jax2tf.call_tf(jax2tf.convert)` (#6947).
jaxlib 0.1.68 (June 23 2021)#
Bug fixes:
Fixed a bug in the TFRT CPU backend that produced NaNs when transferring a TPU buffer to CPU.
jax 0.2.14 (June 10 2021)#
New features:
- `jax2tf.convert()` now has support for `pjit` and `sharded_jit`.
- A new configuration option `JAX_TRACEBACK_FILTERING` controls how JAX filters tracebacks.
- A new traceback filtering mode using `__tracebackhide__` is now enabled by default in sufficiently recent versions of IPython.
- `jax2tf.convert()` supports shape polymorphism even when the unknown dimensions are used in arithmetic operations, e.g., `jnp.reshape(-1)` (#6827).
- `jax2tf.convert()` generates custom attributes with location information in TF ops. The code that XLA generates after jax2tf has the same location information as JAX/XLA.
- New SciPy function `jax.scipy.special.lpmn()`.
Bug fixes:
- `jax2tf.convert()` now ensures that it uses the same typing rules for Python scalars and for choosing 32-bit vs. 64-bit computations as JAX (#6883).
- `jax2tf.convert()` now scopes the `enable_xla` conversion parameter properly to apply only during the just-in-time conversion (#6720).
- `jax2tf.convert()` now converts `lax.dot_general` using the `XlaDot` TensorFlow op, for better fidelity w.r.t. JAX numerical precision (#6717).
- `jax2tf.convert()` now has support for inequality comparisons and min/max for complex numbers (#6892).
jaxlib 0.1.67 (May 17 2021)#
jaxlib 0.1.66 (May 11 2021)#
New features:
CUDA 11.1 wheels are now supported on all CUDA 11 versions 11.1 or higher.
NVidia now promises compatibility between CUDA minor releases starting with CUDA 11.1. This means that JAX can release a single CUDA 11.1 wheel that is compatible with CUDA 11.2 and 11.3.
There is no longer a separate jaxlib release for CUDA 11.2 (or higher); use the CUDA 11.1 wheel for those versions (cuda111).
- Jaxlib now bundles `libdevice.10.bc` in CUDA wheels. There should be no need to point JAX to a CUDA installation to find this file.
- Added automatic support for static keyword arguments to the `jit()` implementation.
- Added support for pretransformation exception traces.
- Initial support for pruning unused arguments from `jit()`-transformed computations. Pruning is still a work in progress.
- Improved the string representation of `PyTreeDef` objects.
- Added support for XLA's variadic ReduceWindow.
Bug fixes:
- Fixed a bug in the remote cloud TPU support when large numbers of arguments are passed to a computation.
- Fixed a bug that meant that JAX garbage collection was not triggered by `jit()`-transformed functions.
jax 0.2.13 (May 3 2021)#
New features:
- When combined with jaxlib 0.1.66, `jax.jit()` now supports static keyword arguments. A new `static_argnames` option has been added to specify keyword arguments as static.
- `jax.nonzero()` has a new optional `size` argument that allows it to be used within `jit` (#6501).
- `jax.numpy.unique()` now supports the `axis` argument (#6532).
- `jax.experimental.host_callback.call()` now supports `pjit.pjit` (#6569).
- Added `jax.scipy.linalg.eigh_tridiagonal()` that computes the eigenvalues of a tridiagonal matrix. Only eigenvalues are supported at present.
- The order of the filtered and unfiltered stack traces in exceptions has been changed. The traceback attached to an exception thrown from JAX-transformed code is now filtered, with an `UnfilteredStackTrace` exception containing the original trace as the `__cause__` of the filtered exception. Filtered stack traces now also work with Python 3.6.
- If an exception is thrown by code that has been transformed by reverse-mode automatic differentiation, JAX now attempts to attach as a `__cause__` of the exception a `JaxStackTraceBeforeTransformation` object that contains the stack trace that created the original operation in the forward pass. Requires jaxlib 0.1.66.
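The new `static_argnames` option can be sketched as follows (the `pad_to` function is a hypothetical example):

```python
from functools import partial
import jax
import jax.numpy as jnp

# `n` is marked static, so it can participate in shape computations under jit.
@partial(jax.jit, static_argnames="n")
def pad_to(x, n):
    return jnp.pad(x, (0, n - x.shape[0]))

y = pad_to(jnp.ones(3), n=5)
```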
Breaking changes:
- The following function names have changed. There are still aliases, so this should not break existing code, but the aliases will eventually be removed so please change your code.
  - `host_id` -> `process_index()`
  - `host_count` -> `process_count()`
  - `host_ids` -> `range(jax.process_count())`
- Similarly, the argument to `local_devices()` has been renamed from `host_id` to `process_index`.
- Arguments to `jax.jit()` other than the function are now marked as keyword-only. This change is to prevent accidental breakage when arguments are added to `jit`.
Bug fixes:
jaxlib 0.1.65 (April 7 2021)#
jax 0.2.12 (April 1 2021)#
New features
- New profiling APIs: `jax.profiler.start_trace()`, `jax.profiler.stop_trace()`, and `jax.profiler.trace()`.
- `jax.lax.reduce()` is now differentiable.
Breaking changes:
The minimum jaxlib version is now 0.1.64.
- Some profiler API names have been changed. There are still aliases, so this should not break existing code, but the aliases will eventually be removed so please change your code.
  - `TraceContext` -> `TraceAnnotation`
  - `StepTraceContext` -> `StepTraceAnnotation`
  - `trace_function` -> `annotate_function`
- Omnistaging can no longer be disabled. See omnistaging for more information.
- Python integers larger than the maximum `int64` value will now lead to an overflow in all cases, rather than being silently converted to `uint64` in some cases (#6047).
- Outside X64 mode, Python integers outside the range representable by `int32` will now lead to an `OverflowError` rather than having their value silently truncated.
Bug fixes:
- `host_callback` now supports empty arrays in arguments and results (#6262).
- `jax.random.randint()` clips rather than wraps out-of-bounds limits, and can now generate integers in the full range of the specified dtype (#5868).
jax 0.2.11 (March 23 2021)#
New features:
Bug fixes:
- #6136 generalized `jax.flatten_util.ravel_pytree` to handle integer dtypes.
- #6129 fixed a bug with handling some constants like `enum.IntEnum`s.
- #6145 fixed batching issues with incomplete beta functions.
- #6014 fixed H2D transfers during tracing.
- #6165 avoids `OverflowError`s when converting some large Python integers to floats.
Breaking changes:
The minimum jaxlib version is now 0.1.62.
jaxlib 0.1.64 (March 18 2021)#
jaxlib 0.1.63 (March 17 2021)#
jax 0.2.10 (March 5 2021)#
New features:
- `jax.scipy.stats.chi2()` is now available as a distribution with logpdf and pdf methods.
- `jax.scipy.stats.betabinom()` is now available as a distribution with logpmf and pmf methods.
- Added `jax.experimental.jax2tf.call_tf()` to call TensorFlow functions from JAX (#5627 and README).
- Extended the batching rule for `lax.pad` to support batching of the padding values.
Bug fixes:
`jax.numpy.take()` properly handles negative indices (#5768).
Breaking changes:
- JAX's promotion rules were adjusted to make promotion more consistent and invariant to JIT. In particular, binary operations can now result in weakly-typed values when appropriate. The main user-visible effect of the change is that some operations result in outputs of different precision than before; for example the expression `jnp.bfloat16(1) + 0.1 * jnp.arange(10)` previously returned a `float64` array, and now returns a `bfloat16` array. JAX's type promotion behavior is described at Type promotion semantics.
- `jax.numpy.linspace()` now computes the floor of integer values, i.e., rounding towards -inf rather than 0. This change was made to match NumPy 1.20.0.
- `jax.numpy.i0()` no longer accepts complex numbers. Previously the function computed the absolute value of complex arguments. This change was made to match the semantics of NumPy 1.20.0.
- Several `jax.numpy` functions no longer accept tuples or lists in place of array arguments: `jax.numpy.pad()`, `jax.numpy.ravel()`, `jax.numpy.repeat()`, `jax.numpy.reshape()`. In general, `jax.numpy` functions should be used with scalars or array arguments.
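The weak-type behavior above can be sketched with the changelog's own expression (the expected dtype is taken from the changelog entry):

```python
import jax.numpy as jnp

# The Python float 0.1 is weakly typed, so it no longer forces promotion
# to a wider dtype; the bfloat16 operand's dtype wins:
out = jnp.bfloat16(1) + 0.1 * jnp.arange(10)
```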
jaxlib 0.1.62 (March 9 2021)#
New features:
jaxlib wheels are now built to require AVX instructions on x86-64 machines by default. If you want to use JAX on a machine that doesn't support AVX, you can build a jaxlib from source using the `--target_cpu_features` flag to `build.py`. `--target_cpu_features` also replaces `--enable_march_native`.
jaxlib 0.1.61 (February 12 2021)#
jaxlib 0.1.60 (February 3 2021)#
Bug fixes:
Fixed a memory leak when converting CPU DeviceArrays to NumPy arrays. The memory leak was present in jaxlib releases 0.1.58 and 0.1.59.
`bool`, `int8`, and `uint8` are now considered safe to cast to the `bfloat16` NumPy extension type.
jax 0.2.9 (January 26 2021)#
New features:
Extend the `jax.experimental.loops` module with support for pytrees. Improved error checking and error messages.
Add `jax.experimental.enable_x64` and `jax.experimental.disable_x64`. These are context managers which allow X64 mode to be temporarily enabled/disabled within a session.
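The context managers can be sketched as follows (assuming X64 mode is off globally, the default; note that in later JAX versions this API may be deprecated in favor of other configuration mechanisms):

```python
import jax.numpy as jnp
from jax.experimental import enable_x64

# Outside the context, default types are 32-bit.
print(jnp.arange(3).dtype)  # int32, assuming x64 mode is off globally

# Inside the context, 64-bit types are enabled for newly created arrays.
with enable_x64():
    print(jnp.arange(3).dtype)  # int64
```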
Breaking changes:
`jax.ops.segment_sum` now drops segment IDs that are out of range rather than wrapping them into the segment ID space. This was done for performance reasons.
jaxlib 0.1.59 (January 15 2021)#
jax 0.2.8 (January 12 2021)#
New features:
Add `jax.closure_convert` for use with higher-order custom derivative functions. (#5244)
Add `jax.experimental.host_callback.call` to call a custom Python function on the host and return a result to the device computation. (#5243)
Bug fixes:
`jax.numpy.arccosh` now returns the same branch as `numpy.arccosh` for complex inputs (#5156).
`host_callback.id_tap` now works for `jax.pmap` also. There is an optional parameter for `id_tap` and `id_print` to request that the device from which the value is tapped be passed as a keyword argument to the tap function (#5182).
Breaking changes:
`jax.numpy.pad` now takes keyword arguments. Positional argument `constant_values` has been removed. In addition, passing unsupported keyword arguments raises an error.
Changes for `jax.experimental.host_callback.id_tap` (#5243):
Removed support for `kwargs` for `jax.experimental.host_callback.id_tap`. (This support has been deprecated for a few months.)
Changed the printing of tuples for `jax.experimental.host_callback.id_print` to use '(' instead of '['.
Changed `jax.experimental.host_callback.id_print` in the presence of JVP to print a pair of primal and tangent. Previously, there were two separate print operations for the primals and the tangent.
`host_callback.outfeed_receiver` has been removed (it is not necessary, and was deprecated a few months ago).
New features:
New flag for debugging `inf`, analogous to that for `NaN` (#5224).
jax 0.2.7 (Dec 4 2020)#
New features:
Add `jax.device_put_replicated`
Add multi-host support to `jax.experimental.sharded_jit`
Add support for differentiating eigenvalues computed by `jax.numpy.linalg.eig`
Add support for building on Windows platforms
Add support for general `in_axes` and `out_axes` in `jax.pmap`
Add complex support for `jax.numpy.linalg.slogdet`
Bug fixes:
Fix higher-than-second order derivatives of `jax.numpy.sinc` at zero
Fix some hard-to-hit bugs around symbolic zeros in transpose rules
Breaking changes:
`jax.experimental.optix` has been deleted, in favor of the standalone `optax` Python package.
Indexing of JAX arrays with non-tuple sequences now raises a `TypeError`. This type of indexing has been deprecated in NumPy since v1.16, and in JAX since v0.2.4. See #4564.
jax 0.2.6 (Nov 18 2020)#
New Features:
Add support for shape-polymorphic tracing for the jax.experimental.jax2tf converter. See README.md.
Breaking change cleanup
Raise an error on non-hashable static arguments for `jax.jit` and `xla_computation`. See cb48f42.
Improve consistency of type promotion behavior (#4744):
Adding a complex Python scalar to a JAX floating point number respects the precision of the JAX float. For example, `jnp.float32(1) + 1j` now returns `complex64`, where previously it returned `complex128`.
Results of type promotion with 3 or more terms involving uint64, a signed int, and a third type are now independent of the order of arguments. For example: `jnp.result_type(jnp.uint64, jnp.int64, jnp.float16)` and `jnp.result_type(jnp.float16, jnp.uint64, jnp.int64)` both return `float16`, where previously the first returned `float64` and the second returned `float16`.
The contents of the (undocumented) `jax.lax_linalg` linear algebra module are now exposed publicly as `jax.lax.linalg`.
`jax.random.PRNGKey` now produces the same results in and out of JIT compilation (#4877). This required changing the result for a given seed in a few particular cases:
With `jax_enable_x64=False`, negative seeds passed as Python integers now return a different result outside JIT mode. For example, `jax.random.PRNGKey(-1)` previously returned `[4294967295, 4294967295]`, and now returns `[0, 4294967295]`. This matches the behavior in JIT.
Seeds outside the range representable by `int64` outside JIT now result in an `OverflowError` rather than a `TypeError`. This matches the behavior in JIT.
To recover the keys returned previously for negative integers with `jax_enable_x64=False` outside JIT, you can use: `key = random.PRNGKey(-1).at[0].set(0xFFFFFFFF)`
DeviceArray now raises `RuntimeError` instead of `ValueError` when trying to access its value while it has been deleted.
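The order-independent promotion behavior described in this release can be checked directly with `jnp.result_type`; a minimal sketch:

```python
import jax.numpy as jnp

# Promotion with three or more terms is now independent of argument order.
a = jnp.result_type(jnp.uint64, jnp.int64, jnp.float16)
b = jnp.result_type(jnp.float16, jnp.uint64, jnp.int64)
print(a, b)  # float16 float16
```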
jaxlib 0.1.58 (January 12ish 2021)#
Fixed a bug that meant JAX sometimes returned platform-specific types (e.g., `np.cint`) instead of standard types (e.g., `np.int32`). (#4903)
Fixed a crash when constant-folding certain int16 operations. (#4971)
Added an `is_leaf` predicate to `pytree.flatten`.
jaxlib 0.1.57 (November 12 2020)#
Fixed manylinux2010 compliance issues in GPU wheels.
Switched the CPU FFT implementation from Eigen to PocketFFT.
Fixed a bug where the hash of bfloat16 values was not correctly initialized and could change (#4651).
Add support for retaining ownership when passing arrays to DLPack (#4636).
Fixed a bug for batched triangular solves with sizes greater than 128 but not a multiple of 128.
Fixed a bug when performing concurrent FFTs on multiple GPUs (#3518).
Fixed a bug in profiler where tools are missing (#4427).
Dropped support for CUDA 10.0.
jax 0.2.5 (October 27 2020)#
Improvements:
Ensure that `check_jaxpr` does not perform FLOPS. See #4650.
Expanded the set of JAX primitives converted by jax2tf. See primitives_with_limited_support.md.
jax 0.2.4 (October 19 2020)#
jaxlib 0.1.56 (October 14, 2020)#
jax 0.2.3 (October 14 2020)#
The reason for another release so soon is we need to temporarily roll back a new jit fastpath while we look into a performance degradation.
jax 0.2.2 (October 13 2020)#
jax 0.2.1 (October 6 2020)#
Improvements:
As a benefit of omnistaging, the host_callback functions are executed (in program order) even if the result of the `jax.experimental.host_callback.id_print` / `jax.experimental.host_callback.id_tap` is not used in the computation.
jax 0.2.0 (September 23 2020)#
Improvements:
Omnistaging on by default. See #3370 and omnistaging.
jax 0.1.77 (September 15 2020)#
Breaking changes:
New simplified interface for `jax.experimental.host_callback.id_tap` (#4101)
jaxlib 0.1.55 (September 8, 2020)#
Update XLA:
Fix bug in DLPackManagedTensorToBuffer (#4196)
jax 0.1.76 (September 8, 2020)#
jax 0.1.75 (July 30, 2020)#
Bug Fixes:
make jnp.abs() work for unsigned inputs (#3914)
Improvements:
“Omnistaging” behavior added behind a flag, disabled by default (#3370)
jax 0.1.74 (July 29, 2020)#
New Features:
BFGS (#3101)
TPU support for half-precision arithmetic (#3878)
Bug Fixes:
Prevent some accidental dtype warnings (#3874)
Fix a multi-threading bug in custom derivatives (#3845, #3869)
Improvements:
Faster searchsorted implementation (#3873)
Better test coverage for jax.numpy sorting algorithms (#3836)
jaxlib 0.1.52 (July 22, 2020)#
Update XLA.
jax 0.1.73 (July 22, 2020)#
The minimum jaxlib version is now 0.1.51.
New Features:
jax.image.resize. (#3703)
hfft and ihfft (#3664)
jax.numpy.intersect1d (#3726)
jax.numpy.lexsort (#3812)
`lax.scan` and the `scan` primitive support an `unroll` parameter for loop unrolling when lowering to XLA (#3738).
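The `unroll` parameter can be sketched as follows (a toy cumulative sum; unrolling changes only how the loop is lowered, not the numerical result):

```python
import jax.numpy as jnp
from jax import lax

def cumulative_sum(xs):
    def step(carry, x):
        carry = carry + x
        return carry, carry
    # unroll=2 asks the lowering to unroll two loop steps per XLA while iteration.
    _, ys = lax.scan(step, 0.0, xs, unroll=2)
    return ys

print(cumulative_sum(jnp.arange(4.0)))  # [0. 1. 3. 6.]
```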
Bug Fixes:
Fix reduction repeated axis error (#3618)
Fix shape rule for lax.pad for input dimensions of size 0. (#3608)
make psum transpose handle zero cotangents (#3653)
Fix shape error when taking JVP of reduce-prod over size 0 axis. (#3729)
Support differentiation through jax.lax.all_to_all (#3733)
address nan issue in jax.scipy.special.zeta (#3777)
Improvements:
Many improvements to jax2tf
Reimplement argmin/argmax using a single pass variadic reduction. (#3611)
Enable XLA SPMD partitioning by default. (#3151)
Add support for 0d transpose convolution (#3643)
Make LU gradient work for low-rank matrices (#3610)
support multiple_results and custom JVPs in jet (#3657)
Generalize reduce-window padding to support (lo, hi) pairs. (#3728)
Implement complex convolutions on CPU and GPU. (#3735)
Make jnp.take work for empty slices of empty arrays. (#3751)
Relax dimension ordering rules for dot_general. (#3778)
Enable buffer donation for GPU. (#3800)
Add support for base dilation and window dilation to reduce window op… (#3803)
jaxlib 0.1.51 (July 2, 2020)#
Update XLA.
Add new runtime support for host_callback.
jax 0.1.72 (June 28, 2020)#
Bug fixes:
Fix an odeint bug introduced in the previous release, see #3587.
jax 0.1.71 (June 25, 2020)#
The minimum jaxlib version is now 0.1.48.
Bug fixes:
Allow `jax.experimental.ode.odeint` dynamics functions to close over values with respect to which we're differentiating (#3562).
jaxlib 0.1.50 (June 25, 2020)#
Add support for CUDA 11.0.
Drop support for CUDA 9.2 (we only maintain support for the last four CUDA versions.)
Update XLA.
jaxlib 0.1.49 (June 19, 2020)#
Bug fixes:
Fix build issue that could result in slow compiles (tensorflow/tensorflow)
jaxlib 0.1.48 (June 12, 2020)#
New features:
Adds support for fast traceback collection.
Adds preliminary support for on-device heap profiling.
Implements `np.nextafter` for `bfloat16` types.
Complex128 support for FFTs on CPU and GPU.
Bug fixes:
Improved float64 `tanh` accuracy on GPU.
float64 scatters on GPU are much faster.
Complex matrix multiplication on CPU should be much faster.
Stable sorts on CPU should actually be stable now.
Concurrency bug fix in CPU backend.
jax 0.1.70 (June 8, 2020)#
New features:
`lax.switch` introduces indexed conditionals with multiple branches, together with a generalization of the `cond` primitive (#3318).
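Indexed conditionals can be sketched as follows (written against the current `lax.switch` signature):

```python
import jax.numpy as jnp
from jax import lax

# A list of branches; lax.switch selects one by integer index.
branches = [
    lambda x: x + 1.0,  # branch 0
    lambda x: x * 2.0,  # branch 1
    lambda x: -x,       # branch 2
]

print(lax.switch(1, branches, 3.0))  # 6.0
```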
jax 0.1.69 (June 3, 2020)#
jax 0.1.68 (May 21, 2020)#
jax 0.1.67 (May 12, 2020)#
New features:
Support for reduction over subsets of a pmapped axis using `axis_index_groups` (#2382).
Experimental support for printing and calling host-side Python functions from compiled code. See id_print and id_tap (#3006).
Notable changes:
The visibility of names exported from `jax.numpy` has been tightened. This may break code that was making use of names that were previously exported accidentally.
jaxlib 0.1.47 (May 8, 2020)#
Fixes crash for outfeed.
jax 0.1.66 (May 5, 2020)#
New features:
Support for `in_axes=None` on `pmap` (#2896).
jaxlib 0.1.46 (May 5, 2020)#
Fixes crash for linear algebra functions on Mac OS X (#432).
Fixes an illegal instruction crash caused by using AVX512 instructions when an operating system or hypervisor disabled them (#2906).
jax 0.1.65 (April 30, 2020)#
New features:
Differentiation of determinants of singular matrices (#2809).
Bug fixes:
jaxlib 0.1.45 (April 21, 2020)#
Fixes segfault: #2755
Plumb is_stable option on Sort HLO through to Python.
jax 0.1.64 (April 21, 2020)#
New features:
Add syntactic sugar for functional indexed updates (#2684).
Add more primitive rules for `jax.experimental.jet`.
Bug fixes:
Better errors:
Improves error message for reverse-mode differentiation of `lax.while_loop` (#2129).
jaxlib 0.1.44 (April 16, 2020)#
Fixes a bug where if multiple GPUs of different models were present, JAX would only compile programs suitable for the first GPU.
Bugfix for `batch_group_count` convolutions.
Added precompiled SASS for more GPU versions to avoid startup PTX compilation hang.
jax 0.1.63 (April 12, 2020)#
Added `jax.custom_jvp` and `jax.custom_vjp` from #2026, see the tutorial notebook. Deprecated `jax.custom_transforms` and removed it from the docs (though it still works).
Add `scipy.sparse.linalg.cg` (#2566).
Changed how Tracers are printed to show more useful information for debugging (#2591).
Made `jax.numpy.isclose` handle `nan` and `inf` correctly (#2501).
Added several new rules for `jax.experimental.jet` (#2537).
Fixed `jax.experimental.stax.BatchNorm` when `scale`/`center` isn't provided.
Fix some missing cases of broadcasting in `jax.numpy.einsum` (#2512).
Implement `jax.numpy.cumsum` and `jax.numpy.cumprod` in terms of a parallel prefix scan (#2596) and make `reduce_prod` differentiable to arbitrary order (#2597).
Add `batch_group_count` to `conv_general_dilated` (#2635).
Add docstring for `test_util.check_grads` (#2656).
Add `callback_transform` (#2665).
Implement `rollaxis`, `convolve`/`correlate` 1d & 2d, `copysign`, `trunc`, `roots`, and `quantile`/`percentile` interpolation options.
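The `jax.custom_jvp` API added in this release can be sketched with the standard `log1pexp` example (written against the current API; the custom rule supplies a numerically stable derivative):

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def log1pexp(x):
    return jnp.log1p(jnp.exp(x))

@log1pexp.defjvp
def log1pexp_jvp(primals, tangents):
    x, = primals
    t, = tangents
    ans = log1pexp(x)
    # The derivative is sigmoid(x); this form avoids overflow for large x.
    return ans, t / (1.0 + jnp.exp(-x))

print(jax.grad(log1pexp)(0.0))  # 0.5
```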
jaxlib 0.1.43 (March 31, 2020)#
Fixed a performance regression for Resnet-50 on GPU.
jax 0.1.62 (March 21, 2020)#
JAX has dropped support for Python 3.5. Please upgrade to Python 3.6 or newer.
Removed the internal function `lax._safe_mul`, which implemented the convention `0. * nan == 0.`. This change means some programs when differentiated will produce nans when they previously produced correct values, though it ensures nans rather than silently incorrect results are produced for other programs. See #2447 and #1052 for details.
Added an `all_gather` parallel convenience function.
More type annotations in core code.
jaxlib 0.1.42 (March 19, 2020)#
jaxlib 0.1.41 broke cloud TPU support due to an API incompatibility. This release fixes it again.
JAX has dropped support for Python 3.5. Please upgrade to Python 3.6 or newer.
jax 0.1.61 (March 17, 2020)#
Fixes Python 3.5 support. This will be the last JAX or jaxlib release that supports Python 3.5.
jax 0.1.60 (March 17, 2020)#
New features:
`jax.pmap` has a `static_broadcasted_argnums` argument which allows the user to specify arguments that should be treated as compile-time constants and should be broadcast to all devices. It works analogously to `static_argnums` in `jax.jit`.
Improved error messages for when tracers are mistakenly saved in global state.
Added `jax.nn.one_hot` utility function.
Added `jax.experimental.jet` for exponentially faster higher-order automatic differentiation.
Added more correctness checking to arguments of `jax.lax.broadcast_in_dim`.
The minimum jaxlib version is now 0.1.41.
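The `jax.nn.one_hot` utility added in this release can be sketched as follows:

```python
import jax
import jax.numpy as jnp

labels = jnp.array([0, 2, 1])

# one_hot(indices, num_classes) builds a float one-hot matrix.
print(jax.nn.one_hot(labels, 3))
# [[1. 0. 0.]
#  [0. 0. 1.]
#  [0. 1. 0.]]
```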
jaxlib 0.1.40 (March 4, 2020)#
Adds experimental support in Jaxlib for TensorFlow profiler, which allows tracing of CPU and GPU computations from TensorBoard.
Includes prototype support for multihost GPU computations that communicate via NCCL.
Improves performance of NCCL collectives on GPU.
Adds TopK, CustomCallWithoutLayout, CustomCallWithLayout, IGammaGradA andRandomGamma implementations.
Supports device assignments known at XLA compilation time.
jax 0.1.59 (February 11, 2020)#
Breaking changes
The minimum jaxlib version is now 0.1.38.
Simplified `Jaxpr` by removing the `Jaxpr.freevars` and `Jaxpr.bound_subjaxprs`. The call primitives (`xla_call`, `xla_pmap`, `sharded_call`, and `remat_call`) get a new parameter `call_jaxpr` with a fully-closed (no `constvars`) jaxpr. Also, added a new field `call_primitive` to primitives.
New features:
Reverse-mode automatic differentiation (e.g. `grad`) of `lax.cond`, making it now differentiable in both modes (#2091).
JAX now supports DLPack, which allows sharing CPU and GPU arrays in a zero-copy way with other libraries, such as PyTorch.
JAX GPU DeviceArrays now support `__cuda_array_interface__`, which is another zero-copy protocol for sharing GPU arrays with other libraries such as CuPy and Numba.
JAX CPU device buffers now implement the Python buffer protocol, which allows zero-copy buffer sharing between JAX and NumPy.
Added JAX_SKIP_SLOW_TESTS environment variable to skip tests known to be slow.
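Reverse-mode differentiation through `lax.cond`, added in this release, can be sketched as follows (written against the current `lax.cond` signature):

```python
import jax
from jax import lax

def abs_like(x):
    # lax.cond is now differentiable in reverse mode, so grad flows
    # through whichever branch is taken.
    return lax.cond(x >= 0, lambda x: x, lambda x: -x, x)

print(jax.grad(abs_like)(-3.0))  # -1.0
print(jax.grad(abs_like)(2.0))   # 1.0
```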
jaxlib 0.1.39 (February 11, 2020)#
Updates XLA.
jaxlib 0.1.38 (January 29, 2020)#
CUDA 9.0 is no longer supported.
CUDA 10.2 wheels are now built by default.
jax 0.1.58 (January 28, 2020)#
Breaking changes
JAX has dropped Python 2 support, because Python 2 reached its end of life on January 1, 2020. Please update to Python 3.5 or newer.
New features
Forward-mode automatic differentiation (`jvp`) of while loop (#1980)
New NumPy and SciPy functions:
Batched Cholesky decomposition on GPU now uses a more efficient batched kernel.
Notable bug fixes#
With the Python 3 upgrade, JAX no longer depends on `fastcache`, which should help with installation.
