Skip to content

BUG: .device attribute inside jax.jit #238

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 26, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 16 additions & 9 deletions array_api_compat/common/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -648,20 +648,24 @@ def device(x: Array, /) -> Device:
if is_numpy_array(x):
return "cpu"
elif is_dask_array(x):
# Peek at the metadata of the jax array to determine type
# Peek at the metadata of the Dask array to determine type
if is_numpy_array(x._meta):
# Must be on CPU since backed by numpy
return "cpu"
return _DASK_DEVICE
elif is_jax_array(x):
# JAX has .device() as a method, but it is being deprecated so that it
# can become a property, in accordance with the standard. In order for
# this function to not break when JAX makes the flip, we check for
# both here.
if inspect.ismethod(x.device):
return x.device()
# FIXME Jitted JAX arrays do not have a device attribute
# https://github.com/jax-ml/jax/issues/26000
# Return None in this case. Note that this workaround breaks
# the standard and will result in new arrays being created on the
# default device instead of the same device as the input array(s).
x_device = getattr(x, 'device', None)
# Older JAX releases had .device() as a method, which has been replaced
# with a property in accordance with the standard.
if inspect.ismethod(x_device):
return x_device()
else:
return x.device
return x_device
elif is_pydata_sparse_array(x):
# `sparse` will gain `.device`, so check for this first.
x_device = getattr(x, 'device', None)
Expand Down Expand Up @@ -792,8 +796,11 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
raise ValueError(f"Unsupported device {device!r}")
elif is_jax_array(x):
if not hasattr(x, "__array_namespace__"):
# In JAX v0.4.31 and older, this import adds to_device method to x.
# In JAX v0.4.31 and older, this import adds to_device method to x...
import jax.experimental.array_api # noqa: F401
# ... but only on eager JAX. It won't work inside jax.jit.
if not hasattr(x, "to_device"):
return x
return x.to_device(device, stream=stream)
elif is_pydata_sparse_array(x) and device == _device(x):
# Perform trivial check to return the same array if
Expand Down
11 changes: 5 additions & 6 deletions tests/test_array_namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def test_array_namespace(library, api_version, use_compat):
if use_compat and library not in wrapped_libraries:
pytest.raises(ValueError, lambda: array_namespace(array, use_compat=use_compat))
return
namespace = array_api_compat.array_namespace(array, api_version=api_version, use_compat=use_compat)
namespace = array_namespace(array, api_version=api_version, use_compat=use_compat)

if use_compat is False or use_compat is None and library not in wrapped_libraries:
if library == "jax.numpy" and use_compat is None:
Expand All @@ -44,7 +44,7 @@ def test_array_namespace(library, api_version, use_compat):

if library == "numpy":
# check that the same namespace is returned for NumPy scalars
scalar_namespace = array_api_compat.array_namespace(
scalar_namespace = array_namespace(
xp.float64(0.0), api_version=api_version, use_compat=use_compat
)
assert scalar_namespace == namespace
Expand Down Expand Up @@ -75,8 +75,7 @@ def test_array_namespace(library, api_version, use_compat):
def test_jax_zero_gradient():
jx = jax.numpy.arange(4)
jax_zero = jax.vmap(jax.grad(jax.numpy.float32, allow_int=True))(jx)
assert (array_api_compat.get_namespace(jax_zero) is
array_api_compat.get_namespace(jx))
assert array_namespace(jax_zero) is array_namespace(jx)

def test_array_namespace_errors():
pytest.raises(TypeError, lambda: array_namespace([1]))
Expand All @@ -91,7 +90,7 @@ def test_array_namespace_errors_torch():
x = np.asarray([1, 2])
pytest.raises(TypeError, lambda: array_namespace(x, y))

def test_api_version():
def test_api_version_torch():
x = torch.asarray([1, 2])
torch_ = import_("torch", wrapper=True)
assert array_namespace(x, api_version="2023.12") == torch_
Expand All @@ -113,7 +112,7 @@ def test_api_version():

def test_get_namespace():
# Backwards compatible wrapper
assert array_api_compat.get_namespace is array_api_compat.array_namespace
assert array_api_compat.get_namespace is array_namespace

def test_python_scalars():
a = torch.asarray([1, 2])
Expand Down
9 changes: 6 additions & 3 deletions tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pytest
import numpy as np
import array
from numpy.testing import assert_allclose
from numpy.testing import assert_equal

from array_api_compat import ( # noqa: F401
is_numpy_array, is_cupy_array, is_torch_array,
Expand Down Expand Up @@ -195,7 +195,10 @@ def test_device(library):
dev = device(x)

x2 = to_device(x, dev)
assert device(x) == device(x2)
assert device(x2) == device(x)

x3 = xp.asarray(x, device=dev)
assert device(x3) == device(x)


@pytest.mark.parametrize("library", wrapped_libraries)
Expand All @@ -214,7 +217,7 @@ def test_to_device_host(library):
# a `device(x)` query; however, what's really important
# here is that we can test portably after calling
# to_device(x, "cpu") to return to host
assert_allclose(x, expected)
assert_equal(x, expected)


@pytest.mark.parametrize("target_library", is_array_functions.keys())
Expand Down
34 changes: 34 additions & 0 deletions tests/test_jax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import jax
import jax.numpy as jnp
from numpy.testing import assert_equal
import pytest

from array_api_compat import device, to_device

HAS_JAX_0_4_31 = jax.__version__ >= "0.4.31"


@pytest.mark.parametrize(
"func",
[
lambda x: jnp.zeros(1, device=device(x)),
lambda x: jnp.zeros_like(jnp.ones(1, device=device(x))),
lambda x: jnp.zeros_like(jnp.empty(1, device=device(x))),
lambda x: jnp.full(1, fill_value=0, device=device(x)),
pytest.param(
lambda x: jnp.asarray([0], device=device(x)),
marks=pytest.mark.skipif(
not HAS_JAX_0_4_31, reason="asarray() has no device= parameter"
),
),
lambda x: to_device(jnp.zeros(1), device(x)),
]
)
def test_device_jit(func):
# Test work around to https://github.com/jax-ml/jax/issues/26000
# Also test missing to_device() method in JAX < 0.4.31
# when inside jax.jit, even after importing jax.experimental.array_api

x = jnp.ones(1)
assert_equal(func(x), jnp.asarray([0]))
assert_equal(jax.jit(func)(x), jnp.asarray([0]))
Loading