
Commit 011d734

yapf
1 parent b8d1ee5 commit 011d734

File tree

test/test_pallas_spmd.py
torch_xla/experimental/custom_kernel.py

2 files changed: +11 -6 lines changed


test/test_pallas_spmd.py

Lines changed: 6 additions & 4 deletions
@@ -84,14 +84,16 @@ def test_flash_attention_spmd_data_parallel(self):
   def test_flash_attention_spmd_data_parallel_5d(self):
     n_devices = xr.global_runtime_device_count()
     xs.set_global_mesh(
-        xs.Mesh(range(n_devices), (n_devices // 2, 2, 1, 1, 1),
-                ('fsdp', 'dp', 'a', 'b', 'c')))
+        xs.Mesh(
+            range(n_devices), (n_devices // 2, 2, 1, 1, 1),
+            ('fsdp', 'dp', 'a', 'b', 'c')))
 
     q = torch.randn(4, 2, 2, 128, 4).to("xla")
     k = torch.randn(4, 2, 2, 128, 4).to("xla")
     v = torch.randn(4, 2, 2, 128, 4).to("xla")
 
-    o = flash_attention(q, k, v, partition_spec=('fsdp', 'dp', None, None, None))
+    o = flash_attention(
+        q, k, v, partition_spec=('fsdp', 'dp', None, None, None))
     dev_ids = ','.join(map(str, range(n_devices)))
     self.assertEqual(
         torch_xla._XLAC._get_xla_sharding_spec(o),

@@ -144,7 +146,7 @@ def test_flash_attention_backward_spmd_data_parallel(self):
     q_grad = q.grad
     k_grad = k.grad
     v_grad = v.grad
-
+
     dev_ids = ','.join(map(str, range(n_devices)))
     self.assertEqual(
         torch_xla._XLAC._get_xla_sharding_spec(q_grad),
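
For reference, a minimal sketch of how the reflowed 5D SPMD flash attention setup reads as standalone code, assuming the imports this test file already uses; the wrapper function name is illustrative, not part of the test, and an XLA/TPU runtime is required:

# Sketch only: mirrors the reformatted test body above; requires an XLA/TPU runtime.
import torch
import torch_xla
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.experimental.custom_kernel import flash_attention

def run_5d_flash_attention():
  # Build a 5D mesh over all devices; yapf now breaks the Mesh(...) call
  # across lines instead of aligning its arguments.
  n_devices = xr.global_runtime_device_count()
  xs.set_global_mesh(
      xs.Mesh(
          range(n_devices), (n_devices // 2, 2, 1, 1, 1),
          ('fsdp', 'dp', 'a', 'b', 'c')))

  # 5D inputs, sharded along the first two mesh axes via partition_spec.
  q = torch.randn(4, 2, 2, 128, 4).to("xla")
  k = torch.randn(4, 2, 2, 128, 4).to("xla")
  v = torch.randn(4, 2, 2, 128, 4).to("xla")
  o = flash_attention(
      q, k, v, partition_spec=('fsdp', 'dp', None, None, None))

  # The test then asserts the output carries the expected sharding spec.
  return torch_xla._XLAC._get_xla_sharding_spec(o)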

torch_xla/experimental/custom_kernel.py

Lines changed: 5 additions & 2 deletions
@@ -90,7 +90,9 @@ def wrapped(*args):
     for i, r in enumerate(res):
       if isinstance(r, torch.Tensor):
         assert str(r.device).startswith('xla'), f'{i}th device is {r.device}'
-        assert len(r.shape) == len(output_specs[i]), f'{i}th shape is {r.shape}, sharding is {output_specs[i]}'
+        assert len(r.shape) == len(
+            output_specs[i]
+        ), f'{i}th shape is {r.shape}, sharding is {output_specs[i]}'
     return tuple(
         xs.disable_manual_sharding(a, spec, _full_shape(a, spec), mesh=mesh).
         global_tensor

@@ -387,7 +389,8 @@ def _fa_custom_forward_single_device(
     args += [ab]
   if segment_ids is not None:
     args += [q_segment_ids_fa, kv_segment_ids_fa]
-  custom_call_output = torch_xla._XLAC._xla_tpu_custom_call(args, payload, shapes, dtypes)
+  custom_call_output = torch_xla._XLAC._xla_tpu_custom_call(
+      args, payload, shapes, dtypes)
 
   assert isinstance(custom_call_output, list)
   if not save_residuals:
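
The reflowed assertion in wrapped(*args) is a rank check: every XLA tensor output must have one dimension per entry in its output partition spec. A standalone sketch of that check follows; the helper name and usage are illustrative, not part of custom_kernel.py, and running it requires an XLA device:

# Sketch only: reproduces the rank check that yapf reflowed above.
import torch
import torch_xla  # importing torch_xla registers the 'xla' device

def check_output_against_spec(i, r, spec):
  # Outputs of the wrapped kernel must already live on an XLA device ...
  assert str(r.device).startswith('xla'), f'{i}th device is {r.device}'
  # ... and their rank must match the length of the partition spec.
  assert len(r.shape) == len(
      spec), f'{i}th shape is {r.shape}, sharding is {spec}'

# Example: a rank-5 output checked against a 5-entry spec passes.
out = torch.randn(4, 2, 2, 128, 4).to("xla")
check_output_against_spec(0, out, ('fsdp', 'dp', None, None, None))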
