jax.numpy.where
- jax.numpy.where(condition, x=None, y=None, /, *, size=None, fill_value=None)
Select elements from two arrays based on a condition.
JAX implementation of numpy.where().
Note
When only condition is provided, jnp.where(condition) is equivalent to jnp.nonzero(condition). For that case, refer to the documentation of jax.numpy.nonzero(). The docstring below focuses on the case where x and y are specified.
The three-term version of jnp.where lowers to jax.lax.select().
- Parameters:
condition – boolean array. Must be broadcast-compatible with x and y when they are specified.
x – arraylike. Should be broadcast-compatible with condition and y, and typecast-compatible with y.
y – arraylike. Should be broadcast-compatible with condition and x, and typecast-compatible with x.
size – integer, only referenced when x and y are None. For details, see jax.numpy.nonzero().
fill_value – only referenced when x and y are None. For details, see jax.numpy.nonzero().
- Returns:
An array of dtype jnp.result_type(x, y) with values drawn from x where condition is True, and from y where condition is False. If x and y are None, the function behaves differently; see jax.numpy.nonzero() for a description of the return type.
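As an illustration of the broadcasting and dtype rules above, the following sketch passes a column-shaped condition, a 1-D integer x, and a Python float y. The shapes and values are arbitrary choices for the example. The three inputs broadcast to a common shape, and the result dtype should match jnp.result_type(x, y).

```python
import jax.numpy as jnp

# condition has shape (3, 1); x has shape (4,); y is a Python scalar.
# All three broadcast together to a (3, 4) result.
condition = jnp.array([[True], [False], [True]])
x = jnp.arange(4, dtype=jnp.int32)  # int32 values 0..3
y = 0.5                             # weakly typed Python float

out = jnp.where(condition, x, y)

print(out.shape)                            # (3, 4)
print(out.dtype == jnp.result_type(x, y))   # True
print(out[1])                               # row where condition is False: all 0.5
```

Under JAX's default 32-bit configuration the promoted dtype here is float32; rows where condition is True hold the values of x converted to that dtype.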
Notes
Special care is needed when the x or y input to jax.numpy.where() could have a value of NaN. Specifically, when a gradient is taken with jax.grad() (reverse-mode differentiation), a NaN in either x or y will propagate into the gradient, regardless of the value of condition. More information on this behavior and workarounds is available in the JAX FAQ.
Examples
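The NaN behavior described in the note can be reproduced with a short sketch. The functions f and safe_f are hypothetical. The fix shown is the commonly used "double where" pattern, assumed here to be representative of the workarounds the JAX FAQ discusses: the unsafe input is replaced before the NaN-producing operation runs, so the unselected branch never yields a NaN.

```python
import jax
import jax.numpy as jnp

def f(x):
    # For x <= 0 the selected value is 0.0, but jnp.sqrt(x) is still
    # evaluated, and at x = -1.0 it produces NaN.
    return jnp.where(x > 0, jnp.sqrt(x), 0.0)

def safe_f(x):
    # "Double where": first swap unsafe inputs for a harmless value (1.0)
    # so sqrt never sees a negative number, then select as before.
    safe_x = jnp.where(x > 0, x, 1.0)
    return jnp.where(x > 0, jnp.sqrt(safe_x), 0.0)

print(f(-1.0))                 # 0.0: the forward value is fine
print(jax.grad(f)(-1.0))       # nan: the NaN from sqrt(-1.0) leaks into the gradient
print(jax.grad(safe_f)(-1.0))  # 0.0
print(jax.grad(safe_f)(4.0))   # 0.25, i.e. 0.5 / sqrt(4.0)
```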
When x and y are not provided, where behaves equivalently to jax.numpy.nonzero():
>>> import jax.numpy as jnp
>>> x = jnp.arange(10)
>>> jnp.where(x > 4)
(Array([5, 6, 7, 8, 9], dtype=int32),)
>>> jnp.nonzero(x > 4)
(Array([5, 6, 7, 8, 9], dtype=int32),)
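The size and fill_value parameters matter mainly under jax.jit, where the data-dependent output length of the one-argument form cannot be traced. The sketch below assumes the truncation and padding semantics documented for jax.numpy.nonzero(); the helper first_three_above is hypothetical.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(10)

# Without `size`, the number of returned indices depends on the data, so the
# one-argument form cannot be traced under jax.jit. A static `size` fixes the
# output shape: extra matches are truncated, missing ones are padded.
@jax.jit
def first_three_above(x, threshold):
    return jnp.where(x > threshold, size=3, fill_value=-1)

(idx_many,) = first_three_above(x, 4)  # 5 matches, truncated to 3
(idx_few,) = first_three_above(x, 7)   # 2 matches, padded with fill_value
print(idx_many)  # [5 6 7]
print(idx_few)   # [ 8  9 -1]
```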
When x and y are provided, where selects between them based on the specified condition:
>>> jnp.where(x > 4, x, 0)
Array([0, 0, 0, 0, 0, 5, 6, 7, 8, 9], dtype=int32)
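Because the three-term form lowers to jax.lax.select(), the last example can be cross-checked against a direct lax.select call. This is an illustrative sketch: unlike jnp.where, lax.select does not broadcast or promote dtypes, so the scalar 0 is expanded by hand with jnp.zeros_like.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(10)
cond = x > 4

# jax.lax.select performs no broadcasting or dtype promotion: both branches
# must already share a shape and dtype, and pred must be boolean.
zeros = jnp.zeros_like(x)
via_where = jnp.where(cond, x, 0)
via_select = jax.lax.select(cond, x, zeros)

print(jnp.array_equal(via_where, via_select))  # True
```

In practice jnp.where is usually the more convenient entry point, since it takes care of the broadcasting and promotion that lax.select leaves to the caller.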
