jax.lax.fori_loop
jax.lax.fori_loop(lower, upper, body_fun, init_val, *, unroll=None)
Loop from lower to upper by reduction to jax.lax.while_loop() or, when the trip count is static, jax.lax.scan(). The Haskell-like type signature in brief is

```
fori_loop :: Int -> Int -> ((Int, a) -> a) -> a -> a
```
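Reading the signature concretely: body_fun receives the current index and the carry, and it returns a new carry of the same type a.

The minimal sketch below assumes jax is installed; fib_body is a hypothetical name. Here a is a pair of integer scalars and the index is unused. Ten iterations step the pair (F(k), F(k+1)) forward from (F(0), F(1)) to (F(10), F(11)).

```python
import jax


def fib_body(i, carry):
    # carry has type a = (int, int): two consecutive Fibonacci numbers.
    # The index i is not needed here, but body_fun always receives it.
    prev, curr = carry
    return (curr, prev + curr)


# Ten iterations starting from (F(0), F(1)) = (0, 1).
f10, f11 = jax.lax.fori_loop(0, 10, fib_body, (0, 1))
print(int(f10), int(f11))  # 55 89
```

Because init_val is a tuple, the result comes back as a tuple with the same structure.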
The semantics of fori_loop are given by this Python implementation:

```python
def fori_loop(lower, upper, body_fun, init_val):
  val = init_val
  for i in range(lower, upper):
    val = body_fun(i, val)
  return val
```
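To compare the pure-Python reference semantics with the real function, the minimal sketch below runs the same running-sum body through both. It assumes jax is installed; python_fori_loop and body are hypothetical names. It also shows that an empty range (here lower == upper) returns init_val unchanged.

```python
import jax
import jax.numpy as jnp


def python_fori_loop(lower, upper, body_fun, init_val):
    # Pure-Python reference semantics, renamed to avoid shadowing the
    # real jax.lax.fori_loop.
    val = init_val
    for i in range(lower, upper):
        val = body_fun(i, val)
    return val


def body(i, total):
    # Add the loop index to the running total.
    return total + i


# Sum of 0..9 via the reference loop and via lax.fori_loop.
ref = python_fori_loop(0, 10, body, 0)
out = jax.lax.fori_loop(0, 10, body, 0)
print(ref, int(out))  # 45 45

# Empty range: no iterations run, so init_val comes back unchanged.
empty = jax.lax.fori_loop(5, 5, body, jnp.int32(7))
print(int(empty))  # 7
```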
As the Python version suggests, setting upper <= lower will produce no iterations. Negative or custom increments are not supported; the loop index always advances by 1.

Unlike that Python version, fori_loop is implemented in terms of either a call to jax.lax.while_loop() or a call to jax.lax.scan(). If the trip count is static (meaning known at tracing time, perhaps because lower and upper are Python integer literals), then fori_loop is implemented in terms of scan() and reverse-mode autodiff is supported. Otherwise, a while_loop is used and reverse-mode autodiff is not supported. See those functions' documentation for more information.

Also unlike the Python analogue, the loop-carried value val must hold a fixed shape and dtype across all iterations. Being consistent up to NumPy rank/shape broadcasting and dtype promotion rules, for example, is not enough. In other words, the type a in the type signature above represents an array with a fixed shape and dtype. It may also be a nested tuple/list/dict container data structure with a fixed structure and arrays with fixed shape and dtype at the leaves.

Note: fori_loop() compiles body_fun, so while it can be combined with jit(), wrapping it in jit() is usually unnecessary.

- Parameters:
lower – an integer representing the loop index lower bound (inclusive)
upper – an integer representing the loop index upper bound (exclusive)
body_fun – function of type (int, a) -> a.
init_val – initial loop carry value of type a.
unroll (int | bool | None) – An optional integer or boolean that determines how much to unroll the loop. If an integer is provided, it determines how many unrolled loop iterations to run within a single rolled iteration of the loop. If a boolean is provided, it determines whether the loop is completely unrolled (i.e. unroll=True) or left completely rolled (i.e. unroll=False). This argument is only applicable if the loop bounds are statically known.
- Returns:
Loop value from the final iteration, of type a.
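The difference between static and traced bounds can be observed directly. The minimal sketch below assumes a recent JAX release; power and power_dynamic are hypothetical helpers.

- **Static bounds.** With Python-int bounds the trip count is static, so both jax.grad and unroll=True can be used.
- **Traced bounds.** When the upper bound is an argument of a jit-compiled function, it is traced. fori_loop then falls back to while_loop, and only forward-mode differentiation such as jax.jvp remains available.

```python
import jax
import jax.numpy as jnp


def power(x, n):
    # Multiply the carry by x, n times. When n is a Python int the trip
    # count is static, so fori_loop is lowered to lax.scan.
    return jax.lax.fori_loop(0, n, lambda i, acc: acc * x, jnp.ones_like(x))


x = jnp.float32(2.0)
print(power(x, 3))            # 8.0

# Static bounds: reverse-mode autodiff works; d/dx x**3 = 3 * x**2 = 12.
print(jax.grad(power)(x, 3))  # 12.0

# Static bounds also permit unroll; the result is the same.
unrolled = jax.lax.fori_loop(0, 3, lambda i, acc: acc * x, jnp.float32(1.0), unroll=True)
print(unrolled)               # 8.0


@jax.jit
def power_dynamic(x, n):
    # Under jit, n is traced, so the trip count is unknown at tracing
    # time and fori_loop falls back to lax.while_loop.
    return jax.lax.fori_loop(0, n, lambda i, acc: acc * x, jnp.ones_like(x))


print(power_dynamic(x, 3))    # 8.0

# Forward-mode differentiation still works through the while_loop path,
tangent = jax.jvp(lambda y: power_dynamic(y, 3), (x,), (jnp.float32(1.0),))[1]
print(tangent)                # 12.0
# but jax.grad(power_dynamic)(x, 3) raises an error, because reverse-mode
# autodiff is not supported when the trip count is dynamic.
```

If a static trip count is needed inside a jitted function, passing n through jax.jit's static_argnums (for example jax.jit(f, static_argnums=1)) keeps it a Python int. The cost is one compilation per distinct value of n.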
