Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings
/jaxPublic

Commit65f1667

Browse files
sharadmvGoogle-ML-Automation
authored andcommitted
[Pallas/Fuser] Add bcast fix for binop push rule
PiperOrigin-RevId: 837172431
1 parentbe2883a commit65f1667

File tree

2 files changed

+78
-4
lines changed

2 files changed

+78
-4
lines changed

‎jax/_src/pallas/fuser/block_spec.py‎

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -718,7 +718,7 @@ def _eltwise_usage_rule(
718718
return [used_out]
719719

720720

721-
def_bcast_block_spec(
721+
def_pull_bcast_block_spec(
722722
block_spec:pallas_core.BlockSpec,i:int
723723
)->pallas_core.BlockSpec:
724724
defnew_index_map(*args):
@@ -741,6 +741,21 @@ def new_index_map(*args):
741741
returnpallas_core.BlockSpec(new_block_shape,new_index_map)
742742

743743

744+
def_push_bcast_block_spec(
745+
block_spec:pallas_core.BlockSpec,
746+
i:int,
747+
size:int,
748+
)->pallas_core.BlockSpec:
749+
750+
bcast_dim_block_shape=size
751+
ifisinstance(block_spec.block_shape[i],pallas_core.Element):
752+
bcast_dim_block_shape=pallas_core.Element(size)
753+
new_block_shape=util.tuple_update(# pytype: disable=wrong-arg-types
754+
block_spec.block_shape,i,bcast_dim_block_shape
755+
)
756+
returnpallas_core.BlockSpec(new_block_shape,block_spec.index_map)
757+
758+
744759
def_binop_usage_rule(prim,ctx,used_out:set[Usage]):
745760
delprim
746761
ifused_out== {Usage.SCALAR_PREFETCH}:
@@ -782,9 +797,9 @@ def _eval_function(_, x, y):
782797
zip(left_aval.shape,right_aval.shape,strict=True)
783798
):
784799
ifl==1andr!=1:
785-
l_block_spec=_bcast_block_spec(l_block_spec,i)
800+
l_block_spec=_pull_bcast_block_spec(l_block_spec,i)
786801
ifr==1andl!=1:
787-
r_block_spec=_bcast_block_spec(r_block_spec,i)
802+
r_block_spec=_pull_bcast_block_spec(r_block_spec,i)
788803

789804
return [l_block_spec,r_block_spec]
790805

@@ -2117,6 +2132,23 @@ def _binop_push_rule(
21172132
left_aval,right_aval=ctx.avals_in
21182133
assertisinstance(left_aval,core.ShapedArray)
21192134
assertisinstance(right_aval,core.ShapedArray)
2135+
ifnotright_aval.shape:
2136+
returnleft_block_spec
2137+
ifnotleft_aval.shape:
2138+
returnright_block_spec
2139+
lhs_has_block_spec=left_block_specisnotpallas_core.no_block_spec
2140+
rhs_has_block_spec=right_block_specisnotpallas_core.no_block_spec
2141+
ifnot (lhs_has_block_spec^rhs_has_block_spec):
2142+
# We can only do a push if one of the block specs is unspecified
2143+
# or they are identical.
2144+
ifleft_block_specisright_block_spec:
2145+
returnleft_block_spec
2146+
raiseValueError('Illegal binary push. One of the block specs must be no_block_spec.')
2147+
forl,rinzip(left_aval.shape,right_aval.shape,strict=True):
2148+
ifl==1andr!=1andlhs_has_block_spec:
2149+
raiseValueError('Cannot propagate block spec through LHS broadcast.')
2150+
ifr==1andl!=1andrhs_has_block_spec:
2151+
raiseValueError('Cannot propagate block spec through RHS broadcast.')
21202152
ifleft_block_specispallas_core.no_block_spec:
21212153
returnright_block_spec
21222154
ifright_block_specispallas_core.no_block_spec:
@@ -2233,7 +2265,6 @@ def _custom_call_hi_primitive_push_block_spec_rule(
22332265
returnprim.push_block_spec_rule(ctx,block_specs)
22342266

22352267

2236-
22372268
@register_push_block_spec_rule(pjit.jit_p)
22382269
def_pjit_push_rule(ctx,*block_specs,jaxpr:core.ClosedJaxpr,**_):
22392270
assertnotjaxpr.consts

‎tests/pallas/fuser_block_spec_test.py‎

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1313,6 +1313,49 @@ def setUp(self):
13131313
ifconfig.enable_x64.value:
13141314
self.skipTest('x64 not supported')
13151315

1316+
deftest_binop(self):
1317+
1318+
deff(x):
1319+
returnx+jnp.ones_like(x)
1320+
1321+
block_spec=pl.BlockSpec((128,128),lambdai,j: (i,j))
1322+
x_type=jax.ShapeDtypeStruct((512,512),jnp.float32)
1323+
out_block_spec=block_spec_lib.push_block_spec(f,block_spec)(x_type)
1324+
self.assertEqual(out_block_spec.block_shape,block_spec.block_shape)
1325+
1326+
deff(x,y):
1327+
returnx+y
1328+
1329+
x_block_spec=pl.BlockSpec((128,128),lambdai,j: (i,j))
1330+
y_block_spec=pl.BlockSpec((128,1),lambdai,j: (i,0))
1331+
x_type=jax.ShapeDtypeStruct((512,512),jnp.float32)
1332+
y_type=jax.ShapeDtypeStruct((512,1),jnp.float32)
1333+
withself.assertRaisesRegex(
1334+
ValueError,'Cannot propagate block spec through RHS broadcast.'
1335+
):
1336+
block_spec_lib.push_block_spec(f,pl.no_block_spec,y_block_spec)(
1337+
x_type,y_type
1338+
)
1339+
out_block_spec=block_spec_lib.push_block_spec(
1340+
f,x_block_spec,pl.no_block_spec
1341+
)(x_type,y_type)
1342+
self.assertIs(x_block_spec,out_block_spec)
1343+
1344+
x_block_spec=pl.BlockSpec((1,128),lambdai,j: (0,j))
1345+
y_block_spec=pl.BlockSpec((128,128),lambdai,j: (i,j))
1346+
x_type=jax.ShapeDtypeStruct((1,512),jnp.float32)
1347+
y_type=jax.ShapeDtypeStruct((512,512),jnp.float32)
1348+
withself.assertRaisesRegex(
1349+
ValueError,'Cannot propagate block spec through LHS broadcast.'
1350+
):
1351+
block_spec_lib.push_block_spec(f,x_block_spec,pl.no_block_spec)(
1352+
x_type,y_type
1353+
)
1354+
out_block_spec=block_spec_lib.push_block_spec(
1355+
f,pl.no_block_spec,y_block_spec
1356+
)(x_type,y_type)
1357+
self.assertIs(out_block_spec,y_block_spec)
1358+
13161359
deftest_jit(self):
13171360

13181361
deff(x):

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp