jax.numpy.round
- jax.numpy.round(a, decimals=0, out=None)
Round input evenly to the given number of decimals.
JAX implementation of numpy.round().
- Parameters:
a (ArrayLike) – input array or scalar.
decimals (int) – default=0. Number of decimal places to round the input to. It must be specified statically, as a concrete Python int rather than a traced value. Not implemented for decimals < 0.
out (None) – Unused by JAX.
- Returns:
An array containing the values of a rounded to the specified number of decimals, with the same shape and dtype as a.
- Return type:
Array
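Because decimals must be static, a function that wraps jnp.round and is compiled with jax.jit has to mark that argument as static. The sketch below does this with static_argnames; round_to is a hypothetical helper name used only for illustration, and passing decimals as a traced argument instead is expected to fail. It also checks that the result keeps the input's shape and dtype.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical helper: `decimals` is declared static so that jnp.round
# receives a concrete Python int rather than a traced value.
@partial(jax.jit, static_argnames="decimals")
def round_to(x, decimals):
    return jnp.round(x, decimals=decimals)


x = jnp.array([[1.532, 3.267], [6.149, -2.718]], dtype=jnp.float32)
y = round_to(x, decimals=1)

# The result keeps the shape and dtype of the input.
print(y.shape, y.dtype)  # (2, 2) float32
print(y)
```

Because decimals is static, each distinct value triggers a separate compilation of round_to; that is usually acceptable, since the number of decimal places rarely changes at run time.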
Note
For values exactly halfway between two rounded decimal values, jnp.round rounds to the even neighbour (round half to even); with decimals=0 this is the nearest even integer.
See also
- jax.numpy.floor(): Rounds the input to the nearest integer downwards.
- jax.numpy.ceil(): Rounds the input to the nearest integer upwards.
- jax.numpy.fix() and jax.numpy.trunc(): Round the input to the nearest integer towards zero.
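The functions listed above differ mainly on negative inputs and on ties. A small comparison, using arbitrarily chosen sample values, may make the distinction concrete:

```python
import jax.numpy as jnp

# Sample values chosen to include negatives and exact .5 ties.
x = jnp.array([-2.5, -1.7, -0.5, 0.5, 1.7, 2.5])

results = {
    "round": jnp.round(x),  # nearest; ties go to the even neighbour
    "floor": jnp.floor(x),  # toward -inf
    "ceil": jnp.ceil(x),    # toward +inf
    "fix": jnp.fix(x),      # toward zero
    "trunc": jnp.trunc(x),  # toward zero
}

for name, values in results.items():
    print(f"{name:>5}: {values}")
```

Note that jnp.round sends both -2.5 and 2.5 to the even neighbour (-2 and 2), whereas round-half-away-from-zero would give -3 and 3.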
Examples
>>> import jax.numpy as jnp
>>> x = jnp.array([1.532, 3.267, 6.149])
>>> jnp.round(x)
Array([2., 3., 6.], dtype=float32)
>>> jnp.round(x, decimals=2)
Array([1.53, 3.27, 6.15], dtype=float32)
For values exactly halfway between rounded values:
>>> x1 = jnp.array([10.5, 21.5, 12.5, 31.5])
>>> jnp.round(x1)
Array([10., 22., 12., 32.], dtype=float32)
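The same tie-breaking rule should also apply when decimals > 0, provided the value is exactly halfway at that decimal position. The sketch below assumes jnp.round scales the input by 10**decimals before rounding (as the current JAX source appears to do), and uses binary fractions such as 0.125 that are exactly representable in float32 so that the ties are genuine. For decimal values with no exact binary representation, such as 2.675, whether a tie occurs depends on floating-point rounding.

```python
import jax.numpy as jnp

# Binary fractions such as 0.125 are exact in float32; assuming jnp.round
# scales by 10**decimals, 0.125 * 100 is exactly 12.5, a genuine tie.
x = jnp.array([0.125, 0.375, 0.625, 0.875], dtype=jnp.float32)

# Ties are resolved to the even neighbour at the second decimal place.
y = jnp.round(x, decimals=2)
print(y)
```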
