jax.lax.dynamic_slice
- jax.lax.dynamic_slice(operand, start_indices, slice_sizes, *, allow_negative_indices=True)
Wraps XLA’s DynamicSlice operator.
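When every start index leaves room for the full slice, the result should match ordinary basic slicing. A small sketch comparing the two; the array shape and the index values `i, j, h, w` are arbitrary choices made only for illustration:

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(20).reshape(4, 5)

# dynamic_slice(x, (i, j), (h, w)) reads an h-by-w window whose top-left
# corner is at (i, j). When the window fits inside x, this is the same as
# basic slicing x[i:i + h, j:j + w].
i, j, h, w = 1, 2, 2, 3
window = lax.dynamic_slice(x, (i, j), (h, w))
reference = x[i:i + h, j:j + w]

print(window)
print(bool(jnp.array_equal(window, reference)))  # True
```

The difference only shows up when the window would overrun the array, or when the indices are traced values inside `jax.jit`.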
- Parameters:
operand (Array | np.ndarray) – an array to slice.
start_indices (Array | np.ndarray | Sequence[ArrayLike]) – a list of scalar indices, one per dimension. These values may be dynamic.
slice_sizes (Shape) – the size of the slice. Must be a sequence of non-negative integers with length equal to ndim(operand). Inside a JIT-compiled function, only static values are supported (all JAX arrays inside JIT must have statically known size).
allow_negative_indices (bool | Sequence[bool]) – a bool or sequence of bools, one per dimension; if a bool is passed, it applies to all dimensions. For each dimension, if true, negative indices are permitted and are interpreted relative to the end of the array. If false, negative indices are treated as if they were out of bounds and the result is implementation defined, typically clamped to the first index.
- Returns:
An array containing the slice.
- Return type:
Array
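The split between dynamic start indices and static slice sizes can be seen under `jax.jit`. A minimal sketch, assuming the slice size is supplied as a static argument via `static_argnums`; the function name `window` is hypothetical, chosen only for illustration:

```python
import functools

import jax
import jax.numpy as jnp
from jax import lax


# Hypothetical helper: a 1-D window of length `size` starting at `start`.
@functools.partial(jax.jit, static_argnums=2)
def window(x, start, size):
    # `start` is traced: its value can change between calls without
    # recompiling. `size` is marked static, so the output shape (size,)
    # is known at trace time, as JIT requires.
    return lax.dynamic_slice(x, (start,), (size,))


x = jnp.arange(10)
print(window(x, 2, 3))   # elements 2, 3, 4
print(window(x, 5, 3))   # elements 5, 6, 7: same compiled function

# With the default allow_negative_indices=True, a negative start index
# counts from the end of the axis: -3 refers to position 10 - 3 = 7.
print(window(x, -3, 2))  # elements 7, 8
```

Passing `size` as a traced argument instead would fail, because the output shape of a jitted function cannot depend on a runtime value.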
Examples
Here is a simple two-dimensional dynamic slice:
>>> import jax.numpy as jnp
>>> from jax.lax import dynamic_slice
>>> x = jnp.arange(12).reshape(3, 4)
>>> x
Array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11]], dtype=int32)
>>> dynamic_slice(x, (1, 1), (2, 3))
Array([[ 5,  6,  7],
       [ 9, 10, 11]], dtype=int32)
Note the potentially surprising behavior when the requested slice overruns the bounds of the array: instead of raising an error, the start index is shifted back (clamped) so that a slice of the requested size is returned:
>>> dynamic_slice(x, (1, 1), (2, 4))
Array([[ 4,  5,  6,  7],
       [ 8,  9, 10, 11]], dtype=int32)
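The adjustment follows XLA's DynamicSlice rule of clamping each start index into the range [0, dim - size]. A minimal sketch that reproduces the effective start indices by hand, assuming that rule; the helper `effective_start` is hypothetical and written only for illustration:

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(12).reshape(3, 4)


def effective_start(shape, start, sizes):
    # Hypothetical helper: clamp each start index so the slice fits,
    # i.e. start' = clip(start, 0, dim - size) along every axis.
    # Assumes non-negative start indices; with allow_negative_indices=True,
    # negative ones are first wrapped relative to the end of the axis.
    return tuple(min(max(s, 0), d - n) for d, s, n in zip(shape, start, sizes))


start, sizes = (1, 1), (2, 4)
adjusted = effective_start(x.shape, start, sizes)  # (1, 0)

# Ordinary slicing from the adjusted start gives the same result as
# dynamic_slice with the original, overrunning start.
manual = x[adjusted[0]:adjusted[0] + sizes[0], adjusted[1]:adjusted[1] + sizes[1]]
print(bool(jnp.array_equal(lax.dynamic_slice(x, start, sizes), manual)))  # True
```

Along axis 1 the slice size equals the dimension (4), so the only valid start is 0; along axis 0 the requested start of 1 already fits and is left unchanged.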
