BUG/TST: special.logsumexp on non-default device #22756
Changes from 4 commits
scipy/conftest.py

@@ -13,7 +13,8 @@
 from scipy._lib._fpumode import get_fpu_mode
 from scipy._lib._array_api import (
-    SCIPY_ARRAY_API, SCIPY_DEVICE, array_namespace, default_xp
+    SCIPY_ARRAY_API, SCIPY_DEVICE, array_namespace, default_xp,
+    is_cupy, is_dask, is_jax, is_torch, xp_device,
 )
 from scipy._lib._testutils import FPUModeChangeWarning
 from scipy._lib.array_api_extra.testing import patch_lazy_xp_functions
@@ -154,9 +155,10 @@ def num_parallel_threads():
     try:
         import torch  # type: ignore[import-not-found]
         xp_available_backends.update({'torch': torch})
-        # can use `mps` or `cpu`
-        torch.set_default_device(SCIPY_DEVICE)
         if SCIPY_DEVICE != "cpu":
+            # FIXME don't set this when SCIPY_DEVICE == "cpu"
+            # as a workaround to pytorch/pytorch#150199
+            torch.set_default_device(SCIPY_DEVICE)
             xp_skip_cpu_only_backends.add('torch')

         # default to float64 unless explicitly requested
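For context on why the conditional matters: below is a minimal sketch, not taken from the PR, of how SCIPY_DEVICE steers the torch backend. With a non-cpu default device, every tensor the suite creates without an explicit device lands on that device; the cpu case is left untouched above to work around pytorch/pytorch#150199. The sketch assumes a CUDA-enabled PyTorch build.

# Hedged sketch, not part of the PR: how SCIPY_DEVICE drives torch's default device.
import os
import torch

device = os.environ.get("SCIPY_DEVICE", "cpu")
if device != "cpu":
    # mirror of the workaround above: only touch the default when it isn't cpu
    torch.set_default_device(device)

x = torch.empty(3)   # created on the current default device
print(x.device)      # e.g. cuda:0 when SCIPY_DEVICE=cuda, cpu otherwise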
@@ -409,6 +411,51 @@ def skip_or_xfail_xp_backends(request: pytest.FixtureRequest,
         skip_or_xfail(reason=reason)


+@pytest.fixture
+def nondefault_device(xp):
+    """Fixture that returns a device other than the default device for the backend.
+    Used to test input->output device propagation.
+
+    Usage
+    -----
+    from scipy._lib._array_api import xp_device
+
+    def test(xp, nondefault_device):
+        a = xp.asarray(..., device=nondefault_device)
+        b = f(a)
+        assert xp_device(b) == nondefault_device
+    """
+    if is_cupy(xp):
+        pytest.xfail(reason="data-apis/array-api-compat#293")
+    if is_dask(xp):
+        pytest.skip(reason="dummy device from array-api-compat does not propagate")
+    if is_jax(xp):
+        # The .device attribute is not accessible inside jax.jit; the consequence
+        # (downstream of array-api-compat hacks) is that a non-default device in
+        # input is not guaranteed to propagate to the output even if the scipy code
+        # states `device=xp_device(arg)` in all array creation functions.
+        # While this issue is specific to jax.jit, it would be unnecessarily
+        # verbose to skip the test for each jit-capable function and run it for
+        # those that only support eager mode.
+        # Additionally, devices() below only returns CudaDevice objects when
+        # CUDA is enabled, which prevents us from running this test on CPU vs. GPU.
+        pytest.xfail(reason="jax-ml/jax#26000 + jax-ml/jax#27606")
+    if is_torch(xp) and SCIPY_DEVICE != "cpu":
+        # Note workaround when parsing SCIPY_DEVICE above.
+        # Also note that when SCIPY_DEVICE=cpu this test won't run in CI
+        # because CUDA-enabled CI boxes always use SCIPY_DEVICE=cuda.
+        pytest.xfail(reason="pytorch/pytorch#150199")

Review comment on the pytorch xfail: Workaround: data-apis/array-api-compat#299

+    devices = xp.__array_namespace_info__().devices()
+    # Don't use xp.__array_namespace_info__().default_device():
+    # https://github.com/data-apis/array-api/issues/835
+    default = xp_device(xp.empty(()))
+    try:
+        return next(iter(d for d in devices if d != default))
+    except StopIteration:
+        pytest.skip(reason="Only one device available")
+
+
 # Following the approach of NumPy's conftest.py...
 # Use a known and persistent tmpdir for hypothesis' caches, which
 # can be automatically cleared by the OS or user.
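To make the fixture's device-discovery logic concrete outside of pytest, here is a standalone sketch. It is not part of the PR; it assumes array-api-strict, which in recent versions exposes extra fake devices precisely so that multi-device code paths can be exercised on any machine.

# Hedged sketch of the fixture's selection logic, using array-api-strict.
import array_api_strict as xp

devices = xp.__array_namespace_info__().devices()
# Infer the default device from a freshly created array rather than calling
# info.default_device() (see data-apis/array-api#835, referenced above).
default = xp.empty(()).device
# Pick any device that differs from the default, if one exists.
nondefault = next((d for d in devices if d != default), None)
print(default, nondefault)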
scipy/special/tests/test_logsumexp.py

@@ -4,7 +4,7 @@
|
||
import numpy as np | ||
|
||
from scipy._lib._array_api import is_array_api_strict, xp_default_dtype | ||
from scipy._lib._array_api import is_array_api_strict, xp_default_dtype, xp_device | ||
from scipy._lib._array_api_no_0d import (xp_assert_equal, xp_assert_close, | ||
xp_assert_less) | ||
|
||
|
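For readers unfamiliar with the helper imported here: xp_device reports the device an array lives on, whatever the backend. A tiny illustration follows; scipy._lib._array_api is SciPy-internal, so treat this as a sketch rather than public API.

import numpy as np
from scipy._lib._array_api import xp_device

x = np.asarray([1.0, 2.0])
print(xp_device(x))  # NumPy exposes a single device, reported as "cpu"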
@@ -292,6 +292,13 @@ def test_gh22601_infinite_elements(self, x, y, xp):
         ref = xp.log(xp.sum(xp.exp(xp.asarray([x, y]))))
         xp_assert_equal(res, ref)

+    @pytest.mark.parametrize("x", [1.0, 1.0j, []])
+    def test_device(self, x, xp, nondefault_device):
+        """Test input device propagation to output."""
+        x = xp.asarray(x, device=nondefault_device)
+        assert xp_device(logsumexp(x)) == nondefault_device
+        assert xp_device(logsumexp(x, b=x)) == nondefault_device
Review thread on test_device:

Comment: At first, if you're working on a machine with a single GPU, for example, this PR may appear to be fairly straightforward. However, when I test on a node/machine that has multiple GPUs and use …, this test currently fails on this branch: … Does CuPy require special treatment? Do multiple GPUs require special treatment? It isn't immediately obvious to me, but the nature of the failure suggests, at first glance, that device propagation is not working as intended. I'm using CuPy ….

Comment: Is our ecosystem currently shimming around …?

Reply: On a single GPU machine, cupy has only one device and the test introduced by this PR is skipped.

Reply: Yes, by array-api-compat.

 class TestSoftmax:
     def test_softmax_fixtures(self, xp):
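The multi-GPU scenario raised in the review thread above can be sketched as follows. This is not part of the PR; it assumes a machine with at least two CUDA devices, CuPy installed, and SCIPY_ARRAY_API=1 exported before importing SciPy so that logsumexp accepts CuPy arrays.

import cupy as cp
from scipy.special import logsumexp

with cp.cuda.Device(1):            # allocate the input on the second GPU
    x = cp.asarray([1.0, 2.0, 3.0])

res = logsumexp(x)
# The PR's intent is that the result stays on the input's device; the
# reviewer reports that on multi-GPU nodes this does not currently hold.
print(x.device, res.device)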