Commit 07270c3

cache
1 parent 0f7c5ce commit 07270c3

1 file changed: +27 −20 lines

torch_xla/experimental/assume_pure.py (+27 −20)
@@ -21,11 +21,29 @@ def assume_pure(fn):
   return _jax2torch(jax_view(fn))
 
 
+# Define the JAX function to compute value and vjp
+def _jax_forward(fn, primals):
+  import jax
+
+  # Prepare the function call for jax.vjp
+  # jax.vjp expects positional primals. We wrap fn to accept args, kwargs.
+  def fn_wrapper(a, kw):
+    return fn(*a, **kw)
+
+  # primals will be (args_rec, kwargs_rec)
+  return jax.vjp(fn_wrapper, *primals)  # Unpack primals here
+
+
+def _jax_backward(vjp_spec, saved_tensors, grad_args):
+  from jax.tree_util import tree_unflatten
+  fun_vjp = tree_unflatten(vjp_spec, saved_tensors)
+  return fun_vjp(grad_args)
+
+
 def _jax2torch(fn):
 
   @wraps(fn)
   def inner(*args, **kwargs):
-    import jax
     from jax.tree_util import tree_flatten, tree_unflatten
 
     class JaxFun(torch.autograd.Function):
@@ -37,19 +55,13 @@ def forward(ctx, tree_def, *flat_args_kwargs_values):
         # Reconstruct the original args and kwargs inside forward
         args_rec, kwargs_rec = tree_unflatten(tree_def, flat_args_kwargs_values)
 
-        # Prepare the function call for jax.vjp
-        # jax.vjp expects positional primals. We wrap fn to accept args, kwargs.
-        def fn_wrapper(a, kw):
-          return fn(*a, **kw)
-
-        # Define the JAX function to compute value and vjp
-        def jax_vjp_func(primals):
-          # primals will be (args_rec, kwargs_rec)
-          return jax.vjp(fn_wrapper, *primals)  # Unpack primals here
-
         # Execute the JAX computation
         # Pass the reconstructed args/kwargs tuple as the primal
-        y_, fun_vjp = xb.call_jax(jax_vjp_func, args=((args_rec, kwargs_rec),))
+        y_, fun_vjp = xb.call_jax(
+            _jax_forward, args=(
+                fn,
+                (args_rec, kwargs_rec),
+            ))
 
         # Save necessary information for backward
         # Flatten the vjp function (may contain tensors/non-tensors)
@@ -72,12 +84,8 @@ def backward(ctx, *grad_args):
         assert len(grad_args) > 0
         grad_args = grad_args if len(grad_args) > 1 else grad_args[0]
 
-        def jax_func(vjp_spec, saved_tensors, grad_args):
-          fun_vjp = tree_unflatten(vjp_spec, saved_tensors)
-          return fun_vjp(grad_args)
-
         input_grads_structured = xb.call_jax(
-            jax_func, args=(ctx.vjp_spec, ctx.saved_tensors, grad_args))
+            _jax_backward, args=(ctx.vjp_spec, ctx.saved_tensors, grad_args))
 
         # Flatten the gradients to match the flat inputs to forward
         flat_input_grads, _ = tree_flatten(input_grads_structured)
@@ -98,10 +106,9 @@ def jax_func(vjp_spec, saved_tensors, grad_args):
               final_grads.append(grad)
             except StopIteration:
               # Should not happen if JAX computed grads for all required inputs
-              print(
+              raise ValueError(
                  "Warning: Mismatch between required grads and JAX output grads."
-              )
-              final_grads.append(None)
+              ) from None
           else:
             # This input leaf did not require grad
            final_grads.append(None)
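
The change hoists the per-call closures (jax_vjp_func inside forward, jax_func inside backward) out to module-level _jax_forward / _jax_backward and passes them to xb.call_jax directly. Given the commit message ("cache"), a plausible reading is that xb.call_jax memoizes traced computations keyed in part on the callable it receives, so a closure defined fresh on every autograd call would never hit that cache, while a stable module-level function can. The sketch below only illustrates that identity-keyed-cache effect; call_cached and the trace counter are illustrative stand-ins, not torch_xla internals.

_cache = {}
_trace_count = 0


def call_cached(fn, arg_spec):
  """Return a cached 'traced computation' for (fn, arg_spec), tracing on a miss."""
  global _trace_count
  key = (fn, arg_spec)
  if key not in _cache:
    _trace_count += 1  # stands in for an expensive trace/lower/compile step
    _cache[key] = object()
  return _cache[key]


def hoisted(x):  # module-level helper: the same object on every call
  return x


def make_closure():

  def inner(x):  # a brand-new function object each time: the key never repeats
    return x

  return inner


call_cached(hoisted, ('f32[2,2]',))
call_cached(hoisted, ('f32[2,2]',))
assert _trace_count == 1  # second call reused the cached trace

call_cached(make_closure(), ('f32[2,2]',))
call_cached(make_closure(), ('f32[2,2]',))
assert _trace_count == 3  # every fresh closure forced a new trace

For context, a hypothetical end-to-end use of the wrapper after this change; the decorator and module path follow the file, but the tensor setup is an assumption for illustration, not part of this commit.

import torch
import torch_xla
from torch_xla.experimental.assume_pure import assume_pure


@assume_pure
def linear(x, w):
  return x @ w


device = torch_xla.device()
x = torch.randn(4, 8, device=device, requires_grad=True)
w = torch.randn(8, 2, device=device, requires_grad=True)
linear(x, w).sum().backward()  # backward routes through the hoisted _jax_backward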

0 commit comments