
Manually register einsum xla #8787


Closed · wants to merge 34 commits

Conversation

@pgmoka (Collaborator) commented Mar 4, 2025

Do manual registration of XLANativeFunctions::einsum for XLA.

This is necessary because PyTorch currently overwrites the AutogradXLA key registration with its XLA key registration. While ideally we would resolve the underlying problem, this workaround fixes the issue from our end. It is also not possible to use full code generation due to #8739.

This manual registration relies on the XLANativeFunctions::einsum function from xla/torch_xla/csrc/aten_xla_type.cpp.
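For context, a minimal sketch of what such a manual dispatcher registration looks like is below. The header path and exact signature here are assumptions based on the description above, not necessarily the code in this PR:

```cpp
// Sketch: point the dispatcher's XLA entry for aten::einsum at the
// hand-written kernel so it is not shadowed by the registration PyTorch
// generates. The header path is an assumption.
#include <torch/library.h>

#include "torch_xla/csrc/XLANativeFunctions.h"  // assumed header declaring XLANativeFunctions

TORCH_LIBRARY_IMPL(aten, XLA, m) {
  m.impl("einsum", TORCH_FN(torch_xla::XLANativeFunctions::einsum));
}
```

Registering directly under the XLA key sidesteps the overwritten AutogradXLA entry, at the cost of keeping this registration in sync by hand instead of through codegen.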

@pgmoka pgmoka self-assigned this Mar 4, 2025
@pgmoka pgmoka requested review from tengyifei and ysiraichi March 4, 2025 20:23
@pgmoka (Collaborator, Author) commented Mar 4, 2025

As the overwrite is written, there is no meaningful unit test we can add, since we rely on the generated XLANativeFunctions function. I could do something like what nms_kernel does and refer to tensor_methods::einsum, but that would require rewriting the conditions from:

at::Tensor XLANativeFunctions::einsum(std::string_view equation,

The cleanest way to do something like that would be to create a utility function shared by both XLANativeFunctions::einsum and our overwrite (a rough sketch of this idea is below), which might then let us test the overwrite. That can be the next step here, or, if we think it is unnecessary, we can just keep the implementation from this PR.
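As an illustration of that shared-utility idea (the helper name, header path, and trailing parameters are hypothetical, assumed from the aten einsum schema):

```cpp
// Hypothetical sketch: both the codegen'd entry point and the manual
// dispatcher registration forward to one shared function that a unit test
// could call directly, without going through the dispatcher.
#include <string_view>

#include <ATen/ATen.h>

#include "torch_xla/csrc/XLANativeFunctions.h"  // assumed header

namespace torch_xla {

// Shared implementation; the name and exact parameter list are illustrative.
at::Tensor einsum_shared_impl(std::string_view equation, at::TensorList tensors,
                              at::OptionalIntArrayRef path);

at::Tensor XLANativeFunctions::einsum(std::string_view equation,
                                      at::TensorList tensors,
                                      at::OptionalIntArrayRef path) {
  return einsum_shared_impl(equation, tensors, path);
}

}  // namespace torch_xla
```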

nms_kernel itself is not tested directly, as far as I can tell. Perhaps that is a larger issue we can track separately from this PR.

@tengyifei @ysiraichi: Do y'all have any opinions on this?

@tengyifei (Collaborator)

@pgmoka are you looking for a unit test? A good test, IMO, is what we wrote in the https://github.com/tengyifei/playground/blob/master/aot-einsum-3.ipynb notebook. We could verify the lowering of einsum in a custom op.

Another test: we should remove the two workarounds referenced in https://github.com/search?q=repo%3Apytorch%2Fxla+8713&type=code, and the unit test for XLAPatchedLinear should still pass, because we also check its lowering there.

@pgmoka (Collaborator, Author) commented Mar 4, 2025

CC: @lsy323

@tengyifei (Collaborator) left a comment


Let's add a unit test in Python and also remove the _xla_einsum workaround in this PR (which will also test that this registration worked).

@tengyifei (Collaborator)

Not sure how all these commits got into this branch. Usually I rebase the branch on top of the latest master and then force push. This way I only have a single commit in the PR.

@ysiraichi (Collaborator)

One thing you can do is to check for xla::einsum in the XLA counters. I believe that, before your PR, it wouldn't be in there, since the CompositeImplicitAutograd kernel was called.

@pgmoka (Collaborator, Author) commented Mar 5, 2025

> Not sure how all these commits got into this branch. Usually I rebase the branch on top of the latest master and then force push. This way I only have a single commit in the PR.

I honestly don't know how this happened either. I think I messed something up while fetching the current master to rebase the branch with. I needed to do that to get the latest changes related to _einsum. The final state is what I wanted, but it leaves this unfortunate commit history on the PR.

@tengyifei (Collaborator)

> I honestly don't know how this happened either. I think I messed something up while fetching the current master to rebase the branch with. I needed to do that to get the latest changes related to _einsum. The final state is what I wanted, but it leaves this unfortunate commit history on the PR.

Gotcha. In that case could you squash the commits from git and reset the commit message so that it hopefully doesn't confuse future readers? Thanks!

@pgmoka pgmoka enabled auto-merge (squash) March 5, 2025 22:41
Use .backward() with in-place grad mutations for the GA API (#8768)

Use placeholder tensor in scan (#8785)

Pin update to 20250303 (#8788)

Co-authored-by: Chengji Yao <[email protected]>

correct linter
@pgmoka pgmoka force-pushed the manually_register_einsum_XLA branch from e2aace1 to b922fa0 on March 5, 2025 23:02
@pgmoka pgmoka closed this Mar 5, 2025
auto-merge was automatically disabled March 5, 2025 23:32 (pull request was closed)

@pgmoka (Collaborator, Author) commented Mar 5, 2025

Too many conflicts. I accidentally merged from master rather than rebasing, and it caused a bunch of issues. My changes are small enough that I will just carry on in a separate PR. I apologize to the reviewers for the noise.
