jax.numpy.nanmax

jax.numpy.nanmax(a, axis=None, out=None, keepdims=False, initial=None, where=None)

Return the maximum of the array elements along a given axis, ignoring NaNs.

JAX implementation of numpy.nanmax().
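The NaN-skipping behaviour can be pictured with a short sketch. This is an illustration, not JAX's source: the hypothetical helper nanmax_sketch replaces each NaN with -inf, takes an ordinary jnp.max, and reports NaN for slices that contained only NaNs. It ignores initial and where and assumes a floating-point input.

```python
import jax.numpy as jnp


def nanmax_sketch(a, axis=None, keepdims=False):
    """Hypothetical helper: mimics jnp.nanmax for float inputs (no initial/where)."""
    a = jnp.asarray(a)
    # Once replaced by -inf, a NaN can never be the maximum of its slice.
    filled = jnp.where(jnp.isnan(a), -jnp.inf, a)
    out = jnp.max(filled, axis=axis, keepdims=keepdims)
    # Slices made up entirely of NaNs have no valid value, so report NaN for them.
    all_nan = jnp.all(jnp.isnan(a), axis=axis, keepdims=keepdims)
    return jnp.where(all_nan, jnp.nan, out)


x = jnp.array([[1.0, jnp.nan, 3.0],
               [jnp.nan, jnp.nan, jnp.nan]])
r = nanmax_sketch(x, axis=1)  # expected to agree with jnp.nanmax(x, axis=1)
```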

Parameters:
  • a (ArrayLike) – Input array.

  • axis (Axis) – int or sequence of ints, default=None. Axis along which the maximum is computed. If None, the maximum is computed along the flattened array.

  • keepdims (bool) – bool, default=False. If True, reduced axes are left in the result with size 1.

  • initial (ArrayLike | None) – int or array, default=None. Initial value for the maximum. It takes part in the reduction, so every output element is at least initial.

  • where (ArrayLike | None) – array of boolean dtype, default=None. The elements to include in the maximum. The array must be broadcast-compatible with the input. initial must be specified when where is used.

  • out (None) – Unused by JAX.
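A small sketch of how initial and where interact, reusing the x array from the examples on this page. With initial=5, each row maximum is raised to at least 5. Passing where without initial is expected to raise a ValueError, since max has no identity element to fall back on; the exact error message is not relied on here because it may differ between JAX versions.

```python
import jax.numpy as jnp

nan = jnp.nan
x = jnp.array([[8, nan, 4, 6],
               [nan, -2, nan, -4],
               [-2, 1, 7, nan]])

# initial joins the reduction, so it acts as a lower bound:
# the row maxima 8, -2 and 7 become 8, 5 and 7.
floored = jnp.nanmax(x, axis=1, initial=5)

# A where mask without initial is rejected.
mask = x > 0
try:
    jnp.nanmax(x, axis=1, where=mask)
    raised = False
except ValueError:
    raised = True
```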

Returns:

An array of maximum values along the given axis, ignoring NaNs. If all values along the given axis are NaN, returns nan.

Return type:

Array
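A brief illustration of the all-NaN case described under Returns (the array y is made up for this sketch): a row with no non-NaN values yields nan, while other rows simply skip their NaNs. Checking the result with jnp.isnan is one way to detect such rows afterwards.

```python
import jax.numpy as jnp

nan = jnp.nan
y = jnp.array([[nan, nan, nan],
               [1.0, nan, 2.0]])

# Row 0 holds no non-NaN value, so its result is nan; row 1 ignores its NaN.
row_max = jnp.nanmax(y, axis=1)
all_nan_rows = jnp.isnan(row_max)  # True where there was nothing to compare
```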

See also

Examples

By default, jnp.nanmax computes the maximum of elements along the flattened array.

>>> import jax.numpy as jnp
>>> nan = jnp.nan
>>> x = jnp.array([[8, nan, 4, 6],
...                [nan, -2, nan, -4],
...                [-2, 1, 7, nan]])
>>> jnp.nanmax(x)
Array(8., dtype=float32)
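For contrast, a sketch comparing jnp.nanmax with the ordinary jnp.max on the same array. jnp.max lets NaN propagate into the result, whereas jnp.nanmax skips the NaN entries and returns 8.

```python
import jax.numpy as jnp

nan = jnp.nan
x = jnp.array([[8, nan, 4, 6],
               [nan, -2, nan, -4],
               [-2, 1, 7, nan]])

plain = jnp.max(x)       # NaN propagates through an ordinary max
ignored = jnp.nanmax(x)  # NaNs are skipped
```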

If axis=1, the maximum is computed along axis 1.

>>> jnp.nanmax(x, axis=1)
Array([ 8., -2.,  7.], dtype=float32)
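The axis argument also accepts other single axes or a tuple of axes. In this sketch, axis=0 gives one maximum per column, and axis=(0, 1) reduces over both axes of the 2-D array, which matches the flattened default.

```python
import jax.numpy as jnp

nan = jnp.nan
x = jnp.array([[8, nan, 4, 6],
               [nan, -2, nan, -4],
               [-2, 1, 7, nan]])

col_max = jnp.nanmax(x, axis=0)    # one maximum per column: 8, 1, 7, 6
both = jnp.nanmax(x, axis=(0, 1))  # a tuple reduces over every listed axis
same_as_flat = bool(both == jnp.nanmax(x))
```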

If keepdims=True, the output has the same ndim as the input.

>>> jnp.nanmax(x, axis=1, keepdims=True)
Array([[ 8.],
       [-2.],
       [ 7.]], dtype=float32)
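One common reason to use keepdims=True is broadcasting the result back against the input. In this sketch, the (3, 1) row maxima are subtracted from the (3, 4) array so that each row is shifted by its own maximum; NaN entries stay NaN, and each row's largest non-NaN value becomes 0.

```python
import jax.numpy as jnp

nan = jnp.nan
x = jnp.array([[8, nan, 4, 6],
               [nan, -2, nan, -4],
               [-2, 1, 7, nan]])

row_max = jnp.nanmax(x, axis=1, keepdims=True)  # shape (3, 1)
# The length-1 axis broadcasts across the 4 columns of x.
shifted = x - row_max
```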

To include only specific elements when computing the maximum, use where. It can either have the same shape as the input

>>> where = jnp.array([[0, 0, 1, 0],
...                    [0, 0, 1, 1],
...                    [1, 1, 1, 0]], dtype=bool)
>>> jnp.nanmax(x, axis=1, keepdims=True, initial=0, where=where)
Array([[4.],
       [0.],
       [7.]], dtype=float32)

or be broadcast-compatible with the input.

>>> where = jnp.array([[True],
...                    [False],
...                    [False]])
>>> jnp.nanmax(x, axis=0, keepdims=True, initial=0, where=where)
Array([[8., 0., 4., 6.]], dtype=float32)
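Two equivalences that may help when reading the where examples; this is an interpretation for illustration, not a description of JAX internals. Because nanmax already skips NaNs, hiding the excluded elements behind NaN (and keeping the same initial) gives the same result as the mask. Likewise, a broadcastable mask behaves like the same mask expanded to the input's full shape with jnp.broadcast_to.

```python
import jax.numpy as jnp

nan = jnp.nan
x = jnp.array([[8, nan, 4, 6],
               [nan, -2, nan, -4],
               [-2, 1, 7, nan]])

# Same-shape mask, as in the axis=1 example.
where = jnp.array([[0, 0, 1, 0],
                   [0, 0, 1, 1],
                   [1, 1, 1, 0]], dtype=bool)
masked = jnp.nanmax(x, axis=1, keepdims=True, initial=0, where=where)
# Replace excluded elements with NaN and let nanmax skip them instead.
via_nan = jnp.nanmax(jnp.where(where, x, nan), axis=1, keepdims=True, initial=0)

# Broadcastable (3, 1) mask, as in the axis=0 example.
row_mask = jnp.array([[True], [False], [False]])
broadcast = jnp.nanmax(x, axis=0, keepdims=True, initial=0, where=row_mask)
# Expanding the mask to x's full (3, 4) shape first gives the same result.
explicit = jnp.nanmax(x, axis=0, keepdims=True, initial=0,
                      where=jnp.broadcast_to(row_mask, x.shape))
```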