Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings
/jaxPublic

Commit d8bf7ec

Browse files
yashk2810 authored and Google-ML-Automation committed
Make sure axes are in the same order in pvary across multiple hosts. This ends up affecting psum on the backward pass if the axes order differs across hosts.
PiperOrigin-RevId: 837212890
1 parent ff91c41 · commit d8bf7ec

File tree

4 files changed

+54
-47
lines changed

4 files changed

+54
-47
lines changed

‎jax/_src/core.py‎

Lines changed: 19 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -2430,48 +2430,23 @@ def primal_sharding_to_cotangent_sharding(sharding):
24302430
############################## pvary #################################
24312431

24322432
# Invariant -> Variant no-op cast
2433-
24342433
defpvary(x,axis_name):
24352434
axes= (axis_name,)ifnotisinstance(axis_name,tuple)elseaxis_name
24362435
ifnotaxis_name:
24372436
returnx
24382437
xs,treedef=tree_flatten(x)
2439-
ys=pvary_p.bind(*xs,axes=axes)
2438+
# TODO(yashkatariya): Maybe move `order_wrt_mesh` to pvary_transpose_rule?
2439+
# Across hosts we should have the same order of axes during lowering time and
2440+
# pvary_p transposes to psum_invariant_p.
2441+
cur_mesh=mesh_lib.get_abstract_mesh()
2442+
new_axes=axesifcur_mesh.emptyelseorder_wrt_mesh(cur_mesh,axes)
2443+
assertset(new_axes)==set(axes)
2444+
delaxes
2445+
ys=pvary_p.bind(*xs,axes=new_axes)
24402446
returntree_unflatten(treedef,ys)
24412447

24422448
pvary_p=Primitive('pvary')
24432449
pvary_p.multiple_results=True
2444-
pvary_p.def_impl(lambda*args,axes:args)
2445-
2446-
def_pvary_abstract_eval(*args,axes):
2447-
ifnotconfig._check_vma.value:
2448-
returnargs
2449-
check_unreduced_args(args,'pvary')
2450-
assertisinstance(axes,tuple)
2451-
arg_vma= [a.vmaforainargs]
2452-
forainarg_vma:
2453-
# If there is intersection between arg_vma and axes, error
2454-
ifset(axes)&a:
2455-
raiseValueError(
2456-
"pvary is a invariant->variant collective. This means that the axis"
2457-
" names mentioned in `axes` passed to `pvary` must not be present in"
2458-
f" `jax.typeof(inp).vma`. Got axes={axes} and"
2459-
f" jax.typeof(inp).vma={a}")
2460-
return [a.update(sharding=a.sharding.update(mesh=mesh_lib.get_abstract_mesh()),
2461-
vma=a.vma.union(frozenset(axes)))
2462-
forainargs]
2463-
pvary_p.def_abstract_eval(_pvary_abstract_eval)
2464-
2465-
defcheck_unreduced_args(args,name):
2466-
forainargs:
2467-
ifa.sharding.spec.unreduced:
2468-
raiseValueError(
2469-
f"{name} cannot accept args which are unreduced. Got"
2470-
f"{a.str_short(True)}")
2471-
ifa.sharding.spec.reduced:
2472-
raiseValueError(
2473-
f"{name} cannot accept args which are reduced. Got"
2474-
f"{a.str_short(True)}")
24752450

24762451
####################### reduced_vary_cast #############################
24772452

@@ -2489,6 +2464,17 @@ def reduced_vary_cast(x, axis_name):
24892464

24902465
#######################################################################
24912466

2467+
defcheck_unreduced_args(args,name):
2468+
forainargs:
2469+
ifa.sharding.spec.unreduced:
2470+
raiseValueError(
2471+
f"{name} cannot accept args which are unreduced. Got"
2472+
f"{a.str_short(True)}")
2473+
ifa.sharding.spec.reduced:
2474+
raiseValueError(
2475+
f"{name} cannot accept args which are reduced. Got"
2476+
f"{a.str_short(True)}")
2477+
24922478
defstandard_insert_pvary(*args):
24932479
ifnotconfig._check_vma.value:
24942480
returnargs

