jax.lax.pmax

jax.lax.pmax(x, axis_name, *, axis_index_groups=None)

Compute an all-reduce max on x over the pmapped axis axis_name.
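A minimal sketch of the all-reduce. To keep it runnable on a single device, it maps with jax.vmap instead of jax.pmap. jax.vmap also accepts an axis_name and supports collectives such as lax.pmax, and for this example both give the same per-element result. The function name normalize_by_max and the axis name "i" are illustrative choices, not part of the API.

```python
import jax
import jax.numpy as jnp
from jax import lax

def normalize_by_max(x):
    # Each mapped element sees the max over the whole mapped axis "i".
    return x / lax.pmax(x, axis_name="i")

xs = jnp.array([1.0, 2.0, 4.0])
# vmap stands in for pmap here so the sketch runs on one device.
out = jax.vmap(normalize_by_max, axis_name="i")(xs)
```

Every element is divided by 4.0, the maximum across the mapped axis. Under jax.pmap, with one element per device, each device would receive that same maximum.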

If x is a pytree, then the result is equivalent to mapping this function to each leaf in the tree.
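The following sketch shows the pytree behaviour. A dict with two leaves is reduced leaf by leaf. As above, jax.vmap with an axis_name stands in for jax.pmap so the example runs on one device. The dict keys and values are arbitrary.

```python
import jax
import jax.numpy as jnp
from jax import lax

tree = {
    "a": jnp.array([1.0, 5.0, 3.0]),            # mapped axis of size 3
    "b": jnp.array([[1, 2], [7, 0], [3, 4]]),   # mapped axis of size 3, per-element shape (2,)
}

# pmax is applied independently to every leaf of the tree.
out = jax.vmap(lambda t: lax.pmax(t, axis_name="i"), axis_name="i")(tree)
```

Each leaf of out has the same shape as the matching input leaf. Every position along the mapped axis holds the elementwise max for that leaf: 5.0 for "a", and [7, 4] for "b".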

Parameters:
  • x – array(s) with a mapped axis named axis_name.

  • axis_name – hashable Python object used to name a pmapped axis (see the jax.pmap() documentation for more details).

  • axis_index_groups – optional list of lists containing axis indices (e.g. for an axis of size 4, [[0, 1], [2, 3]] would perform pmaxes over the first two and last two replicas). Groups must cover all axis indices exactly once, and on TPUs all groups must be the same size.
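Running the grouped form for real needs several devices under jax.pmap. The sketch below instead models what axis_index_groups computes, using plain jax.numpy on per-replica values stacked along axis 0. The helper grouped_max_reference is hypothetical and is not how JAX implements the collective. It checks the "cover every index exactly once" rule but does not model the TPU equal-size requirement.

```python
import jax.numpy as jnp

def grouped_max_reference(stacked, axis_index_groups):
    """Hypothetical reference model: replica i receives the max over
    the group that contains i. `stacked` holds one value per replica
    along axis 0."""
    n = stacked.shape[0]
    covered = sorted(i for group in axis_index_groups for i in group)
    if covered != list(range(n)):
        raise ValueError("groups must cover every axis index exactly once")
    out = [None] * n
    for group in axis_index_groups:
        group_max = jnp.max(stacked[jnp.array(group)], axis=0)
        for i in group:
            out[i] = group_max
    return jnp.stack(out)

x = jnp.array([3.0, 1.0, 2.0, 8.0])  # one value per replica, axis size 4
grouped = grouped_max_reference(x, [[0, 1], [2, 3]])
```

With the groups [[0, 1], [2, 3]] from the parameter description, replicas 0 and 1 each get 3.0 and replicas 2 and 3 each get 8.0.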

Returns:

Array(s) with the same shape as x representing the result of an all-reduce max along the axis axis_name.
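A short sketch that uses jax.pmap directly and checks the documented shape guarantee. It maps over however many local devices are available, including a single CPU device. The axis name "devices" and the input values are arbitrary.

```python
import jax
import jax.numpy as jnp
from jax import lax

n = jax.local_device_count()
# One row of 3 values per device; rows increase, so the last row is the max.
xs = jnp.arange(n * 3, dtype=jnp.float32).reshape(n, 3)

out = jax.pmap(lambda x: lax.pmax(x, axis_name="devices"), axis_name="devices")(xs)
```

out has the same shape as xs, and every row equals the elementwise max across devices, which here is the last row of xs.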