
jax.numpy.where

jax.numpy.where(condition, x=None, y=None, /, *, size=None, fill_value=None)

Select elements from two arrays based on a condition.

JAX implementation of numpy.where().

Note

When only condition is provided, jnp.where(condition) is equivalent to jnp.nonzero(condition). For that case, refer to the documentation of jax.numpy.nonzero(). The docstring below focuses on the case where x and y are specified.

The three-argument version of jnp.where lowers to jax.lax.select().
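As a rough illustration of the relationship with jax.lax.select(), the sketch below compares the two on operands that already share a shape and dtype. It assumes only that lax.select expects a boolean predicate plus two operands of identical shape and dtype, so any broadcasting or promotion that jnp.where performs implicitly has to be done by hand; the array values are arbitrary.

```python
import jax.numpy as jnp
from jax import lax

# Same-shape, same-dtype operands, so lax.select needs no broadcasting or casting.
cond = jnp.array([True, False, True])
x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([10.0, 20.0, 30.0])

via_where = jnp.where(cond, x, y)    # picks x where cond is True, else y
via_select = lax.select(cond, x, y)  # same selection, no implicit broadcasting

# With a scalar y, jnp.where broadcasts it automatically; for lax.select we
# first expand it to x's shape and dtype ourselves.
manual = lax.select(cond, x, jnp.full_like(x, 0.0))
```

Here jnp.where(cond, x, 0.0) and manual hold the same values, which is the convenience the NumPy-style wrapper adds on top of the lower-level primitive.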

Parameters:
  • condition – boolean array. Must be broadcast-compatible with x and y when they are specified.

  • x – array-like. Should be broadcast-compatible with condition and y, and typecast-compatible with y.

  • y – array-like. Should be broadcast-compatible with condition and x, and typecast-compatible with x.

  • size – integer, only referenced when x and y are None. For details, see jax.numpy.nonzero().

  • fill_value – only referenced when x and y are None. For details, see jax.numpy.nonzero().
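The size and fill_value parameters matter mainly under jax.jit, where the output shape cannot depend on how many entries of condition are True. A minimal sketch, with the illustrative function name positive_indices and arbitrary values: size=3 fixes the length of the returned index array, and fill_value=-1 pads the unused slots.

```python
import jax
import jax.numpy as jnp

@jax.jit
def positive_indices(v):
    # One-argument form: behaves like jnp.nonzero and returns a tuple with one
    # index array per axis. `size` makes the output shape static under jit,
    # and `fill_value` pads the result when there are fewer matches.
    (idx,) = jnp.where(v > 0, size=3, fill_value=-1)
    return idx

v = jnp.array([0, 5, 0, 7])
idx = positive_indices(v)  # indices 1 and 3, then one padded slot
```

If there were more than three matches, only the first three indices would be kept; without size, this one-argument call would fail to trace under jit.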

Returns:

An array of dtype jnp.result_type(x, y) with values drawn from x where condition is True, and from y where condition is False. If x and y are None, the function behaves differently; see jax.numpy.nonzero() for a description of the return type.
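To make the broadcasting and dtype rules concrete, here is a small sketch with a column-shaped condition, a 1-D integer x, and a scalar y; the shapes and values are arbitrary choices for illustration.

```python
import jax.numpy as jnp

cond = jnp.array([[True], [False], [True]])  # shape (3, 1)
x = jnp.arange(4)                            # shape (4,), integer dtype

# condition, x and y broadcast together: (3, 1), (4,) and a scalar give (3, 4).
# Rows where cond is True copy x; the other row is filled with y.
out = jnp.where(cond, x, -1)

# The result dtype follows jnp.result_type(x, y): a Python float for y
# promotes the integer x to the default floating-point dtype.
out_f = jnp.where(cond, x, 0.5)
```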

Notes

Special care is needed when the x or y input to jax.numpy.where() could have a value of NaN. Specifically, when a gradient is taken with jax.grad() (reverse-mode differentiation), a NaN in either x or y will propagate into the gradient, regardless of the value of condition. More information on this behavior and workarounds is available in the JAX FAQ.
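A common workaround, sometimes called the "double where" trick, is to also apply jnp.where to the input of the operation that can produce NaN or infinity, so the unselected branch is evaluated at a harmless value. The sketch below assumes a log that should only be used for positive inputs; naive and safe are illustrative names.

```python
import jax
import jax.numpy as jnp

def naive(x):
    # log(x) is still evaluated (and differentiated) at x <= 0, even where
    # the condition discards its value.
    return jnp.where(x > 0, jnp.log(x), 0.0)

def safe(x):
    # Replace the problematic inputs before calling log, so its derivative
    # is finite everywhere; the outer where then selects as before.
    safe_x = jnp.where(x > 0, x, 1.0)
    return jnp.where(x > 0, jnp.log(safe_x), 0.0)

g_naive = jax.grad(naive)(0.0)  # zero cotangent times an infinite derivative
g_safe = jax.grad(safe)(0.0)
```

In naive, the discarded branch receives a zero cotangent, but the derivative of log at 0 is infinite, and 0 times infinity is NaN. For positive inputs both versions give the same gradient.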

Examples

When x and y are not provided, where behaves equivalently to jax.numpy.nonzero():

>>> import jax.numpy as jnp
>>> x = jnp.arange(10)
>>> jnp.where(x > 4)
(Array([5, 6, 7, 8, 9], dtype=int32),)
>>> jnp.nonzero(x > 4)
(Array([5, 6, 7, 8, 9], dtype=int32),)
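For a multi-dimensional condition, the one-argument form returns one index array per axis, as jax.numpy.nonzero() does. A short sketch with an arbitrary 2-by-2 array:

```python
import jax.numpy as jnp

m = jnp.array([[0, 1],
               [2, 0]])

# A 2-D condition yields a tuple of two index arrays (row indices, column
# indices), listed in row-major order.
rows, cols = jnp.where(m > 0)

# The tuple can be used directly to gather the selected elements.
selected = m[rows, cols]
```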

When x and y are provided, where selects between them based on the specified condition:

>>> jnp.where(x > 4, x, 0)
Array([0, 0, 0, 0, 0, 5, 6, 7, 8, 9], dtype=int32)
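Because the three-argument form always returns an array of the broadcast shape, it can be traced under jax.jit, whereas boolean-mask indexing such as x[x > 4] produces a data-dependent size and cannot. The sketch below uses it to replace NaN entries; replace_nan is an illustrative name, and jnp.nan_to_num() may be a more direct choice for this particular task.

```python
import jax
import jax.numpy as jnp

@jax.jit
def replace_nan(a, fill):
    # Output has the same shape as `a`, so the function traces cleanly under jit.
    return jnp.where(jnp.isnan(a), fill, a)

a = jnp.array([1.0, jnp.nan, 3.0])
out = replace_nan(a, 0.0)
```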