@@ -21,11 +21,29 @@ def assume_pure(fn):
   return _jax2torch(jax_view(fn))
 
 
+# Define the JAX function to compute value and vjp
+def _jax_forward(fn, primals):
+  import jax
+
+  # Prepare the function call for jax.vjp
+  # jax.vjp expects positional primals. We wrap fn to accept args, kwargs.
+  def fn_wrapper(a, kw):
+    return fn(*a, **kw)
+
+  # primals will be (args_rec, kwargs_rec)
+  return jax.vjp(fn_wrapper, *primals)  # Unpack primals here
+
+
+def _jax_backward(vjp_spec, saved_tensors, grad_args):
+  from jax.tree_util import tree_unflatten
+  fun_vjp = tree_unflatten(vjp_spec, saved_tensors)
+  return fun_vjp(grad_args)
+
+
 def _jax2torch(fn):
 
   @wraps(fn)
   def inner(*args, **kwargs):
-    import jax
     from jax.tree_util import tree_flatten, tree_unflatten
 
     class JaxFun(torch.autograd.Function):
@@ -37,19 +55,13 @@ def forward(ctx, tree_def, *flat_args_kwargs_values):
         # Reconstruct the original args and kwargs inside forward
         args_rec, kwargs_rec = tree_unflatten(tree_def, flat_args_kwargs_values)
 
-        # Prepare the function call for jax.vjp
-        # jax.vjp expects positional primals. We wrap fn to accept args, kwargs.
-        def fn_wrapper(a, kw):
-          return fn(*a, **kw)
-
-        # Define the JAX function to compute value and vjp
-        def jax_vjp_func(primals):
-          # primals will be (args_rec, kwargs_rec)
-          return jax.vjp(fn_wrapper, *primals)  # Unpack primals here
-
         # Execute the JAX computation
         # Pass the reconstructed args/kwargs tuple as the primal
-        y_, fun_vjp = xb.call_jax(jax_vjp_func, args=((args_rec, kwargs_rec),))
+        y_, fun_vjp = xb.call_jax(
+            _jax_forward, args=(
+                fn,
+                (args_rec, kwargs_rec),
+            ))
 
         # Save necessary information for backward
         # Flatten the vjp function (may contain tensors/non-tensors)
@@ -72,12 +84,8 @@ def backward(ctx, *grad_args):
         assert len(grad_args) > 0
         grad_args = grad_args if len(grad_args) > 1 else grad_args[0]
 
-        def jax_func(vjp_spec, saved_tensors, grad_args):
-          fun_vjp = tree_unflatten(vjp_spec, saved_tensors)
-          return fun_vjp(grad_args)
-
         input_grads_structured = xb.call_jax(
-            jax_func, args=(ctx.vjp_spec, ctx.saved_tensors, grad_args))
+            _jax_backward, args=(ctx.vjp_spec, ctx.saved_tensors, grad_args))
 
         # Flatten the gradients to match the flat inputs to forward
         flat_input_grads, _ = tree_flatten(input_grads_structured)
@@ -98,10 +106,9 @@ def jax_func(vjp_spec, saved_tensors, grad_args):
               final_grads.append(grad)
             except StopIteration:
               # Should not happen if JAX computed grads for all required inputs
-              print(
+              raise ValueError(
                   "Warning: Mismatch between required grads and JAX output grads."
-              )
-              final_grads.append(None)
+              ) from None
           else:
             # This input leaf did not require grad
             final_grads.append(None)
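
The hoisted _jax_forward / _jax_backward helpers follow the usual JAX recipe: capture the vjp closure in the forward pass, flatten it into tensor leaves plus a treedef so it can be stashed on ctx, then rebuild and apply it in the backward pass. Below is a minimal sketch of that recipe in plain JAX, independent of xb.call_jax and torch.autograd.Function; the function fn and its scale keyword are hypothetical stand-ins:

    import jax
    import jax.numpy as jnp
    from jax.tree_util import tree_flatten, tree_unflatten

    def fn(x, *, scale):
      # Hypothetical pure function standing in for the wrapped user function.
      return jnp.sin(x) * scale

    # Forward: wrap fn so jax.vjp sees (args, kwargs) as positional primals,
    # mirroring _jax_forward above.
    def fn_wrapper(a, kw):
      return fn(*a, **kw)

    primals = ((jnp.ones(3),), {"scale": 2.0})
    y, fun_vjp = jax.vjp(fn_wrapper, *primals)

    # The returned vjp closure is itself a pytree, so it can be flattened into
    # leaves (saved_tensors) plus a treedef (vjp_spec) and rebuilt later.
    saved_tensors, vjp_spec = tree_flatten(fun_vjp)

    # Backward: rebuild the closure and apply it to the output cotangent,
    # mirroring _jax_backward above.
    rebuilt_vjp = tree_unflatten(vjp_spec, saved_tensors)
    args_grads, kwargs_grads = rebuilt_vjp(jnp.ones_like(y))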