write _shard_map; refactor flash attention to support 5d inputs. #8730
Conversation
Makes sense to refactor flash_attention to be shard_map(local_flash_attention)
q = torch.randn(4, 2, 2, 128, 4).to("xla")
k = torch.randn(4, 2, 2, 128, 4).to("xla")
v = torch.randn(4, 2, 2, 128, 4).to("xla")
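For context, a hedged sketch of how these 5-D inputs would plausibly be exercised, assuming the usual torch_xla SPMD setup (xs.Mesh, xr.use_spmd) and that flash_attention's partition_spec / mesh keywords now accept a 5-axis spec; the exact test code in this PR may differ:

```python
import numpy as np
import torch
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.experimental.custom_kernel import flash_attention

xr.use_spmd()
n_devices = xr.global_runtime_device_count()

# 5-axis mesh matching the discussion below: (n_devices // 2, 2, 1, 1, 1).
mesh = xs.Mesh(np.arange(n_devices), (n_devices // 2, 2, 1, 1, 1))

q = torch.randn(4, 2, 2, 128, 4).to("xla")
k = torch.randn(4, 2, 2, 128, 4).to("xla")
v = torch.randn(4, 2, 2, 128, 4).to("xla")

# Map tensor dim i to mesh axis i, so only the two leading dims are sharded.
o = flash_attention(q, k, v, partition_spec=tuple(range(5)), mesh=mesh)
```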
In test_flash_attention_spmd_data_parallel, (8, 2, 128, 8) is used. Would the equivalent here be something like (8, 2, 2, 128, 8)?
(8, 2, 2, 128, 8) would also work.
In general, each dimension being sharded has to be a multiple of the number of devices along the corresponding mesh axis.
In this case the mesh is (n_devices // 2, 2, 1, 1, 1), i.e. (4, 2, 1, 1, 1) on 8 devices, so having (4, 2) as the leading dims is sufficient.
On a TPU with 4 devices the mesh would be (2, 2, 1, 1, 1), and (4, 2) is still a multiple of that, so this configuration is general.
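A quick illustration of the divisibility constraint described above (plain Python, no torch_xla needed; check_shardable is just a helper for this comment):

```python
def check_shardable(global_shape, mesh_shape):
    """Each tensor dim must be a multiple of the mesh size on its axis."""
    return all(dim % n == 0 for dim, n in zip(global_shape, mesh_shape))

n_devices = 8
mesh = (n_devices // 2, 2, 1, 1, 1)              # (4, 2, 1, 1, 1)
assert check_shardable((4, 2, 2, 128, 4), mesh)  # shape used in this test
assert check_shardable((8, 2, 2, 128, 8), mesh)  # also fine, as noted above

mesh_4_dev = (4 // 2, 2, 1, 1, 1)                # (2, 2, 1, 1, 1) on a 4-device TPU
assert check_shardable((4, 2, 2, 128, 4), mesh_4_dev)
```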
Still need to fix the TPU test, which looks relevant.
JAX's shard_map works by enabling manual sharding on its inputs and disabling it on its outputs.
Here we introduce _shard_map to simulate this behavior. This is sufficient to support the use case of calling Pallas kernels.
It is not sufficient to support other shard_map use cases, such as using manual collectives, because the current collective implementations (e.g. xm.all_gather) fail their checks under manual sharding.
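A minimal sketch of the idea, assuming the xs.enable_manual_sharding / xs.disable_manual_sharding SPMD helpers behave as in the existing flash-attention wrapper and that there is a single output whose global shape matches the first input; the actual _shard_map added in this PR may differ in signature and generality:

```python
import torch_xla.distributed.spmd as xs

def _shard_map(func, mesh, input_specs, output_spec):
  """Run `func` on per-device local shards, shard_map-style (single output)."""

  def wrapped(*args):
    # Flash attention's output has the same global shape as q (args[0]);
    # record it before switching to manual sharding.
    full_shape = args[0].shape
    # Enable manual sharding: each arg is replaced by its local shard.
    local_args = [
        xs.enable_manual_sharding(a, spec, mesh=mesh).global_tensor
        for a, spec in zip(args, input_specs)
    ]
    # Run the per-shard computation (e.g. the Pallas flash-attention kernel).
    local_out = func(*local_args)
    # Disable manual sharding on the output: re-attach the partition spec and
    # the full (global) shape so the result is an ordinary SPMD-sharded tensor.
    return xs.disable_manual_sharding(
        local_out, output_spec, full_shape, mesh=mesh).global_tensor

  return wrapped
```

With such a wrapper, flash_attention over 5-D inputs becomes roughly shard_map(local_flash_attention), as suggested earlier in this thread.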