jax.numpy.piecewise
- jax.numpy.piecewise(x, condlist, funclist, *args, **kw)
Evaluate a function defined piecewise across the domain.
JAX implementation of numpy.piecewise(), in terms of jax.lax.switch().
Note
Unlike numpy.piecewise(), jax.numpy.piecewise() requires functions in funclist to be traceable by JAX, as it is implemented via jax.lax.switch().
- Parameters:
x (ArrayLike) – array of input values.
condlist (Array | Sequence[ArrayLike]) – boolean array or sequence of boolean arrays corresponding to the functions in funclist. If a sequence of arrays, the length of each array must match the length of x.
funclist (list[ArrayLike | Callable[..., Array]]) – list of arrays or functions; must either be the same length as condlist, or have length len(condlist) + 1, in which case the last entry is the default applied when none of the conditions are True. Alternatively, entries of funclist may be numerical values, in which case they indicate a constant function.
args – additional positional arguments are passed to each function in funclist.
kwargs – additional keyword arguments are passed to each function in funclist.
- Returns:
An array which is the result of evaluating the functions on x at the specified conditions.
- Return type:
Array
See also
jax.lax.switch(): choose between N functions based on an index.
jax.lax.cond(): choose between two functions based on a boolean condition.
jax.numpy.where(): choose between two results based on a boolean mask.
jax.lax.select(): choose between two results based on a boolean mask.
jax.lax.select_n(): choose between N results based on a boolean mask.
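The description above says jax.numpy.piecewise() is built in terms of jax.lax.switch(). The following is a simplified sketch of that idea, not JAX's actual implementation: `piecewise_sketch` is a hypothetical name, it assumes every condition has the same shape as x, and it computes one branch index per element before applying lax.switch elementwise via jax.vmap. Elements matched by no condition go to the default entry (0 when funclist has no explicit default), and where conditions overlap the last matching one wins, which is the choice made in this sketch.

```python
import jax
import jax.numpy as jnp
from jax import lax


def piecewise_sketch(x, condlist, funclist):
    """Hypothetical, simplified piecewise built on lax.switch (not JAX's code).

    Assumes every entry of condlist has the same shape as x.
    """
    x = jnp.asarray(x)
    n = len(condlist)
    if len(funclist) == n:
        funclist = list(funclist) + [0]  # no explicit default: unmatched -> 0
    elif len(funclist) != n + 1:
        raise ValueError("funclist needs len(condlist) or len(condlist) + 1 entries")

    # lax.switch requires all branches to return the same shape and dtype,
    # so constants are broadcast and callable outputs are cast to x.dtype.
    def make_branch(f):
        if callable(f):
            return lambda v: jnp.asarray(f(v), dtype=v.dtype)
        return lambda v: jnp.full_like(v, f)

    branches = [make_branch(f) for f in funclist]

    # Stack the conditions into a boolean array of shape (n, *x.shape).
    conds = jnp.stack([jnp.asarray(c, dtype=bool) for c in condlist])
    # Index of the last True condition (later conditions win on overlap);
    # index n selects the default branch where no condition holds.
    last_true = n - 1 - jnp.argmax(conds[::-1].astype(jnp.int32), axis=0)
    index = jnp.where(conds.any(axis=0), last_true, n)

    # Apply lax.switch elementwise by vmapping it over the flattened input.
    out = jax.vmap(lambda i, v: lax.switch(i, branches, v))(index.ravel(), x.ravel())
    return out.reshape(x.shape)
```

For instance, `piecewise_sketch(x, [x < -1, x > 1], [lambda x: 1 + x, lambda x: x - 1, 0])` should agree with the default-value example in the Examples section. Mapping switch with a batched index lowers to a select over all branch outputs, so this sketch evaluates every branch on every element.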
Examples
Here’s an example of a function which is zero for negative values, and linear for positive values:
>>> import jax.numpy as jnp
>>> x = jnp.array([-4, -3, -2, -1, 0, 1, 2, 3, 4])
>>> condlist = [x < 0, x >= 0]
>>> funclist = [lambda x: 0 * x, lambda x: x]
>>> jnp.piecewise(x, condlist, funclist)
Array([0, 0, 0, 0, 0, 1, 2, 3, 4], dtype=int32)
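Because the entries of funclist are traced by JAX, a call like the one above also works inside jax.jit(). A minimal sketch of both sides of the traceability requirement from the note: `relu` and `abs_with_python_if` are illustrative names, and the second function fails because Python `if` needs a concrete boolean while the branch only sees an abstract traced value.

```python
import jax
import jax.numpy as jnp


@jax.jit
def relu(x):
    # Both pieces are traceable: a constant and a jnp-compatible lambda.
    return jnp.piecewise(x, [x < 0, x >= 0], [0, lambda x: x])


def abs_with_python_if(x):
    # Python control flow on an array value: NOT traceable.
    if x > 0:
        return x
    return -x


x = jnp.array([-4, -3, -2, -1, 0, 1, 2, 3, 4])
print(relu(x))

try:
    jnp.piecewise(x, [x < 0], [abs_with_python_if])
    traced_ok = True
except jax.errors.ConcretizationTypeError:
    # Raised when the branch is traced with an abstract value.
    traced_ok = False
print("python-if branch traced:", traced_ok)
```

Rewriting such a branch with jnp operations (for example `jnp.abs`) makes it usable in funclist.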
funclist can also contain a simple scalar value for constant functions:
>>> condlist = [x < 0, x >= 0]
>>> funclist = [0, lambda x: x]
>>> jnp.piecewise(x, condlist, funclist)
Array([0, 0, 0, 0, 0, 1, 2, 3, 4], dtype=int32)
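For a two-piece function like this one, jnp.where() and jax.lax.select_n() from the See also list produce the same values; they take already-computed arrays rather than functions. A small check, reusing the example's x (for a boolean selector, select_n picks the first case where it is False and the second where it is True):

```python
import jax
import jax.numpy as jnp

x = jnp.array([-4, -3, -2, -1, 0, 1, 2, 3, 4])

via_piecewise = jnp.piecewise(x, [x < 0, x >= 0], [0, lambda x: x])
via_where = jnp.where(x < 0, 0, x)
# select_n: case 0 where the predicate is False, case 1 where it is True.
via_select_n = jax.lax.select_n(x >= 0, jnp.zeros_like(x), x)

print(jnp.array_equal(via_piecewise, via_where))
print(jnp.array_equal(via_piecewise, via_select_n))
```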
You can specify a default value by appending an extra entry to funclist:
>>> condlist = [x < -1, x > 1]
>>> funclist = [lambda x: 1 + x, lambda x: x - 1, 0]
>>> jnp.piecewise(x, condlist, funclist)
Array([-3, -2, -1, 0, 0, 0, 1, 2, 3], dtype=int32)
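The trailing default entry is optional: following numpy.piecewise(), elements not covered by any condition are 0 when funclist has exactly len(condlist) entries, and any other length raises a ValueError. A quick check of both points, reusing the example's x:

```python
import jax.numpy as jnp

x = jnp.array([-4, -3, -2, -1, 0, 1, 2, 3, 4])
condlist = [x < -1, x > 1]

explicit = jnp.piecewise(x, condlist, [lambda x: 1 + x, lambda x: x - 1, 0])
# Same funclist without the trailing default entry: unmatched elements become 0.
implicit = jnp.piecewise(x, condlist, [lambda x: 1 + x, lambda x: x - 1])
print(jnp.array_equal(explicit, implicit))

# Two conditions allow only 2 or 3 funclist entries.
try:
    jnp.piecewise(x, condlist, [0, 0, 0, 0])
    rejected = False
except ValueError:
    rejected = True
print("wrong length rejected:", rejected)
```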
condlist may also be a simple array of scalar conditions, in which case the associated function applies to the whole range:
>>> condlist = jnp.array([False, True, False])
>>> funclist = [lambda x: x * 0, lambda x: x * 10, lambda x: x * 100]
>>> jnp.piecewise(x, condlist, funclist)
Array([-40, -30, -20, -10, 0, 10, 20, 30, 40], dtype=int32)
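The args and kwargs parameters are forwarded to every callable in funclist, after x; constant entries ignore them. A sketch with a hypothetical helper `line()`: since the same extra arguments go to each callable, every callable in funclist has to accept them.

```python
import jax.numpy as jnp


def line(x, slope, intercept=0.0):
    # Hypothetical helper: receives x plus the extra args/kwargs given to piecewise.
    return slope * x + intercept


x = jnp.array([-2.0, -1.0, 0.0, 1.0, 2.0])
# 3.0 is forwarded positionally as `slope`; intercept=1.0 as a keyword argument.
y = jnp.piecewise(x, [x < 0, x >= 0], [0.0, line], 3.0, intercept=1.0)
print(y)
```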
