Fixes #8805.
We introduce a decorator, `@assume_pure`, that can be placed on PyTorch/XLA functions to easily eliminate lazy tensor tracing overhead. If you have a pure function that only uses torch upstream ops, that function can be decorated with `@assume_pure` and will only be traced once for each unique combination of input tensor shapes.
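For illustration, here is a minimal usage sketch. The import path and the example function are assumptions made for this sketch, not a documented API surface.

```python
# Minimal usage sketch; the import path below is an assumption and may differ.
import torch
import torch_xla.core.xla_model as xm
from torch_xla.experimental.assume_pure import assume_pure  # assumed path


@assume_pure
def pure_mlp(x, w1, w2):
    # Pure function: only upstream torch ops, no side effects.
    return torch.nn.functional.gelu(x @ w1) @ w2


device = xm.xla_device()
x = torch.randn(8, 128, device=device)
w1 = torch.randn(128, 256, device=device, requires_grad=True)
w2 = torch.randn(256, 128, device=device, requires_grad=True)

# Traced once for this shape combination; later calls with the same shapes
# reuse the cached computation instead of re-tracing lazily.
out = pure_mlp(x, w1, w2)
out.sum().backward()
```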
## Design

`@assume_pure` brings together three pieces of existing technology:

- `jax.vjp`, which takes a JAX function and gives you the autograd forward and backward passes
- `torchax`, which converts a pure PyTorch function to a JAX function
- `xb.call_jax`, which can call any JAX function from PyTorch/XLA and integrate it into the HLO graph (see the sketch after this list)
It works by:

1. Using `torchax.interop.jax_view` to obtain a JAX function from the input PyTorch function
2. Using `jax.vjp` to get the forward and backward passes
3. Building a `torch.autograd.Function` instance, where the forward implementation is `xb.call_jax(forward_pass)` and the backward implementation is `xb.call_jax(backward_pass)`, respectively

The core logic is actually just a single line:
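The original code block is not reproduced here; as a rough sketch under stated assumptions (the `j2t_autograd` helper name and its exact signature are assumptions, not confirmed by this description), the one-liner composes the JAX view of the function with an autograd wrapper:

```python
# Rough sketch only. `j2t_autograd` is an assumed helper that wraps a JAX
# function into a torch.autograd.Function whose forward and backward passes
# are dispatched through xb.call_jax.
from torchax.interop import jax_view, j2t_autograd  # j2t_autograd is assumed


def assume_pure(fn):
    # jax_view(fn): pure PyTorch function -> JAX function.
    # j2t_autograd(...): JAX function -> torch function with autograd support.
    return j2t_autograd(jax_view(fn))
```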
## Alternatives
Instead of `jax.vjp` we could also use AOTAutograd to get the forward and backward passes. However, AOTAutograd has a number of downsides, one of which concerns mapping `xp.Trace(...)` to `jax.named_scope(...)`.

## Benchmarks
TODO