jax.numpy.repeat

jax.numpy.repeat(a, repeats, axis=None, *, total_repeat_length=None, out_sharding=None)

Construct an array from repeated elements.

JAX implementation of numpy.repeat().

Parameters:
  • a (ArrayLike) – N-dimensional array

  • repeats (ArrayLike) – integer or 1D integer array specifying the number of repeats. A scalar applies the same count to every element; an array must match the length of the repeated axis.

  • axis (int | None) – integer specifying the axis of a along which to construct the repeated array. If None (default), a is first flattened.

  • total_repeat_length (int | None) – this must be specified statically for jnp.repeat to be compatible with jit() and other JAX transformations when repeats is not itself a static value. If sum(repeats) is larger than the specified total_repeat_length, the remaining values will be discarded. If sum(repeats) is smaller than total_repeat_length, the final value will be repeated.

  • out_sharding (NamedSharding | P | None)
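To make the truncate-or-pad rule for total_repeat_length concrete, here is a minimal NumPy reference sketch of the documented behavior. It is not JAX's implementation: the helper name repeat_with_total is hypothetical, and it assumes the untruncated result is non-empty whenever padding is needed.

```python
import numpy as np


def repeat_with_total(a, repeats, axis=None, total_repeat_length=None):
    """Hypothetical NumPy reference for the documented total_repeat_length rule."""
    a = np.asarray(a)
    if axis is None:
        # Mirror the axis=None behavior: flatten first, then repeat along axis 0.
        a = a.ravel()
        axis = 0
    out = np.repeat(a, repeats, axis=axis)
    if total_repeat_length is None:
        return out
    n = out.shape[axis]
    if n >= total_repeat_length:
        # sum(repeats) is too large: keep only the first total_repeat_length entries.
        return np.take(out, np.arange(total_repeat_length), axis=axis)
    # sum(repeats) is too small: repeat the final entry to fill the remaining slots.
    # Assumes n > 0, since otherwise there is no final entry to repeat.
    idx = np.concatenate([np.arange(n), np.full(total_repeat_length - n, n - 1)])
    return np.take(out, idx, axis=axis)


a = np.array([[1, 2], [3, 4]])
print(repeat_with_total(a, [2, 3], axis=1, total_repeat_length=4))
print(repeat_with_total(a, [2, 3], axis=1, total_repeat_length=7))
```

Building the result from an ordinary repeat and then slicing or padding keeps the sketch easy to check against the examples further down; a traced implementation cannot do this, because the intermediate length sum(repeats) is not known at compile time.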

Returns:

an array constructed from repeated values of a.

Return type:

Array
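The shape of the returned array follows directly from the parameters. The pure-Python sketch below computes it; repeat_output_shape is a hypothetical helper, and the length check for array-valued repeats follows the parameter description rather than any extra broadcasting the library may allow.

```python
import math


def repeat_output_shape(shape, repeats, axis=None, total_repeat_length=None):
    """Hypothetical helper: shape of the array returned by repeat."""
    shape = tuple(shape)
    if axis is None:
        # The input is flattened first, so the result is always 1D.
        shape = (math.prod(shape),)
        axis = 0
    axis %= len(shape)  # normalize negative axes such as -1
    n = shape[axis]
    if isinstance(repeats, int):
        # A scalar count repeats every element the same number of times.
        length = n * repeats
    else:
        repeats = list(repeats)
        if len(repeats) != n:
            raise ValueError(f"expected {n} repeat counts, got {len(repeats)}")
        length = sum(repeats)
    if total_repeat_length is not None:
        # A static total_repeat_length fixes the length regardless of sum(repeats).
        length = total_repeat_length
    return shape[:axis] + (length,) + shape[axis + 1:]


print(repeat_output_shape((2, 2), 2, axis=-1))  # (2, 4)
print(repeat_output_shape((2, 2), [2, 3], axis=1, total_repeat_length=7))  # (2, 7)
```

Only the repeated axis changes length; every other dimension of a is carried through unchanged.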


Examples

Repeat each value twice along the last axis:

>>> import jax
>>> import jax.numpy as jnp
>>> a = jnp.array([[1, 2],
...                [3, 4]])
>>> jnp.repeat(a, 2, axis=-1)
Array([[1, 1, 2, 2],
       [3, 3, 4, 4]], dtype=int32)
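Repeating every value the same number of times along the last axis can also be expressed with broadcasting and a reshape, which makes the element order explicit. A NumPy sketch of that equivalence, for illustration only:

```python
import numpy as np

a = np.array([[1, 2],
              [3, 4]])
r = 2  # uniform repeat count

# Add a trailing axis of length 1, broadcast it to length r, then merge it
# into the last axis so that each value's r copies end up adjacent.
manual = np.broadcast_to(a[..., None], a.shape + (r,)).reshape(a.shape[:-1] + (-1,))

print(manual)
print(np.array_equal(manual, np.repeat(a, r, axis=-1)))  # True
```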

If axis is not specified, the input array will be flattened:

>>> jnp.repeat(a, 2)
Array([1, 1, 2, 2, 3, 3, 4, 4], dtype=int32)
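The flattened output above is in row-major (C) order, so with axis=None the call behaves like repeating a.ravel() along axis 0. A quick NumPy check of that reading, with column-major flattening shown for contrast:

```python
import numpy as np

a = np.array([[1, 2],
              [3, 4]])

# ravel() flattens in row-major (C) order: [1, 2, 3, 4].
flat_first = np.repeat(a.ravel(), 2, axis=0)

print(flat_first)  # [1 1 2 2 3 3 4 4]
print(np.array_equal(flat_first, np.repeat(a, 2)))  # True

# Column-major flattening gives a different order, so it is not what axis=None does.
print(np.repeat(a.ravel(order="F"), 2))  # [1 1 3 3 2 2 4 4]
```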

Pass an array to repeats to repeat each value a different number of times:

>>> repeats = jnp.array([2, 3])
>>> jnp.repeat(a, repeats, axis=1)
Array([[1, 1, 2, 2, 2],
       [3, 3, 4, 4, 4]], dtype=int32)
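Per-element repeats can be viewed as a gather: for each output column, compute the index of the source column it copies, then take along the axis. A NumPy sketch of that view, offered as an illustration rather than a description of how JAX implements it:

```python
import numpy as np

a = np.array([[1, 2],
              [3, 4]])
repeats = np.array([2, 3])

# Source column index for every output column: column 0 twice, column 1 three times.
src = np.repeat(np.arange(a.shape[1]), repeats)
print(src)  # [0 0 1 1 1]

# Gathering those columns reproduces the per-element repeat along axis 1.
gathered = np.take(a, src, axis=1)
print(gathered)
print(np.array_equal(gathered, np.repeat(a, repeats, axis=1)))  # True
```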

To use repeat within jit and other JAX transformations, the size of the output must be specified statically using total_repeat_length:

>>> jit_repeat = jax.jit(jnp.repeat, static_argnames=['axis', 'total_repeat_length'])
>>> jit_repeat(a, repeats, axis=1, total_repeat_length=5)
Array([[1, 1, 2, 2, 2],
       [3, 3, 4, 4, 4]], dtype=int32)
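A static total_repeat_length is what fixes the output shape: once the length is known, the source index of every output position can be computed from repeats with fixed-shape operations, even when the counts themselves are only known at run time. The NumPy sketch below shows one such construction using a cumulative sum and searchsorted. The helper name fixed_length_indices is hypothetical, and this is not claimed to be JAX's implementation. Clamping to the last index yields the fill-with-final-value behavior described for total_repeat_length.

```python
import numpy as np


def fixed_length_indices(repeats, total_repeat_length):
    """Hypothetical helper: source index for each of total_repeat_length outputs."""
    repeats = np.asarray(repeats)
    # ends[i] is one past the last output position that copies element i.
    ends = np.cumsum(repeats)
    positions = np.arange(total_repeat_length)
    # Output position j copies the first element whose end lies beyond j.
    idx = np.searchsorted(ends, positions, side="right")
    # Positions past sum(repeats) would index out of range; clamping them to
    # the last element repeats the final value.
    return np.minimum(idx, len(repeats) - 1)


a = np.array([[1, 2],
              [3, 4]])
repeats = np.array([2, 3])
for total in (4, 5, 7):
    print(np.take(a, fixed_length_indices(repeats, total), axis=1))
```

Every array in this computation has a shape determined by len(repeats) or total_repeat_length alone, never by the values inside repeats.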

If total_repeat_length is smaller than sum(repeats), the result will be truncated:

>>> jit_repeat(a, repeats, axis=1, total_repeat_length=4)
Array([[1, 1, 2, 2],
       [3, 3, 4, 4]], dtype=int32)

If it is larger, then the additional entries will be filled with the final value:

>>> jit_repeat(a, repeats, axis=1, total_repeat_length=7)
Array([[1, 1, 2, 2, 2, 2, 2],
       [3, 3, 4, 4, 4, 4, 4]], dtype=int32)
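The truncation and padding rules can also be read as changes to the effective per-element counts: truncation trims counts from the end, and padding gives the missing copies to the last element. A small NumPy sketch of that reading; the helper effective_repeats is hypothetical:

```python
import numpy as np


def effective_repeats(repeats, total_repeat_length):
    """Hypothetical helper: per-element counts after truncation or padding."""
    repeats = np.asarray(repeats)
    # Cap the running total at total_repeat_length, then recover per-element counts.
    capped = np.minimum(np.cumsum(repeats), total_repeat_length)
    counts = np.diff(capped, prepend=0)
    # Any shortfall is made up by extra copies of the final element.
    counts[-1] += total_repeat_length - capped[-1]
    return counts


a = np.array([[1, 2],
              [3, 4]])
repeats = np.array([2, 3])
print(effective_repeats(repeats, 4))  # [2 2]
print(effective_repeats(repeats, 7))  # [2 5]
print(np.repeat(a, effective_repeats(repeats, 7), axis=1))
```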