
@assume_pure #8923

Open

tengyifei wants to merge 6 commits into master from yifeit/vjp-in-xla
Conversation

tengyifei (Collaborator) commented Apr 2, 2025

Fixes #8805.

We introduce a decorator, @assume_pure, that can be placed on PyTorch/XLA functions to eliminate lazy tensor tracing overhead. If a pure function only uses torch upstream ops, it can be decorated with @assume_pure and will be traced only once for each unique combination of input tensor shapes.
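As a usage sketch (the import path of assume_pure below is an assumption for illustration, not a confirmed location in this PR):

import torch
import torch_xla.core.xla_model as xm
from torch_xla.experimental.assume_pure import assume_pure  # assumed import path

@assume_pure
def pure_linear(x, w, b):
  # Pure: only upstream torch ops, no side effects.
  return torch.nn.functional.linear(x, w, b)

device = xm.xla_device()
x = torch.randn(4, 8, device=device)
w = torch.randn(16, 8, device=device, requires_grad=True)
b = torch.randn(16, device=device, requires_grad=True)

y = pure_linear(x, w, b)  # traced once per unique input shape combination
y.sum().backward()        # backward runs through the JAX-derived vjp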

Design

@assume_pure brings together three pieces of existing technology:

  • jax.vjp, which takes a JAX function and returns the autograd forward and backward passes
  • torchax, which converts a pure PyTorch function to a JAX function
  • xb.call_jax, which can call any JAX function from PyTorch/XLA and integrate it into the HLO graph
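As a rough illustration of the xb.call_jax building block (the call_jax signature shown here is an assumption; only the behavior described above is taken from this PR):

import jax.numpy as jnp
import torch
import torch_xla.core.xla_builder as xb
import torch_xla.core.xla_model as xm

def jax_dot_sum(a, b):
  # A pure JAX function; call_jax lowers it and inlines it into the HLO graph.
  return jnp.sum(a * b)

device = xm.xla_device()
a = torch.randn(3, 3, device=device)
b = torch.randn(3, 3, device=device)

out = xb.call_jax(jax_dot_sum, (a, b))  # assumed signature: call_jax(fn, args)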

It works by:

  • Using torchax.interop.jax_view to obtain a JAX function from the input PyTorch function
  • Using jax.vjp to derive the forward and backward passes
  • Returning a torch.autograd.Function whose forward implementation is xb.call_jax(forward_pass) and whose backward implementation is xb.call_jax(backward_pass)

The core logic is essentially a single line:

def assume_pure(fn):
  from torchax.interop import jax_view
  # j2t_autograd wraps the resulting JAX function in a torch.autograd.Function
  # whose forward and backward passes are dispatched through xb.call_jax.
  return j2t_autograd(jax_view(fn))
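For readers unfamiliar with j2t_autograd, here is a simplified, hedged sketch of the idea rather than the implementation in this PR: a torch.autograd.Function whose forward and backward both go through xb.call_jax, with jax.vjp supplying the backward rule. For brevity it recomputes the forward pass inside backward instead of saving the vjp residuals, and it assumes a flat tuple of tensor inputs and a single tensor output.

import jax
import torch
import torch_xla.core.xla_builder as xb

def j2t_autograd_sketch(jax_fn):
  """Illustrative only: expose a JAX function to PyTorch autograd via call_jax."""

  class _JaxFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, *inputs):
      ctx.save_for_backward(*inputs)
      # Forward pass: lower the JAX function into the current HLO graph.
      return xb.call_jax(jax_fn, inputs)

    @staticmethod
    def backward(ctx, grad_output):
      inputs = ctx.saved_tensors

      def bwd(*flat):
        *primals, g = flat
        # jax.vjp returns the primal output plus a function that applies the
        # vector-Jacobian product; the real implementation saves the vjp
        # residuals during forward() instead of recomputing them here.
        _, vjp_fn = jax.vjp(jax_fn, *primals)
        return vjp_fn(g)

      grads = xb.call_jax(bwd, (*inputs, grad_output))
      return tuple(grads)

  return lambda *inputs: _JaxFunction.apply(*inputs)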

Alternatives

Instead of jax.vjp we could also use AOTAutograd to obtain the forward and backward passes. However, AOTAutograd has a number of downsides:

  • It does more than derive the backward pass: it also forcefully decomposes all operations into the "aten" op set. This decomposition can hurt performance, especially in the case of einsum.
  • There is no straightforward path to supporting profiler trace spans. In contrast, with the proposed approach we could translate xp.Trace(...) into jax.named_scope(...) (a small illustration follows this list).
  • Supporting custom operations such as Pallas kernels would be cumbersome: every kernel would have to be wrapped in a PyTorch custom operator so that AOTAutograd does not crash on it. In contrast, with the proposed approach we could augment our Pallas kernels to jump directly into JAX when the input tensor is a torchax tensor.
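To make the profiler point above concrete, a hypothetical mapping from xp.Trace to jax.named_scope could look like this (both are real context managers; the automatic translation is a possibility raised above, not something this PR implements):

import jax
import jax.numpy as jnp

def jax_layer(x):
  # The PyTorch/XLA span `with xp.Trace('my_layer'):` could map onto this:
  with jax.named_scope('my_layer'):
    return jnp.tanh(x)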

Benchmarks

TODO

@tengyifei tengyifei changed the title from Yifeit/vjp in xla to @assume_pure on Apr 2, 2025
@tengyifei tengyifei force-pushed the yifeit/torchax-in-torch-xla branch from 13a631b to 8621943 on April 4, 2025 05:55
Add more tests
@tengyifei tengyifei force-pushed the yifeit/vjp-in-xla branch from 4e63515 to a1230e1 on April 9, 2025 00:38
@tengyifei tengyifei changed the base branch from yifeit/torchax-in-torch-xla to master on April 9, 2025 00:38
@tengyifei tengyifei force-pushed the yifeit/vjp-in-xla branch from c37e821 to 422c6df on April 9, 2025 07:52
@tengyifei tengyifei marked this pull request as ready for review on April 9, 2025 07:52
Development
Successfully merging this pull request may close these issues.

Use jax autograd from PyTorch