‎jax/_src/interpreters/batching.py‎

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1232,15 +1232,6 @@ def add_batched(axis_data, batched_args, batch_dims):
12321232
fancy_primitive_batchers[add_jaxvals_p]=add_batched
12331233
skippable_batchers[add_jaxvals_p]=lambda_: ()
12341234

1235-
########################### core. ##################################
1236-
1237-
def_pvary_batcher(vals_in,dims_in,*,axes):
1238-
ifany(type(axis)isintforaxisinaxes):
1239-
raiseNotImplementedError
1240-
vals_out=core.pvary_p.bind(*vals_in,axes=axes)
1241-
returnvals_out,dims_in
1242-
primitive_batchers[core.pvary_p]=_pvary_batcher
1243-
12441235
### mutable arrays
12451236

12461237
defvectorized(core.ref_p)

‎jax/_src/interpreters/mlir.py‎

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3294,7 +3294,3 @@ def refine_polymorphic_shapes(module: ir.Module) -> ir.Module:
32943294
context=make_ir_context()
32953295
withcontext:
32963296
returnir.Module.parse(refined_module_str)
3297-
3298-
########################### pvary ##################################
3299-
3300-
register_lowering(core.pvary_p,lambdactx,*x,axes:x)

‎jax/_src/lax/parallel.py‎

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2364,6 +2364,29 @@ def f(ct, arg):
23642364

23652365
########################### pvary ##################################
23662366

2367+
core.pvary_p.def_impl(lambda*args,axes:args)
2368+
mlir.register_lowering(core.pvary_p,lambdactx,*x,axes:x)
2369+
2370+
def_pvary_abstract_eval(*args,axes):
2371+
ifnotconfig._check_vma.value:
2372+
returnargs
2373+
_check_axis_names(axes,'pvary')
2374+
check_unreduced_args(args,'pvary')
2375+
assertisinstance(axes,tuple)
2376+
arg_vma= [a.vmaforainargs]
2377+
forainarg_vma:
2378+
# If there is intersection between arg_vma and axes, error
2379+
ifset(axes)&a:
2380+
raiseValueError(
2381+
"pvary is a invariant->variant collective. This means that the axis"
2382+
" names mentioned in `axes` passed to `pvary` must not be present in"
2383+
f" `jax.typeof(inp).vma`. Got axes={axes} and"
2384+
f" jax.typeof(inp).vma={a}")
2385+
return [a.update(sharding=a.sharding.update(mesh=get_abstract_mesh()),
2386+
vma=a.vma.union(frozenset(axes)))
2387+
forainargs]
2388+
core.pvary_p.def_abstract_eval(_pvary_abstract_eval)
2389+
23672390
def_pvary_transpose_rule(cts,*args,axes):
23682391
deff(ct,arg):
23692392
assertad.is_undefined_primal(arg)
@@ -2374,6 +2397,13 @@ def f(ct, arg):
23742397
returntree_util.tree_unflatten(treedef,nonzero_in_cts)
23752398
ad.deflinear2(core.pvary_p,_pvary_transpose_rule)
23762399

2400+
def_pvary_batcher(vals_in,dims_in,*,axes):
2401+
ifany(type(axis)isintforaxisinaxes):
2402+
raiseNotImplementedError
2403+
vals_out=core.pvary_p.bind(*vals_in,axes=axes)
2404+
returnvals_out,dims_in
2405+
batching.primitive_batchers[core.pvary_p]=_pvary_batcher
2406+
23772407
####################### all_gather_reduced ###########################
23782408

23792409
# Varying -> Reduced collective
@@ -2626,7 +2656,11 @@ def preduced(x, axis_name):
26262656
ifnotaxes:
26272657
returnx
26282658
x_flat,treedef=tree_util.tree_flatten(x)
2629-
out_flat=preduced_p.bind(*x_flat,axes=axes)
2659+
cur_mesh=get_abstract_mesh()
2660+
new_axes=axesifcur_mesh.emptyelsecore.order_wrt_mesh(cur_mesh,axes)
2661+
assertset(new_axes)==set(axes)
2662+
delaxes
2663+
out_flat=preduced_p.bind(*x_flat,axes=new_axes)
26302664
returntree_util.tree_unflatten(treedef,out_flat)
26312665

26322666
preduced_p=core.Primitive('preduced')

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp