Pallas Changelog#

This is the list of changes specific to jax.experimental.pallas. For changes to JAX as a whole, see the main JAX change log.

Unreleased#

  • New features:

  • Changes

    • Removed the backend argument of jax.experimental.pallas.pallas_call() in favor of compiler_params. For example, to force the use of the Triton backend you now have to write compiler_params=pltriton.CompilerParams(), where pltriton refers to jax.experimental.pallas.triton.

    • Renamed jax.experimental.pallas.tpu.KernelType to CoreType. The old name is deprecated and will be removed in a future release.

Released with JAX 0.9.0#

  • New features:

    • Added a reduction_scratch_bytes field to jax.experimental.pallas.mosaic_gpu.CompilerParams. This gives the user control over how much shared memory Pallas is allowed to reserve for cross-warp reductions on GPU. Increasing this value typically allows for faster reductions.

  • Changes

  • Removals

    • Removed the previously deprecated pl.atomic_*, pl.load, pl.store, pl.swap and pl.max_contiguous.

Released with jax 0.8.1#

  • New features:

  • Deprecations

  • Removals

    • Removed the previously deprecated jax.experimental.pallas.tpu.TPUCompilerParams, jax.experimental.pallas.tpu.TPUMemorySpace, jax.experimental.pallas.tpu.TritonCompilerParams.

Released with jax 0.7.1#

  • New features:

    • The device_id argument of pltpu.make_async_remote_copy and pltpu.semaphore_signal now accepts a dictionary that specifies only the device index along the communication axis, instead of the full coordinates. It also supports a TPU core id index.

    • jax.debug.print now works in Pallas kernels and is the recommended way to print.

  • Deprecations

Released with jax 0.7.0#

Released with jax 0.6.1#

Released with jax 0.5.0#

Released with jax 0.4.37#

  • New functionality

    • Added support for DotAlgorithmPreset precision arguments for dot lowering on the Triton backend.

Released with jax 0.4.36 (December 6, 2024)#

Released with jax 0.4.35 (October 22, 2024)#

  • Removals

    • Removed previously deprecated aliases jax.experimental.pallas.tpu.CostEstimate and jax.experimental.tpu.run_scoped(). Both are now available in jax.experimental.pallas.

  • New functionality

    • Added a cost estimate tool pl.estimate_cost() for automatically constructing a kernel cost estimate from a JAX reference function.

Released with jax 0.4.34 (October 4, 2024)#

  • Changes

  • Deprecations

  • New functionality

    • jax.experimental.pallas.pallas_call() now accepts scratch_shapes, a PyTree specifying backend-specific temporary objects needed by the kernel, for example, buffers, synchronization primitives etc.

    • checkify.check() can now be used to insert runtime asserts when pallas_call is called with the pltpu.enable_runtime_assert(True) context manager.

Released with jax 0.4.33 (September 16, 2024)#

Released with jax 0.4.32 (September 11, 2024)#

  • Changes

    • The kernel function is not allowed to close over constants. Instead, all the needed arrays must be passed as inputs, with proper block specs (#22746).

  • New functionality

    • Improved error messages for mistakes in the signature of the index map functions, to include the name and source location of the index map.

Released with jax 0.4.31 (July 29, 2024)#

  • Changes

    • jax.experimental.pallas.BlockSpec now expects block_shape to be passed before index_map. The old argument order is deprecated and will be removed in a future release.

    • jax.experimental.pallas.GridSpec no longer has the in_specs_tree and out_specs_tree fields, and the in_specs and out_specs fields now store the values as pytrees of BlockSpec. Previously, in_specs and out_specs were flattened (#22552).

    • The method compute_index of jax.experimental.pallas.GridSpec has been removed because it is private. Similarly, get_grid_mapping and unzip_dynamic_bounds have been removed from BlockSpec (#22593).

    • Fixed the interpret mode to work with BlockSpec that involve padding (#22275). Padding in interpret mode will be with NaN, to help debug out-of-bounds errors, but this behavior is not present when running in custom kernel mode, and should not be depended on.

    • Previously it was possible to import many APIs that are meant to be private through jax.experimental.pallas.pallas. This is not possible anymore.

  • New Functionality

    • Added documentation for BlockSpec: Grids and BlockSpecs.

    • Improved error messages for the jax.experimental.pallas.pallas_call() API.

    • Added lowering rules for TPU for lax.shift_right_arithmetic (#22279) and lax.erf_inv (#22310).

    • Added initial support for shape polymorphism for the Pallas TPU custom kernels (#22084).

    • Added TPU support for checkify (#22480).

    • Added clearer error messages when the block sizes do not match the TPU requirements. Previously, the errors were coming from the Mosaic backend and did not have useful Python stack traces.

    • Added support for TPU lowering with 1D blocks, and relaxed the requirements for the block sizes with at least 2 dimensions: the last 2 dimensions must be divisible by 8 and 128 respectively, unless they span the entire corresponding array dimension. Previously, block dimensions that spanned the entire array were allowed only if the block dimensions in the last two dimensions were smaller than 8 and 128 respectively.
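The relaxed block size rule for blocks with at least 2 dimensions can be written out as a small check. This is only a sketch of the rule as worded above; the helper `tpu_block_shape_ok` is hypothetical and is not part of Pallas, and it deliberately does not cover 1D blocks.

```python
def tpu_block_shape_ok(block_shape, array_shape):
    """Check the last-two-dimension rule for blocks of rank >= 2."""
    if len(block_shape) != len(array_shape):
        raise ValueError("block_shape and array_shape must have the same rank")
    if len(block_shape) < 2:
        raise ValueError("this sketch only covers blocks with at least 2 dimensions")
    # The second-to-last dimension is checked against 8, the last against 128.
    for block_dim, array_dim, multiple in zip(
        block_shape[-2:], array_shape[-2:], (8, 128)
    ):
        divisible = block_dim % multiple == 0
        spans_array = block_dim == array_dim
        if not (divisible or spans_array):
            return False
    return True


print(tpu_block_shape_ok((8, 128), (64, 512)))   # True: divisible by 8 and 128
print(tpu_block_shape_ok((4, 128), (64, 512)))   # False: 4 is neither divisible by 8 nor the full dimension
print(tpu_block_shape_ok((4, 100), (4, 100)))    # True: both dimensions span the array
print(tpu_block_shape_ok((16, 100), (64, 100)))  # True: 16 is divisible by 8, 100 spans the array
```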

Released with JAX 0.4.30 (June 18, 2024)#

