@@ -718,7 +718,7 @@ def _eltwise_usage_rule(
718718return [used_out ]
719719
720720
721- def _bcast_block_spec (
721+ def _pull_bcast_block_spec (
722722block_spec :pallas_core .BlockSpec ,i :int
723723)-> pallas_core .BlockSpec :
724724def new_index_map (* args ):
@@ -741,6 +741,21 @@ def new_index_map(*args):
741741return pallas_core .BlockSpec (new_block_shape ,new_index_map )
742742
743743
744+ def _push_bcast_block_spec (
745+ block_spec :pallas_core .BlockSpec ,
746+ i :int ,
747+ size :int ,
748+ )-> pallas_core .BlockSpec :
749+
750+ bcast_dim_block_shape = size
751+ if isinstance (block_spec .block_shape [i ],pallas_core .Element ):
752+ bcast_dim_block_shape = pallas_core .Element (size )
753+ new_block_shape = util .tuple_update (# pytype: disable=wrong-arg-types
754+ block_spec .block_shape ,i ,bcast_dim_block_shape
755+ )
756+ return pallas_core .BlockSpec (new_block_shape ,block_spec .index_map )
757+
758+
744759def _binop_usage_rule (prim ,ctx ,used_out :set [Usage ]):
745760del prim
746761if used_out == {Usage .SCALAR_PREFETCH }:
@@ -782,9 +797,9 @@ def _eval_function(_, x, y):
782797zip (left_aval .shape ,right_aval .shape ,strict = True )
783798 ):
784799if l == 1 and r != 1 :
785- l_block_spec = _bcast_block_spec (l_block_spec ,i )
800+ l_block_spec = _pull_bcast_block_spec (l_block_spec ,i )
786801if r == 1 and l != 1 :
787- r_block_spec = _bcast_block_spec (r_block_spec ,i )
802+ r_block_spec = _pull_bcast_block_spec (r_block_spec ,i )
788803
789804return [l_block_spec ,r_block_spec ]
790805
@@ -2117,6 +2132,23 @@ def _binop_push_rule(
21172132left_aval ,right_aval = ctx .avals_in
21182133assert isinstance (left_aval ,core .ShapedArray )
21192134assert isinstance (right_aval ,core .ShapedArray )
2135+ if not right_aval .shape :
2136+ return left_block_spec
2137+ if not left_aval .shape :
2138+ return right_block_spec
2139+ lhs_has_block_spec = left_block_spec is not pallas_core .no_block_spec
2140+ rhs_has_block_spec = right_block_spec is not pallas_core .no_block_spec
2141+ if not (lhs_has_block_spec ^ rhs_has_block_spec ):
2142+ # We can only do a push if one of the block specs is unspecified
2143+ # or they are identical.
2144+ if left_block_spec is right_block_spec :
2145+ return left_block_spec
2146+ raise ValueError ('Illegal binary push. One of the block specs must be no_block_spec.' )
2147+ for l ,r in zip (left_aval .shape ,right_aval .shape ,strict = True ):
2148+ if l == 1 and r != 1 and lhs_has_block_spec :
2149+ raise ValueError ('Cannot propagate block spec through LHS broadcast.' )
2150+ if r == 1 and l != 1 and rhs_has_block_spec :
2151+ raise ValueError ('Cannot propagate block spec through RHS broadcast.' )
21202152if left_block_spec is pallas_core .no_block_spec :
21212153return right_block_spec
21222154if right_block_spec is pallas_core .no_block_spec :
@@ -2233,7 +2265,6 @@ def _custom_call_hi_primitive_push_block_spec_rule(
22332265return prim .push_block_spec_rule (ctx ,block_specs )
22342266
22352267
2236-
22372268@register_push_block_spec_rule (pjit .jit_p )
22382269def _pjit_push_rule (ctx ,* block_specs ,jaxpr :core .ClosedJaxpr ,** _ ):
22392270assert not jaxpr .consts