JAX documentation

jax.lax.fori_loop


jax.lax.fori_loop(lower, upper, body_fun, init_val, *, unroll=None)

Loop from lower to upper by reduction to jax.lax.while_loop().
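To make the reduction concrete, here is a minimal sketch that expresses the same loop with the public jax.lax.while_loop API by carrying the loop index alongside the value. The helper name fori_loop_via_while is hypothetical, and this is not JAX's actual implementation, which may also choose jax.lax.scan and handles unrolling and the dtypes of the bounds.

```python
import jax


def fori_loop_via_while(lower, upper, body_fun, init_val):
    """Hypothetical sketch: express a fori_loop as a lax.while_loop.

    The while_loop carry is a pair (i, val); the index i is carried
    explicitly and incremented by one on every iteration.
    """

    def cond_fun(carry):
        i, _ = carry
        # Stop once the index reaches the exclusive upper bound.
        return i < upper

    def while_body(carry):
        i, val = carry
        # Advance the index and apply the user's body to the carried value.
        return i + 1, body_fun(i, val)

    _, final_val = jax.lax.while_loop(cond_fun, while_body, (lower, init_val))
    return final_val


# Sum of the integers 0..9 via the sketch.
total = fori_loop_via_while(0, 10, lambda i, acc: acc + i, 0)
print(int(total))
```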

The Haskell-like type signature in brief is

fori_loop :: Int -> Int -> ((Int, a) -> a) -> a -> a

The semantics of fori_loop are given by this Python implementation:

```python
def fori_loop(lower, upper, body_fun, init_val):
  val = init_val
  for i in range(lower, upper):
    val = body_fun(i, val)
  return val
```
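As a quick sanity check of those semantics, the sketch below runs the pure-Python reference (renamed fori_loop_reference, a hypothetical name, so it does not shadow the JAX function) next to jax.lax.fori_loop on the same body and compares the results. The body, which accumulates squares of the index, is an illustrative assumption.

```python
import jax
import jax.numpy as jnp


def fori_loop_reference(lower, upper, body_fun, init_val):
    # Pure-Python reference: apply body_fun once per index in [lower, upper).
    val = init_val
    for i in range(lower, upper):
        val = body_fun(i, val)
    return val


def body(i, acc):
    # Accumulate i**2 into the carry.
    return acc + i * i


expected = fori_loop_reference(1, 5, body, 0)  # 1 + 4 + 9 + 16
actual = jax.lax.fori_loop(1, 5, body, jnp.int32(0))
print(expected, int(actual))
```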

As the Python version suggests, setting upper <= lower will produce no iterations. Negative or custom increments are not supported.
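A small sketch of the empty-range behavior, assuming a float32 array carry chosen only for illustration: when upper <= lower, the initial carry comes back unchanged. The second part shows one possible workaround for a custom step (an assumption on our part, not a fori_loop feature): loop over the trip count and derive the index inside the body.

```python
import jax
import jax.numpy as jnp

init = jnp.array([1.0, 2.0, 3.0])

# upper <= lower: the body never runs, so the initial carry is returned as-is.
out_equal = jax.lax.fori_loop(5, 5, lambda i, x: x * 2.0, init)
out_reversed = jax.lax.fori_loop(5, 0, lambda i, x: x * 2.0, init)

# Workaround sketch for a custom positive step: iterate k = 0..trip_count-1
# and compute the index start + k * step. Emulates range(0, 10, 3).
start, stop, step = 0, 10, 3
trip_count = (stop - start + step - 1) // step  # ceiling division, positive step
stepped_sum = jax.lax.fori_loop(
    0, trip_count, lambda k, acc: acc + (start + k * step), 0
)
print(out_equal, out_reversed, int(stepped_sum))
```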

Unlike that Python version, fori_loop is implemented in terms of either a call to jax.lax.while_loop() or a call to jax.lax.scan(). If the trip count is static (meaning known at tracing time, perhaps because lower and upper are Python integer literals) then the fori_loop is implemented in terms of scan() and reverse-mode autodiff is supported; otherwise, a while_loop is used and reverse-mode autodiff is not supported. See those functions' docstrings for more information.
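The sketch below illustrates the static versus dynamic trip-count distinction with a hypothetical power helper that computes x**n. With a Python-int bound, jax.grad succeeds; once the bound is traced under jax.jit, forward-mode jax.jvp still works but jax.grad is expected to raise. The code catches a generic Exception because the exact error type is an implementation detail not stated here.

```python
import jax
import jax.numpy as jnp


def power(x, n):
    # Hypothetical helper: multiply the carry by x, n times, giving x**n.
    return jax.lax.fori_loop(0, n, lambda i, acc: acc * x, jnp.ones_like(x))


# n is a Python int, so the trip count is static (scan path): reverse mode works.
grad_static = jax.grad(power)(2.0, 3)  # d/dx x**3 at x=2 -> 12.0

# Under jit, n becomes a tracer, so the trip count is no longer static.
power_jit = jax.jit(power)

# Forward-mode autodiff still works with a dynamic trip count.
primal, tangent = jax.jvp(lambda x: power_jit(x, 3), (2.0,), (1.0,))

# Reverse-mode autodiff is expected to fail with a dynamic trip count.
try:
    jax.grad(lambda x: power_jit(x, 3))(2.0)
    reverse_mode_ok = True
except Exception:  # exact exception type not relied on
    reverse_mode_ok = False

print(grad_static, primal, tangent, reverse_mode_ok)
```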

Also unlike the Python analogue, the loop-carried value val must hold a fixed shape and dtype across all iterations (and not just be consistent up to NumPy rank/shape broadcasting and dtype promotion rules, for example). In other words, the type a in the type signature above represents an array with a fixed shape and dtype (or a nested tuple/list/dict container data structure with a fixed structure and arrays with fixed shape and dtype at the leaves).
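As an illustration of a container carry, the sketch below threads a dict of two float32 scalars through the loop to compute a running sum and sum of squares; the data and the key names are assumptions made for the example. It also shows that a body which changes the carry's shape is rejected at trace time (the exact exception type is not relied on).

```python
import jax
import jax.numpy as jnp

xs = jnp.arange(1.0, 6.0)  # [1., 2., 3., 4., 5.], float32


def body(i, carry):
    # The carry is a dict pytree; its keys, shapes, and dtypes never change.
    return {
        "total": carry["total"] + xs[i],
        "sum_sq": carry["sum_sq"] + xs[i] ** 2,
    }


init = {"total": jnp.float32(0.0), "sum_sq": jnp.float32(0.0)}
stats = jax.lax.fori_loop(0, xs.shape[0], body, init)

# A body that doubles the carry's length violates the fixed-shape rule.
try:
    jax.lax.fori_loop(0, 3, lambda i, x: jnp.concatenate([x, x]), jnp.zeros(2))
    shape_change_ok = True
except Exception:  # exact exception type not relied on
    shape_change_ok = False

print(stats, shape_change_ok)
```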

Note

fori_loop() compiles body_fun, so while it can be combined with jit(), doing so is usually unnecessary.

Parameters:
  • lower – an integer representing the loop index lower bound (inclusive)

  • upper – an integer representing the loop index upper bound (exclusive)

  • body_fun – function of type (int, a) -> a.

  • init_val – initial loop carry value of type a.

  • unroll (int | bool | None) – An optional integer or boolean that determines how much to unroll the loop. If an integer is provided, it determines how many unrolled loop iterations to run within a single rolled iteration of the loop. If a boolean is provided, it determines whether the loop is completely unrolled (i.e. unroll=True) or left completely rolled (i.e. unroll=False). This argument is only applicable if the loop bounds are statically known.
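A minimal sketch of the unroll parameter, assuming static Python-int bounds (required for unroll to apply) and a simple summing body chosen for illustration: the rolled, partially unrolled, and fully unrolled loops are expected to return the same value, since unrolling changes only how the loop body is laid out in the compiled program, not what it computes.

```python
import jax
import jax.numpy as jnp


def body(i, acc):
    # Add the loop index to the carry.
    return acc + i


# Static bounds (0, 10), so unroll is applicable in all three calls.
rolled = jax.lax.fori_loop(0, 10, body, jnp.int32(0), unroll=False)
partly_unrolled = jax.lax.fori_loop(0, 10, body, jnp.int32(0), unroll=2)
fully_unrolled = jax.lax.fori_loop(0, 10, body, jnp.int32(0), unroll=True)
print(int(rolled), int(partly_unrolled), int(fully_unrolled))
```

Larger unroll factors can reduce per-iteration loop overhead at the cost of a bigger program and longer compile times; whether that trade-off pays off depends on the workload.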

Returns:

Loop value from the final iteration, of type a.
