
jax.experimental.pallas module#

Module for Pallas, a JAX extension for custom kernels.

See the Pallas documentation at https://docs.jax.dev/en/latest/pallas/index.html.
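As a rough orientation, a minimal sketch of a custom kernel follows. The names `add_kernel` and `add` are hypothetical, and `interpret=True` is an assumption made here so that the sketch can run without a GPU or TPU; a real deployment would normally drop it and target a backend.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def add_kernel(x_ref, y_ref, o_ref):
    # Kernels receive references to inputs and outputs; read and write
    # them with indexing rather than returning a value.
    o_ref[...] = x_ref[...] + y_ref[...]


def add(x, y):
    # out_shape describes the result; interpret=True (an assumption for
    # portability) evaluates the kernel without a hardware backend.
    return pl.pallas_call(
        add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=True,
    )(x, y)


x = jnp.arange(8, dtype=jnp.float32)
y = jnp.ones(8, dtype=jnp.float32)
result = add(x, y)
```

With no grid given, the kernel runs once and sees the whole arrays.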

Backends#

Classes#

BlockSpec([block_shape, index_map, ...])

Specifies how an array should be sliced for each invocation of a kernel.

GridSpec([grid, in_specs, out_specs, ...])

Encodes the grid parameters for jax.experimental.pallas.pallas_call().

Slice(start, size[, stride])

A slice with a start index and a size.

Functions#

core_map(mesh, *[, compiler_params, ...])

Runs a function on a mesh, mapping it over the devices in the mesh.

kernel([body, out_shape, scratch_shapes, ...])

Entry point for creating a Pallas kernel.

pallas_call(kernel, out_shape, *[, ...])

Entry point for creating a Pallas kernel.

program_id(axis)

Returns the kernel execution position along the given axis of the grid.

num_programs(axis)

Returns the size of the grid along the given axis.

cdiv()

Computes the ceiling division of a divided by b.

dslice(start[, size, stride])

Constructs a Slice from a start index and a size.

empty(shape, dtype, *[, out_sharding])

Create an empty array of possibly uninitialized values.

empty_like(x)

Create an empty PyTree of possibly uninitialized values.

broadcast_to(a, shape)

Broadcasts an array to a new shape.

debug_check(condition, message)

Check the condition if enable_debug_checks() is set, otherwise do nothing.

debug_print(fmt, *args)

Prints values from inside a Pallas kernel.

dot(a, b[, trans_a, trans_b, allow_tf32, ...])

Computes the dot product of two arrays.

get_global(what)

Returns a global reference that persists across all kernel invocations.

loop(lower, upper, *[, step, unroll])

Returns a decorator that calls the decorated function in a loop.

multiple_of(x, values)

A compiler hint that asserts a value is a static multiple of another.

run_scoped(f, *types[, collective_axes])

Calls the function with allocated references and returns the result.

when(condition, /)

Calls the decorated function when the condition is met.

with_scoped(*types[, collective_axes])

Returns a function decorator that runs a function with provided allocations.

Synchronization#

semaphore_read(sem_or_view)

Reads the value of a semaphore.

semaphore_signal(sem_or_view[, inc, ...])

Increments the value of a semaphore.

semaphore_wait(sem_or_view[, value, decrement])

Blocks execution of the current thread until a semaphore reaches a value.

