Skip to content

Commit 78a0682

Browse files
committed
CuPy multi-device support
1 parent b5a57eb commit 78a0682

File tree

9 files changed

+134
-98
lines changed

9 files changed

+134
-98
lines changed

array_api_compat/common/_aliases.py

+24-24
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,12 @@
77
import inspect
88
from typing import NamedTuple, Optional, Sequence, Tuple, Union
99

10-
from ._helpers import array_namespace, _check_device, device, is_cupy_namespace
10+
from ._helpers import array_namespace, device, is_cupy_namespace, _device_ctx
1111
from ._typing import Array, Device, DType, Namespace
1212

1313
# These functions are modified from the NumPy versions.
1414

15-
# Creation functions add the device keyword (which does nothing for NumPy)
15+
# Creation functions add the device keyword (which does nothing for NumPy and Dask)
1616

1717
def arange(
1818
start: Union[int, float],
@@ -25,8 +25,8 @@ def arange(
2525
device: Optional[Device] = None,
2626
**kwargs,
2727
) -> Array:
28-
_check_device(xp, device)
29-
return xp.arange(start, stop=stop, step=step, dtype=dtype, **kwargs)
28+
with _device_ctx(xp, device):
29+
return xp.arange(start, stop=stop, step=step, dtype=dtype, **kwargs)
3030

3131
def empty(
3232
shape: Union[int, Tuple[int, ...]],
@@ -36,8 +36,8 @@ def empty(
3636
device: Optional[Device] = None,
3737
**kwargs,
3838
) -> Array:
39-
_check_device(xp, device)
40-
return xp.empty(shape, dtype=dtype, **kwargs)
39+
with _device_ctx(xp, device):
40+
return xp.empty(shape, dtype=dtype, **kwargs)
4141

4242
def empty_like(
4343
x: Array,
@@ -48,8 +48,8 @@ def empty_like(
4848
device: Optional[Device] = None,
4949
**kwargs,
5050
) -> Array:
51-
_check_device(xp, device)
52-
return xp.empty_like(x, dtype=dtype, **kwargs)
51+
with _device_ctx(xp, device, like=x):
52+
return xp.empty_like(x, dtype=dtype, **kwargs)
5353

5454
def eye(
5555
n_rows: int,
@@ -62,8 +62,8 @@ def eye(
6262
device: Optional[Device] = None,
6363
**kwargs,
6464
) -> Array:
65-
_check_device(xp, device)
66-
return xp.eye(n_rows, M=n_cols, k=k, dtype=dtype, **kwargs)
65+
with _device_ctx(xp, device):
66+
return xp.eye(n_rows, M=n_cols, k=k, dtype=dtype, **kwargs)
6767

6868
def full(
6969
shape: Union[int, Tuple[int, ...]],
@@ -74,8 +74,8 @@ def full(
7474
device: Optional[Device] = None,
7575
**kwargs,
7676
) -> Array:
77-
_check_device(xp, device)
78-
return xp.full(shape, fill_value, dtype=dtype, **kwargs)
77+
with _device_ctx(xp, device):
78+
return xp.full(shape, fill_value, dtype=dtype, **kwargs)
7979

8080
def full_like(
8181
x: Array,
@@ -87,8 +87,8 @@ def full_like(
8787
device: Optional[Device] = None,
8888
**kwargs,
8989
) -> Array:
90-
_check_device(xp, device)
91-
return xp.full_like(x, fill_value, dtype=dtype, **kwargs)
90+
with _device_ctx(xp, device, like=x):
91+
return xp.full_like(x, fill_value, dtype=dtype, **kwargs)
9292

9393
def linspace(
9494
start: Union[int, float],
@@ -102,8 +102,8 @@ def linspace(
102102
endpoint: bool = True,
103103
**kwargs,
104104
) -> Array:
105-
_check_device(xp, device)
106-
return xp.linspace(start, stop, num, dtype=dtype, endpoint=endpoint, **kwargs)
105+
with _device_ctx(xp, device):
106+
return xp.linspace(start, stop, num, dtype=dtype, endpoint=endpoint, **kwargs)
107107

108108
def ones(
109109
shape: Union[int, Tuple[int, ...]],
@@ -113,8 +113,8 @@ def ones(
113113
device: Optional[Device] = None,
114114
**kwargs,
115115
) -> Array:
116-
_check_device(xp, device)
117-
return xp.ones(shape, dtype=dtype, **kwargs)
116+
with _device_ctx(xp, device):
117+
return xp.ones(shape, dtype=dtype, **kwargs)
118118

119119
def ones_like(
120120
x: Array,
@@ -125,8 +125,8 @@ def ones_like(
125125
device: Optional[Device] = None,
126126
**kwargs,
127127
) -> Array:
128-
_check_device(xp, device)
129-
return xp.ones_like(x, dtype=dtype, **kwargs)
128+
with _device_ctx(xp, device, like=x):
129+
return xp.ones_like(x, dtype=dtype, **kwargs)
130130

131131
def zeros(
132132
shape: Union[int, Tuple[int, ...]],
@@ -136,8 +136,8 @@ def zeros(
136136
device: Optional[Device] = None,
137137
**kwargs,
138138
) -> Array:
139-
_check_device(xp, device)
140-
return xp.zeros(shape, dtype=dtype, **kwargs)
139+
with _device_ctx(xp, device):
140+
return xp.zeros(shape, dtype=dtype, **kwargs)
141141

142142
def zeros_like(
143143
x: Array,
@@ -148,8 +148,8 @@ def zeros_like(
148148
device: Optional[Device] = None,
149149
**kwargs,
150150
) -> Array:
151-
_check_device(xp, device)
152-
return xp.zeros_like(x, dtype=dtype, **kwargs)
151+
with _device_ctx(xp, device, like=x):
152+
return xp.zeros_like(x, dtype=dtype, **kwargs)
153153

154154
# np.unique() is split into four functions in the array API:
155155
# unique_all, unique_counts, unique_inverse, and unique_values (this is done

array_api_compat/common/_helpers.py

+60-34
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,13 @@
77
"""
88
from __future__ import annotations
99

10+
import contextlib
1011
import sys
1112
import math
1213
import inspect
1314
import warnings
15+
from collections.abc import Generator
16+
from types import ModuleType
1417
from typing import Optional, Union, Any
1518

1619
from ._typing import Array, Device, Namespace
@@ -595,10 +598,6 @@ def your_function(x, y):
595598
# backwards compatibility alias
596599
get_namespace = array_namespace
597600

598-
def _check_device(xp, device):
599-
if xp == sys.modules.get('numpy'):
600-
if device not in ["cpu", None]:
601-
raise ValueError(f"Unsupported device for NumPy: {device!r}")
602601

603602
# Placeholder object to represent the dask device
604603
# when the array backend is not the CPU.
@@ -609,6 +608,7 @@ def __repr__(self):
609608

610609
_DASK_DEVICE = _dask_device()
611610

611+
612612
# device() is not on numpy.ndarray or dask.array and to_device() is not on numpy.ndarray
613613
# or cupy.ndarray. They are not included in array objects of this library
614614
# because this library just reuses the respective ndarray classes without
@@ -685,50 +685,39 @@ def device(x: Array, /) -> Device:
685685
# Prevent shadowing, used below
686686
_device = device
687687

688+
688689
# Based on cupy.array_api.Array.to_device
689690
def _cupy_to_device(x, device, /, stream=None):
690691
import cupy as cp
691-
from cupy.cuda import Device as _Device
692-
from cupy.cuda import stream as stream_module
693-
from cupy_backends.cuda.api import runtime
694692

695-
if device == x.device:
696-
return x
697-
elif device == "cpu":
693+
if device == "cpu":
698694
# allowing us to use `to_device(x, "cpu")`
699695
# is useful for portable test swapping between
700696
# host and device backends
701697
return x.get()
702-
elif not isinstance(device, _Device):
703-
raise ValueError(f"Unsupported device {device!r}")
704-
else:
705-
# see cupy/cupy#5985 for the reason how we handle device/stream here
706-
prev_device = runtime.getDevice()
707-
prev_stream: stream_module.Stream = None
708-
if stream is not None:
709-
prev_stream = stream_module.get_current_stream()
710-
# stream can be an int as specified in __dlpack__, or a CuPy stream
711-
if isinstance(stream, int):
712-
stream = cp.cuda.ExternalStream(stream)
713-
elif isinstance(stream, cp.cuda.Stream):
714-
pass
715-
else:
716-
raise ValueError('the input stream is not recognized')
717-
stream.use()
718-
try:
719-
runtime.setDevice(device.id)
720-
arr = x.copy()
721-
finally:
722-
runtime.setDevice(prev_device)
723-
if stream is not None:
724-
prev_stream.use()
725-
return arr
698+
if not isinstance(device, cp.cuda.Device):
699+
raise TypeError(f"Unsupported device {device!r}")
700+
701+
# see cupy/cupy#5985 for why device/stream are handled this way
702+
703+
# stream can be an int as specified in __dlpack__, or a CuPy stream
704+
if isinstance(stream, int):
705+
stream = cp.cuda.ExternalStream(stream)
706+
elif stream is None:
707+
stream = contextlib.nullcontext()
708+
elif not isinstance(stream, cp.cuda.Stream):
709+
raise TypeError('the input stream is not recognized')
710+
711+
with device, stream:
712+
return cp.asarray(x)
713+
726714

727715
def _torch_to_device(x, device, /, stream=None):
    """Move ``x`` to ``device`` by delegating to ``x.to(device)``.

    Parameters
    ----------
    x
        Object exposing a ``to(device)`` method.
    device
        Target device, forwarded unchanged to ``x.to``.
    stream
        Must be ``None``; stream selection is not implemented here.

    Raises
    ------
    NotImplementedError
        If ``stream`` is anything other than ``None``.
    """
    if stream is None:
        return x.to(device)
    raise NotImplementedError
731719

720+
732721
def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]] = None) -> Array:
733722
"""
734723
Copy the array from the device on which it currently resides to the specified ``device``.
@@ -811,6 +800,43 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
811800
return x.to_device(device, stream=stream)
812801

813802

803+
def _device_ctx(
    bare_xp: ModuleType, device: Device, like: Array | None = None
) -> contextlib.AbstractContextManager:
    """Return a context manager which changes the current device in CuPy.

    Used internally by array creation functions in common._aliases.

    Parameters
    ----------
    bare_xp : ModuleType
        The array library module itself. It is compared by identity with
        the ``numpy``, ``dask.array`` and ``cupy`` entries of
        ``sys.modules``.
    device : Device
        Requested device. ``None`` means "not specified"; in that case the
        device of ``like`` is used when ``like`` is given.
    like : Array, optional
        Reference array supplying the device when ``device`` is ``None``.

    Returns
    -------
    contextlib.AbstractContextManager
        ``contextlib.nullcontext()`` when neither ``device`` nor ``like`` is
        given, and for NumPy and Dask once the device has been validated.
        For CuPy, the ``device`` object itself, which is used as the
        context manager.

    Raises
    ------
    ValueError
        If the device is not supported by NumPy or Dask.
    TypeError
        If ``bare_xp`` is CuPy and the device is not a ``cupy.cuda.Device``.
    AssertionError
        If ``bare_xp`` is none of the libraries handled here.
    """
    if device is None:
        if like is None:
            # Nothing requested and nothing to infer from: leave the
            # library's current device untouched.
            return contextlib.nullcontext()
        device = _device(like)

    # Use sys.modules lookups so that none of the optional libraries is
    # imported just to perform the identity check.
    if bare_xp is sys.modules.get('numpy'):
        if device != "cpu":
            raise ValueError(f"Unsupported device for NumPy: {device!r}")
        return contextlib.nullcontext()

    if bare_xp is sys.modules.get('dask.array'):
        if device not in ("cpu", _DASK_DEVICE):
            raise ValueError(f"Unsupported device for Dask: {device!r}")
        return contextlib.nullcontext()

    if bare_xp is sys.modules.get('cupy'):
        if not isinstance(device, bare_xp.cuda.Device):
            raise TypeError(f"device is not a cupy.cuda.Device: {device!r}")
        return device

    # PyTorch doesn't have a "current device" context manager and you
    # can't use array creation functions from common._aliases.
    raise AssertionError("unreachable")  # pragma: nocover
833+
834+
835+
def _validate_device(bare_xp: ModuleType, device: Device) -> None:
    """Check ``device`` against ``bare_xp`` without creating any array.

    Obtains the context manager from ``_device_ctx`` and enters and leaves
    it once with an empty body. Any exception raised while building or
    entering that context propagates to the caller.
    """
    ctx = _device_ctx(bare_xp, device)
    with ctx:
        pass
838+
839+
814840
def size(x: Array) -> int | None:
815841
"""
816842
Return the total number of elements of x.

array_api_compat/cupy/_aliases.py

+9-23
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,6 @@
6262
tensordot = get_xp(cp)(_aliases.tensordot)
6363
sign = get_xp(cp)(_aliases.sign)
6464

65-
_copy_default = object()
66-
6765

6866
# asarray also adds the copy keyword, which is not present in numpy 1.0.
6967
def asarray(
@@ -74,7 +72,7 @@ def asarray(
7472
*,
7573
dtype: Optional[DType] = None,
7674
device: Optional[Device] = None,
77-
copy: Optional[bool] = _copy_default,
75+
copy: Optional[bool] = None,
7876
**kwargs,
7977
) -> Array:
8078
"""
@@ -83,26 +81,14 @@ def asarray(
8381
See the corresponding documentation in the array library and/or the array API
8482
specification for more details.
8583
"""
86-
with cp.cuda.Device(device):
87-
# cupy is like NumPy 1.26 (except without _CopyMode). See the comments
88-
# in asarray in numpy/_aliases.py.
89-
if copy is not _copy_default:
90-
# A future version of CuPy will change the meaning of copy=False
91-
# to mean no-copy. We don't know for certain what version it will
92-
# be yet, so to avoid breaking that version, we use a different
93-
# default value for copy so asarray(obj) with no copy kwarg will
94-
# always do the copy-if-needed behavior.
95-
96-
# This will still need to be updated to remove the
97-
# NotImplementedError for copy=False, but at least this won't
98-
# break the default or existing behavior.
99-
if copy is None:
100-
copy = False
101-
elif copy is False:
102-
raise NotImplementedError("asarray(copy=False) is not yet supported in cupy")
103-
kwargs['copy'] = copy
104-
105-
return cp.array(obj, dtype=dtype, **kwargs)
84+
if copy is False:
85+
raise NotImplementedError("asarray(copy=False) is not yet supported in cupy")
86+
87+
with _helpers._device_ctx(cp, device):
88+
if copy is None:
89+
return cp.asarray(obj, dtype=dtype, **kwargs)
90+
else:
91+
return cp.array(obj, dtype=dtype, copy=True, **kwargs)
10692

10793

10894
def astype(

array_api_compat/cupy/_info.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
complex128,
2727
)
2828

29+
2930
class __array_namespace_info__:
3031
"""
3132
Get the array API inspection namespace for CuPy.
@@ -117,7 +118,7 @@ def default_device(self):
117118
118119
Returns
119120
-------
120-
device : str
121+
device : Device
121122
The default device used for new CuPy arrays.
122123
123124
Examples
@@ -126,6 +127,15 @@ def default_device(self):
126127
>>> info.default_device()
127128
Device(0)
128129
130+
Notes
131+
-----
132+
This method returns the static default device when CuPy is initialized.
133+
However, the *current* device used by creation functions (``empty`` etc.)
134+
can be changed globally or with a context manager.
135+
136+
See Also
137+
--------
138+
https://github.com/data-apis/array-api/issues/835
129139
"""
130140
return cuda.Device(0)
131141

@@ -312,7 +322,7 @@ def devices(self):
312322
313323
Returns
314324
-------
315-
devices : list of str
325+
devices : list[Device]
316326
The devices supported by CuPy.
317327
318328
See Also

0 commit comments

Comments
 (0)