jax.numpy.repeat
jax.numpy.repeat(a, repeats, axis=None, *, total_repeat_length=None, out_sharding=None)
Construct an array from repeated elements.
JAX implementation of numpy.repeat().

Parameters:
a (ArrayLike) – N-dimensional array
repeats (ArrayLike) – integer, or 1D integer array, specifying the number of repeats. A scalar applies the same count to every element; a 1D array must match the length of the repeated axis.
axis (int | None) – integer specifying the axis of a along which to construct the repeated array. If None (default), then a is first flattened.
total_repeat_length (int | None) – this must be specified statically for jnp.repeat to be compatible with jit() and other JAX transformations. If sum(repeats) is larger than the specified total_repeat_length, the remaining values will be discarded. If sum(repeats) is smaller than total_repeat_length, the final value will be repeated.
out_sharding (NamedSharding | P | None)
Returns:
an array constructed from repeated values of a.
Return type:
Array
See also
jax.numpy.tile(): repeat a full array rather than individual values.
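To make the contrast with jax.numpy.tile() concrete, this short sketch applies both functions to the same 1D array. The variable names are illustrative only.

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3])

# repeat duplicates each element in place: [1, 1, 2, 2, 3, 3]
repeated = jnp.repeat(x, 2)

# tile duplicates the whole array end to end: [1, 2, 3, 1, 2, 3]
tiled = jnp.tile(x, 2)

print(repeated.tolist())
print(tiled.tolist())
```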
Examples
Repeat each value twice along the last axis:
>>> a = jnp.array([[1, 2],
...                [3, 4]])
>>> jnp.repeat(a, 2, axis=-1)
Array([[1, 1, 2, 2],
       [3, 3, 4, 4]], dtype=int32)
If axis is not specified, the input array will be flattened:
>>> jnp.repeat(a, 2)
Array([1, 1, 2, 2, 3, 3, 4, 4], dtype=int32)
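As a quick check of the flattening behavior, this sketch compares axis=None against flattening explicitly with ravel() and then repeating along axis 0. It redefines the same 2x2 array so that it runs on its own.

```python
import jax.numpy as jnp

a = jnp.array([[1, 2],
               [3, 4]])

# axis=None: a is flattened before repeating
implicit = jnp.repeat(a, 2)

# Equivalent explicit form: flatten first, then repeat along the only axis
explicit = jnp.repeat(a.ravel(), 2, axis=0)

print(bool(jnp.array_equal(implicit, explicit)))
```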
Pass an array to repeats to repeat each value a different number of times:
>>> repeats = jnp.array([2, 3])
>>> jnp.repeat(a, repeats, axis=1)
Array([[1, 1, 2, 2, 2],
       [3, 3, 4, 4, 4]], dtype=int32)
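Per-element counts may include zero, in which case that element is omitted from the output, as with numpy.repeat. A minimal sketch, run outside of jit so the output size can be computed from the concrete counts (values and counts are illustrative names):

```python
import jax.numpy as jnp

values = jnp.array([10, 20, 30])
counts = jnp.array([1, 0, 2])  # 20 is repeated zero times, so it is dropped

out = jnp.repeat(values, counts)
print(out.tolist())
```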
In order to use repeat within jit and other JAX transformations, the size of the output must be specified statically using total_repeat_length:
>>> jit_repeat = jax.jit(jnp.repeat, static_argnames=['axis', 'total_repeat_length'])
>>> jit_repeat(a, repeats, axis=1, total_repeat_length=5)
Array([[1, 1, 2, 2, 2],
       [3, 3, 4, 4, 4]], dtype=int32)
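Marking total_repeat_length in static_argnames is one option. Another, sketched here under the assumption that the output length is already known when the function is written, is to call jnp.repeat inside your own jitted function and pass total_repeat_length as a Python constant, so that only repeats is traced. The function name expand_counts and the length 6 are hypothetical.

```python
import jax
import jax.numpy as jnp

OUTPUT_LENGTH = 6  # hypothetical fixed output size, known ahead of time


@jax.jit
def expand_counts(values, counts):
    # counts is traced here, so the output size cannot be taken from
    # sum(counts); the Python constant OUTPUT_LENGTH fixes it statically.
    return jnp.repeat(values, counts, total_repeat_length=OUTPUT_LENGTH)


out = expand_counts(jnp.array([1, 2, 3]), jnp.array([1, 2, 3]))
print(out.tolist())
```

Because the counts are traced, the same compiled function can be reused with different counts, as long as the output length stays at 6.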
If total_repeat_length is smaller than sum(repeats), the result will be truncated:
>>> jit_repeat(a, repeats, axis=1, total_repeat_length=4)
Array([[1, 1, 2, 2],
       [3, 3, 4, 4]], dtype=int32)
If it is larger, then the additional entries will be filled with the final value:
>>> jit_repeat(a, repeats, axis=1, total_repeat_length=7)
Array([[1, 1, 2, 2, 2, 2, 2],
       [3, 3, 4, 4, 4, 4, 4]], dtype=int32)
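The truncation and padding rules follow naturally if the output is gathered through a fixed-size index buffer, whose shape depends only on total_repeat_length and so is static under jit. The following is one way to do this for a 1D input using cumulative sums, a scatter, and a gather. repeat_static is a hypothetical helper written for illustration, not JAX's own source, and it assumes that padding uses the last element of a, which is consistent with the examples above.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames=["total_repeat_length"])
def repeat_static(a, repeats, total_repeat_length):
    """Hypothetical 1D sketch of repeat with a statically known output size."""
    # Exclusive cumulative sum: the output offset where each element's block starts.
    starts = jnp.cumsum(repeats) - repeats
    # Mark block starts in a fixed-size buffer. Starts past the end (when
    # sum(repeats) > total_repeat_length) are dropped, which truncates.
    markers = jnp.zeros(total_repeat_length, dtype=jnp.int32)
    markers = markers.at[starts].add(1, mode="drop")
    # The running count of starts seen so far, minus one, is the source index
    # for each output slot. Slots after the last start keep the last index,
    # which pads with the final element of a.
    idx = jnp.cumsum(markers) - 1
    return a[idx]


a = jnp.array([1, 2])
repeats = jnp.array([2, 3])
for n in (4, 5, 7):
    print(n, repeat_static(a, repeats, n).tolist())
```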
