
Introduce apply_xla_patch_to_nn_linear and test that in a scan #8739


Merged — tengyifei merged 1 commit into master from yifeit/workaround-einsum on Feb 25, 2025

Conversation

tengyifei (Collaborator) commented:

To propagate sharding annotations under 2D sharding, linear layers should be implemented with einsum instead of transposes/reshapes. Additionally, they need to keep working inside `scan`/`scan_layers`.

For this to work we need three pieces:

  • I added an `apply_xla_patch_to_nn_linear` function that replaces the implementation of `nn.Linear` with einsum (calling `XLAPatchedLinear`).
  • The `XLAPatchedLinear` implementation should be wrapped in a torch custom op, because the AOTAutograd pass used by `scan` will decompose all einsums into transposes/reshapes unless we use `@custom_op` to mark the function as opaque to AOTAutograd (a minimal sketch of this wrapping follows the list).
  • Even after wrapping it with `@custom_op`, the einsum is still decomposed into transposes/reshapes due to #8713 ("torch.einsum is incorrectly decomposed when wrapped inside a custom op"), which is a PyTorch bug/limitation. To work around this, I added a `_xla_einsum` C++ function that builds an einsum directly from XLA tensors, bypassing the PyTorch dispatcher.
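
Below is a minimal sketch of the custom-op idea from the second bullet, not the actual `XLAPatchedLinear` code: the op namespace, name, signature, and fake registration are illustrative, and the real implementation also handles backward and (per the third bullet) lowers through the `_xla_einsum` C++ helper.

```python
import torch
from torch.library import custom_op


# Illustrative namespace/name; the real op lives inside torch_xla.
@custom_op("xla_demo::einsum_linear_forward", mutates_args=())
def einsum_linear_forward(x: torch.Tensor, weight: torch.Tensor,
                          bias: torch.Tensor) -> torch.Tensor:
    # Contract only the last dim of `x` with `weight`; the leading
    # (non-contracting) dims pass through unchanged.
    return torch.einsum("...n,mn->...m", x, weight) + bias


# Fake (meta) implementation so tracing under AOTAutograd can infer
# output shapes without running or looking inside the op.
@einsum_linear_forward.register_fake
def _(x, weight, bias):
    return x.new_empty((*x.shape[:-1], weight.shape[0]))
```

Because the einsum lives inside a single registered custom op, AOTAutograd sees it as an opaque node and cannot decompose it into transposes/reshapes.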

I also added a test that demonstrates how `nn.Linear` layers by default flatten the non-contracting dims, and how we can avoid that with `apply_xla_patch_to_nn_linear`.
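
For illustration, here is a hedged sketch of the property the test relies on, written in plain PyTorch (no torch_xla required): the einsum formulation computes the same result as `nn.Linear` while contracting only the last dimension, so the leading non-contracting dims (batch, sequence) never need to be flattened into a single dim.

```python
import torch
import torch.nn as nn

# A [batch, seq, features] input; batch and seq are non-contracting dims.
x = torch.randn(4, 16, 32)
linear = nn.Linear(32, 64)

# Default nn.Linear path; per this PR, its XLA lowering flattens the
# leading dims to 2-D around the matmul, losing sharding annotations.
y_default = linear(x)

# Einsum formulation: contracts the last dim only, keeping (4, 16) explicit.
y_einsum = torch.einsum("...n,mn->...m", x, linear.weight) + linear.bias

assert y_einsum.shape == (4, 16, 64)
torch.testing.assert_close(y_default, y_einsum)
```

On XLA devices the actual test would additionally inspect the lowered graph for the absence of the flattening reshape after applying `apply_xla_patch_to_nn_linear`; that detail is omitted here since the exact inspection hooks depend on the torch_xla build.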

@tengyifei force-pushed the yifeit/workaround-einsum branch from c53acea to 7d834f7 on February 24, 2025 at 22:56
@tengyifei changed the title from "Support einsum layers in a scan" to "Introduce apply_xla_patch_to_nn_linear and test that in a scan" on Feb 24, 2025
@tengyifei marked this pull request as ready for review on February 24, 2025 at 23:22
@tengyifei force-pushed the yifeit/workaround-einsum branch from 7d834f7 to 2cf50c8 on February 24, 2025 at 23:48
@zpcore (Collaborator) left a comment:

LGTM!

@tengyifei merged commit 6f020aa into master on Feb 25, 2025
23 checks passed
@pgmoka (Collaborator) left a comment:

LGTM
