jax.lax.round
- jax.lax.round(x, rounding_method=RoundingMethod.AWAY_FROM_ZERO)
Elementwise round.
Rounds values to the nearest integer. This function lowers directly to the stablehlo.round operation.
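One way to inspect this lowering is to print the StableHLO text that jax.jit produces. This is a minimal sketch, assuming a recent JAX where Lowered.as_text() emits StableHLO by default. The helper name lowered_text is hypothetical. In the versions we are aware of, the op is printed as stablehlo.round_nearest_afz or stablehlo.round_nearest_even, depending on rounding_method. For that reason, the filter below only matches the common stablehlo.round prefix.

```python
import jax
import jax.numpy as jnp
from jax import lax

x = jnp.array([-1.5, -0.5, 0.5, 1.5], dtype=jnp.float32)


def lowered_text(method):
    """Hypothetical helper: return the StableHLO text for lax.round with `method`."""
    # rounding_method is a plain Python enum value, so close over it
    # instead of passing it to the jitted function as a traced argument.
    fn = jax.jit(lambda v: lax.round(v, rounding_method=method))
    return fn.lower(x).as_text()


for method in lax.RoundingMethod:
    text = lowered_text(method)
    # Keep only the lines that mention a stablehlo round op.
    ops = [line.strip() for line in text.splitlines() if "stablehlo.round" in line]
    print(method.name, ops)
```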
- Parameters:
x (ArrayLike) – an array or scalar value to round. Must have floating-point type.
rounding_method (RoundingMethod) – the method to use when rounding halfway values (e.g., 0.5). See jax.lax.RoundingMethod for possible values.
- Returns:
An array of the same shape and dtype as x, containing the elementwise rounding of x.
- Return type:
Array
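The parameter and return contract above can be checked against NumPy. The sketch below assumes NumPy is installed and uses two hypothetical reference helpers:

- np.round, which rounds halfway values to even.
- A sign/floor formula for away-from-zero rounding. This formula is adequate only for the small, exactly representable halfway values used here and is not a general-purpose implementation.

The sketch also checks two other points. A float16 input keeps its shape and dtype. An int32 input is rejected, which we expect to surface as a TypeError.

```python
import numpy as np
import jax.numpy as jnp
from jax import lax

# Halfway values are exactly where the two rounding methods disagree.
x = jnp.array([-2.5, -1.5, -0.5, 0.5, 1.5, 2.5], dtype=jnp.float32)


def numpy_away_from_zero(v):
    # Hypothetical reference: push |v| up by 0.5, floor it, restore the sign.
    # Fine for these small halfway values; not robust for all floats.
    v = np.asarray(v)
    return np.sign(v) * np.floor(np.abs(v) + 0.5)


def numpy_to_nearest_even(v):
    # np.round rounds halfway values to the nearest even integer.
    return np.round(np.asarray(v))


afz = lax.round(x, rounding_method=lax.RoundingMethod.AWAY_FROM_ZERO)
even = lax.round(x, rounding_method=lax.RoundingMethod.TO_NEAREST_EVEN)

np.testing.assert_array_equal(np.asarray(afz), numpy_away_from_zero(x))
np.testing.assert_array_equal(np.asarray(even), numpy_to_nearest_even(x))

# The result keeps the shape and dtype of the input (float16 here).
h = jnp.ones((2, 3), dtype=jnp.float16) * 0.5
hr = lax.round(h)
assert hr.shape == (2, 3) and hr.dtype == jnp.float16

# The input must be floating point; an integer array is expected to raise.
ints = jnp.array([1, 2, 3], dtype=jnp.int32)
try:
    lax.round(ints)
except TypeError as err:
    print("rejected:", err)

# Casting to a floating type first works.
print(lax.round(ints.astype(jnp.float32)))
```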
See also
- jax.lax.floor(): round to the next integer toward negative infinity
- jax.lax.ceil(): round to the next integer toward positive infinity
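To contrast the three functions, the short sketch below applies them to the same inputs. The values are chosen only for illustration. They include a halfway case (-0.5 and 0.5) and two non-halfway values (-1.7 and 1.2).

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([-1.7, -0.5, 0.5, 1.2], dtype=jnp.float32)

lo = lax.floor(x)  # toward negative infinity
hi = lax.ceil(x)   # toward positive infinity
nr = lax.round(x)  # nearest integer; halfway cases go away from zero by default

print("floor:", lo)
print("ceil: ", hi)
print("round:", nr)
```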
Examples
>>> import jax.numpy as jnp
>>> from jax import lax
>>> x = jnp.array([-1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5])
>>> lax.round(x)  # default method is AWAY_FROM_ZERO
Array([-2., -1., -1.,  0.,  1.,  1.,  2.], dtype=float32)
>>> lax.round(x, rounding_method=lax.RoundingMethod.TO_NEAREST_EVEN)
Array([-2., -1., -0.,  0.,  0.,  1.,  2.], dtype=float32)
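rounding_method is an ordinary Python enum value rather than an array. A jitted wrapper that forwards it therefore presumably needs to mark it as a static argument. The sketch below makes that assumption and uses a hypothetical helper, quantize, which snaps values to the nearest multiple of step. The example uses step = 0.25, a power of two, so the division and multiplication are exact and the halfway cases stay exactly at .5.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax


@partial(jax.jit, static_argnames=("rounding_method",))
def quantize(x, step, rounding_method=lax.RoundingMethod.AWAY_FROM_ZERO):
    """Hypothetical helper: round x to the nearest multiple of step."""
    # rounding_method is static, so it is fixed at trace time and selects
    # the rounding behaviour; x and step are traced as usual.
    return lax.round(x / step, rounding_method=rounding_method) * step


x = jnp.array([-0.375, -0.125, 0.125, 0.375], dtype=jnp.float32)
afz = quantize(x, 0.25)
even = quantize(x, 0.25, rounding_method=lax.RoundingMethod.TO_NEAREST_EVEN)
print(afz)
print(even)
```

With x / step equal to [-1.5, -0.5, 0.5, 1.5], away-from-zero rounding gives multiples of ±2 and ±1. Round-to-even sends the ±0.5 cases to zero.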
