jax.lax.select
- jax.lax.select(pred, on_true, on_false)
Selects between two branches based on a boolean predicate.
Wraps XLA’s Select operator.
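The following is a minimal sketch of the elementwise behavior. It assumes JAX is installed and uses small illustrative arrays whose names simply mirror the parameter names. Each output entry is taken from `on_true` where `pred` is True and from `on_false` otherwise.

```python
import jax.numpy as jnp
from jax import lax

# Boolean predicate and two candidate arrays; all three share the shape (4,).
pred = jnp.array([True, False, True, False])
on_true = jnp.array([1.0, 2.0, 3.0, 4.0])
on_false = jnp.array([-1.0, -2.0, -3.0, -4.0])

# Elementwise: take on_true[i] where pred[i] is True, otherwise on_false[i].
result = lax.select(pred, on_true, on_false)
# result holds the values [1., -2., 3., -4.] with dtype float32
```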
In general, select() leads to evaluation of both branches, although the compiler may elide computations if possible. For a similar function that usually evaluates only a single branch, see cond().
- Parameters:
pred (ArrayLike) – boolean array
on_true (ArrayLike) – array containing entries to return where pred is True. Must have the same shape as pred, and the same shape and dtype as on_false.
on_false (ArrayLike) – array containing entries to return where pred is False. Must have the same shape as pred, and the same shape and dtype as on_true.
- Returns:
array with the same shape and dtype as on_true and on_false.
- Return type:
Array
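The shape and dtype requirements on on_true and on_false are strict: lax.select does not broadcast or promote its branch arguments. The sketch below uses `is_rejected`, a hypothetical helper written only for this illustration. It checks that a 0-d on_false and a float32 on_false are both refused when paired with an int32 on_true, and that building on_false with `jnp.zeros_like` satisfies both requirements. It catches either `TypeError` or `ValueError` so it does not depend on the exact exception type.

```python
import jax.numpy as jnp
from jax import lax

pred = jnp.array([True, False, True])
on_true = jnp.array([10, 20, 30], dtype=jnp.int32)


def is_rejected(on_false):
    """Hypothetical helper: True if lax.select refuses this on_false."""
    try:
        lax.select(pred, on_true, on_false)
    except (TypeError, ValueError):
        return True
    return False


# Shape mismatch: a 0-d on_false is not broadcast to the shape (3,) of on_true.
scalar_rejected = is_rejected(jnp.int32(0))

# Dtype mismatch: float32 is not promoted to match the int32 on_true.
float_rejected = is_rejected(jnp.zeros(3, dtype=jnp.float32))

# Matching shape (3,) and dtype int32: accepted.
on_false = jnp.zeros_like(on_true)
result = lax.select(pred, on_true, on_false)  # values [10, 0, 30]
```

If broadcasting is needed, one option is to expand the operands yourself (for example with `jnp.broadcast_to`) before calling lax.select.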
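The description notes that select() evaluates both branches while cond() usually evaluates only one. The contrast can be sketched with a scalar predicate using `with_select` and `with_cond`, which are hypothetical helpers defined only for this example. `with_select` builds both `jnp.sqrt(x)` and `-x` before lax.select keeps one of them. For a negative `x`, the square root is therefore computed (producing NaN) and then discarded, unless the compiler elides it. `with_cond` passes the branches to `lax.cond` as functions, so with a scalar predicate typically only the chosen branch runs. Both helpers return the same values.

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def with_select(x):
    # Both branch values are built first; select then keeps one of them.
    # For x <= 0, jnp.sqrt(x) yields NaN, but that value is discarded.
    return lax.select(x > 0, jnp.sqrt(x), -x)


@jax.jit
def with_cond(x):
    # lax.cond(pred, true_fun, false_fun, *operands) takes the branches as
    # functions; with a scalar predicate it usually runs only the chosen one.
    return lax.cond(x > 0, jnp.sqrt, jnp.negative, x)


for value in (9.0, -4.0):
    x = jnp.float32(value)
    print(value, with_select(x), with_cond(x))
```

lax.cond expects a scalar predicate. When the predicate is batched under `jax.vmap`, JAX lowers the cond to a select, so both branches run in that case too.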
