Releases: jax-ml/jax

JAX v0.6.2

17 Jun 23:06
  • New features:

    • Added jax.tree.broadcast, which implements a pytree prefix broadcasting helper.
  • Changes

    • The minimum NumPy version is 1.26 and the minimum SciPy version is 1.12.

JAX v0.6.1

21 May 18:30
  • New features:

    • Added jax.lax.axis_size, which returns the size of the mapped axis
      given its name.
  • Changes

    • Additional checking for the versions of CUDA package dependencies was
      reenabled, having been accidentally disabled in a previous release.
    • JAX nightly packages are now published to artifact registry. To install
      these packages, see the JAX installation guide.
    • jax.sharding.PartitionSpec no longer inherits from a tuple.
    • jax.ShapeDtypeStruct is now immutable. Use the .update method to
      update your ShapeDtypeStruct instead of doing in-place updates.
  • Deprecations

    • jax.custom_derivatives.custom_jvp_call_jaxpr_p is deprecated, and will be
      removed in JAX v0.7.0.

JAX v0.6.0

17 Apr 00:04
  • Breaking changes

    • jax.numpy.array no longer accepts None. This behavior was
      deprecated since November 2023 and is now removed.
    • Removed the config.jax_data_dependent_tracing_fallback config option,
      which was added temporarily in v0.4.36 to allow users to opt out of the
      new "stackless" tracing machinery.
    • Removed the config.jax_eager_pmap config option.
    • Disallow calling the lower and trace AOT APIs on the result
      of jax.jit if there have been subsequent wrappers applied.
      Previously this worked, but silently ignored the wrappers.
      The workaround is to apply jax.jit last among the wrappers,
      and similarly for jax.pmap.
      See #27873.
    • The cuda12_pip extra for jax has been removed; use pip install jax[cuda12]
      instead.
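
The AOT workaround above (apply jax.jit last among the wrappers) can be sketched as follows. The scale_output wrapper and the function f are hypothetical, chosen only to show the ordering.

```python
import functools

import jax
import jax.numpy as jnp


def scale_output(fn):
    """Hypothetical wrapper, for illustration only: doubles fn's result."""
    @functools.wraps(fn)
    def wrapped(*args, **kwargs):
        return 2.0 * fn(*args, **kwargs)
    return wrapped


def f(x):
    return jnp.sin(x) + 1.0


# jax.jit is applied last, so the wrapper is part of the traced function
# and the AOT APIs (lower, trace) see it.
jitted = jax.jit(scale_output(f))

x = jnp.arange(3.0)
lowered = jitted.lower(x)      # AOT lowering of the wrapped function
compiled = lowered.compile()   # compile ahead of time
result = compiled(x)
```

Wrapping the other way around, scale_output(jax.jit(f)), is the pattern on which lower and trace are now disallowed.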
  • Changes

    • The minimum CuDNN version is v9.8.
    • JAX is now built using CUDA 12.8. All versions of CUDA 12.1 or newer remain
      supported.
    • JAX package extras are now updated to use dash instead of underscore to
      align with PEP 685. For instance, if you were previously using pip install jax[cuda12_local]
      to install JAX, run pip install jax[cuda12-local] instead.
    • jax.jit now requires fun to be passed by position, and additional
      arguments to be passed by keyword. Doing otherwise will result in a
      DeprecationWarning in v0.6.X, and an error starting in v0.7.X.
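
A short sketch of the jax.jit calling convention described in the last item above. The power function is a made-up example.

```python
import jax
import jax.numpy as jnp


def power(x, n):
    return x ** n


# fun is passed by position; every other jit argument is passed by keyword.
power_jit = jax.jit(power, static_argnums=(1,))

y = power_jit(jnp.arange(4.0), 2)
```

Forms such as jax.jit(fun=power) or jax.jit(power, (1,)) are the ones the note warns about.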
  • Deprecations

    • jax.tree_util.build_tree is deprecated. Use jax.tree.unflatten
      instead.
    • Implemented host callback handlers for CPU and GPU devices using XLA's FFI
      and removed existing CPU/GPU handlers using XLA's custom call.
    • All APIs in jax.lib.xla_extension are now deprecated.
    • jax.interpreters.mlir.hlo and jax.interpreters.mlir.func_dialect,
      which were accidental exports, have been removed. If needed, they are
      available from jax.extend.mlir.
    • jax.interpreters.mlir.custom_call is deprecated. The APIs provided by
      jax.ffi should be used instead.
    • The deprecated use of jax.ffi.ffi_call with inline arguments is no
      longer supported. jax.ffi.ffi_call now unconditionally returns a
      callable.
    • The following exports in jax.lib.xla_client are deprecated:
      get_topology_for_devices, heap_profile, mlir_api_version, Client,
      CompileOptions, DeviceAssignment, Frame, HloSharding, OpSharding,
      Traceback.
    • The following internal APIs in jax.util are deprecated:
      HashableFunction, as_hashable_function, cache, safe_map, safe_zip,
      split_dict, split_list, split_list_checked, split_merge, subvals,
      toposort, unzip2, wrap_name, and wraps.
    • jax.dlpack.to_dlpack has been deprecated. You can usually pass a JAX
      Array directly to the from_dlpack function of another framework. If you
      need the functionality of to_dlpack, use the __dlpack__ attribute of an
      array.
    • jax.lax.infeed, jax.lax.infeed_p, jax.lax.outfeed, and
      jax.lax.outfeed_p are deprecated and will be removed in JAX v0.7.0.
    • Several previously-deprecated APIs have been removed, including:
      • From jax.lib.xla_client: ArrayImpl, FftType, PaddingType,
        PrimitiveType, XlaBuilder, dtype_to_etype,
        ops, register_custom_call_target, shape_from_pyval, Shape,
        XlaComputation.
      • From jax.lib.xla_extension: ArrayImpl, XlaRuntimeError.
      • From jax: jax.treedef_is_leaf, jax.tree_flatten, jax.tree_map,
        jax.tree_leaves, jax.tree_structure, jax.tree_transpose, and
        jax.tree_unflatten. Replacements can be found in jax.tree or
        jax.tree_util.
      • From jax.core: AxisSize, ClosedJaxpr, EvalTrace, InDBIdx, InputType,
        Jaxpr, JaxprEqn, Literal, MapPrimitive, OpaqueTraceState, OutDBIdx,
        Primitive, Token, TRACER_LEAK_DEBUGGER_WARNING, Var, concrete_aval,
        dedup_referents, escaped_tracer_error, extend_axis_env_nd, full_lower,
        get_referent, jaxpr_as_fun, join_effects, lattice_join,
        leaked_tracer_error, maybe_find_leaked_tracers, raise_to_shaped,
        raise_to_shaped_mappings, reset_trace_state, str_eqn_compact,
        substitute_vars_in_output_ty, typecompat, and used_axis_names_jaxpr. Most
        have no public replacement, though a few are available at jax.extend.core.
      • The vectorized argument to jax.pure_callback and
        jax.ffi.ffi_call. Use the vmap_method parameter instead.
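
For the removed top-level tree functions listed above, the jax.tree replacements can be sketched as below. The params dictionary is a made-up example; the one-to-one mapping in the comments is an assumption based on the matching names.

```python
import jax
import jax.numpy as jnp

params = {"w": jnp.ones((2, 2)), "b": jnp.zeros(2)}

# jax.tree_map(f, tree) -> jax.tree.map(f, tree)
doubled = jax.tree.map(lambda leaf: 2 * leaf, params)

# jax.tree_flatten / jax.tree_unflatten -> jax.tree.flatten / jax.tree.unflatten
leaves, treedef = jax.tree.flatten(params)
rebuilt = jax.tree.unflatten(treedef, leaves)

# jax.tree_leaves / jax.tree_structure -> jax.tree.leaves / jax.tree.structure
n_leaves = len(jax.tree.leaves(params))
same_structure = jax.tree.structure(rebuilt) == treedef
```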

JAX v0.5.3

19 Mar 18:20
  • New Features

    • Added an allow_negative_indices option to jax.lax.dynamic_slice,
      jax.lax.dynamic_update_slice and related functions. The default is
      true, matching the current behavior. If set to false, JAX does not need to
      emit code clamping negative indices, which reduces code size.
    • Added a replace option to jax.random.categorical to enable sampling
      without replacement.

JAX v0.5.2

05 Mar 02:36

Patch release of 0.5.1

  • Bug fixes
    • Fixes TPU metric logging and tpu-info, which were broken in 0.5.1.

JAX v0.5.1

24 Feb 21:03
  • New Features

    • Added an experimental jax.experimental.custom_dce.custom_dce
      decorator to support customizing the behavior of opaque functions under
      JAX-level dead code elimination (DCE). See #25956 for more
      details.
    • Added low-level reduction APIs in jax.lax: jax.lax.reduce_sum,
      jax.lax.reduce_prod, jax.lax.reduce_max, jax.lax.reduce_min,
      jax.lax.reduce_and, jax.lax.reduce_or, and jax.lax.reduce_xor.
    • jax.lax.linalg.qr and jax.scipy.linalg.qr now support
      column-pivoting on CPU and GPU. See #20282 and
      #25955 for more details.
  • Changes

    • JAX_CPU_COLLECTIVES_IMPLEMENTATION and JAX_NUM_CPU_DEVICES now work as
      env vars. Before they could only be specified via jax.config or flags.
    • JAX_CPU_COLLECTIVES_IMPLEMENTATION now defaults to 'gloo', meaning
      multi-process CPU communication works out-of-the-box.
    • The jax[tpu] TPU extra no longer depends on the libtpu-nightly package.
      This package may safely be removed if it is present on your machine; JAX now
      uses libtpu instead.
  • Deprecations

    • The internal function linear_util.wrap_init and the constructor
      core.Jaxpr now must take a non-empty core.DebugInfo kwarg. For
      a limited time, a DeprecationWarning is printed if
      jax.extend.linear_util.wrap_init is used without debugging info.
      A downstream effect of this is that several other internal functions need
      debug info. This change does not affect public APIs.
      See #26480 for more detail.
  • Bug fixes

    • TPU runtime startup and shutdown time should be significantly improved on
      TPU v5e and newer (from around 17s to around 8s). If not already set, you may
      need to enable transparent hugepages in your VM image
      (sudo sh -c 'echo always > /sys/kernel/mm/transparent_hugepage/enabled').
      We hope to improve this further in future releases.
    • Persistent compilation cache no longer writes access time file if
      JAX_COMPILATION_CACHE_MAX_SIZE is unset or set to -1, i.e. if the LRU
      eviction policy isn't enabled. This should improve performance when using
      the cache with large-scale network storage.

JAX v0.5.0

17 Jan 18:27

As of this release, JAX now uses effort-based versioning.
Since this release makes a breaking change to PRNG key semantics that
may require users to update their code, we are bumping the "meso" version of JAX
to signify this.

  • Breaking changes

    • Enable jax_threefry_partitionable by default (see
      the update note).

    • This release drops support for Mac x86 wheels. Mac ARM of course remains
      supported. For a recent discussion, see #22936.

      Two key factors motivated this decision:

      • The Mac x86 build (only) has a number of test failures and crashes. We
        would prefer to ship no release than a broken release.
      • Mac x86 hardware is end-of-life and cannot be easily obtained for
        developers at this point. So it is difficult for us to fix this kind of
        problem even if we wanted to.

      We are open to readding support for Mac x86 if the community is willing
      to help support that platform: in particular, we would need the JAX test
      suite to pass cleanly on Mac x86 before we could ship releases again.
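
For the jax_threefry_partitionable change in the first breaking-change item above, the following is a minimal sketch of toggling the flag through jax.config.update. It assumes the flag remains settable at runtime in your installed version, and that setting it to False opts out of the new default; the update note referenced above is the authoritative description of the semantics.

```python
import jax

key = jax.random.key(0)

# Assumption: setting the flag to False opts out of the new default.
jax.config.update("jax_threefry_partitionable", False)
legacy_sample = jax.random.normal(key, (4,))

# Restore the default introduced in v0.5.0.
jax.config.update("jax_threefry_partitionable", True)
default_sample = jax.random.normal(key, (4,))
```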

  • Changes:

    • The minimum NumPy version is now 1.25. NumPy 1.25 will remain the minimum
      supported version until June 2025.
    • The minimum SciPy version is now 1.11. SciPy 1.11 will remain the minimum
      supported version until June 2025.
    • jax.numpy.einsum now defaults to optimize='auto' rather than
      optimize='optimal'. This avoids exponentially-scaling trace-time in
      the case of many arguments (#25214).
    • jax.numpy.linalg.solve no longer supports batched 1D arguments
      on the right hand side. To recover the previous behavior in these cases,
      use solve(a, b[..., None]).squeeze(-1).
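
The solve workaround above can be sketched on a small batch. The matrices and right-hand sides are made up for illustration.

```python
import jax.numpy as jnp

# A batch of 3 systems, each 2x2, with one 1D right-hand side per system.
a = jnp.stack([jnp.eye(2) * (i + 1.0) for i in range(3)])  # shape (3, 2, 2)
b = jnp.ones((3, 2))                                        # shape (3, 2)

# Add a trailing axis so each right-hand side is a column vector, solve,
# then drop the axis again.
x = jnp.linalg.solve(a, b[..., None]).squeeze(-1)           # shape (3, 2)
```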
  • New Features

    • jax.numpy.fft.fftn, jax.numpy.fft.rfftn,
      jax.numpy.fft.ifftn, and jax.numpy.fft.irfftn now support
      transforms in more than 3 dimensions, which was previously the limit. See
      #25606 for more details.
    • Support added for user defined state in the FFI via the new
      jax.ffi.register_ffi_type_id function.
    • The AOT lowering .as_text() method now supports the debug_info option
      to include debugging information, e.g., source location, in the output.
  • Deprecations

    • From jax.interpreters.xla, abstractify and pytype_aval_mappings
      are now deprecated, having been replaced by symbols of the same name
      in jax.core.
    • jax.scipy.special.lpmn and jax.scipy.special.lpmn_values
      are deprecated, following their deprecation in SciPy v1.15.0. There are
      no plans to replace these deprecated functions with new APIs.
    • The jax.extend.ffi submodule was moved to jax.ffi, and the
      previous import path is deprecated.
  • Deletions

    • The jax_enable_memories flag has been deleted and the behavior of that flag
      is on by default.
    • From jax.lib.xla_client, the previously-deprecated Device and
      XlaRuntimeError symbols have been removed; instead use jax.Device
      and jax.errors.JaxRuntimeError respectively.
    • The jax.experimental.array_api module has been removed after being
      deprecated in JAX v0.4.32. Since that release, jax.numpy supports
      the array API directly.

JAX v0.4.38

17 Dec 23:00
  • Changes:

    • jax.tree.flatten_with_path and jax.tree.map_with_path are added
      as shortcuts of the corresponding tree_util functions.
  • Deprecations

    • A number of APIs in the internal jax.core namespace have been deprecated.
      Most were no-ops, were little-used, or can be replaced by APIs of the same
      name in jax.extend.core; see the documentation for jax.extend
      for information on the compatibility guarantees of these semi-public extensions.
    • Several previously-deprecated APIs have been removed, including:
      • from jax.core: check_eqn, check_type, check_valid_jaxtype, and
        non_negative_dim.
      • from jax.lib.xla_bridge: xla_client and default_backend.
      • from jax.lib.xla_client: _xla and bfloat16.
      • from jax.numpy: round_.
  • New Features

    • jax.export.export can be used for device-polymorphic export with
      shardings constructed with jax.sharding.AbstractMesh.
      See the jax.export documentation.
    • Added jax.lax.split. This is a primitive version of
      jax.numpy.split, added because it yields a more compact
      transpose during automatic differentiation.

JAX v0.4.37

10 Dec 01:17

This is a patch release of jax 0.4.36. Only "jax" was released at this version.

  • Bug fixes
    • Fixed a bug where jit would error if an argument was named f (#25329).
    • Fixed a bug that threw an index out of range error in
      jax.lax.while_loop if the user registered a pytree node class with
      different aux data for flatten and flatten_with_path.
    • Pinned a new libtpu release (0.0.6) that fixes a compiler bug on TPU v6e.

JAX v0.4.36

05 Dec 23:33
  • Breaking Changes

    • This release lands "stackless", an internal change to JAX's tracing
      machinery. We made trace dispatch purely a function of context rather than a
      function of both context and data. This let us delete a lot of machinery for
      managing data-dependent tracing: levels, sublevels, post_process_call,
      new_base_main, custom_bind, and so on. The change should only affect
      users that use JAX internals.

      If you do use JAX internals then you may need to
      update your code (see
      c36e1f7
      for clues about how to do this). There might also be version skew
      issues with JAX libraries that do this. If you find this change breaks your
      non-JAX-internals-using code then try the
      config.jax_data_dependent_tracing_fallback flag as a workaround, and if
      you need help updating your code then please file a bug.

    • jax.experimental.jax2tf.convert with native_serialization=False
      or with enable_xla=False has been deprecated since July 2024, with
      JAX version 0.4.31. Support for these use cases has now been removed. jax2tf
      with native serialization will still be supported.

    • In jax.interpreters.xla, the xb, xc, and xe symbols have been removed
      after being deprecated in JAX v0.4.31. Instead use xb = jax.lib.xla_bridge,
      xc = jax.lib.xla_client, and xe = jax.lib.xla_extension.

    • The deprecated module jax.experimental.export has been removed. It was replaced
      by jax.export in JAX v0.4.30. See the migration guide
      for information on migrating to the new API.

    • The initial argument to jax.nn.softmax and jax.nn.log_softmax
      has been removed, after being deprecated in v0.4.27.

    • Calling np.asarray on typed PRNG keys (i.e. keys produced by jax.random.key)
      now raises an error. Previously, this returned a scalar object array.

    • The following deprecated methods and functions in jax.export have
      been removed:

      • jax.export.DisabledSafetyCheck.shape_assertions: it had no effect
        already.
      • jax.export.Exported.lowering_platforms: use platforms.
      • jax.export.Exported.mlir_module_serialization_version:
        use calling_convention_version.
      • jax.export.Exported.uses_shape_polymorphism:
        use uses_global_constants.
      • the lowering_platforms kwarg for jax.export.export: use
        platforms instead.
    • The kwargs symbolic_scope and symbolic_constraints from
      jax.export.symbolic_args_specs have been removed. They were
      deprecated in June 2024. Use scope and constraints instead.

    • Hashing of tracers, which has been deprecated since version 0.4.30, now
      results in a TypeError.

    • Refactor: JAX build CLI (build/build.py) now uses a subcommand structure and
      replaces previous build.py usage. Run python build/build.py --help for
      more details. Brief overview of the new subcommand options:

      • build: Builds JAX wheel packages. For example: python build/build.py build --wheels=jaxlib,jax-cuda-pjrt
      • requirements_update: Updates requirements_lock.txt files.
    • jax.scipy.linalg.toeplitz now does implicit batching on multi-dimensional
      inputs. To recover the previous behavior, you can call jax.numpy.ravel
      on the function inputs.

    • jax.scipy.special.gamma and jax.scipy.special.gammasgn now
      return NaN for negative integer inputs, to match the behavior of SciPy from
      scipy/scipy#21827.

    • jax.clear_backends was removed after being deprecated in v0.4.26.

    • We removed the custom call "__gpu$xla.gpu.triton" from the list of custom
      calls for which we guarantee export stability. This is because this custom call
      relies on Triton IR, which is not guaranteed to be stable. If you need
      to export code that uses this custom call, you can use the disabled_checks
      parameter. See more details in the documentation.
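
For the typed PRNG key change listed above (np.asarray on a typed key now raises), one way to get at the underlying key data is sketched below. jax.random.key_data and jax.random.wrap_key_data are not named in the release note; using them as the replacement is an assumption for illustration.

```python
import jax
import jax.numpy as jnp
import numpy as np

key = jax.random.key(0)  # a typed PRNG key

# Extract the underlying unsigned-integer key data explicitly, then
# convert that ordinary array to NumPy.
raw = np.asarray(jax.random.key_data(key))

# Wrap raw key data back into a typed key when needed.
restored = jax.random.wrap_key_data(jnp.asarray(raw))
```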

  • New Features

    • jax.jit got a new compiler_options: dict[str, Any] argument, for
      passing compilation options to XLA. For the moment it's undocumented and
      may be in flux.
    • jax.tree_util.register_dataclass now allows metadata fields to be
      declared inline via dataclasses.field. See the function documentation
      for examples.
    • Added jax.numpy.put_along_axis.
    • jax.lax.linalg.eig and the related jax.numpy functions
      (jax.numpy.linalg.eig and jax.numpy.linalg.eigvals) are now
      supported on GPU. See #24663 for more details.
    • Added two new configuration flags, jax_exec_time_optimization_effort and
      jax_memory_fitting_effort, to control the amount of effort the compiler
      spends minimizing execution time and memory usage, respectively. Valid
      values are between -1.0 and 1.0; the default is 0.0.
  • Bug fixes

    • Fixed a bug where the GPU implementations of LU and QR decomposition would
      result in an indexing overflow for batch sizes close to int32 max. See
      #24843 for more details.
  • Deprecations

    • jax.lib.xla_extension.ArrayImpl and jax.lib.xla_client.ArrayImpl are deprecated;
      use jax.Array instead.
    • jax.lib.xla_extension.XlaRuntimeError is deprecated; use jax.errors.JaxRuntimeError
      instead.