jax.lax.pmin
jax.lax.pmin(x, axis_name, *, axis_index_groups=None)
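Compute an all-reduce min on x over the pmapped axis axis_name: every participating replica receives the elementwise minimum of x across the replicas along that axis.

If x is a pytree, the result is equivalent to mapping this function to each leaf in the tree.

Parameters: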
- x – array(s) with a mapped axis named axis_name.
- axis_name – hashable Python object used to name a pmapped axis (see the jax.pmap() documentation for more details).
- axis_index_groups – optional list of lists containing axis indices (e.g. for an axis of size 4, [[0, 1], [2, 3]] would perform pmins over the first two and last two replicas). Groups must cover all axis indices exactly once, and on TPUs all groups must be the same size.
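To make the axis_index_groups semantics concrete, here is a pure jax.numpy reference model rather than a call to the collective itself. grouped_min_reference is a hypothetical helper (not a JAX API): it treats the leading axis of a stacked array as the mapped axis and computes the value each replica would receive, checking the documented rule that groups cover every index exactly once.

```python
import jax.numpy as jnp

def grouped_min_reference(stacked, axis_index_groups):
    """Hypothetical reference model (not a JAX API) of
    pmin(x, axis_name, axis_index_groups=...): the leading axis of `stacked`
    plays the role of the mapped axis, and each replica receives the
    elementwise minimum over the replicas in its own group."""
    n = stacked.shape[0]
    covered = sorted(i for group in axis_index_groups for i in group)
    # Mirror the documented rule: groups cover every axis index exactly once.
    if covered != list(range(n)):
        raise ValueError("groups must cover all axis indices exactly once")
    out = stacked
    for group in axis_index_groups:
        idx = jnp.array(group)
        # Elementwise minimum over this group's replicas,
        group_min = stacked[idx].min(axis=0)
        # written back to every member of the group.
        out = out.at[idx].set(group_min)
    return out

# Axis of size 4 split into the first two and last two replicas.
xs = jnp.array([3.0, 1.0, 4.0, 2.0])
print(grouped_min_reference(xs, [[0, 1], [2, 3]]))  # [1. 1. 2. 2.]
```

The sketch deliberately avoids the collective so it runs on any number of devices; under jax.pmap the documented semantics give each replica the same per-group value, with the extra TPU requirement that all groups have equal size.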
Returns:
Array(s) with the same shape as x representing the result of an all-reduce min along the axis axis_name.
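A minimal usage sketch, assuming a recent JAX release. So that it runs on a single device, it substitutes jax.vmap with axis_name="i" for jax.pmap; named-axis collectives such as pmin have the same semantics under both transforms, so with jax.pmap over four devices each device would receive the same values. The axis name "i" and the sample data are illustrative choices. The second half shows the pytree behavior: pmin is applied to each leaf independently, and each leaf keeps its per-replica shape.

```python
import jax
import jax.numpy as jnp

# Four mapped "replicas", each holding one row of xs (per-replica shape (2,)).
xs = jnp.array([[3.0, 7.0],
                [1.0, 9.0],
                [4.0, 2.0],
                [5.0, 6.0]])

def replica_min(x):
    # All-reduce min over the named axis "i": every replica receives the
    # elementwise minimum across all replicas.
    return jax.lax.pmin(x, axis_name="i")

# jax.vmap stands in for jax.pmap here so the sketch runs on one device.
out = jax.vmap(replica_min, axis_name="i")(xs)
print(out)  # every row is [1. 2.]

# Pytree input: pmin is applied leaf by leaf.
tree = {"a": xs, "b": jnp.arange(4)}
out_tree = jax.vmap(replica_min, axis_name="i")(tree)
print(out_tree["b"])  # [0 0 0 0]
```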
