
Commit 17270e2

tengyifei and qihqi authored
Minimal support for calling JAX from PyTorch/XLA (#8781)
Co-authored-by: Han Qi <[email protected]>
1 parent 00fac78 commit 17270e2
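
In short, this commit adds xb.call_jax so a JAX function can be traced into the PyTorch/XLA lazy-tensor graph. The snippet below is a minimal usage sketch distilled from the test added in this commit (test/test_jax_interop.py); it assumes an XLA device plus the nightly JAX and torchax install that the CI change sets up, and is illustrative rather than part of the commit itself.

import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.core.xla_builder as xb

def f(a, b):
  # JAX traces this function; the resulting HLO is embedded into the XLA graph.
  import jax.numpy as jnp
  return a + jnp.sin(b)

a = torch.ones((3, 3), device=xm.xla_device())
b = xb.call_jax(f, (a, a), {}, 'my_test')  # positional args, kwargs, optional name
torch_xla.sync()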

File tree: 4 files changed, +111 −0 lines changed

  .github/workflows/_test.yml
  test/run_tests.sh
  test/test_jax_interop.py
  torch_xla/core/xla_builder.py

.github/workflows/_test.yml

Lines changed: 6 additions & 0 deletions
@@ -128,6 +128,12 @@ jobs:
       set -x

       pip install expecttest unittest-xml-reporting
+      pip install torch_xla[pallas] \
+        -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
+        -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
+
+      # Install torchax
+      pip install pytorch/xla/torchax

       if [[ ! -z "$RUN_BENCHMARK_TESTS" ]]; then
         pip install -r pytorch/xla/benchmarks/requirements.txt

test/run_tests.sh

Lines changed: 1 addition & 0 deletions
@@ -208,6 +208,7 @@ function run_xla_op_tests2 {
   run_test "$CDIR/eager/test_eager_spmd.py"
   run_test "$CDIR/test_callback.py"
   XLA_USE_SPMD=1 run_test "$CDIR/test_callback.py"
+  run_test "$CDIR/test_jax_interop.py"
 }

 # All the new xla op tests should go to run_xla_op_tests3

test/test_jax_interop.py

Lines changed: 57 additions & 0 deletions
@@ -0,0 +1,57 @@
+from absl.testing import absltest
+
+import torch
+import torch_xla
+import torch_xla.core.xla_model as xm
+import torch_xla.core.xla_builder as xb
+
+
+class TestJaxInterop(absltest.TestCase):
+
+  def test_call_jax(self):
+    """
+    Test that we can call a JAX function from PyTorch/XLA lazy tensor tracing.
+    """
+
+    dev = xm.xla_device()
+    a = torch.ones((3, 3), device=dev)
+
+    def f(a, b):
+      import jax.numpy as jnp
+      return a + jnp.sin(b)
+
+    b = xb.call_jax(f, (a, a), {}, 'my_test')
+    torch_xla.sync()
+    torch.testing.assert_close(
+        b, torch.sin(torch.ones(3, 3)) + 1, check_device=False)
+
+  def test_call_jax_pytree(self):
+    """
+    Test that call_jax works with PyTree inputs.
+    """
+    dev = xm.xla_device()
+    a = torch.ones((2, 2), device=dev)
+    b = torch.ones((2, 2), device=dev) * 2
+
+    def f(inputs):
+      a = inputs['a']
+      b = inputs['b']
+      return a @ b
+
+    inputs = {'a': a, 'b': b}
+    c = xb.call_jax(f, (inputs,))
+    torch_xla.sync()
+    torch.testing.assert_close(
+        c,
+        torch.tensor(
+            [
+                [4, 4],
+                [4, 4],
+            ],
+            dtype=torch.float32,
+        ),
+        check_device=False)
+
+
+if __name__ == "__main__":
+  absltest.main()
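
The dict argument in test_call_jax_pytree survives because call_jax flattens and rebuilds its inputs with torch.utils._pytree, the same helper imported in xla_builder.py below. A small illustrative round-trip, not part of the commit:

import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

args = ({'a': torch.ones(2, 2), 'b': torch.ones(2, 2) * 2},)
kwargs = {}
# tree_flatten returns the tensor leaves plus a spec describing the nesting.
flattened, spec = tree_flatten((args, kwargs))
restored_args, restored_kwargs = tree_unflatten(flattened, spec)
assert torch.equal(restored_args[0]['b'], args[0]['b'])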

torch_xla/core/xla_builder.py

Lines changed: 47 additions & 0 deletions
@@ -1,5 +1,7 @@
 import torch
 import torch_xla
+from torch.utils._pytree import tree_flatten, tree_unflatten
+from torch_xla.experimental.custom_kernel import jax_import_guard


 class Type:
@@ -799,3 +801,48 @@ def computation_from_module_proto(name, proto):


 def get_computation_hlo(computation):
   return torch_xla._XLAC._xla_computation_text(computation)
+
+
+def call_jax(jax_func, args, kwargs=None, name=None):
+  """
+  Call a JAX function `jax_func` with the given `args` and `kwargs` that may contain
+  XLA tensors.
+  """
+
+  if name is None:
+    name = 'jax_func_' + jax_func.__name__
+  kwargs = kwargs or {}
+
+  # If we don't do this before calling jax, any torch_xla operation will hang.
+  jax_import_guard()
+
+  import jax
+  import torchax.ops.mappings as mappings
+
+  flattened, spec = tree_flatten((args, kwargs))
+
+  def fn_flattened_inputs(*flattened):
+    args, kwargs = tree_unflatten(flattened, spec)
+    return jax_func(*args, **kwargs)
+
+  sample_input_shapes = tuple(
+      jax.ShapeDtypeStruct(a.shape, mappings.t2j_dtype(a.dtype))
+      for a in flattened)
+  # `as_serialized_hlo_module_proto` is mentioned at
+  # https://github.com/jax-ml/jax/discussions/22266
+  hlo_module = jax.jit(fn_flattened_inputs).lower(
+      *sample_input_shapes).compiler_ir(
+          'hlo').as_serialized_hlo_module_proto()  # type: ignore
+  computation = computation_from_module_proto(name, hlo_module)
+
+  builder = create_builder(name)
+  params = []
+  for idx, val in enumerate(flattened):
+    params.append(mkparam(builder, idx, tensor_shape(val)))
+  call_op = Op.call(computation, params)
+  call_computation = call_op.build('call_jax')
+  result = torch_xla._XLAC._xla_user_computation(f'xla::call_jax_{name}',
+                                                 flattened, call_computation)
+  if isinstance(result, list) and len(result) == 1:
+    return result[0]
+  return result
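
The heart of call_jax is the lowering chain jax.jit(...).lower(...).compiler_ir('hlo').as_serialized_hlo_module_proto(), which produces serialized HLO bytes without executing anything. Below is a standalone sketch of just that step, assuming a JAX install; it is illustrative only and not part of the commit.

import jax
import jax.numpy as jnp

def f(a, b):
  return a + jnp.sin(b)

# Abstract shapes/dtypes stand in for real arrays, mirroring the
# jax.ShapeDtypeStruct placeholders built from the flattened torch tensors.
shapes = tuple(jax.ShapeDtypeStruct((3, 3), jnp.float32) for _ in range(2))
lowered = jax.jit(f).lower(*shapes)
hlo_proto = lowered.compiler_ir('hlo').as_serialized_hlo_module_proto()
print(type(hlo_proto))  # bytes, ready to feed computation_from_module_proto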
