
Commit 4a71be2

qihqipgmoka authored and committed
write _shard_map; refactor flash attention to support 5d inputs. (#8730)
1 parent: a954763

File tree

3 files changed: +418 −187 lines


test/scan/test_scan_pallas.py

Lines changed: 1 addition & 1 deletion
@@ -71,7 +71,7 @@ class ScanFlashAttentionTest(parameterized.TestCase):
   def fake_fa_wrapper(self, has_model_weight, use_scan):
     torch.manual_seed(12)
     torch_xla.manual_seed(12)
-    hidden_states = torch.randn((2, 4, 256, 256)).requires_grad_().to('xla')
+    hidden_states = torch.randn((8, 4, 256, 256)).requires_grad_().to('xla')
     with xm.xla_device():
       attention_layers = AttentionLayers(
           has_model_weight, num_layer=3, use_scan=use_scan)
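The only functional change in this file is the batch dimension of the test input, which grows from 2 to 8. A minimal stdlib-only sketch of the size impact, assuming the conventional (batch, num_heads, seq_len, head_dim) layout for the hidden states (the layout is an assumption; the diff does not name the dimensions):

```python
# Hypothetical illustration of the test-input change in fake_fa_wrapper.
# Assumed layout: (batch, num_heads, seq_len, head_dim).
old_shape = (2, 4, 256, 256)
new_shape = (8, 4, 256, 256)

def numel(shape):
    """Total number of elements in a tensor of the given shape."""
    n = 1
    for d in shape:
        n *= d
    return n

# Only the batch dimension changed, so the tensor grows by the same factor.
print(numel(new_shape) // numel(old_shape))  # → 4
```

A larger batch such as 8 divides evenly across more common device-mesh sizes, which is consistent with the commit's shard_map focus, though the diff itself does not state the motivation.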
