Commit cecefde

comments
1 parent 011d734 commit cecefde

2 files changed: 83 additions, 78 deletions


test/test_pallas_spmd.py

Lines changed: 18 additions & 9 deletions
@@ -76,7 +76,8 @@ def test_flash_attention_spmd_data_parallel(self):
         f"{{devices=[{n_devices},1,1,1]{dev_ids}}}")

     expected_o = self._attention(q, k, v)
-    self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu(), atol=1e-05))
+    torch.testing.assert_close(
+        o.cpu(), expected_o.cpu(), atol=1e-05, rtol=1e-05)

   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
                    "This test only works on TPUv3+.")
@@ -100,7 +101,8 @@ def test_flash_attention_spmd_data_parallel_5d(self):
         f"{{devices=[{n_devices//2},2,1,1,1]{dev_ids}}}")

     expected_o = self._attention(q, k, v)
-    self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu(), atol=1e-05))
+    torch.testing.assert_close(
+        o.cpu(), expected_o.cpu(), atol=1e-05, rtol=1e-05)

   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
                    "This test only works on TPUv3+.")
@@ -121,7 +123,8 @@ def test_flash_attention_spmd_data_parallel_kv_and_ab_padding(self):
         f"{{devices=[{n_devices},1,1,1]{dev_ids}}}")

     expected_o = self._attention(q, k, v, ab=ab)
-    self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu(), atol=1e-05))
+    torch.testing.assert_close(
+        o.cpu(), expected_o.cpu(), atol=1e-05, rtol=1e-05)

   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
                    "This test only works on TPUv3+.")
@@ -172,7 +175,8 @@ def test_flash_attention_backward_spmd_data_parallel(self):
     xm.mark_step()

     for i in [(q, q_grad), (k, k_grad), (v, v_grad)]:
-      self.assertTrue(torch.allclose(i[0].grad.cpu(), i[1].cpu(), atol=1e-05))
+      torch.testing.assert_close(
+          i[0].grad.cpu(), i[1].cpu(), atol=1e-05, rtol=1e-05)

   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
                    "This test only works on TPUv3+.")
@@ -216,7 +220,8 @@ def test_flash_attention_wrapper_segment_ids_spmd(self):
                 segment_ids=SegmentIds(jax_segment_ids, jax_segment_ids),
             )))

-    self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu(), atol=1e-05))
+    torch.testing.assert_close(
+        o.cpu(), expected_o.cpu(), atol=1e-05, rtol=1e-05)

   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
                    "This test only works on TPUv3+.")
@@ -286,7 +291,8 @@ def test_flash_attention_backward_segment_ids_spmd(self):
     xm.mark_step()

     for i in [(q, q_grad), (k, k_grad), (v, v_grad)]:
-      self.assertTrue(torch.allclose(i[0].grad.cpu(), i[1].cpu(), atol=1e-05))
+      torch.testing.assert_close(
+          i[0].grad.cpu(), i[1].cpu(), atol=1e-05, rtol=1e-05)

   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
                    "This test only works on TPUv3+.")
@@ -331,7 +337,8 @@ def test_cross_flash_attention_wrapper_segment_ids_spmd(self):
                 segment_ids=SegmentIds(jax_q_segment_ids, jax_kv_segment_ids),
             )))

-    self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu(), atol=1e-05))
+    torch.testing.assert_close(
+        o.cpu(), expected_o.cpu(), atol=1e-05, rtol=1e-05)

   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
                    "This test only works on TPUv3+.")
@@ -402,7 +409,8 @@ def test_cross_flash_attention_backward_segment_ids_spmd(self):
     xm.mark_step()

     for i in [(q, q_grad), (k, k_grad), (v, v_grad)]:
-      self.assertTrue(torch.allclose(i[0].grad.cpu(), i[1].cpu(), atol=1e-05))
+      torch.testing.assert_close(
+          i[0].grad.cpu(), i[1].cpu(), atol=1e-05, rtol=1e-05)

   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
                    "This test only works on TPUv4+.")
@@ -484,7 +492,8 @@ def flash_attention_wrapper(q, k, v, casual, q_segment_ids, kv_segment_ids,
     xm.mark_step()

     for i in [(q, q_grad), (k, k_grad), (v, v_grad), (ab, ab_grad)]:
-      self.assertTrue(torch.allclose(i[0].grad.cpu(), i[1].cpu(), atol=1e-02))
+      torch.testing.assert_close(
+          i[0].grad.cpu(), i[1].cpu(), atol=1e-02, rtol=1e-05)


 if __name__ == '__main__':
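
Why the switch matters: unittest's assertTrue(torch.allclose(...)) only reports "False is not true" on failure, while torch.testing.assert_close raises an AssertionError that describes how many elements mismatch and the greatest absolute and relative difference, and it checks both atol and rtol. A minimal standalone sketch, not part of the commit (the tensors here are made up):

    import torch

    a = torch.zeros(4)
    b = torch.tensor([0.0, 0.0, 0.0, 1.0])

    # Old style: on mismatch the test output carries no detail.
    print(torch.allclose(a, b, atol=1e-05, rtol=1e-05))  # False

    # New style: the raised error says where and by how much the tensors differ.
    try:
      torch.testing.assert_close(a, b, atol=1e-05, rtol=1e-05)
    except AssertionError as e:
      print(e)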

torch_xla/experimental/custom_kernel.py

Lines changed: 65 additions & 69 deletions
@@ -20,7 +20,7 @@ def _shard_map(func, mesh, input_specs, output_specs):

   Note:
     ``shard_map`` is an experimental API, and still subject to change. For an
-    introduction to sharded data, refer to :ref:`sharded-computation`. For a more
+    introduction to sharded data and a more
     in-depth look at using ``shard_map``, refer to
     [SPMD multi-device parallelism with shard_map](https://docs.jax.dev/en/latest/notebooks/shard_map.html)

@@ -43,7 +43,7 @@ def _shard_map(func, mesh, input_specs, output_specs):
   the ``mesh`` and ``out_specs``.

   Reference:
-    This function is identical Jax's shard_map:
+    This function behaves identically to Jax's shard_map:
     https://docs.jax.dev/en/latest/_autosummary/jax.experimental.shard_map.shard_map.html
   """

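Since the docstring defers to JAX's shard_map for the semantics, here is a minimal JAX sketch of the behavior that _shard_map mirrors (not part of this commit; it assumes jax is installed, at least one device is visible, and the sharded dimension is divisible by the device count): the wrapped function sees per-device local shards, and outputs are stitched back into global arrays according to out_specs.

    import numpy as np
    import jax
    import jax.numpy as jnp
    from jax.sharding import Mesh, PartitionSpec as P
    from jax.experimental.shard_map import shard_map

    # One-axis mesh over whatever devices are available.
    mesh = Mesh(np.array(jax.devices()), axis_names=('data',))

    def local_fn(x):
      # Runs per device; x is the local shard, not the global array.
      return x * 2

    fn = shard_map(local_fn, mesh=mesh, in_specs=P('data'), out_specs=P('data'))
    x = jnp.ones((8, 4))   # dim 0 is split across the 'data' axis
    print(fn(x).shape)     # (8, 4): local results reassembled to the global shape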

@@ -56,18 +56,14 @@ def _full_shape(a, spec):
   result_shape = []
   for axis_size, axis_sharding in zip(a.shape, spec):
     if axis_sharding is None:
-      new_size = axis_size
-    else:
-      if isinstance(axis_sharding, (str, int)):
-        mesh_mult = mesh_name_to_size[axis_sharding]
-      else:
-        # tuple or list
-        mesh_mult = math.prod(mesh_name_to_size[a]
-                              for a in axis_sharding
-                              if mesh_name_to_size[a] is not None)
-
-      if mesh_mult is not None:
-        new_size = axis_size * mesh_mult
+      axis_sharding = ()
+    mesh_mult = []
+    if isinstance(axis_sharding, (str, int)):
+      axis_sharding = [axis_sharding]
+    for a in axis_sharding:
+      size = mesh_name_to_size[a] or 1
+      mesh_mult.append(size)
+    new_size = axis_size * math.prod(mesh_mult)
     result_shape.append(new_size)
   return tuple(result_shape)

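To make the new control flow concrete, a small self-contained sketch that mirrors the loop above (the mesh axis sizes below are made up and stand in for mesh_name_to_size in the real code): an axis whose spec names several mesh axes is multiplied by the product of their sizes, while an axis whose spec is None, or whose mesh axis has no size, keeps its local size.

    import math

    # Hypothetical mesh axis sizes standing in for mesh_name_to_size.
    mesh_name_to_size = {'data': 4, 'model': 2, 'replica': None}

    def full_shape(local_shape, spec):
      result_shape = []
      for axis_size, axis_sharding in zip(local_shape, spec):
        if axis_sharding is None:
          axis_sharding = ()
        if isinstance(axis_sharding, (str, int)):
          axis_sharding = [axis_sharding]
        mesh_mult = [mesh_name_to_size[a] or 1 for a in axis_sharding]
        result_shape.append(axis_size * math.prod(mesh_mult))
      return tuple(result_shape)

    # Local shard (2, 8, 128) sharded over ('data', 'model') on dim 0 and over
    # 'replica' (size None, treated as 1) on dim 1 -> global shape (16, 8, 128).
    print(full_shape((2, 8, 128), (('data', 'model'), 'replica', None)))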

@@ -87,17 +83,16 @@ def wrapped(*args):
     res = func(*new_args)
     if isinstance(res, tuple):
       res_updated = []
-      for i, r in enumerate(res):
-        if isinstance(r, torch.Tensor):
+      for i, (r, spec) in enumerate(zip(res, output_specs)):
+        if isinstance(r, torch.Tensor) and spec is not None:
           assert str(r.device).startswith('xla'), f'{i}th device is {r.device}'
           assert len(r.shape) == len(
-              output_specs[i]
-          ), f'{i}th shape is {r.shape}, sharding is {output_specs[i]}'
-      return tuple(
-          xs.disable_manual_sharding(a, spec, _full_shape(a, spec), mesh=mesh).
-          global_tensor
-          if isinstance(a, torch.Tensor) and spec is not None else a
-          for a, spec in zip(res, output_specs))
+              spec), f'{i}th shape is {r.shape}, sharding is {output_specs[i]}'
+          new_r = xs.disable_manual_sharding(
+              r, spec, _full_shape(r, spec), mesh=mesh).global_tensor
+        else:
+          new_r = r
+        res_updated.append(new_r)
     else:
       return xs.disable_manual_sharding(
           res, output_specs[0], _full_shape(res, output_specs[0]),
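
A small torch-only sketch of what the rewritten loop does with tuple outputs (unshard below is a stand-in for xs.disable_manual_sharding(...).global_tensor, so this is not the real API path): tensor outputs with a non-None spec are converted back to global tensors, while non-tensor outputs and outputs whose spec is None pass through unchanged.

    import torch

    def unshard(t, spec):
      # Stand-in for xs.disable_manual_sharding(t, spec, _full_shape(t, spec),
      # mesh=mesh).global_tensor in the real code.
      return t

    def postprocess_outputs(res, output_specs):
      res_updated = []
      for r, spec in zip(res, output_specs):
        if isinstance(r, torch.Tensor) and spec is not None:
          res_updated.append(unshard(r, spec))
        else:
          res_updated.append(r)  # e.g. scalars or aux outputs with spec None
      return res_updated

    res = (torch.ones(2, 3), 0.125, torch.ones(2))
    specs = [('data', None, None), None, None]
    print([type(r).__name__ for r in postprocess_outputs(res, specs)])
    # ['Tensor', 'float', 'Tensor']
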
@@ -309,6 +304,24 @@ def wrapped_kernel(kernel: Callable,
   return functools.partial(wrapped_kernel, kernel, output_shape_dtype_fn)


+def _maybe_reshape_input_output_funcs(current_shape, non_batch_dims=3):
+  batch_dims = len(current_shape) - non_batch_dims
+  orig_batch_dims = current_shape[:batch_dims]
+  other_dims = current_shape[batch_dims:]
+
+  def reshape_input(tensor):
+    if tensor is None:
+      return None
+    return tensor.reshape(-1, *tensor.shape[batch_dims:])
+
+  def reshape_output(tensor):
+    if tensor is None:
+      return None
+    return tensor.reshape(*orig_batch_dims, *tensor.shape[1:])
+
+  return reshape_input, reshape_output
+
+
 def _fa_custom_forward_single_device(
     q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, causal: bool,
     q_segment_ids: torch.Tensor, kv_segment_ids: torch.Tensor, sm_scale: float,
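
A quick standalone check of the round-trip the forward and backward passes now rely on (the helper is a trimmed copy of the hunk above; the tensor shape is made up): all leading batch dimensions are collapsed into one before the Pallas kernel runs and restored afterwards, and None inputs such as absent segment ids or ab pass straight through.

    import torch

    def _maybe_reshape_input_output_funcs(current_shape, non_batch_dims=3):
      batch_dims = len(current_shape) - non_batch_dims
      orig_batch_dims = current_shape[:batch_dims]

      def reshape_input(tensor):
        if tensor is None:
          return None
        return tensor.reshape(-1, *tensor.shape[batch_dims:])

      def reshape_output(tensor):
        if tensor is None:
          return None
        return tensor.reshape(*orig_batch_dims, *tensor.shape[1:])

      return reshape_input, reshape_output

    # A 5-D query: (num_batches, batch, heads, seq, head_dim).
    q = torch.randn(2, 4, 8, 128, 64)
    reshape_to_4d, undo_reshape = _maybe_reshape_input_output_funcs(q.shape, 3)

    assert reshape_to_4d(q).shape == (8, 8, 128, 64)        # leading dims collapsed
    assert undo_reshape(reshape_to_4d(q)).shape == q.shape  # restored after the kernel
    assert reshape_to_4d(None) is None                      # optional inputs pass through
    print("round-trip ok")
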
@@ -318,20 +331,16 @@ def _fa_custom_forward_single_device(

   num_batches = None
   batch_size = None
-  if len(q.shape) == 5:
-    num_batches, batch_size, *rest = q.shape
-    q = q.reshape(-1, *rest)
-    k = k.reshape(-1, *rest)
-    v = v.reshape(-1, *rest)
-    if q_segment_ids is not None:
-      q_segment_ids = q_segment_ids.reshape(-1, *rest)
-    if kv_segment_ids is not None:
-      kv_segment_ids = kv_segment_ids.reshape(-1, *rest)
-    if ab is not None:
-      ab = ab.reshape(-1, *rest)
-
-  # Suprisingly, any tensor that is input to the custom_op decorated function will show
-  # requires_grad=False. Is this a bug or feature? We have to pass ctx_grad to record the
+  reshape_to_4d, undo_reshape = _maybe_reshape_input_output_funcs(q.shape, 3)
+  q = reshape_to_4d(q)
+  v = reshape_to_4d(v)
+  k = reshape_to_4d(k)
+  q_segment_ids = reshape_to_4d(q_segment_ids)
+  kv_segment_ids = reshape_to_4d(kv_segment_ids)
+  ab = reshape_to_4d(ab)
+
+  # Surprisingly, any tensor that is input to the custom_op decorated function will show
+  # requires_grad=False by design. We have to pass ctx_grad to record the
   # requires_grad for inputs.
   # Original we use save_residuals = q.requires_grad or k.requires_grad or v.requires_grad
   save_residuals = any(ctx_grad[:3])
@@ -401,12 +410,9 @@ def _fa_custom_forward_single_device(
   o, *aux = custom_call_output
   l, m = (v[..., 0] for v in aux[-2:])

-  if num_batches is not None:
-    o = o.reshape(num_batches, batch_size, *o.shape[1:])
-    if l is not None:
-      l = l.reshape(num_batches, batch_size, *l.shape[1:])
-    if m is not None:
-      m = m.reshape(num_batches, batch_size, *m.shape[1:])
+  o = undo_reshape(o)
+  l = undo_reshape(l)
+  m = undo_reshape(m)

   return o, l, m

@@ -518,21 +524,18 @@ def _fa_custom_backward_single_device(

   num_batches = None
   batch_size = None
-  if len(q.shape) == 5:
-    num_batches, batch_size, *rest = q.shape
-    grad_output = grad_output.reshape(-1, *rest)
-    q = q.reshape(-1, *rest)
-    k = k.reshape(-1, *rest)
-    v = v.reshape(-1, *rest)
-    o = o.reshape(-1, *rest)
-    l = l.reshape(-1, *rest)
-    m = m.reshape(-1, *rest)
-    if q_segment_ids is not None:
-      q_segment_ids = q_segment_ids.reshape(-1, *rest)
-    if kv_segment_ids is not None:
-      kv_segment_ids = kv_segment_ids.reshape(-1, *rest)
-    if ab is not None:
-      ab = ab.reshape(-1, *rest)
+  reshape_to_4d, undo_reshape = _maybe_reshape_input_output_funcs(q.shape, 3)
+
+  grad_output = reshape_to_4d(grad_output)
+  q = reshape_to_4d(q)
+  k = reshape_to_4d(k)
+  v = reshape_to_4d(v)
+  o = reshape_to_4d(o)
+  l = reshape_to_4d(l)
+  m = reshape_to_4d(m)
+  q_segment_ids = reshape_to_4d(q_segment_ids)
+  kv_segment_ids = reshape_to_4d(kv_segment_ids)
+  ab = reshape_to_4d(ab)

   require_grad_q, require_grad_k, require_grad_v, *rest = ctx_grad
   require_grad_ab = ctx_grad[-3]
@@ -646,17 +649,10 @@ def _fa_custom_backward_single_device(
   if require_grad_v:
     grad_v = grads[1]

-  if num_batches is not None:
-
-    def _reshape(x):
-      if x is not None:
-        return x.reshape(num_batches, batch_size, *x.shape[1:])
-      return None
-
-    grad_q = _reshape(grad_q)
-    grad_k = _reshape(grad_k)
-    grad_v = _reshape(grad_v)
-    grad_ab = _reshape(grad_ab)
+  grad_q = undo_reshape(grad_q)
+  grad_k = undo_reshape(grad_k)
+  grad_v = undo_reshape(grad_v)
+  grad_ab = undo_reshape(grad_ab)

   return grad_q, grad_k, grad_v, grad_ab
