Skip to content

Commit 4fffbe2

Browse files
committed
BUG: .device attribute inside jax.jit
1 parent 8a79994 commit 4fffbe2

File tree

4 files changed

+60
-17
lines changed

4 files changed

+60
-17
lines changed

Diff for: array_api_compat/common/_helpers.py

+15-8
Original file line numberDiff line numberDiff line change
@@ -658,14 +658,18 @@ def device(x: Array, /) -> Device:
658658
pass
659659
return _DASK_DEVICE
660660
elif is_jax_array(x):
661-
# JAX has .device() as a method, but it is being deprecated so that it
662-
# can become a property, in accordance with the standard. In order for
663-
# this function to not break when JAX makes the flip, we check for
664-
# both here.
665-
if inspect.ismethod(x.device):
666-
return x.device()
661+
# FIXME Jitted JAX arrays do not have a device attribute
662+
# https://github.com/jax-ml/jax/issues/26000
663+
# Return None in this case. Note that this workaround breaks
664+
# the standard and will result in new arrays being created on the
665+
# default device instead of the same device as the input array(s).
666+
x_device = getattr(x, 'device', None)
667+
# Older JAX releases had .device() as a method, which has been replaced
668+
# with a property in accordance with the standard.
669+
if inspect.ismethod(x_device):
670+
return x_device()
667671
else:
668-
return x.device
672+
return x_device
669673
elif is_pydata_sparse_array(x):
670674
# `sparse` will gain `.device`, so check for this first.
671675
x_device = getattr(x, 'device', None)
@@ -796,8 +800,11 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
796800
raise ValueError(f"Unsupported device {device!r}")
797801
elif is_jax_array(x):
798802
if not hasattr(x, "__array_namespace__"):
799-
# In JAX v0.4.31 and older, this import adds to_device method to x.
803+
# In JAX v0.4.31 and older, this import adds to_device method to x...
800804
import jax.experimental.array_api # noqa: F401
805+
# ... but only on eager JAX. It won't work inside jax.jit.
806+
if not hasattr(x, "to_device"):
807+
return x
801808
return x.to_device(device, stream=stream)
802809
elif is_pydata_sparse_array(x) and device == _device(x):
803810
# Perform trivial check to return the same array if

Diff for: tests/test_array_namespace.py

+5-6
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def test_array_namespace(library, api_version, use_compat):
2222
if use_compat is True and library in {'array_api_strict', 'jax.numpy', 'sparse'}:
2323
pytest.raises(ValueError, lambda: array_namespace(array, use_compat=use_compat))
2424
return
25-
namespace = array_api_compat.array_namespace(array, api_version=api_version, use_compat=use_compat)
25+
namespace = array_namespace(array, api_version=api_version, use_compat=use_compat)
2626

2727
if use_compat is False or use_compat is None and library not in wrapped_libraries:
2828
if library == "jax.numpy" and use_compat is None:
@@ -44,7 +44,7 @@ def test_array_namespace(library, api_version, use_compat):
4444

4545
if library == "numpy":
4646
# check that the same namespace is returned for NumPy scalars
47-
scalar_namespace = array_api_compat.array_namespace(
47+
scalar_namespace = array_namespace(
4848
xp.float64(0.0), api_version=api_version, use_compat=use_compat
4949
)
5050
assert scalar_namespace == namespace
@@ -75,8 +75,7 @@ def test_array_namespace(library, api_version, use_compat):
7575
def test_jax_zero_gradient():
7676
jx = jax.numpy.arange(4)
7777
jax_zero = jax.vmap(jax.grad(jax.numpy.float32, allow_int=True))(jx)
78-
assert (array_api_compat.get_namespace(jax_zero) is
79-
array_api_compat.get_namespace(jx))
78+
assert array_namespace(jax_zero) is array_namespace(jx)
8079

8180
def test_array_namespace_errors():
8281
pytest.raises(TypeError, lambda: array_namespace([1]))
@@ -91,7 +90,7 @@ def test_array_namespace_errors_torch():
9190
x = np.asarray([1, 2])
9291
pytest.raises(TypeError, lambda: array_namespace(x, y))
9392

94-
def test_api_version():
93+
def test_api_version_torch():
9594
x = torch.asarray([1, 2])
9695
torch_ = import_("torch", wrapper=True)
9796
assert array_namespace(x, api_version="2023.12") == torch_
@@ -113,7 +112,7 @@ def test_api_version():
113112

114113
def test_get_namespace():
115114
# Backwards compatible wrapper
116-
assert array_api_compat.get_namespace is array_api_compat.array_namespace
115+
assert array_api_compat.get_namespace is array_namespace
117116

118117
def test_python_scalars():
119118
a = torch.asarray([1, 2])

Diff for: tests/test_common.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import pytest
44
import numpy as np
55
import array
6-
from numpy.testing import assert_allclose
6+
from numpy.testing import assert_equal
77

88
from array_api_compat import ( # noqa: F401
99
is_numpy_array, is_cupy_array, is_torch_array,
@@ -188,7 +188,10 @@ def test_device(library):
188188
dev = device(x)
189189

190190
x2 = to_device(x, dev)
191-
assert device(x) == device(x2)
191+
assert device(x2) == device(x)
192+
193+
x3 = xp.asarray(x, device=dev)
194+
assert device(x3) == device(x)
192195

193196

194197
@pytest.mark.parametrize("library", wrapped_libraries)
@@ -207,7 +210,7 @@ def test_to_device_host(library):
207210
# a `device(x)` query; however, what's really important
208211
# here is that we can test portably after calling
209212
# to_device(x, "cpu") to return to host
210-
assert_allclose(x, expected)
213+
assert_equal(x, expected)
211214

212215

213216
@pytest.mark.parametrize("target_library", is_array_functions.keys())

Diff for: tests/test_jax.py

+34
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import jax
2+
import jax.numpy as jnp
3+
from numpy.testing import assert_equal
4+
import pytest
5+
6+
from array_api_compat import device, to_device
7+
8+
HAS_JAX_0_4_31 = jax.__version__ >= "0.4.31"
9+
10+
11+
@pytest.mark.parametrize(
12+
"func",
13+
[
14+
lambda x: jnp.zeros(1, device=device(x)),
15+
lambda x: jnp.zeros_like(jnp.ones(1, device=device(x))),
16+
lambda x: jnp.zeros_like(jnp.empty(1, device=device(x))),
17+
lambda x: jnp.full(1, fill_value=0, device=device(x)),
18+
pytest.param(
19+
lambda x: jnp.asarray([0], device=device(x)),
20+
marks=pytest.mark.skipif(
21+
not HAS_JAX_0_4_31, reason="asarray() has no device= parameter"
22+
),
23+
),
24+
lambda x: to_device(jnp.zeros(1), device(x)),
25+
]
26+
)
27+
def test_device_jit(func):
28+
# Test work around to https://github.com/jax-ml/jax/issues/26000
29+
# Also test missing to_device() method in JAX < 0.4.31
30+
# when inside jax.jit, even after importing jax.experimental.array_api
31+
32+
x = jnp.ones(1)
33+
assert_equal(func(x), jnp.asarray([0]))
34+
assert_equal(jax.jit(func)(x), jnp.asarray([0]))

0 commit comments

Comments
 (0)