
Commit b14a500

1 parent 1119146 commit b14a500

3 files changed, +87 -81 lines changed

test/scan/test_scan_pallas.py

Lines changed: 1 addition & 1 deletion
@@ -71,7 +71,7 @@ class ScanFlashAttentionTest(parameterized.TestCase):
   def fake_fa_wrapper(self, has_model_weight, use_scan):
     torch.manual_seed(12)
     torch_xla.manual_seed(12)
-    hidden_states = torch.randn((2, 4, 256, 256)).requires_grad_().to('xla')
+    hidden_states = torch.randn((8, 4, 256, 256)).requires_grad_().to('xla')
     with xm.xla_device():
       attention_layers = AttentionLayers(
           has_model_weight, num_layer=3, use_scan=use_scan)

test/test_pallas_spmd.py

Lines changed: 18 additions & 9 deletions
@@ -76,7 +76,8 @@ def test_flash_attention_spmd_data_parallel(self):
         f"{{devices=[{n_devices},1,1,1]{dev_ids}}}")

     expected_o = self._attention(q, k, v)
-    self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu(), atol=1e-05))
+    torch.testing.assert_close(
+        o.cpu(), expected_o.cpu(), atol=1e-05, rtol=1e-05)

   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
                    "This test only works on TPUv3+.")
@@ -100,7 +101,8 @@ def test_flash_attention_spmd_data_parallel_5d(self):
         f"{{devices=[{n_devices//2},2,1,1,1]{dev_ids}}}")

     expected_o = self._attention(q, k, v)
-    self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu(), atol=1e-05))
+    torch.testing.assert_close(
+        o.cpu(), expected_o.cpu(), atol=1e-05, rtol=1e-05)

   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
                    "This test only works on TPUv3+.")
@@ -121,7 +123,8 @@ def test_flash_attention_spmd_data_parallel_kv_and_ab_padding(self):
         f"{{devices=[{n_devices},1,1,1]{dev_ids}}}")

     expected_o = self._attention(q, k, v, ab=ab)
-    self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu(), atol=1e-05))
+    torch.testing.assert_close(
+        o.cpu(), expected_o.cpu(), atol=1e-05, rtol=1e-05)

   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
                    "This test only works on TPUv3+.")
@@ -172,7 +175,8 @@ def test_flash_attention_backward_spmd_data_parallel(self):
     xm.mark_step()

     for i in [(q, q_grad), (k, k_grad), (v, v_grad)]:
-      self.assertTrue(torch.allclose(i[0].grad.cpu(), i[1].cpu(), atol=1e-05))
+      torch.testing.assert_close(
+          i[0].grad.cpu(), i[1].cpu(), atol=1e-05, rtol=1e-05)

   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
                    "This test only works on TPUv3+.")
@@ -216,7 +220,8 @@ def test_flash_attention_wrapper_segment_ids_spmd(self):
                 segment_ids=SegmentIds(jax_segment_ids, jax_segment_ids),
             )))

-    self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu(), atol=1e-05))
+    torch.testing.assert_close(
+        o.cpu(), expected_o.cpu(), atol=1e-05, rtol=1e-05)

   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
                    "This test only works on TPUv3+.")
@@ -286,7 +291,8 @@ def test_flash_attention_backward_segment_ids_spmd(self):
     xm.mark_step()

     for i in [(q, q_grad), (k, k_grad), (v, v_grad)]:
-      self.assertTrue(torch.allclose(i[0].grad.cpu(), i[1].cpu(), atol=1e-05))
+      torch.testing.assert_close(
+          i[0].grad.cpu(), i[1].cpu(), atol=1e-05, rtol=1e-05)

   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
                    "This test only works on TPUv3+.")
@@ -331,7 +337,8 @@ def test_cross_flash_attention_wrapper_segment_ids_spmd(self):
                 segment_ids=SegmentIds(jax_q_segment_ids, jax_kv_segment_ids),
            )))

-    self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu(), atol=1e-05))
+    torch.testing.assert_close(
+        o.cpu(), expected_o.cpu(), atol=1e-05, rtol=1e-05)

   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
                    "This test only works on TPUv3+.")
@@ -402,7 +409,8 @@ def test_cross_flash_attention_backward_segment_ids_spmd(self):
     xm.mark_step()

     for i in [(q, q_grad), (k, k_grad), (v, v_grad)]:
-      self.assertTrue(torch.allclose(i[0].grad.cpu(), i[1].cpu(), atol=1e-05))
+      torch.testing.assert_close(
+          i[0].grad.cpu(), i[1].cpu(), atol=1e-05, rtol=1e-05)

   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
                    "This test only works on TPUv4+.")
@@ -484,7 +492,8 @@ def flash_attention_wrapper(q, k, v, casual, q_segment_ids, kv_segment_ids,
     xm.mark_step()

     for i in [(q, q_grad), (k, k_grad), (v, v_grad), (ab, ab_grad)]:
-      self.assertTrue(torch.allclose(i[0].grad.cpu(), i[1].cpu(), atol=1e-02))
+      torch.testing.assert_close(
+          i[0].grad.cpu(), i[1].cpu(), atol=1e-02, rtol=1e-05)


 if __name__ == '__main__':
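
For context on the assertion swap above: torch.testing.assert_close raises an AssertionError that reports the largest absolute and relative mismatch, while self.assertTrue(torch.allclose(...)) only reports "False is not true". A minimal sketch of the pattern these hunks adopt (the toy tensors are illustrative, not taken from the tests):

import torch

# Stand-ins for the kernel output and the reference attention output.
actual = torch.ones(4, 128)
expected = actual + 1e-6

# Passes: the difference is within the given tolerances. When either tolerance
# is specified, torch.testing.assert_close expects atol and rtol together,
# which is why the new test lines always pass the pair.
torch.testing.assert_close(actual, expected, atol=1e-05, rtol=1e-05)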

torch_xla/experimental/custom_kernel.py

Lines changed: 68 additions & 71 deletions
@@ -1,5 +1,6 @@
 import functools
 import os
+import math
 import warnings

 import torch
@@ -20,7 +21,7 @@ def _shard_map(func, mesh, input_specs, output_specs):

   Note:
     ``shard_map`` is an experimental API, and still subject to change. For an
-    introduction to sharded data, refer to :ref:`sharded-computation`. For a more
+    introduction to sharded data. For a more
     in-depth look at using ``shard_map``, refer to
     [SPMD multi-device parallelism with shard_map](https://docs.jax.dev/en/latest/notebooks/shard_map.html)

@@ -43,7 +44,7 @@ def _shard_map(func, mesh, input_specs, output_specs):
     the ``mesh`` and ``out_specs``.

   Reference:
-    This function is identical Jax's shard_map:
+    This function behaves identically to Jax's shard_map:
     https://docs.jax.dev/en/latest/_autosummary/jax.experimental.shard_map.shard_map.html
   """

@@ -56,18 +57,14 @@ def _full_shape(a, spec):
     result_shape = []
     for axis_size, axis_sharding in zip(a.shape, spec):
       if axis_sharding is None:
-        new_size = axis_size
-      else:
-        if isinstance(axis_sharding, (str, int)):
-          mesh_mult = mesh_name_to_size[axis_sharding]
-        else:
-          # tuple or list
-          mesh_mult = math.prod(mesh_name_to_size[a]
-                                for a in axis_sharding
-                                if mesh_name_to_size[a] is not None)
-
-        if mesh_mult is not None:
-          new_size = axis_size * mesh_mult
+        axis_sharding = ()
+      mesh_mult = []
+      if isinstance(axis_sharding, (str, int)):
+        axis_sharding = [axis_sharding]
+      for axis in axis_sharding:
+        size = mesh_name_to_size[axis] or 1
+        mesh_mult.append(size)
+      new_size = axis_size * math.prod(mesh_mult)
       result_shape.append(new_size)
     return tuple(result_shape)

@@ -76,7 +73,7 @@ def wrapped(*args):
         input_specs), f'args={len(args)}; input_specs={len(input_specs)}'
     new_args = []
     for i, (a, spec) in enumerate(zip(args, input_specs)):
-      if isinstance(a, torch.Tensor) and spec is not None:
+      if isinstance(a, torch.Tensor):
         assert (len(a.shape) == len(spec)
                ), f'{i}th input has wrong shape: {a.shape} for {spec}'
         new_a = xs.enable_manual_sharding(a, spec, mesh=mesh).global_tensor
@@ -87,22 +84,21 @@ def wrapped(*args):
     res = func(*new_args)
     if isinstance(res, tuple):
       res_updated = []
-      for i, r in enumerate(res):
-        if isinstance(r, torch.Tensor):
+      for i, (r, spec) in enumerate(zip(res, output_specs)):
+        if isinstance(r, torch.Tensor) and spec is not None:
           assert str(r.device).startswith('xla'), f'{i}th device is {r.device}'
           assert len(r.shape) == len(
-              output_specs[i]
-          ), f'{i}th shape is {r.shape}, sharding is {output_specs[i]}'
-      return tuple(
-          xs.disable_manual_sharding(a, spec, _full_shape(a, spec), mesh=mesh).
-          global_tensor
-          if isinstance(a, torch.Tensor) and spec is not None else a
-          for a, spec in zip(res, output_specs))
+              spec), f'{i}th shape is {r.shape}, sharding is {output_specs[i]}'
+          new_r = xs.disable_manual_sharding(
+              r, spec, _full_shape(r, spec), mesh=mesh).global_tensor
+        else:
+          new_r = r
+        res_updated.append(new_r)
+      return res_updated
     else:
       return xs.disable_manual_sharding(
           res, output_specs[0], _full_shape(res, output_specs[0]),
           mesh=mesh).global_tensor
-      return res

   return wrapped

@@ -309,6 +305,24 @@ def wrapped_kernel(kernel: Callable,
   return functools.partial(wrapped_kernel, kernel, output_shape_dtype_fn)


+def _maybe_reshape_input_output_funcs(current_shape, non_batch_dims=3):
+  batch_dims = len(current_shape) - non_batch_dims
+  orig_batch_dims = current_shape[:batch_dims]
+  other_dims = current_shape[batch_dims:]
+
+  def reshape_input(tensor):
+    if tensor is None:
+      return None
+    return tensor.reshape(-1, *tensor.shape[batch_dims:])
+
+  def reshape_output(tensor):
+    if tensor is None:
+      return None
+    return tensor.reshape(*orig_batch_dims, *tensor.shape[1:])
+
+  return reshape_input, reshape_output
+
+
 def _fa_custom_forward_single_device(
     q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, causal: bool,
     q_segment_ids: torch.Tensor, kv_segment_ids: torch.Tensor, sm_scale: float,
@@ -318,20 +332,16 @@ def _fa_custom_forward_single_device(

   num_batches = None
   batch_size = None
-  if len(q.shape) == 5:
-    num_batches, batch_size, *rest = q.shape
-    q = q.reshape(-1, *rest)
-    k = k.reshape(-1, *rest)
-    v = v.reshape(-1, *rest)
-    if q_segment_ids is not None:
-      q_segment_ids = q_segment_ids.reshape(-1, *rest)
-    if kv_segment_ids is not None:
-      kv_segment_ids = kv_segment_ids.reshape(-1, *rest)
-    if ab is not None:
-      ab = ab.reshape(-1, *rest)
-
-  # Suprisingly, any tensor that is input to the custom_op decorated function will show
-  # requires_grad=False. Is this a bug or feature? We have to pass ctx_grad to record the
+  reshape_to_4d, undo_reshape = _maybe_reshape_input_output_funcs(q.shape, 3)
+  q = reshape_to_4d(q)
+  v = reshape_to_4d(v)
+  k = reshape_to_4d(k)
+  q_segment_ids = reshape_to_4d(q_segment_ids)
+  kv_segment_ids = reshape_to_4d(kv_segment_ids)
+  ab = reshape_to_4d(ab)
+
+  # Surprisingly, any tensor that is input to the custom_op decorated function will show
+  # requires_grad=False by design. We have to pass ctx_grad to record the
   # requires_grad for inputs.
   # Original we use save_residuals = q.requires_grad or k.requires_grad or v.requires_grad
   save_residuals = any(ctx_grad[:3])
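
The reshape pair used above can be exercised on its own. This sketch mirrors _maybe_reshape_input_output_funcs from the hunk that introduces it; the 5-D tensor shape is an illustrative assumption, not from the tests:

import torch

def _maybe_reshape_input_output_funcs(current_shape, non_batch_dims=3):
  # Everything in front of the last `non_batch_dims` axes counts as batch dims.
  batch_dims = len(current_shape) - non_batch_dims
  orig_batch_dims = current_shape[:batch_dims]

  def reshape_input(tensor):
    if tensor is None:
      return None
    return tensor.reshape(-1, *tensor.shape[batch_dims:])

  def reshape_output(tensor):
    if tensor is None:
      return None
    return tensor.reshape(*orig_batch_dims, *tensor.shape[1:])

  return reshape_input, reshape_output

# A (2, 4, 8, 128, 128) query collapses to (8, 8, 128, 128) for the 4-D kernel
# and is restored afterwards; None inputs (e.g. a missing ab) pass through.
q = torch.randn(2, 4, 8, 128, 128)
reshape_to_4d, undo_reshape = _maybe_reshape_input_output_funcs(q.shape, 3)
flat = reshape_to_4d(q)        # shape: (8, 8, 128, 128)
restored = undo_reshape(flat)  # shape: (2, 4, 8, 128, 128)
assert restored.shape == q.shape and reshape_to_4d(None) is None
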
@@ -401,12 +411,9 @@ def _fa_custom_forward_single_device(
   o, *aux = custom_call_output
   l, m = (v[..., 0] for v in aux[-2:])

-  if num_batches is not None:
-    o = o.reshape(num_batches, batch_size, *o.shape[1:])
-    if l is not None:
-      l = l.reshape(num_batches, batch_size, *l.shape[1:])
-    if m is not None:
-      m = m.reshape(num_batches, batch_size, *m.shape[1:])
+  o = undo_reshape(o)
+  l = undo_reshape(l)
+  m = undo_reshape(m)

   return o, l, m

@@ -518,21 +525,18 @@ def _fa_custom_backward_single_device(

   num_batches = None
   batch_size = None
-  if len(q.shape) == 5:
-    num_batches, batch_size, *rest = q.shape
-    grad_output = grad_output.reshape(-1, *rest)
-    q = q.reshape(-1, *rest)
-    k = k.reshape(-1, *rest)
-    v = v.reshape(-1, *rest)
-    o = o.reshape(-1, *rest)
-    l = l.reshape(-1, *rest)
-    m = m.reshape(-1, *rest)
-    if q_segment_ids is not None:
-      q_segment_ids = q_segment_ids.reshape(-1, *rest)
-    if kv_segment_ids is not None:
-      kv_segment_ids = kv_segment_ids.reshape(-1, *rest)
-    if ab is not None:
-      ab = ab.reshape(-1, *rest)
+  reshape_to_4d, undo_reshape = _maybe_reshape_input_output_funcs(q.shape, 3)
+
+  grad_output = reshape_to_4d(grad_output)
+  q = reshape_to_4d(q)
+  k = reshape_to_4d(k)
+  v = reshape_to_4d(v)
+  o = reshape_to_4d(o)
+  l = reshape_to_4d(l)
+  m = reshape_to_4d(m)
+  q_segment_ids = reshape_to_4d(q_segment_ids)
+  kv_segment_ids = reshape_to_4d(kv_segment_ids)
+  ab = reshape_to_4d(ab)

   require_grad_q, require_grad_k, require_grad_v, *rest = ctx_grad
   require_grad_ab = ctx_grad[-3]
@@ -646,17 +650,10 @@ def _fa_custom_backward_single_device(
   if require_grad_v:
     grad_v = grads[1]

-  if num_batches is not None:
-
-    def _reshape(x):
-      if x is not None:
-        return x.reshape(num_batches, batch_size, *x.shape[1:])
-      return None
-
-    grad_q = _reshape(grad_q)
-    grad_k = _reshape(grad_k)
-    grad_v = _reshape(grad_v)
-    grad_ab = _reshape(grad_ab)
+  grad_q = undo_reshape(grad_q)
+  grad_k = undo_reshape(grad_k)
+  grad_v = undo_reshape(grad_v)
+  grad_ab = undo_reshape(grad_ab)

   return grad_q, grad_k, grad_v, grad_ab

