write _shard_map; refactor flash attention to support 5d inputs. #8730

Merged: 9 commits merged into master from hanq_shard_map on Mar 3, 2025

Conversation

@qihqi (Collaborator) commented on Feb 21, 2025

JAX's shard_map works by enabling manual sharding on the inputs and disabling it on the outputs. Here we introduce _shard_map to simulate that behavior. This is sufficient to support the use case of calling Pallas kernels. It is not sufficient for other uses of shard_map, such as issuing manual collectives, because the current collective implementations (e.g. xm.all_gather) fail their checks under manual sharding.
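
A minimal sketch of the idea, assuming the existing torch_xla xs.enable_manual_sharding / xs.disable_manual_sharding helpers; the name and signature below are illustrative, and the actual _shard_map added in this PR may differ:

import torch_xla.distributed.spmd as xs

def _shard_map_sketch(func, mesh, input_specs, output_specs):
    # Illustrative wrapper: run `func` on per-device shards under manual sharding.
    def wrapped(*args):
        # Enable manual sharding on every input, so `func` only sees the
        # local shard of each tensor.
        local_args = [
            xs.enable_manual_sharding(a, spec, mesh=mesh).global_tensor
            for a, spec in zip(args, input_specs)
        ]
        outs = func(*local_args)
        if not isinstance(outs, (tuple, list)):
            outs = (outs,)
        # Disable manual sharding on the outputs, restoring their full
        # (global) shapes for the rest of the SPMD program.
        return tuple(
            xs.disable_manual_sharding(o, spec, full_shape, mesh=mesh).global_tensor
            for o, (spec, full_shape) in zip(outs, output_specs))

    return wrapped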

@qihqi force-pushed the hanq_shard_map branch 2 times, most recently from 0a27ca4 to 8a5a29f, on February 27, 2025 05:19
@qihqi requested a review from pgmoka on February 27, 2025 05:19
@qihqi changed the title from "write _shard_map; refactor flash attention to use it." to "write _shard_map; refactor flash attention to support 5d inputs." on Feb 27, 2025
@qihqi requested a review from tengyifei on February 27, 2025 05:25
@qihqi force-pushed the hanq_shard_map branch 2 times, most recently from d885e7b to 8983df7, on February 28, 2025 04:21
@tengyifei (Collaborator) left a comment


Makes sense to refactor flash_attention to be shard_map(local_flash_attention)
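
For illustration only, the refactor described above would look roughly like the following, reusing the hypothetical _shard_map_sketch from the PR description; the kernel and variable names are assumed, not the actual ones in this PR:

# `local_flash_attention` stands for the per-shard Pallas kernel call; q, k, v
# are the full 5d inputs and `partition_spec` shards their leading dims.
sharded_flash_attention = _shard_map_sketch(
    local_flash_attention,
    mesh=mesh,
    input_specs=[partition_spec] * 3,
    output_specs=[(partition_spec, q.shape)])
o = sharded_flash_attention(q, k, v)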

Comment on lines +91 to +94
q = torch.randn(4, 2, 2, 128, 4).to("xla")
k = torch.randn(4, 2, 2, 128, 4).to("xla")
v = torch.randn(4, 2, 2, 128, 4).to("xla")

Collaborator commented:

In test_flash_attention_spmd_data_parallel, (8, 2, 128, 8) is used. Would the equivalent here be something like (8, 2, 2, 128, 8)?

@qihqi (Collaborator Author) replied:

(8, 2, 2, 128, 8) would also work.
In general, each sharded axis has to be a multiple of the number of devices along that mesh axis.
In this case the mesh is (n_devices // 2, 2, 1, 1, 1), i.e. (4, 2, 1, 1, 1) on 8 devices, so having (4, 2) as the leading dims is sufficient.

For TPUs with 4 devices the mesh would be (2, 2, 1, 1, 1), and (4, 2) is still a multiple of that, so this configuration is general.
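
A quick check of that divisibility rule, using the shapes discussed above (the helper is just for illustration):

def check_shardable(tensor_shape, mesh_shape):
    # Each tensor dimension must be divisible by the mesh size on that axis.
    for dim, devices in zip(tensor_shape, mesh_shape):
        assert dim % devices == 0, f"dim {dim} not divisible by {devices} devices"

n_devices = 8
mesh_shape = (n_devices // 2, 2, 1, 1, 1)           # (4, 2, 1, 1, 1)
check_shardable((4, 2, 2, 128, 4), mesh_shape)      # shape used in this test
check_shardable((8, 2, 2, 128, 8), mesh_shape)      # the suggested shape also works

# On a 4-device TPU the mesh becomes (2, 2, 1, 1, 1); (4, 2, ...) still divides.
check_shardable((4, 2, 2, 128, 4), (2, 2, 1, 1, 1))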

@qihqi requested reviews from pgmoka and tengyifei on March 1, 2025 01:02
@tengyifei (Collaborator) commented

Still need to fix the TPU test, which looks relevant.

@qihqi merged commit cee0820 into master on Mar 3, 2025
23 checks passed