Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings
/jaxPublic

Commit11c1aa8

Browse files
brianwa84Google-ML-Automation
authored andcommitted
[Pallas:SC] Mark plsc.subcore_barrier() effect as lowerable.
PiperOrigin-RevId: 837832833
1 parent9e97164 commit11c1aa8

File tree

2 files changed

+28
-2
lines changed

2 files changed

+28
-2
lines changed

‎jax/_src/pallas/mosaic/sc_primitives.py‎

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -444,6 +444,7 @@ class MemoryEffect(jax_core.Effect):
444444

445445

446446
effects.control_flow_allowed_effects.add_type(MemoryEffect)
447+
effects.lowerable_effects.add_type(MemoryEffect)
447448
_memory_effect=MemoryEffect()
448449

449450
barrier_p=jax_core.Primitive("barrier")

‎tests/pallas/tpu_sparsecore_pallas_test.py‎

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,17 +20,19 @@
2020

2121
fromabsl.testingimportabsltest
2222
fromabsl.testingimportparameterized
23+
importhypothesisashp
24+
importhypothesis.strategiesashps
2325
importjax
2426
fromjaximportlax
2527
fromjax._srcimporttest_utilasjtu
2628
fromjax._src.pallas.mosaicimportsc_core
29+
fromjax._src.stateimportdischargeasstate_discharge
2730
fromjax.experimentalimportpallasaspl
2831
fromjax.experimental.pallasimporttpuaspltpu
2932
fromjax.experimental.pallasimporttpu_scasplsc
3033
importjax.numpyasjnp
3134
importnumpyasnp
32-
importhypothesisashp
33-
importhypothesis.strategiesashps
35+
3436

3537
jtu.setup_hypothesis()
3638
jax.config.parse_flags_with_absl()
@@ -1176,6 +1178,29 @@ def kernel(x_ref, o_ref, scratch_scalar_ref, scratch_vec_ref):
11761178
x)
11771179
np.testing.assert_array_equal(kernel(x),expected)
11781180

1181+
@parameterized.named_parameters(
1182+
("barrier",lambda_:plsc.subcore_barrier()),
1183+
("debug_print",lambdavec:pl.debug_print('test',vec)),
1184+
)
1185+
deftest_effect_discharge(self,effectful_op):
1186+
x=jnp.arange(self.sc_info.num_lanes)
1187+
mesh=plsc.VectorSubcoreMesh(
1188+
core_axis_name="core",subcore_axis_name="subcore",num_cores=1
1189+
)
1190+
defstateful(refs):
1191+
defbody(x_ref,o_ref):
1192+
defwith_scratch(scratch_ref):
1193+
pltpu.sync_copy(x_ref,scratch_ref)
1194+
scratch_ref[...]=scratch_ref[...]+1
1195+
effectful_op(scratch_ref[...])
1196+
pltpu.sync_copy(scratch_ref,o_ref)
1197+
pl.run_scoped(with_scratch,pltpu.VMEM(x.shape,x.dtype))
1198+
pl.core_map(mesh)(lambda:body(*refs))
1199+
1200+
_,out=jax.jit(state_discharge.run_state(stateful))(
1201+
(x,jnp.empty_like(x)))
1202+
np.testing.assert_array_equal(out,x+1)
1203+
11791204
deftest_parallel_loop_effects(self):
11801205
chunk_size=8
11811206

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp