
BUG/TST: special.logsumexp on non-default device #22756


Merged Apr 4, 2025 · 9 commits (changes shown from 4 commits)
53 changes: 50 additions & 3 deletions scipy/conftest.py
@@ -13,7 +13,8 @@

from scipy._lib._fpumode import get_fpu_mode
from scipy._lib._array_api import (
    SCIPY_ARRAY_API, SCIPY_DEVICE, array_namespace, default_xp
    SCIPY_ARRAY_API, SCIPY_DEVICE, array_namespace, default_xp,
    is_cupy, is_dask, is_jax, is_torch, xp_device,
)
from scipy._lib._testutils import FPUModeChangeWarning
from scipy._lib.array_api_extra.testing import patch_lazy_xp_functions
@@ -154,9 +155,10 @@ def num_parallel_threads():
try:
import torch # type: ignore[import-not-found]
xp_available_backends.update({'torch': torch})
# can use `mps` or `cpu`
torch.set_default_device(SCIPY_DEVICE)
if SCIPY_DEVICE != "cpu":
# FIXME don't set this when SCIPY_DEVICE == "cpu"
# as a workaround to pytorch/pytorch#150199
torch.set_default_device(SCIPY_DEVICE)
xp_skip_cpu_only_backends.add('torch')

# default to float64 unless explicitly requested
@@ -409,6 +411,51 @@ def skip_or_xfail_xp_backends(request: pytest.FixtureRequest,
skip_or_xfail(reason=reason)


@pytest.fixture
def nondefault_device(xp):
    """Fixture that returns a device other than the default device for the backend.
    Used to test input->output device propagation.

    Usage
    -----
    from scipy._lib._array_api import xp_device

    def test(xp, nondefault_device):
        a = xp.asarray(..., device=nondefault_device)
        b = f(a)
        assert xp_device(b) == nondefault_device
    """
    if is_cupy(xp):
        pytest.xfail(reason="data-apis/array-api-compat#293")
    if is_dask(xp):
        pytest.skip(reason="dummy device from array-api-compat does not propagate")
    if is_jax(xp):
        # The .device attribute is not accessible inside jax.jit; the consequence
        # (downstream of array-api-compat hacks) is that a non-default device in
        # input is not guaranteed to propagate to the output even if the scipy code
        # states `device=xp_device(arg)` in all array creation functions.
        # While this issue is specific to jax.jit, it would be unnecessarily
        # verbose to skip the test for each jit-capable function and run it for
        # those that only support eager mode.
        # Additionally, devices() below only returns CudaDevice objects when
        # CUDA is enabled, which prevents us from running this test on CPU vs. GPU.
        pytest.xfail(reason="jax-ml/jax#26000 + jax-ml/jax#27606")
    if is_torch(xp) and SCIPY_DEVICE != "cpu":
        # Note workaround when parsing SCIPY_DEVICE above.
        # Also note that when SCIPY_DEVICE=cpu this test won't run in CI
        # because CUDA-enabled CI boxes always use SCIPY_DEVICE=cuda.
        pytest.xfail(reason="pytorch/pytorch#150199")
    devices = xp.__array_namespace_info__().devices()
    # Don't use xp.__array_namespace_info__().default_device():
    # https://github.com/data-apis/array-api/issues/835
    default = xp_device(xp.empty(()))
    try:
        return next(iter(d for d in devices if d != default))
    except StopIteration:
        pytest.skip(reason="Only one device available")


# Following the approach of NumPy's conftest.py...
# Use a known and persistent tmpdir for hypothesis' caches, which
# can be automatically cleared by the OS or user.
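For readers without a multi-device backend at hand, the selection logic of the `nondefault_device` fixture above can be exercised standalone. A minimal sketch, assuming a recent `array_api_strict` (2.2 or later), which ships extra dummy devices precisely so that device handling can be tested without a GPU; the variable names below are illustrative and not part of this PR:

```python
import array_api_strict as xp

# Enumerate the backend's devices via the standard inspection API.
devices = xp.__array_namespace_info__().devices()

# Infer the default device from a freshly created array instead of calling
# default_device(); see data-apis/array-api#835 for why the fixture avoids it.
default = xp.empty(()).device

try:
    nondefault = next(d for d in devices if d != default)
except StopIteration:
    raise SystemExit("only one device available; nothing to test")

# An array created on the non-default device should report that device,
# and so should the output of any function that handles devices correctly.
a = xp.asarray([1.0, 2.0, 3.0], device=nondefault)
assert a.device == nondefault
```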
5 changes: 3 additions & 2 deletions scipy/special/_logsumexp.py
@@ -1,6 +1,7 @@
import numpy as np
from scipy._lib._array_api import (
    array_namespace,
    xp_device,
    xp_size,
    xp_promote,
    xp_float_to_complex,
@@ -133,7 +134,7 @@ def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
else:
shape = np.asarray(a.shape) # NumPy is convenient for shape manipulation
shape[axis] = 1
out = xp.full(tuple(shape), -xp.inf, dtype=a.dtype)
out = xp.full(tuple(shape), -xp.inf, dtype=a.dtype, device=xp_device(a))
sgn = xp.sign(out)

if xp.isdtype(out.dtype, 'complex floating'):
@@ -184,7 +185,7 @@ def _elements_and_indices_with_max_real(a, axis=-1, xp=None):
# Of those, choose one arbitrarily. This is a reasonably
# simple, array-API compatible way of doing so that doesn't
# have a problem with `axis` being a tuple or None.
i = xp.reshape(xp.arange(xp_size(a)), a.shape)
i = xp.reshape(xp.arange(xp_size(a), device=xp_device(a)), a.shape)
i = xpx.at(i, ~mask).set(-1)
max_i = xp.max(i, axis=axis, keepdims=True)
mask = i == max_i
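The two one-line changes to `_logsumexp.py` above are instances of the same pattern: array creation functions allocate on the backend's default device unless told otherwise, so the input's device has to be forwarded explicitly. A minimal sketch of the failure mode in plain PyTorch, assuming a CUDA-capable build; the variable names are illustrative, not SciPy code:

```python
import torch

a = torch.zeros(3, device="cuda")  # input lives on a non-default device

# Without an explicit device, the new array lands on the default device (CPU here).
lost = torch.full(a.shape, float("-inf"))

# Forwarding the input's device propagates it; device=xp_device(a) plays this role above.
kept = torch.full(a.shape, float("-inf"), device=a.device)

assert lost.device.type == "cpu"
assert kept.device == a.device
```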
9 changes: 8 additions & 1 deletion scipy/special/tests/test_logsumexp.py
@@ -4,7 +4,7 @@

import numpy as np

from scipy._lib._array_api import is_array_api_strict, xp_default_dtype
from scipy._lib._array_api import is_array_api_strict, xp_default_dtype, xp_device
from scipy._lib._array_api_no_0d import (xp_assert_equal, xp_assert_close,
xp_assert_less)

@@ -292,6 +292,13 @@ def test_gh22601_infinite_elements(self, x, y, xp):
        ref = xp.log(xp.sum(xp.exp(xp.asarray([x, y]))))
        xp_assert_equal(res, ref)

    @pytest.mark.parametrize("x", [1.0, 1.0j, []])
    def test_device(self, x, xp, nondefault_device):
        """Test input device propagation to output."""
        x = xp.asarray(x, device=nondefault_device)
        assert xp_device(logsumexp(x)) == nondefault_device
        assert xp_device(logsumexp(x, b=x)) == nondefault_device
Contributor:
At first, if you're working on a machine with a single GPU for example, this PR may appear to be fairly straightforward.

However, when I test on a node/machine that has multiple GPUs and run, e.g., `SCIPY_DEVICE=cuda python dev.py test -t scipy/special/tests/test_logsumexp.py::TestLogSumExp::test_device -b cupy`, this test currently fails on this branch:

scipy/special/tests/test_logsumexp.py:299: in test_device
    assert xp_device(logsumexp(x)) == nondefault_device
E   assert <CUDA Device 0> == <CUDA Device 1>

Does CuPy require special treatment? Do multiple GPUs require special treatment? It isn't immediately obvious to me, but at first glance the nature of the failure suggests that device propagation is not working as intended.

I'm using CuPy 13.3.0, which is fairly recent; I could try bumping the version.

Contributor:

Is our ecosystem currently shimming around cupy.cuda.Device to compensate for not having the device kwarg on the array coercion for CuPy?

Contributor Author:

> At first, if you're working on a machine with a single GPU for example, this PR may appear to be fairly straightforward.

On a single-GPU machine, CuPy has only one device and the test introduced by this PR is skipped.

> Is our ecosystem currently shimming around cupy.cuda.Device to compensate for not having the device kwarg on the array coercion for CuPy?

Yes, by array-api-compat.
Looks like multi-device support was not thought through: data-apis/array-api-compat#293
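For context on the shim question above: most CuPy creation functions have no `device=` keyword; the current GPU is instead selected with the `cupy.cuda.Device` context manager, which is what array-api-compat has to wrap. A minimal sketch, assuming a host with at least two CUDA GPUs:

```python
import cupy as cp

with cp.cuda.Device(1):   # make GPU 1 the current device for this block
    x = cp.zeros(3)       # allocated on GPU 1
print(x.device)           # <CUDA Device 1>

y = cp.zeros(3)           # outside the context manager: back on the default GPU 0
print(y.device)           # <CUDA Device 0>
```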



class TestSoftmax:
    def test_softmax_fixtures(self, xp):