
Commit 1f24d81

Fix tests
1 parent 59b2e16 commit 1f24d81


2 files changed: +16 −6 lines changed

test/test_as_stride_use_slice.py

Lines changed: 1 addition & 1 deletion
@@ -203,7 +203,7 @@ class ScanFlashAttentionTest(parameterized.TestCase):
   def fake_fa_wrapper(self, has_model_weight, use_scan):
     with xm.xla_device():
       dm = AttentionLayers(has_model_weight, 3, use_scan)
-      hidden_states = torch.randn((2, 4, 256, 256)).requires_grad_()
+      hidden_states = torch.randn((8, 4, 256, 256)).requires_grad_()
       hidden_states.retain_grad()
       output = dm(hidden_states)
       return output
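
For context, a hedged reading of the shapes involved (the layout is an assumption inferred from the kernel changes below, not stated in the test): the input follows a (batch, num_heads, seq_len, head_dim) layout, and the flash-attention residuals l and m drop the trailing head_dim axis, which is what the new lm_partition_spec in custom_kernel.py accounts for. A minimal sketch:

import torch

# Assumed layout of the test input: (batch, num_heads, seq_len, head_dim).
batch, num_heads, seq_len, head_dim = 8, 4, 256, 256
hidden_states = torch.randn((batch, num_heads, seq_len, head_dim)).requires_grad_()

# Expected result shapes (assumption based on the custom_kernel.py diff below):
#   o    -> (batch, num_heads, seq_len, head_dim)
#   l, m -> (batch, num_heads, seq_len)   # no head_dim, hence lm_partition_spec
print(hidden_states.shape)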

torch_xla/experimental/custom_kernel.py

Lines changed: 15 additions & 5 deletions
@@ -399,7 +399,7 @@ def _fa_custom_forward_single_device(
     o = o[0]
     # SPMD integration
     # We need to consistently return full_q, full_k, full_v,... even though they are empty to support AOT.
-    return tuple([o] + [torch.Tensor() for _ in range(6)])
+    return tuple([o] + [torch.Tensor() for _ in range(2)])

   assert isinstance(o, list)
   o, *aux = o
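
A hedged reading of this hunk (inferred from the diff, not stated in the commit message): _fa_custom_forward_single_device now returns three values (o, l, m), so the no-residuals early return pads with two empty placeholder tensors instead of six to keep the arity consistent for AOT. A minimal sketch of the padding pattern (the helper name _pad_single_device_output is hypothetical):

import torch

def _pad_single_device_output(o):
  # Hypothetical helper mirroring the early-return path above: the normal
  # return is (o, l, m), so pad with two empty tensors when residuals are
  # not saved.
  return tuple([o] + [torch.Tensor() for _ in range(2)])

o, l, m = _pad_single_device_output(torch.zeros(2, 4, 256, 256))
assert l.numel() == 0 and m.numel() == 0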
@@ -410,6 +410,8 @@ def _fa_custom_forward_single_device(
   l = l.reshape(num_batches, batch_size, *l.shape[1:])
   m = m.reshape(num_batches, batch_size, *m.shape[1:])

+  print(f'o: {o.shape}')
+
   return o, l, m


@@ -455,8 +457,10 @@ def fa_custom_forward(
   if partition_spec is not None:
     if len(partition_spec) == 5:
       segment_id_partition_spec = (partition_spec[0], partition_spec[1], partition_spec[3])
+      lm_partition_spec = partition_spec[:4]
     else:
       segment_id_partition_spec = (partition_spec[0], partition_spec[2])
+      lm_partition_spec = partition_spec[:3]

     input_specs = [
         partition_spec,  # q
@@ -472,8 +476,8 @@ def fa_custom_forward(

     output_specs = [
         partition_spec,  # o
-        partition_spec,  # l
-        partition_spec,  # m
+        lm_partition_spec,  # l
+        lm_partition_spec,  # m
     ]

     fa_forward_callable = _shard_map(
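
The new lm_partition_spec simply drops the last (head_dim) entry of partition_spec, since l and m carry one fewer dimension than o; the backward pass below reuses the same spec for its l and m inputs. A minimal sketch of the slicing, with illustrative mesh-axis names (the names are assumptions, not taken from the code):

# Illustrative axis names only; the slicing mirrors the change above.
partition_spec_5d = ('data', 'fsdp', None, None, None)  # (num_batches, batch, heads, seq, head_dim)
partition_spec_4d = ('fsdp', None, None, None)          # (batch, heads, seq, head_dim)

lm_spec_5d = partition_spec_5d[:4]  # drops the head_dim entry
lm_spec_4d = partition_spec_4d[:3]  # drops the head_dim entry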
@@ -696,16 +700,18 @@ def fa_custom_backward(
   if partition_spec:
     if len(partition_spec) == 5:
       segment_id_partition_spec = (partition_spec[0], partition_spec[1], partition_spec[3])
+      lm_partition_spec = partition_spec[:4]
     else:
       segment_id_partition_spec = (partition_spec[0], partition_spec[2])
+      lm_partition_spec = partition_spec[:3]
     input_specs = [
         partition_spec,  # grad_output
         partition_spec,  # q
         partition_spec,  # k
         partition_spec,  # v
         partition_spec,  # o
-        partition_spec,  # l
-        partition_spec,  # m
+        lm_partition_spec,  # l
+        lm_partition_spec,  # m
         segment_id_partition_spec,  # q_segment_ids
         segment_id_partition_spec,  # kv_segment_ids
         partition_spec,  # ab
@@ -849,6 +855,10 @@ def forward(ctx, q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab,
     # AOT compatiable funtion only accepts argument types listed https://github.com/pytorch/pytorch/blob/82859f61857ef39898b34a5cdf0ae56ec25704d9/torch/_functorch/_aot_autograd/utils.py#L23-L34, so we serliaze partition_spec and mesh into string.
     outs = fa_custom_forward(*custom_op_arg, ctx_grads)

+    for i, o in enumerate(outs):
+      if isinstance(o, torch.Tensor):
+        print(f'{i}: {o.shape}')
+
     o = outs[0]
     full_q, full_k, full_v, l, m, full_ab = [x for x in outs[1:]]
