Skip to content

Commit 866e32b

Browse files
superbobry authored and Google-ML-Automation committed
[pallas:mosaic_gpu] ModuleContext.reserve_barrier is now a context manager
This allows unreserving the barrier once it is no longer needed and is consistent with how resource estimation works, e.g. for `cond`.

PiperOrigin-RevId: 745483567
1 parent 4275135 commit 866e32b

File tree

2 files changed

+28
-3
lines changed

2 files changed

+28
-3
lines changed

jax/_src/pallas/mosaic_gpu/lowering.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,7 @@ class ModuleContext:
311311
lowering_semantics: mgpu.LoweringSemantics
312312
primitive_semantics: gpu_core.PrimitiveSemantics
313313

314+
@contextlib.contextmanager
314315
def reserve_barrier(self, barrier: mgpu.Barrier) -> mgpu.BarrierRef:
315316
"""Reserves a barrier.
316317
@@ -320,7 +321,9 @@ def reserve_barrier(self, barrier: mgpu.Barrier) -> mgpu.BarrierRef:
320321
available = self.runtime_barriers.get(barrier, [])
321322
if not available:
322323
raise RuntimeError(f"Barrier {barrier} is already reserved")
323-
return available.pop()
324+
barrier = available.pop()
325+
yield barrier
326+
available.append(barrier)
324327

325328
@contextlib.contextmanager
326329
def alloc_tmem(
@@ -1965,7 +1968,7 @@ def _run_scoped_lowering_rule(
19651968
input_refs.append(acc)
19661969
should_discharge.append(True)
19671970
elif isinstance(aval.dtype, gpu_core.BarrierType):
1968-
input_refs.append(
1971+
barrier_ref = alloc_stack.enter_context(
19691972
ctx.module_ctx.reserve_barrier(
19701973
mgpu.Barrier(
19711974
aval.dtype.num_arrivals
@@ -1974,17 +1977,19 @@ def _run_scoped_lowering_rule(
19741977
)
19751978
)
19761979
)
1980+
input_refs.append(barrier_ref)
19771981
should_discharge.append(False)
19781982
elif isinstance(aval.dtype, gpu_core.ClusterBarrierType):
19791983
collective_dims = jax.tree.map(
19801984
lambda axis: _resolve_cluster_axis(ctx.module_ctx.axis_names, axis),
19811985
aval.dtype.collective_axes,
19821986
)
1983-
input_refs.append(
1987+
barrier_ref = alloc_stack.enter_context(
19841988
ctx.module_ctx.reserve_barrier(
19851989
mgpu.ClusterBarrier(collective_dims, *aval.shape)
19861990
)
19871991
)
1992+
input_refs.append(barrier_ref)
19881993
should_discharge.append(False)
19891994
elif aval.memory_space == gpu_core.SMEM:
19901995
[input_ref] = alloc_stack.enter_context(

tests/pallas/mosaic_gpu_test.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1046,6 +1046,26 @@ def body(tmp_ref):
10461046
x = np.ones((8, 128), jnp.float32)
10471047
np.testing.assert_array_equal(kernel(x), x + 1.0)
10481048

1049+
def test_run_scoped_in_cond(self):
1050+
@functools.partial(
1051+
self.pallas_call,
1052+
out_shape=jax.ShapeDtypeStruct([256], jnp.int32),
1053+
in_specs=[pl.BlockSpec(memory_space=plgpu.GPUMemorySpace.GMEM)],
1054+
out_specs=pl.BlockSpec(memory_space=plgpu.GPUMemorySpace.SMEM),
1055+
)
1056+
def kernel(x_ref_gmem, o_ref):
1057+
def scoped_kernel(barrier_ref):
1058+
plgpu.copy_gmem_to_smem(x_ref_gmem, o_ref, barrier_ref)
1059+
plgpu.barrier_wait(barrier_ref)
1060+
1061+
def branch():
1062+
pl.run_scoped(scoped_kernel, plgpu.Barrier(num_arrivals=1))
1063+
1064+
jax.lax.cond(x_ref_gmem[0] % 2 == 0, branch, branch)
1065+
1066+
x = jnp.full((256,), 1234, dtype=jnp.int32)
1067+
np.testing.assert_array_equal(kernel(x), x)
1068+
10491069
def test_program_id(self):
10501070
@functools.partial(
10511071
self.pallas_call,

0 commit comments

Comments (0)