jax.numpy.maximum
- jax.numpy.maximum = <jnp.ufunc 'maximum'>
Return element-wise maximum of the input arrays.
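JAX implementation of numpy.maximum.
- Parameters:
x – input array or scalar.
y – input array or scalar. Both x and y should either have the same shape or be broadcast compatible.
args (ArrayLike) – the positional inputs x and y.
out (None) – only the default None is supported.
where (None) – only the default None is supported.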
- Returns:
An array containing the element-wise maximum of x and y.
- Return type:
Any
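The parameter list leaves two practical points implicit: a scalar y is broadcast against every element of x, and the result dtype follows JAX's type-promotion rules. The minimal sketch below illustrates both using only jnp.array and jnp.maximum; the variable names are illustrative, and the float32 note assumes JAX's default 32-bit mode.

```python
import jax.numpy as jnp

x = jnp.array([-1.5, 0.0, 2.0, -3.0])

# A scalar is broadcast against every element of x, so this clamps
# negative entries to 0 (the familiar ReLU pattern).
relu = jnp.maximum(x, 0.0)

# Mixing an integer array with a Python float promotes the result to a
# floating dtype (float32 under JAX's default 32-bit mode).
ints = jnp.array([1, 2, 3])
mixed = jnp.maximum(ints, 1.5)

print(relu)          # [0. 0. 2. 0.]
print(mixed.dtype)
```

Argument order does not matter here: jnp.maximum(0.0, x) gives the same result.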
Note
For each pair of elements, jnp.maximum returns:
- the larger of the two if neither element is nan (infinities compare as usual).
- nan if either element is nan.
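The NaN rule in the note is the main practical difference between jnp.maximum and jnp.fmax. The short sketch below compares the two on the same inputs; the arrays a and b are illustrative.

```python
import jax.numpy as jnp

a = jnp.array([1.0, jnp.nan, 3.0, jnp.nan])
b = jnp.array([2.0, 5.0, jnp.nan, jnp.nan])

# jnp.maximum propagates NaN: any pair containing a NaN yields NaN.
m = jnp.maximum(a, b)   # -> [2, nan, nan, nan]

# jnp.fmax ignores a NaN when the other element is a number and
# returns NaN only when both elements are NaN.
f = jnp.fmax(a, b)      # -> [2, 5, 3, nan]
```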
See also
- jax.numpy.minimum(): Returns element-wise minimum of the input arrays.
- jax.numpy.fmax(): Returns element-wise maximum of the input arrays, ignoring NaNs.
- jax.numpy.amax(): Returns the maximum of array elements along a given axis.
- jax.numpy.nanmax(): Returns the maximum of the array elements along a given axis, ignoring NaNs.
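The distinction between jnp.maximum (a binary, element-wise comparison) and the axis reductions listed above is easy to blur. The sketch below uses jnp.max, a reduction in the same family as jnp.amax, alongside jnp.maximum. It also folds jnp.maximum over a list with functools.reduce to take the element-wise maximum of more than two arrays; that fold is one illustrative approach, not the only one.

```python
import functools
import jax.numpy as jnp

x = jnp.array([[3, 8, 1],
               [7, 2, 9]])

# A reduction collapses one array along an axis ...
col_max = jnp.max(x, axis=0)          # -> [7, 8, 9]

# ... whereas jnp.maximum compares two arrays element by element.
pair_max = jnp.maximum(x[0], x[1])    # -> [7, 8, 9]

# jnp.maximum takes two inputs; folding it over a list gives the
# element-wise maximum of any number of broadcast-compatible arrays.
arrays = [jnp.array([1, 5, 2]), jnp.array([4, 0, 6]), jnp.array([3, 3, 3])]
many_max = functools.reduce(jnp.maximum, arrays)   # -> [4, 5, 6]
```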
Examples
Inputs with x.shape == y.shape:
>>> x = jnp.array([1, -5, 3, 2])
>>> y = jnp.array([-2, 4, 7, -6])
>>> jnp.maximum(x, y)
Array([1, 4, 7, 2], dtype=int32)
Inputs with broadcast compatibility:
>>> x1 = jnp.array([[-2, 5, 7, 4],
...                 [1, -6, 3, 8]])
>>> y1 = jnp.array([-5, 3, 6, 9])
>>> jnp.maximum(x1, y1)
Array([[-2,  5,  7,  9],
       [ 1,  3,  6,  9]], dtype=int32)
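Broadcasting can also expand both inputs at once. In the sketch below (illustrative arrays), a column of shape (3, 1) and a row of shape (4,) align on trailing dimensions, so every entry of the column is compared with every entry of the row. jnp.broadcast_shapes is used only to show the resulting shape ahead of time.

```python
import jax.numpy as jnp

col = jnp.array([[0], [5], [10]])   # shape (3, 1)
row = jnp.array([1, 6, 11, 4])      # shape (4,)

# (3, 1) and (4,) broadcast to (3, 4).
print(jnp.broadcast_shapes(col.shape, row.shape))   # (3, 4)

grid = jnp.maximum(col, row)
# Row i of grid is the element-wise maximum of col[i] and row.
print(grid.shape)                                   # (3, 4)
```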
Inputs having nan:
>>> nan = jnp.nan
>>> x2 = jnp.array([nan, -3, 9])
>>> y2 = jnp.array([[4, -2, nan],
...                 [-3, -5, 10]])
>>> jnp.maximum(x2, y2)
Array([[nan, -2., nan],
       [nan, -3., 10.]], dtype=float32)
