Skip to content

Commit e7e71fb

Browse files
committedJan 24, 2025·
BUG: .device attribute inside jax.jit
1 parent 73f6426 commit e7e71fb

File tree

4 files changed

+61
-18
lines changed

4 files changed

+61
-18
lines changed
 

‎array_api_compat/common/_helpers.py

+16-9
Original file line numberDiff line numberDiff line change
@@ -648,20 +648,24 @@ def device(x: Array, /) -> Device:
648648
if is_numpy_array(x):
649649
return "cpu"
650650
elif is_dask_array(x):
651-
# Peek at the metadata of the jax array to determine type
651+
# Peek at the metadata of the Dask array to determine type
652652
if is_numpy_array(x._meta):
653653
# Must be on CPU since backed by numpy
654654
return "cpu"
655655
return _DASK_DEVICE
656656
elif is_jax_array(x):
657-
# JAX has .device() as a method, but it is being deprecated so that it
658-
# can become a property, in accordance with the standard. In order for
659-
# this function to not break when JAX makes the flip, we check for
660-
# both here.
661-
if inspect.ismethod(x.device):
662-
return x.device()
657+
# FIXME Jitted JAX arrays do not have a device attribute
658+
# https://github.com/jax-ml/jax/issues/26000
659+
# Return None in this case. Note that this workaround breaks
660+
# the standard and will result in new arrays being created on the
661+
# default device instead of the same device as the input array(s).
662+
x_device = getattr(x, 'device', None)
663+
# Older JAX releases had .device() as a method, which has been replaced
664+
# with a property in accordance with the standard.
665+
if inspect.ismethod(x_device):
666+
return x_device()
663667
else:
664-
return x.device
668+
return x_device
665669
elif is_pydata_sparse_array(x):
666670
# `sparse` will gain `.device`, so check for this first.
667671
x_device = getattr(x, 'device', None)
@@ -792,8 +796,11 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
792796
raise ValueError(f"Unsupported device {device!r}")
793797
elif is_jax_array(x):
794798
if not hasattr(x, "__array_namespace__"):
795-
# In JAX v0.4.31 and older, this import adds to_device method to x.
799+
# In JAX v0.4.31 and older, this import adds to_device method to x...
796800
import jax.experimental.array_api # noqa: F401
801+
# ... but only on eager JAX. It won't work inside jax.jit.
802+
if not hasattr(x, "to_device"):
803+
return x
797804
return x.to_device(device, stream=stream)
798805
elif is_pydata_sparse_array(x) and device == _device(x):
799806
# Perform trivial check to return the same array if

‎tests/test_array_namespace.py

+5-6
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def test_array_namespace(library, api_version, use_compat):
2222
if use_compat and library not in wrapped_libraries:
2323
pytest.raises(ValueError, lambda: array_namespace(array, use_compat=use_compat))
2424
return
25-
namespace = array_api_compat.array_namespace(array, api_version=api_version, use_compat=use_compat)
25+
namespace = array_namespace(array, api_version=api_version, use_compat=use_compat)
2626

2727
if use_compat is False or use_compat is None and library not in wrapped_libraries:
2828
if library == "jax.numpy" and use_compat is None:
@@ -44,7 +44,7 @@ def test_array_namespace(library, api_version, use_compat):
4444

4545
if library == "numpy":
4646
# check that the same namespace is returned for NumPy scalars
47-
scalar_namespace = array_api_compat.array_namespace(
47+
scalar_namespace = array_namespace(
4848
xp.float64(0.0), api_version=api_version, use_compat=use_compat
4949
)
5050
assert scalar_namespace == namespace
@@ -75,8 +75,7 @@ def test_array_namespace(library, api_version, use_compat):
7575
def test_jax_zero_gradient():
7676
jx = jax.numpy.arange(4)
7777
jax_zero = jax.vmap(jax.grad(jax.numpy.float32, allow_int=True))(jx)
78-
assert (array_api_compat.get_namespace(jax_zero) is
79-
array_api_compat.get_namespace(jx))
78+
assert array_namespace(jax_zero) is array_namespace(jx)
8079

8180
def test_array_namespace_errors():
8281
pytest.raises(TypeError, lambda: array_namespace([1]))
@@ -91,7 +90,7 @@ def test_array_namespace_errors_torch():
9190
x = np.asarray([1, 2])
9291
pytest.raises(TypeError, lambda: array_namespace(x, y))
9392

94-
def test_api_version():
93+
def test_api_version_torch():
9594
x = torch.asarray([1, 2])
9695
torch_ = import_("torch", wrapper=True)
9796
assert array_namespace(x, api_version="2023.12") == torch_
@@ -113,7 +112,7 @@ def test_api_version():
113112

114113
def test_get_namespace():
115114
# Backwards compatible wrapper
116-
assert array_api_compat.get_namespace is array_api_compat.array_namespace
115+
assert array_api_compat.get_namespace is array_namespace
117116

118117
def test_python_scalars():
119118
a = torch.asarray([1, 2])

‎tests/test_common.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import pytest
44
import numpy as np
55
import array
6-
from numpy.testing import assert_allclose
6+
from numpy.testing import assert_equal
77

88
from array_api_compat import ( # noqa: F401
99
is_numpy_array, is_cupy_array, is_torch_array,
@@ -195,7 +195,10 @@ def test_device(library):
195195
dev = device(x)
196196

197197
x2 = to_device(x, dev)
198-
assert device(x) == device(x2)
198+
assert device(x2) == device(x)
199+
200+
x3 = xp.asarray(x, device=dev)
201+
assert device(x3) == device(x)
199202

200203

201204
@pytest.mark.parametrize("library", wrapped_libraries)
@@ -214,7 +217,7 @@ def test_to_device_host(library):
214217
# a `device(x)` query; however, what's really important
215218
# here is that we can test portably after calling
216219
# to_device(x, "cpu") to return to host
217-
assert_allclose(x, expected)
220+
assert_equal(x, expected)
218221

219222

220223
@pytest.mark.parametrize("target_library", is_array_functions.keys())

‎tests/test_jax.py

+34
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import jax
2+
import jax.numpy as jnp
3+
from numpy.testing import assert_equal
4+
import pytest
5+
6+
from array_api_compat import device, to_device
7+
8+
HAS_JAX_0_4_31 = jax.__version__ >= "0.4.31"
9+
10+
11+
@pytest.mark.parametrize(
12+
"func",
13+
[
14+
lambda x: jnp.zeros(1, device=device(x)),
15+
lambda x: jnp.zeros_like(jnp.ones(1, device=device(x))),
16+
lambda x: jnp.zeros_like(jnp.empty(1, device=device(x))),
17+
lambda x: jnp.full(1, fill_value=0, device=device(x)),
18+
pytest.param(
19+
lambda x: jnp.asarray([0], device=device(x)),
20+
marks=pytest.mark.skipif(
21+
not HAS_JAX_0_4_31, reason="asarray() has no device= parameter"
22+
),
23+
),
24+
lambda x: to_device(jnp.zeros(1), device(x)),
25+
]
26+
)
27+
def test_device_jit(func):
28+
# Test work around to https://github.com/jax-ml/jax/issues/26000
29+
# Also test missing to_device() method in JAX < 0.4.31
30+
# when inside jax.jit, even after importing jax.experimental.array_api
31+
32+
x = jnp.ones(1)
33+
assert_equal(func(x), jnp.asarray([0]))
34+
assert_equal(jax.jit(func)(x), jnp.asarray([0]))

0 commit comments

Comments
 (0)
Please sign in to comment.