diff --git a/array_api_strict/__init__.py b/array_api_strict/__init__.py index 8dfa09f..ff43660 100644 --- a/array_api_strict/__init__.py +++ b/array_api_strict/__init__.py @@ -309,6 +309,9 @@ __all__ += ["all", "any"] +from ._array_object import Device +__all__ += ["Device"] + # Helper functions that are not part of the standard from ._flags import ( diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index d8ed018..cd9a360 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -43,13 +43,26 @@ import numpy as np -# Placeholder object to represent the "cpu" device (the only device NumPy -# supports). -class _cpu_device: +class Device: + def __init__(self, device="CPU_DEVICE"): + if device not in ("CPU_DEVICE", "device1", "device2"): + raise ValueError(f"The device '{device}' is not a valid choice.") + self._device = device + def __repr__(self): - return "CPU_DEVICE" + return f"array_api_strict.Device('{self._device}')" + + def __eq__(self, other): + if not isinstance(other, Device): + return False + return self._device == other._device + + def __hash__(self): + return hash(("Device", self._device)) + -CPU_DEVICE = _cpu_device() +CPU_DEVICE = Device() +ALL_DEVICES = (CPU_DEVICE, Device("device1"), Device("device2")) _default = object() @@ -73,7 +86,7 @@ class Array: # Use a custom constructor instead of __init__, as manually initializing # this class is not supported API. @classmethod - def _new(cls, x, /): + def _new(cls, x, /, device): """ This is a private method for initializing the array API Array object. @@ -95,6 +108,9 @@ def _new(cls, x, /): ) obj._array = x obj._dtype = _dtype + if device is None: + device = CPU_DEVICE + obj._device = device return obj # Prevent Array() from working @@ -116,7 +132,11 @@ def __repr__(self: Array, /) -> str: """ Performs the operation __repr__. """ - suffix = f", dtype={self.dtype})" + suffix = f", dtype={self.dtype}" + if self.device != CPU_DEVICE: + suffix += f", device={self.device})" + else: + suffix += ")" if 0 in self.shape: prefix = "empty(" mid = str(self.shape) @@ -134,6 +154,8 @@ def __array__(self, dtype: None | np.dtype[Any] = None, copy: None | bool = None will be present in other implementations. """ + if self._device != CPU_DEVICE: + raise RuntimeError(f"Can not convert array on the '{self._device}' device to a Numpy array.") # copy keyword is new in 2.0.0; for older versions don't use it # retry without that keyword. if np.__version__[0] < '2': @@ -193,6 +215,14 @@ def _check_allowed_dtypes(self, other: bool | int | float | Array, dtype_categor return other + def _check_device(self, other): + """Check that other is on a device compatible with the current array""" + if isinstance(other, (int, complex, float, bool)): + return + elif isinstance(other, Array): + if self.device != other.device: + raise ValueError(f"Arrays from two different devices ({self.device} and {other.device}) can not be combined.") + # Helper function to match the type promotion rules in the spec def _promote_scalar(self, scalar): """ @@ -244,7 +274,7 @@ def _promote_scalar(self, scalar): # behavior for integers within the bounds of the integer dtype. # Outside of those bounds we use the default NumPy behavior (either # cast or raise OverflowError). - return Array._new(np.array(scalar, dtype=self.dtype._np_dtype)) + return Array._new(np.array(scalar, dtype=self.dtype._np_dtype), device=CPU_DEVICE) @staticmethod def _normalize_two_args(x1, x2) -> Tuple[Array, Array]: @@ -276,9 +306,9 @@ def _normalize_two_args(x1, x2) -> Tuple[Array, Array]: # performant. broadcast_to(x1._array, x2.shape) is much slower. We # could also manually type promote x2, but that is more complicated # and about the same performance as this. - x1 = Array._new(x1._array[None]) + x1 = Array._new(x1._array[None], device=x1.device) elif x2.ndim == 0 and x1.ndim != 0: - x2 = Array._new(x2._array[None]) + x2 = Array._new(x2._array[None], device=x2.device) return (x1, x2) # Note: A large fraction of allowed indices are disallowed here (see the @@ -462,29 +492,31 @@ def __abs__(self: Array, /) -> Array: if self.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in __abs__") res = self._array.__abs__() - return self.__class__._new(res) + return self.__class__._new(res, device=self.device) def __add__(self: Array, other: Union[int, float, Array], /) -> Array: """ Performs the operation __add__. """ + self._check_device(other) other = self._check_allowed_dtypes(other, "numeric", "__add__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) res = self._array.__add__(other._array) - return self.__class__._new(res) + return self.__class__._new(res, device=self.device) def __and__(self: Array, other: Union[int, bool, Array], /) -> Array: """ Performs the operation __and__. """ + self._check_device(other) other = self._check_allowed_dtypes(other, "integer or boolean", "__and__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) res = self._array.__and__(other._array) - return self.__class__._new(res) + return self.__class__._new(res, device=self.device) def __array_namespace__( self: Array, /, *, api_version: Optional[str] = None @@ -568,6 +600,7 @@ def __eq__(self: Array, other: Union[int, float, bool, Array], /) -> Array: """ Performs the operation __eq__. """ + self._check_device(other) # Even though "all" dtypes are allowed, we still require them to be # promotable with each other. other = self._check_allowed_dtypes(other, "all", "__eq__") @@ -575,7 +608,7 @@ def __eq__(self: Array, other: Union[int, float, bool, Array], /) -> Array: return other self, other = self._normalize_two_args(self, other) res = self._array.__eq__(other._array) - return self.__class__._new(res) + return self.__class__._new(res, device=self.device) def __float__(self: Array, /) -> float: """ @@ -593,23 +626,25 @@ def __floordiv__(self: Array, other: Union[int, float, Array], /) -> Array: """ Performs the operation __floordiv__. """ + self._check_device(other) other = self._check_allowed_dtypes(other, "real numeric", "__floordiv__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) res = self._array.__floordiv__(other._array) - return self.__class__._new(res) + return self.__class__._new(res, device=self.device) def __ge__(self: Array, other: Union[int, float, Array], /) -> Array: """ Performs the operation __ge__. """ + self._check_device(other) other = self._check_allowed_dtypes(other, "real numeric", "__ge__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) res = self._array.__ge__(other._array) - return self.__class__._new(res) + return self.__class__._new(res, device=self.device) def __getitem__( self: Array, @@ -625,6 +660,7 @@ def __getitem__( """ Performs the operation __getitem__. """ + # XXX Does key have to be on the same device? Is there an exception for CPU_DEVICE? # Note: Only indices required by the spec are allowed. See the # docstring of _validate_index self._validate_index(key) @@ -632,18 +668,19 @@ def __getitem__( # Indexing self._array with array_api_strict arrays can be erroneous key = key._array res = self._array.__getitem__(key) - return self._new(res) + return self._new(res, device=self.device) def __gt__(self: Array, other: Union[int, float, Array], /) -> Array: """ Performs the operation __gt__. """ + self._check_device(other) other = self._check_allowed_dtypes(other, "real numeric", "__gt__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) res = self._array.__gt__(other._array) - return self.__class__._new(res) + return self.__class__._new(res, device=other.device) def __int__(self: Array, /) -> int: """ @@ -671,7 +708,7 @@ def __invert__(self: Array, /) -> Array: if self.dtype not in _integer_or_boolean_dtypes: raise TypeError("Only integer or boolean dtypes are allowed in __invert__") res = self._array.__invert__() - return self.__class__._new(res) + return self.__class__._new(res, device=self.device) def __iter__(self: Array, /): """ @@ -686,85 +723,92 @@ def __iter__(self: Array, /): # define __iter__, but it doesn't disallow it. The default Python # behavior is to implement iter as a[0], a[1], ... when __getitem__ is # implemented, which implies iteration on 1-D arrays. - return (Array._new(i) for i in self._array) + return (Array._new(i, device=self.device) for i in self._array) def __le__(self: Array, other: Union[int, float, Array], /) -> Array: """ Performs the operation __le__. """ + self._check_device(other) other = self._check_allowed_dtypes(other, "real numeric", "__le__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) res = self._array.__le__(other._array) - return self.__class__._new(res) + return self.__class__._new(res, device=self.device) def __lshift__(self: Array, other: Union[int, Array], /) -> Array: """ Performs the operation __lshift__. """ + self._check_device(other) other = self._check_allowed_dtypes(other, "integer", "__lshift__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) res = self._array.__lshift__(other._array) - return self.__class__._new(res) + return self.__class__._new(res, device=self.device) def __lt__(self: Array, other: Union[int, float, Array], /) -> Array: """ Performs the operation __lt__. """ + self._check_device(other) other = self._check_allowed_dtypes(other, "real numeric", "__lt__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) res = self._array.__lt__(other._array) - return self.__class__._new(res) + return self.__class__._new(res, device=self.device) def __matmul__(self: Array, other: Array, /) -> Array: """ Performs the operation __matmul__. """ + self._check_device(other) # matmul is not defined for scalars, but without this, we may get # the wrong error message from asarray. other = self._check_allowed_dtypes(other, "numeric", "__matmul__") if other is NotImplemented: return other res = self._array.__matmul__(other._array) - return self.__class__._new(res) + return self.__class__._new(res, device=self.device) def __mod__(self: Array, other: Union[int, float, Array], /) -> Array: """ Performs the operation __mod__. """ + self._check_device(other) other = self._check_allowed_dtypes(other, "real numeric", "__mod__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) res = self._array.__mod__(other._array) - return self.__class__._new(res) + return self.__class__._new(res, device=self.device) def __mul__(self: Array, other: Union[int, float, Array], /) -> Array: """ Performs the operation __mul__. """ + self._check_device(other) other = self._check_allowed_dtypes(other, "numeric", "__mul__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) res = self._array.__mul__(other._array) - return self.__class__._new(res) + return self.__class__._new(res, device=self.device) def __ne__(self: Array, other: Union[int, float, bool, Array], /) -> Array: """ Performs the operation __ne__. """ + self._check_device(other) other = self._check_allowed_dtypes(other, "all", "__ne__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) res = self._array.__ne__(other._array) - return self.__class__._new(res) + return self.__class__._new(res, device=self.device) def __neg__(self: Array, /) -> Array: """ @@ -773,18 +817,19 @@ def __neg__(self: Array, /) -> Array: if self.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in __neg__") res = self._array.__neg__() - return self.__class__._new(res) + return self.__class__._new(res, device=self.device) def __or__(self: Array, other: Union[int, bool, Array], /) -> Array: """ Performs the operation __or__. """ + self._check_device(other) other = self._check_allowed_dtypes(other, "integer or boolean", "__or__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) res = self._array.__or__(other._array) - return self.__class__._new(res) + return self.__class__._new(res, device=self.device) def __pos__(self: Array, /) -> Array: """ @@ -793,7 +838,7 @@ def __pos__(self: Array, /) -> Array: if self.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in __pos__") res = self._array.__pos__() - return self.__class__._new(res) + return self.__class__._new(res, device=self.device) def __pow__(self: Array, other: Union[int, float, Array], /) -> Array: """ @@ -801,6 +846,7 @@ def __pow__(self: Array, other: Union[int, float, Array], /) -> Array: """ from ._elementwise_functions import pow + self._check_device(other) other = self._check_allowed_dtypes(other, "numeric", "__pow__") if other is NotImplemented: return other @@ -812,12 +858,13 @@ def __rshift__(self: Array, other: Union[int, Array], /) -> Array: """ Performs the operation __rshift__. """ + self._check_device(other) other = self._check_allowed_dtypes(other, "integer", "__rshift__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) res = self._array.__rshift__(other._array) - return self.__class__._new(res) + return self.__class__._new(res, device=self.device) def __setitem__( self, @@ -842,12 +889,13 @@ def __sub__(self: Array, other: Union[int, float, Array], /) -> Array: """ Performs the operation __sub__. """ + self._check_device(other) other = self._check_allowed_dtypes(other, "numeric", "__sub__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) res = self._array.__sub__(other._array) - return self.__class__._new(res) + return self.__class__._new(res, device=self.device) # PEP 484 requires int to be a subtype of float, but __truediv__ should # not accept int. @@ -855,28 +903,31 @@ def __truediv__(self: Array, other: Union[float, Array], /) -> Array: """ Performs the operation __truediv__. """ + self._check_device(other) other = self._check_allowed_dtypes(other, "floating-point", "__truediv__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) res = self._array.__truediv__(other._array) - return self.__class__._new(res) + return self.__class__._new(res, device=self.device) def __xor__(self: Array, other: Union[int, bool, Array], /) -> Array: """ Performs the operation __xor__. """ + self._check_device(other) other = self._check_allowed_dtypes(other, "integer or boolean", "__xor__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) res = self._array.__xor__(other._array) - return self.__class__._new(res) + return self.__class__._new(res, device=self.device) def __iadd__(self: Array, other: Union[int, float, Array], /) -> Array: """ Performs the operation __iadd__. """ + self._check_device(other) other = self._check_allowed_dtypes(other, "numeric", "__iadd__") if other is NotImplemented: return other @@ -887,17 +938,19 @@ def __radd__(self: Array, other: Union[int, float, Array], /) -> Array: """ Performs the operation __radd__. """ + self._check_device(other) other = self._check_allowed_dtypes(other, "numeric", "__radd__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) res = self._array.__radd__(other._array) - return self.__class__._new(res) + return self.__class__._new(res, device=self.device) def __iand__(self: Array, other: Union[int, bool, Array], /) -> Array: """ Performs the operation __iand__. """ + self._check_device(other) other = self._check_allowed_dtypes(other, "integer or boolean", "__iand__") if other is NotImplemented: return other @@ -908,17 +961,19 @@ def __rand__(self: Array, other: Union[int, bool, Array], /) -> Array: """ Performs the operation __rand__. """ + self._check_device(other) other = self._check_allowed_dtypes(other, "integer or boolean", "__rand__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) res = self._array.__rand__(other._array) - return self.__class__._new(res) + return self.__class__._new(res, device=self.device) def __ifloordiv__(self: Array, other: Union[int, float, Array], /) -> Array: """ Performs the operation __ifloordiv__. """ + self._check_device(other) other = self._check_allowed_dtypes(other, "real numeric", "__ifloordiv__") if other is NotImplemented: return other @@ -929,17 +984,19 @@ def __rfloordiv__(self: Array, other: Union[int, float, Array], /) -> Array: """ Performs the operation __rfloordiv__. """ + self._check_device(other) other = self._check_allowed_dtypes(other, "real numeric", "__rfloordiv__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) res = self._array.__rfloordiv__(other._array) - return self.__class__._new(res) + return self.__class__._new(res, device=self.device) def __ilshift__(self: Array, other: Union[int, Array], /) -> Array: """ Performs the operation __ilshift__. """ + self._check_device(other) other = self._check_allowed_dtypes(other, "integer", "__ilshift__") if other is NotImplemented: return other @@ -950,12 +1007,13 @@ def __rlshift__(self: Array, other: Union[int, Array], /) -> Array: """ Performs the operation __rlshift__. """ + self._check_device(other) other = self._check_allowed_dtypes(other, "integer", "__rlshift__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) res = self._array.__rlshift__(other._array) - return self.__class__._new(res) + return self.__class__._new(res, device=self.device) def __imatmul__(self: Array, other: Array, /) -> Array: """ @@ -966,8 +1024,9 @@ def __imatmul__(self: Array, other: Array, /) -> Array: other = self._check_allowed_dtypes(other, "numeric", "__imatmul__") if other is NotImplemented: return other + self._check_device(other) res = self._array.__imatmul__(other._array) - return self.__class__._new(res) + return self.__class__._new(res, device=self.device) def __rmatmul__(self: Array, other: Array, /) -> Array: """ @@ -978,8 +1037,9 @@ def __rmatmul__(self: Array, other: Array, /) -> Array: other = self._check_allowed_dtypes(other, "numeric", "__rmatmul__") if other is NotImplemented: return other + self._check_device(other) res = self._array.__rmatmul__(other._array) - return self.__class__._new(res) + return self.__class__._new(res, device=self.device) def __imod__(self: Array, other: Union[int, float, Array], /) -> Array: """ @@ -998,9 +1058,10 @@ def __rmod__(self: Array, other: Union[int, float, Array], /) -> Array: other = self._check_allowed_dtypes(other, "real numeric", "__rmod__") if other is NotImplemented: return other + self._check_device(other) self, other = self._normalize_two_args(self, other) res = self._array.__rmod__(other._array) - return self.__class__._new(res) + return self.__class__._new(res, device=self.device) def __imul__(self: Array, other: Union[int, float, Array], /) -> Array: """ @@ -1019,9 +1080,10 @@ def __rmul__(self: Array, other: Union[int, float, Array], /) -> Array: other = self._check_allowed_dtypes(other, "numeric", "__rmul__") if other is NotImplemented: return other + self._check_device(other) self, other = self._normalize_two_args(self, other) res = self._array.__rmul__(other._array) - return self.__class__._new(res) + return self.__class__._new(res, device=self.device) def __ior__(self: Array, other: Union[int, bool, Array], /) -> Array: """ @@ -1037,12 +1099,13 @@ def __ror__(self: Array, other: Union[int, bool, Array], /) -> Array: """ Performs the operation __ror__. """ + self._check_device(other) other = self._check_allowed_dtypes(other, "integer or boolean", "__ror__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) res = self._array.__ror__(other._array) - return self.__class__._new(res) + return self.__class__._new(res, device=self.device) def __ipow__(self: Array, other: Union[int, float, Array], /) -> Array: """ @@ -1084,9 +1147,10 @@ def __rrshift__(self: Array, other: Union[int, Array], /) -> Array: other = self._check_allowed_dtypes(other, "integer", "__rrshift__") if other is NotImplemented: return other + self._check_device(other) self, other = self._normalize_two_args(self, other) res = self._array.__rrshift__(other._array) - return self.__class__._new(res) + return self.__class__._new(res, device=self.device) def __isub__(self: Array, other: Union[int, float, Array], /) -> Array: """ @@ -1105,9 +1169,10 @@ def __rsub__(self: Array, other: Union[int, float, Array], /) -> Array: other = self._check_allowed_dtypes(other, "numeric", "__rsub__") if other is NotImplemented: return other + self._check_device(other) self, other = self._normalize_two_args(self, other) res = self._array.__rsub__(other._array) - return self.__class__._new(res) + return self.__class__._new(res, device=self.device) def __itruediv__(self: Array, other: Union[float, Array], /) -> Array: """ @@ -1126,9 +1191,10 @@ def __rtruediv__(self: Array, other: Union[float, Array], /) -> Array: other = self._check_allowed_dtypes(other, "floating-point", "__rtruediv__") if other is NotImplemented: return other + self._check_device(other) self, other = self._normalize_two_args(self, other) res = self._array.__rtruediv__(other._array) - return self.__class__._new(res) + return self.__class__._new(res, device=self.device) def __ixor__(self: Array, other: Union[int, bool, Array], /) -> Array: """ @@ -1147,15 +1213,19 @@ def __rxor__(self: Array, other: Union[int, bool, Array], /) -> Array: other = self._check_allowed_dtypes(other, "integer or boolean", "__rxor__") if other is NotImplemented: return other + self._check_device(other) self, other = self._normalize_two_args(self, other) res = self._array.__rxor__(other._array) - return self.__class__._new(res) + return self.__class__._new(res, device=self.device) def to_device(self: Array, device: Device, /, stream: None = None) -> Array: if stream is not None: raise ValueError("The stream argument to to_device() is not supported") - if device == CPU_DEVICE: + if device == self._device: return self + elif isinstance(device, Device): + arr = np.asarray(self._array, copy=True) + return self.__class__._new(arr, device=device) raise ValueError(f"Unsupported device {device!r}") @property @@ -1169,7 +1239,7 @@ def dtype(self) -> Dtype: @property def device(self) -> Device: - return CPU_DEVICE + return self._device # Note: mT is new in array API spec (see matrix_transpose) @property @@ -1216,4 +1286,4 @@ def T(self) -> Array: # https://data-apis.org/array-api/latest/API_specification/array_object.html#t if self.ndim != 2: raise ValueError("x.T requires x to have 2 dimensions. Use x.mT to transpose stacks of matrices and permute_dims() to permute dimensions.") - return self.__class__._new(self._array.T) + return self.__class__._new(self._array.T, device=self.device) diff --git a/array_api_strict/_creation_functions.py b/array_api_strict/_creation_functions.py index 67ba67c..a46c7a8 100644 --- a/array_api_strict/_creation_functions.py +++ b/array_api_strict/_creation_functions.py @@ -32,9 +32,12 @@ def _supports_buffer_protocol(obj): def _check_device(device): # _array_object imports in this file are inside the functions to avoid # circular imports - from ._array_object import CPU_DEVICE + from ._array_object import Device, ALL_DEVICES - if device not in [CPU_DEVICE, None]: + if device is not None and not isinstance(device, Device): + raise ValueError(f"Unsupported device {device!r}") + + if device is not None and device not in ALL_DEVICES: raise ValueError(f"Unsupported device {device!r}") def asarray( @@ -76,10 +79,10 @@ def asarray( new_array = np.array(obj._array, copy=copy, dtype=_np_dtype) if new_array is not obj._array: raise ValueError("Unable to avoid copy while creating an array from given array.") - return Array._new(new_array) + return Array._new(new_array, device=device) elif _supports_buffer_protocol(obj): # Buffer protocol will always support no-copy - return Array._new(np.array(obj, copy=copy, dtype=_np_dtype)) + return Array._new(np.array(obj, copy=copy, dtype=_np_dtype), device=device) else: # No-copy is unsupported for Python built-in types. raise ValueError("Unable to avoid copy while creating an array from given object.") @@ -89,13 +92,13 @@ def asarray( copy = False if isinstance(obj, Array): - return Array._new(np.array(obj._array, copy=copy, dtype=_np_dtype)) + return Array._new(np.array(obj._array, copy=copy, dtype=_np_dtype), device=device) if dtype is None and isinstance(obj, int) and (obj > 2 ** 64 or obj < -(2 ** 63)): # Give a better error message in this case. NumPy would convert this # to an object array. TODO: This won't handle large integers in lists. raise OverflowError("Integer out of bounds for array dtypes") res = np.array(obj, dtype=_np_dtype, copy=copy) - return Array._new(res) + return Array._new(res, device=device) def arange( @@ -119,7 +122,7 @@ def arange( if dtype is not None: dtype = dtype._np_dtype - return Array._new(np.arange(start, stop=stop, step=step, dtype=dtype)) + return Array._new(np.arange(start, stop=stop, step=step, dtype=dtype), device=device) def empty( @@ -140,7 +143,7 @@ def empty( if dtype is not None: dtype = dtype._np_dtype - return Array._new(np.empty(shape, dtype=dtype)) + return Array._new(np.empty(shape, dtype=dtype), device=device) def empty_like( @@ -158,7 +161,7 @@ def empty_like( if dtype is not None: dtype = dtype._np_dtype - return Array._new(np.empty_like(x._array, dtype=dtype)) + return Array._new(np.empty_like(x._array, dtype=dtype), device=device) def eye( @@ -182,7 +185,7 @@ def eye( if dtype is not None: dtype = dtype._np_dtype - return Array._new(np.eye(n_rows, M=n_cols, k=k, dtype=dtype)) + return Array._new(np.eye(n_rows, M=n_cols, k=k, dtype=dtype), device=device) _default = object() @@ -208,7 +211,7 @@ def from_dlpack( if copy not in [_default, None]: raise NotImplementedError("The copy argument to from_dlpack is not yet implemented") - return Array._new(np.from_dlpack(x)) + return Array._new(np.from_dlpack(x), device=device) def full( @@ -237,7 +240,7 @@ def full( # This will happen if the fill value is not something that NumPy # coerces to one of the acceptable dtypes. raise TypeError("Invalid input to full") - return Array._new(res) + return Array._new(res, device=device) def full_like( @@ -265,7 +268,7 @@ def full_like( # This will happen if the fill value is not something that NumPy # coerces to one of the acceptable dtypes. raise TypeError("Invalid input to full_like") - return Array._new(res) + return Array._new(res, device=device) def linspace( @@ -290,7 +293,7 @@ def linspace( if dtype is not None: dtype = dtype._np_dtype - return Array._new(np.linspace(start, stop, num, dtype=dtype, endpoint=endpoint)) + return Array._new(np.linspace(start, stop, num, dtype=dtype, endpoint=endpoint), device=device) def meshgrid(*arrays: Array, indexing: str = "xy") -> List[Array]: @@ -307,8 +310,17 @@ def meshgrid(*arrays: Array, indexing: str = "xy") -> List[Array]: if len({a.dtype for a in arrays}) > 1: raise ValueError("meshgrid inputs must all have the same dtype") + if len({a.device for a in arrays}) > 1: + raise ValueError("meshgrid inputs must all be on the same device") + + # arrays is allowed to be empty + if arrays: + device = arrays[0].device + else: + device = None + return [ - Array._new(array) + Array._new(array, device=device) for array in np.meshgrid(*[a._array for a in arrays], indexing=indexing) ] @@ -331,7 +343,7 @@ def ones( if dtype is not None: dtype = dtype._np_dtype - return Array._new(np.ones(shape, dtype=dtype)) + return Array._new(np.ones(shape, dtype=dtype), device=device) def ones_like( @@ -346,10 +358,12 @@ def ones_like( _check_valid_dtype(dtype) _check_device(device) + if device is None: + device = x.device if dtype is not None: dtype = dtype._np_dtype - return Array._new(np.ones_like(x._array, dtype=dtype)) + return Array._new(np.ones_like(x._array, dtype=dtype), device=device) def tril(x: Array, /, *, k: int = 0) -> Array: @@ -363,7 +377,7 @@ def tril(x: Array, /, *, k: int = 0) -> Array: if x.ndim < 2: # Note: Unlike np.tril, x must be at least 2-D raise ValueError("x must be at least 2-dimensional for tril") - return Array._new(np.tril(x._array, k=k)) + return Array._new(np.tril(x._array, k=k), device=x.device) def triu(x: Array, /, *, k: int = 0) -> Array: @@ -377,7 +391,7 @@ def triu(x: Array, /, *, k: int = 0) -> Array: if x.ndim < 2: # Note: Unlike np.triu, x must be at least 2-D raise ValueError("x must be at least 2-dimensional for triu") - return Array._new(np.triu(x._array, k=k)) + return Array._new(np.triu(x._array, k=k), device=x.device) def zeros( @@ -398,7 +412,7 @@ def zeros( if dtype is not None: dtype = dtype._np_dtype - return Array._new(np.zeros(shape, dtype=dtype)) + return Array._new(np.zeros(shape, dtype=dtype), device=device) def zeros_like( @@ -413,7 +427,9 @@ def zeros_like( _check_valid_dtype(dtype) _check_device(device) + if device is None: + device = x.device if dtype is not None: dtype = dtype._np_dtype - return Array._new(np.zeros_like(x._array, dtype=dtype)) + return Array._new(np.zeros_like(x._array, dtype=dtype), device=device) diff --git a/array_api_strict/_data_type_functions.py b/array_api_strict/_data_type_functions.py index 3405710..046dfc7 100644 --- a/array_api_strict/_data_type_functions.py +++ b/array_api_strict/_data_type_functions.py @@ -37,10 +37,12 @@ def astype( _check_device(device) else: raise TypeError("The device argument to astype requires at least version 2023.12 of the array API") + else: + device = x.device if not copy and dtype == x.dtype: return x - return Array._new(x._array.astype(dtype=dtype._np_dtype, copy=copy)) + return Array._new(x._array.astype(dtype=dtype._np_dtype, copy=copy), device=device) def broadcast_arrays(*arrays: Array) -> List[Array]: @@ -52,7 +54,7 @@ def broadcast_arrays(*arrays: Array) -> List[Array]: from ._array_object import Array return [ - Array._new(array) for array in np.broadcast_arrays(*[a._array for a in arrays]) + Array._new(array, device=arrays[0].device) for array in np.broadcast_arrays(*[a._array for a in arrays]) ] @@ -64,7 +66,7 @@ def broadcast_to(x: Array, /, shape: Tuple[int, ...]) -> Array: """ from ._array_object import Array - return Array._new(np.broadcast_to(x._array, shape)) + return Array._new(np.broadcast_to(x._array, shape), device=x.device) def can_cast(from_: Union[Dtype, Array], to: Dtype, /) -> bool: diff --git a/array_api_strict/_dtypes.py b/array_api_strict/_dtypes.py index a91454f..b51ed92 100644 --- a/array_api_strict/_dtypes.py +++ b/array_api_strict/_dtypes.py @@ -121,7 +121,7 @@ def __hash__(self): "integer": _integer_dtypes, "integer or boolean": _integer_or_boolean_dtypes, "boolean": _boolean_dtypes, - "real floating-point": _floating_dtypes, + "real floating-point": _real_floating_dtypes, "complex floating-point": _complex_floating_dtypes, "floating-point": _floating_dtypes, } diff --git a/array_api_strict/_elementwise_functions.py b/array_api_strict/_elementwise_functions.py index ab1cbb7..761caff 100644 --- a/array_api_strict/_elementwise_functions.py +++ b/array_api_strict/_elementwise_functions.py @@ -28,7 +28,7 @@ def abs(x: Array, /) -> Array: """ if x.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in abs") - return Array._new(np.abs(x._array)) + return Array._new(np.abs(x._array), device=x.device) # Note: the function name is different here @@ -40,7 +40,7 @@ def acos(x: Array, /) -> Array: """ if x.dtype not in _floating_dtypes: raise TypeError("Only floating-point dtypes are allowed in acos") - return Array._new(np.arccos(x._array)) + return Array._new(np.arccos(x._array), device=x.device) # Note: the function name is different here @@ -52,7 +52,7 @@ def acosh(x: Array, /) -> Array: """ if x.dtype not in _floating_dtypes: raise TypeError("Only floating-point dtypes are allowed in acosh") - return Array._new(np.arccosh(x._array)) + return Array._new(np.arccosh(x._array), device=x.device) def add(x1: Array, x2: Array, /) -> Array: @@ -61,12 +61,15 @@ def add(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ + if x1.device != x2.device: + raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") + if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in add") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.add(x1._array, x2._array)) + return Array._new(np.add(x1._array, x2._array), device=x1.device) # Note: the function name is different here @@ -78,7 +81,7 @@ def asin(x: Array, /) -> Array: """ if x.dtype not in _floating_dtypes: raise TypeError("Only floating-point dtypes are allowed in asin") - return Array._new(np.arcsin(x._array)) + return Array._new(np.arcsin(x._array), device=x.device) # Note: the function name is different here @@ -90,7 +93,7 @@ def asinh(x: Array, /) -> Array: """ if x.dtype not in _floating_dtypes: raise TypeError("Only floating-point dtypes are allowed in asinh") - return Array._new(np.arcsinh(x._array)) + return Array._new(np.arcsinh(x._array), device=x.device) # Note: the function name is different here @@ -102,7 +105,7 @@ def atan(x: Array, /) -> Array: """ if x.dtype not in _floating_dtypes: raise TypeError("Only floating-point dtypes are allowed in atan") - return Array._new(np.arctan(x._array)) + return Array._new(np.arctan(x._array), device=x.device) # Note: the function name is different here @@ -112,12 +115,14 @@ def atan2(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ + if x1.device != x2.device: + raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") if x1.dtype not in _real_floating_dtypes or x2.dtype not in _real_floating_dtypes: raise TypeError("Only real floating-point dtypes are allowed in atan2") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.arctan2(x1._array, x2._array)) + return Array._new(np.arctan2(x1._array, x2._array), device=x1.device) # Note: the function name is different here @@ -129,7 +134,7 @@ def atanh(x: Array, /) -> Array: """ if x.dtype not in _floating_dtypes: raise TypeError("Only floating-point dtypes are allowed in atanh") - return Array._new(np.arctanh(x._array)) + return Array._new(np.arctanh(x._array), device=x.device) def bitwise_and(x1: Array, x2: Array, /) -> Array: @@ -138,6 +143,9 @@ def bitwise_and(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ + if x1.device != x2.device: + raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") + if ( x1.dtype not in _integer_or_boolean_dtypes or x2.dtype not in _integer_or_boolean_dtypes @@ -146,7 +154,7 @@ def bitwise_and(x1: Array, x2: Array, /) -> Array: # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.bitwise_and(x1._array, x2._array)) + return Array._new(np.bitwise_and(x1._array, x2._array), device=x1.device) # Note: the function name is different here @@ -156,6 +164,9 @@ def bitwise_left_shift(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ + if x1.device != x2.device: + raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") + if x1.dtype not in _integer_dtypes or x2.dtype not in _integer_dtypes: raise TypeError("Only integer dtypes are allowed in bitwise_left_shift") # Call result type here just to raise on disallowed type combinations @@ -164,7 +175,7 @@ def bitwise_left_shift(x1: Array, x2: Array, /) -> Array: # Note: bitwise_left_shift is only defined for x2 nonnegative. if np.any(x2._array < 0): raise ValueError("bitwise_left_shift(x1, x2) is only defined for x2 >= 0") - return Array._new(np.left_shift(x1._array, x2._array)) + return Array._new(np.left_shift(x1._array, x2._array), device=x1.device) # Note: the function name is different here @@ -176,7 +187,7 @@ def bitwise_invert(x: Array, /) -> Array: """ if x.dtype not in _integer_or_boolean_dtypes: raise TypeError("Only integer or boolean dtypes are allowed in bitwise_invert") - return Array._new(np.invert(x._array)) + return Array._new(np.invert(x._array), device=x.device) def bitwise_or(x1: Array, x2: Array, /) -> Array: @@ -185,6 +196,9 @@ def bitwise_or(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ + if x1.device != x2.device: + raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") + if ( x1.dtype not in _integer_or_boolean_dtypes or x2.dtype not in _integer_or_boolean_dtypes @@ -193,7 +207,7 @@ def bitwise_or(x1: Array, x2: Array, /) -> Array: # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.bitwise_or(x1._array, x2._array)) + return Array._new(np.bitwise_or(x1._array, x2._array), device=x1.device) # Note: the function name is different here @@ -203,6 +217,9 @@ def bitwise_right_shift(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ + if x1.device != x2.device: + raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") + if x1.dtype not in _integer_dtypes or x2.dtype not in _integer_dtypes: raise TypeError("Only integer dtypes are allowed in bitwise_right_shift") # Call result type here just to raise on disallowed type combinations @@ -211,7 +228,7 @@ def bitwise_right_shift(x1: Array, x2: Array, /) -> Array: # Note: bitwise_right_shift is only defined for x2 nonnegative. if np.any(x2._array < 0): raise ValueError("bitwise_right_shift(x1, x2) is only defined for x2 >= 0") - return Array._new(np.right_shift(x1._array, x2._array)) + return Array._new(np.right_shift(x1._array, x2._array), device=x1.device) def bitwise_xor(x1: Array, x2: Array, /) -> Array: @@ -220,6 +237,9 @@ def bitwise_xor(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ + if x1.device != x2.device: + raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") + if ( x1.dtype not in _integer_or_boolean_dtypes or x2.dtype not in _integer_or_boolean_dtypes @@ -228,7 +248,7 @@ def bitwise_xor(x1: Array, x2: Array, /) -> Array: # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.bitwise_xor(x1._array, x2._array)) + return Array._new(np.bitwise_xor(x1._array, x2._array), device=x1.device) def ceil(x: Array, /) -> Array: @@ -242,7 +262,7 @@ def ceil(x: Array, /) -> Array: if x.dtype in _integer_dtypes: # Note: The return dtype of ceil is the same as the input return x - return Array._new(np.ceil(x._array)) + return Array._new(np.ceil(x._array), device=x.device) # WARNING: This function is not yet tested by the array-api-tests test suite. @@ -259,6 +279,11 @@ def clip( See its docstring for more information. """ + if isinstance(min, Array) and x.device != min.device: + raise ValueError(f"Arrays from two different devices ({x.device} and {min.device}) can not be combined.") + if isinstance(max, Array) and x.device != max.device: + raise ValueError(f"Arrays from two different devices ({x.device} and {max.device}) can not be combined.") + if (x.dtype not in _real_numeric_dtypes or isinstance(min, Array) and min.dtype not in _real_numeric_dtypes or isinstance(max, Array) and max.dtype not in _real_numeric_dtypes): @@ -307,7 +332,7 @@ def clip( # TODO: I'm not completely sure this always gives the correct thing # for integer dtypes. See https://github.com/numpy/numpy/issues/24976 result = result.astype(x.dtype._np_dtype) - return Array._new(result) + return Array._new(result, device=x.device) def conj(x: Array, /) -> Array: """ @@ -317,7 +342,7 @@ def conj(x: Array, /) -> Array: """ if x.dtype not in _complex_floating_dtypes: raise TypeError("Only complex floating-point dtypes are allowed in conj") - return Array._new(np.conj(x._array)) + return Array._new(np.conj(x._array), device=x.device) @requires_api_version('2023.12') def copysign(x1: Array, x2: Array, /) -> Array: @@ -326,12 +351,15 @@ def copysign(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ - if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: + if x1.device != x2.device: + raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") + + if x1.dtype not in _real_floating_dtypes or x2.dtype not in _real_floating_dtypes: raise TypeError("Only real numeric dtypes are allowed in copysign") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.copysign(x1._array, x2._array)) + return Array._new(np.copysign(x1._array, x2._array), device=x1.device) def cos(x: Array, /) -> Array: """ @@ -341,7 +369,7 @@ def cos(x: Array, /) -> Array: """ if x.dtype not in _floating_dtypes: raise TypeError("Only floating-point dtypes are allowed in cos") - return Array._new(np.cos(x._array)) + return Array._new(np.cos(x._array), device=x.device) def cosh(x: Array, /) -> Array: @@ -352,7 +380,7 @@ def cosh(x: Array, /) -> Array: """ if x.dtype not in _floating_dtypes: raise TypeError("Only floating-point dtypes are allowed in cosh") - return Array._new(np.cosh(x._array)) + return Array._new(np.cosh(x._array), device=x.device) def divide(x1: Array, x2: Array, /) -> Array: @@ -361,12 +389,14 @@ def divide(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ + if x1.device != x2.device: + raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") if x1.dtype not in _floating_dtypes or x2.dtype not in _floating_dtypes: raise TypeError("Only floating-point dtypes are allowed in divide") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.divide(x1._array, x2._array)) + return Array._new(np.divide(x1._array, x2._array), device=x1.device) def equal(x1: Array, x2: Array, /) -> Array: @@ -375,10 +405,12 @@ def equal(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ + if x1.device != x2.device: + raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.equal(x1._array, x2._array)) + return Array._new(np.equal(x1._array, x2._array), device=x1.device) def exp(x: Array, /) -> Array: @@ -389,7 +421,7 @@ def exp(x: Array, /) -> Array: """ if x.dtype not in _floating_dtypes: raise TypeError("Only floating-point dtypes are allowed in exp") - return Array._new(np.exp(x._array)) + return Array._new(np.exp(x._array), device=x.device) def expm1(x: Array, /) -> Array: @@ -400,7 +432,7 @@ def expm1(x: Array, /) -> Array: """ if x.dtype not in _floating_dtypes: raise TypeError("Only floating-point dtypes are allowed in expm1") - return Array._new(np.expm1(x._array)) + return Array._new(np.expm1(x._array), device=x.device) def floor(x: Array, /) -> Array: @@ -414,7 +446,7 @@ def floor(x: Array, /) -> Array: if x.dtype in _integer_dtypes: # Note: The return dtype of floor is the same as the input return x - return Array._new(np.floor(x._array)) + return Array._new(np.floor(x._array), device=x.device) def floor_divide(x1: Array, x2: Array, /) -> Array: @@ -423,12 +455,14 @@ def floor_divide(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ + if x1.device != x2.device: + raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: raise TypeError("Only real numeric dtypes are allowed in floor_divide") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.floor_divide(x1._array, x2._array)) + return Array._new(np.floor_divide(x1._array, x2._array), device=x1.device) def greater(x1: Array, x2: Array, /) -> Array: @@ -437,12 +471,14 @@ def greater(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ + if x1.device != x2.device: + raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: raise TypeError("Only real numeric dtypes are allowed in greater") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.greater(x1._array, x2._array)) + return Array._new(np.greater(x1._array, x2._array), device=x1.device) def greater_equal(x1: Array, x2: Array, /) -> Array: @@ -451,12 +487,14 @@ def greater_equal(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ + if x1.device != x2.device: + raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: raise TypeError("Only real numeric dtypes are allowed in greater_equal") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.greater_equal(x1._array, x2._array)) + return Array._new(np.greater_equal(x1._array, x2._array), device=x1.device) @requires_api_version('2023.12') def hypot(x1: Array, x2: Array, /) -> Array: @@ -465,12 +503,14 @@ def hypot(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ + if x1.device != x2.device: + raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") if x1.dtype not in _real_floating_dtypes or x2.dtype not in _real_floating_dtypes: raise TypeError("Only real floating-point dtypes are allowed in hypot") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.hypot(x1._array, x2._array)) + return Array._new(np.hypot(x1._array, x2._array), device=x1.device) def imag(x: Array, /) -> Array: """ @@ -480,7 +520,7 @@ def imag(x: Array, /) -> Array: """ if x.dtype not in _complex_floating_dtypes: raise TypeError("Only complex floating-point dtypes are allowed in imag") - return Array._new(np.imag(x._array)) + return Array._new(np.imag(x._array), device=x.device) def isfinite(x: Array, /) -> Array: @@ -491,7 +531,7 @@ def isfinite(x: Array, /) -> Array: """ if x.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in isfinite") - return Array._new(np.isfinite(x._array)) + return Array._new(np.isfinite(x._array), device=x.device) def isinf(x: Array, /) -> Array: @@ -502,7 +542,7 @@ def isinf(x: Array, /) -> Array: """ if x.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in isinf") - return Array._new(np.isinf(x._array)) + return Array._new(np.isinf(x._array), device=x.device) def isnan(x: Array, /) -> Array: @@ -513,7 +553,7 @@ def isnan(x: Array, /) -> Array: """ if x.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in isnan") - return Array._new(np.isnan(x._array)) + return Array._new(np.isnan(x._array), device=x.device) def less(x1: Array, x2: Array, /) -> Array: @@ -522,12 +562,14 @@ def less(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ + if x1.device != x2.device: + raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: raise TypeError("Only real numeric dtypes are allowed in less") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.less(x1._array, x2._array)) + return Array._new(np.less(x1._array, x2._array), device=x1.device) def less_equal(x1: Array, x2: Array, /) -> Array: @@ -536,12 +578,14 @@ def less_equal(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ + if x1.device != x2.device: + raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: raise TypeError("Only real numeric dtypes are allowed in less_equal") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.less_equal(x1._array, x2._array)) + return Array._new(np.less_equal(x1._array, x2._array), device=x1.device) def log(x: Array, /) -> Array: @@ -552,7 +596,7 @@ def log(x: Array, /) -> Array: """ if x.dtype not in _floating_dtypes: raise TypeError("Only floating-point dtypes are allowed in log") - return Array._new(np.log(x._array)) + return Array._new(np.log(x._array), device=x.device) def log1p(x: Array, /) -> Array: @@ -563,7 +607,7 @@ def log1p(x: Array, /) -> Array: """ if x.dtype not in _floating_dtypes: raise TypeError("Only floating-point dtypes are allowed in log1p") - return Array._new(np.log1p(x._array)) + return Array._new(np.log1p(x._array), device=x.device) def log2(x: Array, /) -> Array: @@ -574,7 +618,7 @@ def log2(x: Array, /) -> Array: """ if x.dtype not in _floating_dtypes: raise TypeError("Only floating-point dtypes are allowed in log2") - return Array._new(np.log2(x._array)) + return Array._new(np.log2(x._array), device=x.device) def log10(x: Array, /) -> Array: @@ -585,21 +629,23 @@ def log10(x: Array, /) -> Array: """ if x.dtype not in _floating_dtypes: raise TypeError("Only floating-point dtypes are allowed in log10") - return Array._new(np.log10(x._array)) + return Array._new(np.log10(x._array), device=x.device) -def logaddexp(x1: Array, x2: Array) -> Array: +def logaddexp(x1: Array, x2: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.logaddexp `. See its docstring for more information. """ + if x1.device != x2.device: + raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") if x1.dtype not in _real_floating_dtypes or x2.dtype not in _real_floating_dtypes: raise TypeError("Only real floating-point dtypes are allowed in logaddexp") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.logaddexp(x1._array, x2._array)) + return Array._new(np.logaddexp(x1._array, x2._array), device=x1.device) def logical_and(x1: Array, x2: Array, /) -> Array: @@ -608,12 +654,14 @@ def logical_and(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ + if x1.device != x2.device: + raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") if x1.dtype not in _boolean_dtypes or x2.dtype not in _boolean_dtypes: raise TypeError("Only boolean dtypes are allowed in logical_and") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.logical_and(x1._array, x2._array)) + return Array._new(np.logical_and(x1._array, x2._array), device=x1.device) def logical_not(x: Array, /) -> Array: @@ -624,7 +672,7 @@ def logical_not(x: Array, /) -> Array: """ if x.dtype not in _boolean_dtypes: raise TypeError("Only boolean dtypes are allowed in logical_not") - return Array._new(np.logical_not(x._array)) + return Array._new(np.logical_not(x._array), device=x.device) def logical_or(x1: Array, x2: Array, /) -> Array: @@ -633,12 +681,14 @@ def logical_or(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ + if x1.device != x2.device: + raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") if x1.dtype not in _boolean_dtypes or x2.dtype not in _boolean_dtypes: raise TypeError("Only boolean dtypes are allowed in logical_or") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.logical_or(x1._array, x2._array)) + return Array._new(np.logical_or(x1._array, x2._array), device=x1.device) def logical_xor(x1: Array, x2: Array, /) -> Array: @@ -647,12 +697,14 @@ def logical_xor(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ + if x1.device != x2.device: + raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") if x1.dtype not in _boolean_dtypes or x2.dtype not in _boolean_dtypes: raise TypeError("Only boolean dtypes are allowed in logical_xor") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.logical_xor(x1._array, x2._array)) + return Array._new(np.logical_xor(x1._array, x2._array), device=x1.device) @requires_api_version('2023.12') def maximum(x1: Array, x2: Array, /) -> Array: @@ -661,6 +713,8 @@ def maximum(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ + if x1.device != x2.device: + raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: raise TypeError("Only real numeric dtypes are allowed in maximum") # Call result type here just to raise on disallowed type combinations @@ -668,7 +722,7 @@ def maximum(x1: Array, x2: Array, /) -> Array: x1, x2 = Array._normalize_two_args(x1, x2) # TODO: maximum(-0., 0.) is unspecified. Should we issue a warning/error # in that case? - return Array._new(np.maximum(x1._array, x2._array)) + return Array._new(np.maximum(x1._array, x2._array), device=x1.device) @requires_api_version('2023.12') def minimum(x1: Array, x2: Array, /) -> Array: @@ -677,12 +731,14 @@ def minimum(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ + if x1.device != x2.device: + raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: raise TypeError("Only real numeric dtypes are allowed in minimum") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.minimum(x1._array, x2._array)) + return Array._new(np.minimum(x1._array, x2._array), device=x1.device) def multiply(x1: Array, x2: Array, /) -> Array: """ @@ -690,12 +746,14 @@ def multiply(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ + if x1.device != x2.device: + raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in multiply") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.multiply(x1._array, x2._array)) + return Array._new(np.multiply(x1._array, x2._array), device=x1.device) def negative(x: Array, /) -> Array: @@ -706,7 +764,7 @@ def negative(x: Array, /) -> Array: """ if x.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in negative") - return Array._new(np.negative(x._array)) + return Array._new(np.negative(x._array), device=x.device) def not_equal(x1: Array, x2: Array, /) -> Array: @@ -715,10 +773,12 @@ def not_equal(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ + if x1.device != x2.device: + raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.not_equal(x1._array, x2._array)) + return Array._new(np.not_equal(x1._array, x2._array), device=x1.device) def positive(x: Array, /) -> Array: @@ -729,7 +789,7 @@ def positive(x: Array, /) -> Array: """ if x.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in positive") - return Array._new(np.positive(x._array)) + return Array._new(np.positive(x._array), device=x.device) # Note: the function name is different here @@ -739,12 +799,14 @@ def pow(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ + if x1.device != x2.device: + raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in pow") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.power(x1._array, x2._array)) + return Array._new(np.power(x1._array, x2._array), device=x1.device) def real(x: Array, /) -> Array: @@ -755,7 +817,7 @@ def real(x: Array, /) -> Array: """ if x.dtype not in _complex_floating_dtypes: raise TypeError("Only complex floating-point dtypes are allowed in real") - return Array._new(np.real(x._array)) + return Array._new(np.real(x._array), device=x.device) def remainder(x1: Array, x2: Array, /) -> Array: @@ -764,12 +826,14 @@ def remainder(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ + if x1.device != x2.device: + raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: raise TypeError("Only real numeric dtypes are allowed in remainder") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.remainder(x1._array, x2._array)) + return Array._new(np.remainder(x1._array, x2._array), device=x1.device) def round(x: Array, /) -> Array: @@ -780,7 +844,7 @@ def round(x: Array, /) -> Array: """ if x.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in round") - return Array._new(np.round(x._array)) + return Array._new(np.round(x._array), device=x.device) def sign(x: Array, /) -> Array: @@ -791,7 +855,7 @@ def sign(x: Array, /) -> Array: """ if x.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in sign") - return Array._new(np.sign(x._array)) + return Array._new(np.sign(x._array), device=x.device) @requires_api_version('2023.12') @@ -803,7 +867,7 @@ def signbit(x: Array, /) -> Array: """ if x.dtype not in _real_floating_dtypes: raise TypeError("Only real floating-point dtypes are allowed in signbit") - return Array._new(np.signbit(x._array)) + return Array._new(np.signbit(x._array), device=x.device) def sin(x: Array, /) -> Array: @@ -814,7 +878,7 @@ def sin(x: Array, /) -> Array: """ if x.dtype not in _floating_dtypes: raise TypeError("Only floating-point dtypes are allowed in sin") - return Array._new(np.sin(x._array)) + return Array._new(np.sin(x._array), device=x.device) def sinh(x: Array, /) -> Array: @@ -825,7 +889,7 @@ def sinh(x: Array, /) -> Array: """ if x.dtype not in _floating_dtypes: raise TypeError("Only floating-point dtypes are allowed in sinh") - return Array._new(np.sinh(x._array)) + return Array._new(np.sinh(x._array), device=x.device) def square(x: Array, /) -> Array: @@ -836,7 +900,7 @@ def square(x: Array, /) -> Array: """ if x.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in square") - return Array._new(np.square(x._array)) + return Array._new(np.square(x._array), device=x.device) def sqrt(x: Array, /) -> Array: @@ -847,7 +911,7 @@ def sqrt(x: Array, /) -> Array: """ if x.dtype not in _floating_dtypes: raise TypeError("Only floating-point dtypes are allowed in sqrt") - return Array._new(np.sqrt(x._array)) + return Array._new(np.sqrt(x._array), device=x.device) def subtract(x1: Array, x2: Array, /) -> Array: @@ -856,12 +920,14 @@ def subtract(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ + if x1.device != x2.device: + raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in subtract") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.subtract(x1._array, x2._array)) + return Array._new(np.subtract(x1._array, x2._array), device=x1.device) def tan(x: Array, /) -> Array: @@ -872,7 +938,7 @@ def tan(x: Array, /) -> Array: """ if x.dtype not in _floating_dtypes: raise TypeError("Only floating-point dtypes are allowed in tan") - return Array._new(np.tan(x._array)) + return Array._new(np.tan(x._array), device=x.device) def tanh(x: Array, /) -> Array: @@ -883,7 +949,7 @@ def tanh(x: Array, /) -> Array: """ if x.dtype not in _floating_dtypes: raise TypeError("Only floating-point dtypes are allowed in tanh") - return Array._new(np.tanh(x._array)) + return Array._new(np.tanh(x._array), device=x.device) def trunc(x: Array, /) -> Array: @@ -897,4 +963,4 @@ def trunc(x: Array, /) -> Array: if x.dtype in _integer_dtypes: # Note: The return dtype of trunc is the same as the input return x - return Array._new(np.trunc(x._array)) + return Array._new(np.trunc(x._array), device=x.device) diff --git a/array_api_strict/_fft.py b/array_api_strict/_fft.py index 32b9551..4b0ceb6 100644 --- a/array_api_strict/_fft.py +++ b/array_api_strict/_fft.py @@ -14,7 +14,7 @@ float32, complex64, ) -from ._array_object import Array, CPU_DEVICE +from ._array_object import Array, ALL_DEVICES from ._data_type_functions import astype from ._flags import requires_extension @@ -36,7 +36,7 @@ def fft( """ if x.dtype not in _complex_floating_dtypes: raise TypeError("Only complex floating-point dtypes are allowed in fft") - res = Array._new(np.fft.fft(x._array, n=n, axis=axis, norm=norm)) + res = Array._new(np.fft.fft(x._array, n=n, axis=axis, norm=norm), device=x.device) # Note: np.fft functions improperly upcast float32 and complex64 to # complex128 if x.dtype == complex64: @@ -59,7 +59,7 @@ def ifft( """ if x.dtype not in _complex_floating_dtypes: raise TypeError("Only complex floating-point dtypes are allowed in ifft") - res = Array._new(np.fft.ifft(x._array, n=n, axis=axis, norm=norm)) + res = Array._new(np.fft.ifft(x._array, n=n, axis=axis, norm=norm), device=x.device) # Note: np.fft functions improperly upcast float32 and complex64 to # complex128 if x.dtype == complex64: @@ -82,7 +82,7 @@ def fftn( """ if x.dtype not in _complex_floating_dtypes: raise TypeError("Only complex floating-point dtypes are allowed in fftn") - res = Array._new(np.fft.fftn(x._array, s=s, axes=axes, norm=norm)) + res = Array._new(np.fft.fftn(x._array, s=s, axes=axes, norm=norm), device=x.device) # Note: np.fft functions improperly upcast float32 and complex64 to # complex128 if x.dtype == complex64: @@ -105,7 +105,7 @@ def ifftn( """ if x.dtype not in _complex_floating_dtypes: raise TypeError("Only complex floating-point dtypes are allowed in ifftn") - res = Array._new(np.fft.ifftn(x._array, s=s, axes=axes, norm=norm)) + res = Array._new(np.fft.ifftn(x._array, s=s, axes=axes, norm=norm), device=x.device) # Note: np.fft functions improperly upcast float32 and complex64 to # complex128 if x.dtype == complex64: @@ -128,7 +128,7 @@ def rfft( """ if x.dtype not in _real_floating_dtypes: raise TypeError("Only real floating-point dtypes are allowed in rfft") - res = Array._new(np.fft.rfft(x._array, n=n, axis=axis, norm=norm)) + res = Array._new(np.fft.rfft(x._array, n=n, axis=axis, norm=norm), device=x.device) # Note: np.fft functions improperly upcast float32 and complex64 to # complex128 if x.dtype == float32: @@ -151,7 +151,7 @@ def irfft( """ if x.dtype not in _complex_floating_dtypes: raise TypeError("Only complex floating-point dtypes are allowed in irfft") - res = Array._new(np.fft.irfft(x._array, n=n, axis=axis, norm=norm)) + res = Array._new(np.fft.irfft(x._array, n=n, axis=axis, norm=norm), device=x.device) # Note: np.fft functions improperly upcast float32 and complex64 to # complex128 if x.dtype == complex64: @@ -174,7 +174,7 @@ def rfftn( """ if x.dtype not in _real_floating_dtypes: raise TypeError("Only real floating-point dtypes are allowed in rfftn") - res = Array._new(np.fft.rfftn(x._array, s=s, axes=axes, norm=norm)) + res = Array._new(np.fft.rfftn(x._array, s=s, axes=axes, norm=norm), device=x.device) # Note: np.fft functions improperly upcast float32 and complex64 to # complex128 if x.dtype == float32: @@ -197,7 +197,7 @@ def irfftn( """ if x.dtype not in _complex_floating_dtypes: raise TypeError("Only complex floating-point dtypes are allowed in irfftn") - res = Array._new(np.fft.irfftn(x._array, s=s, axes=axes, norm=norm)) + res = Array._new(np.fft.irfftn(x._array, s=s, axes=axes, norm=norm), device=x.device) # Note: np.fft functions improperly upcast float32 and complex64 to # complex128 if x.dtype == complex64: @@ -220,7 +220,7 @@ def hfft( """ if x.dtype not in _complex_floating_dtypes: raise TypeError("Only complex floating-point dtypes are allowed in hfft") - res = Array._new(np.fft.hfft(x._array, n=n, axis=axis, norm=norm)) + res = Array._new(np.fft.hfft(x._array, n=n, axis=axis, norm=norm), device=x.device) # Note: np.fft functions improperly upcast float32 and complex64 to # complex128 if x.dtype == complex64: @@ -243,7 +243,7 @@ def ihfft( """ if x.dtype not in _real_floating_dtypes: raise TypeError("Only real floating-point dtypes are allowed in ihfft") - res = Array._new(np.fft.ihfft(x._array, n=n, axis=axis, norm=norm)) + res = Array._new(np.fft.ihfft(x._array, n=n, axis=axis, norm=norm), device=x.device) # Note: np.fft functions improperly upcast float32 and complex64 to # complex128 if x.dtype == float32: @@ -257,9 +257,9 @@ def fftfreq(n: int, /, *, d: float = 1.0, device: Optional[Device] = None) -> Ar See its docstring for more information. """ - if device not in [CPU_DEVICE, None]: + if device is not None and device not in ALL_DEVICES: raise ValueError(f"Unsupported device {device!r}") - return Array._new(np.fft.fftfreq(n, d=d)) + return Array._new(np.fft.fftfreq(n, d=d), device=device) @requires_extension('fft') def rfftfreq(n: int, /, *, d: float = 1.0, device: Optional[Device] = None) -> Array: @@ -268,9 +268,9 @@ def rfftfreq(n: int, /, *, d: float = 1.0, device: Optional[Device] = None) -> A See its docstring for more information. """ - if device not in [CPU_DEVICE, None]: + if device is not None and device not in ALL_DEVICES: raise ValueError(f"Unsupported device {device!r}") - return Array._new(np.fft.rfftfreq(n, d=d)) + return Array._new(np.fft.rfftfreq(n, d=d), device=device) @requires_extension('fft') def fftshift(x: Array, /, *, axes: Union[int, Sequence[int]] = None) -> Array: @@ -281,7 +281,7 @@ def fftshift(x: Array, /, *, axes: Union[int, Sequence[int]] = None) -> Array: """ if x.dtype not in _floating_dtypes: raise TypeError("Only floating-point dtypes are allowed in fftshift") - return Array._new(np.fft.fftshift(x._array, axes=axes)) + return Array._new(np.fft.fftshift(x._array, axes=axes), device=x.device) @requires_extension('fft') def ifftshift(x: Array, /, *, axes: Union[int, Sequence[int]] = None) -> Array: @@ -292,7 +292,7 @@ def ifftshift(x: Array, /, *, axes: Union[int, Sequence[int]] = None) -> Array: """ if x.dtype not in _floating_dtypes: raise TypeError("Only floating-point dtypes are allowed in ifftshift") - return Array._new(np.fft.ifftshift(x._array, axes=axes)) + return Array._new(np.fft.ifftshift(x._array, axes=axes), device=x.device) __all__ = [ "fft", diff --git a/array_api_strict/_indexing_functions.py b/array_api_strict/_indexing_functions.py index 316a3a7..c0f8e26 100644 --- a/array_api_strict/_indexing_functions.py +++ b/array_api_strict/_indexing_functions.py @@ -22,4 +22,6 @@ def take(x: Array, indices: Array, /, *, axis: Optional[int] = None) -> Array: raise TypeError("Only integer dtypes are allowed in indexing") if indices.ndim != 1: raise ValueError("Only 1-dim indices array is supported") - return Array._new(np.take(x._array, indices._array, axis=axis)) + if x.device != indices.device: + raise ValueError(f"Arrays from two different devices ({x.device} and {indices.device}) can not be combined.") + return Array._new(np.take(x._array, indices._array, axis=axis), device=x.device) diff --git a/array_api_strict/_info.py b/array_api_strict/_info.py index ab5447a..3ed7fb2 100644 --- a/array_api_strict/_info.py +++ b/array_api_strict/_info.py @@ -6,7 +6,7 @@ from typing import Optional, Union, Tuple, List from ._typing import device, DefaultDataTypes, DataTypes, Capabilities, Info -from ._array_object import CPU_DEVICE +from ._array_object import ALL_DEVICES, CPU_DEVICE from ._flags import get_array_api_strict_flags, requires_api_version from ._dtypes import bool, int8, int16, int32, int64, uint8, uint16, uint32, uint64, float32, float64, complex64, complex128 @@ -121,7 +121,7 @@ def dtypes( @requires_api_version('2023.12') def devices() -> List[device]: - return [CPU_DEVICE] + return list(ALL_DEVICES) __all__ = [ "capabilities", diff --git a/array_api_strict/_linalg.py b/array_api_strict/_linalg.py index bd11aa4..d341277 100644 --- a/array_api_strict/_linalg.py +++ b/array_api_strict/_linalg.py @@ -1,5 +1,7 @@ from __future__ import annotations +from functools import partial + from ._dtypes import ( _floating_dtypes, _numeric_dtypes, @@ -59,11 +61,11 @@ def cholesky(x: Array, /, *, upper: bool = False) -> Array: raise TypeError('Only floating-point dtypes are allowed in cholesky') L = np.linalg.cholesky(x._array) if upper: - U = Array._new(L).mT + U = Array._new(L, device=x.device).mT if U.dtype in [complex64, complex128]: U = conj(U) return U - return Array._new(L) + return Array._new(L, device=x.device) # Note: cross is the numpy top-level namespace, not np.linalg @requires_extension('linalg') @@ -81,6 +83,9 @@ def cross(x1: Array, x2: Array, /, *, axis: int = -1) -> Array: if x1.shape[axis] != 3: raise ValueError('cross() dimension must equal 3') + if x1.device != x2.device: + raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") + if get_array_api_strict_flags()['api_version'] >= '2023.12': if axis >= 0: raise ValueError("axis must be negative in cross") @@ -91,7 +96,7 @@ def cross(x1: Array, x2: Array, /, *, axis: int = -1) -> Array: # positive axis applied before or after broadcasting. NumPy applies # the axis before broadcasting. Since that behavior is what has always # been implemented here, we keep it for backwards compatibility. - return Array._new(np.cross(x1._array, x2._array, axis=axis)) + return Array._new(np.cross(x1._array, x2._array, axis=axis), device=x1.device) @requires_extension('linalg') def det(x: Array, /) -> Array: @@ -104,7 +109,7 @@ def det(x: Array, /) -> Array: # np.linalg.det. if x.dtype not in _floating_dtypes: raise TypeError('Only floating-point dtypes are allowed in det') - return Array._new(np.linalg.det(x._array)) + return Array._new(np.linalg.det(x._array), device=x.device) # Note: diagonal is the numpy top-level namespace, not np.linalg @requires_extension('linalg') @@ -116,7 +121,7 @@ def diagonal(x: Array, /, *, offset: int = 0) -> Array: """ # Note: diagonal always operates on the last two axes, whereas np.diagonal # operates on the first two axes by default - return Array._new(np.diagonal(x._array, offset=offset, axis1=-2, axis2=-1)) + return Array._new(np.diagonal(x._array, offset=offset, axis1=-2, axis2=-1), device=x.device) @requires_extension('linalg') def eigh(x: Array, /) -> EighResult: @@ -132,7 +137,7 @@ def eigh(x: Array, /) -> EighResult: # Note: the return type here is a namedtuple, which is different from # np.eigh, which only returns a tuple. - return EighResult(*map(Array._new, np.linalg.eigh(x._array))) + return EighResult(*map(partial(Array._new, device=x.device), np.linalg.eigh(x._array))) @requires_extension('linalg') @@ -147,7 +152,7 @@ def eigvalsh(x: Array, /) -> Array: if x.dtype not in _floating_dtypes: raise TypeError('Only floating-point dtypes are allowed in eigvalsh') - return Array._new(np.linalg.eigvalsh(x._array)) + return Array._new(np.linalg.eigvalsh(x._array), device=x.device) @requires_extension('linalg') def inv(x: Array, /) -> Array: @@ -161,7 +166,7 @@ def inv(x: Array, /) -> Array: if x.dtype not in _floating_dtypes: raise TypeError('Only floating-point dtypes are allowed in inv') - return Array._new(np.linalg.inv(x._array)) + return Array._new(np.linalg.inv(x._array), device=x.device) # Note: the name here is different from norm(). The array API norm is split # into matrix_norm and vector_norm(). @@ -181,7 +186,7 @@ def matrix_norm(x: Array, /, *, keepdims: bool = False, ord: Optional[Union[int, if x.dtype not in _floating_dtypes: raise TypeError('Only floating-point dtypes are allowed in matrix_norm') - return Array._new(np.linalg.norm(x._array, axis=(-2, -1), keepdims=keepdims, ord=ord)) + return Array._new(np.linalg.norm(x._array, axis=(-2, -1), keepdims=keepdims, ord=ord), device=x.device) @requires_extension('linalg') @@ -197,7 +202,7 @@ def matrix_power(x: Array, n: int, /) -> Array: raise TypeError('Only floating-point dtypes are allowed for the first argument of matrix_power') # np.matrix_power already checks if n is an integer - return Array._new(np.linalg.matrix_power(x._array, n)) + return Array._new(np.linalg.matrix_power(x._array, n), device=x.device) # Note: the keyword argument name rtol is different from np.linalg.matrix_rank @requires_extension('linalg') @@ -220,7 +225,7 @@ def matrix_rank(x: Array, /, *, rtol: Optional[Union[float, Array]] = None) -> A # Note: this is different from np.linalg.matrix_rank, which does not multiply # the tolerance by the largest singular value. tol = S.max(axis=-1, keepdims=True)*np.asarray(rtol)[..., np.newaxis] - return Array._new(np.count_nonzero(S > tol, axis=-1)) + return Array._new(np.count_nonzero(S > tol, axis=-1), device=x.device) # Note: outer is the numpy top-level namespace, not np.linalg @@ -240,7 +245,10 @@ def outer(x1: Array, x2: Array, /) -> Array: if x1.ndim != 1 or x2.ndim != 1: raise ValueError('The input arrays to outer must be 1-dimensional') - return Array._new(np.outer(x1._array, x2._array)) + if x1.device != x2.device: + raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") + + return Array._new(np.outer(x1._array, x2._array), device=x1.device) # Note: the keyword argument name rtol is different from np.linalg.pinv @requires_extension('linalg') @@ -259,7 +267,7 @@ def pinv(x: Array, /, *, rtol: Optional[Union[float, Array]] = None) -> Array: # default tolerance by max(M, N). if rtol is None: rtol = max(x.shape[-2:]) * finfo(x.dtype).eps - return Array._new(np.linalg.pinv(x._array, rcond=rtol)) + return Array._new(np.linalg.pinv(x._array, rcond=rtol), device=x.device) @requires_extension('linalg') def qr(x: Array, /, *, mode: Literal['reduced', 'complete'] = 'reduced') -> QRResult: # noqa: F821 @@ -275,7 +283,7 @@ def qr(x: Array, /, *, mode: Literal['reduced', 'complete'] = 'reduced') -> QRRe # Note: the return type here is a namedtuple, which is different from # np.linalg.qr, which only returns a tuple. - return QRResult(*map(Array._new, np.linalg.qr(x._array, mode=mode))) + return QRResult(*map(partial(Array._new, device=x.device), np.linalg.qr(x._array, mode=mode))) @requires_extension('linalg') def slogdet(x: Array, /) -> SlogdetResult: @@ -291,7 +299,7 @@ def slogdet(x: Array, /) -> SlogdetResult: # Note: the return type here is a namedtuple, which is different from # np.linalg.slogdet, which only returns a tuple. - return SlogdetResult(*map(Array._new, np.linalg.slogdet(x._array))) + return SlogdetResult(*map(partial(Array._new, device=x.device), np.linalg.slogdet(x._array))) # Note: unlike np.linalg.solve, the array API solve() only accepts x2 as a # vector when it is exactly 1-dimensional. All other cases treat x2 as a stack @@ -348,7 +356,10 @@ def solve(x1: Array, x2: Array, /) -> Array: if x1.dtype not in _floating_dtypes or x2.dtype not in _floating_dtypes: raise TypeError('Only floating-point dtypes are allowed in solve') - return Array._new(_solve(x1._array, x2._array)) + if x1.device != x2.device: + raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") + + return Array._new(_solve(x1._array, x2._array), device=x1.device) @requires_extension('linalg') def svd(x: Array, /, *, full_matrices: bool = True) -> SVDResult: @@ -364,7 +375,7 @@ def svd(x: Array, /, *, full_matrices: bool = True) -> SVDResult: # Note: the return type here is a namedtuple, which is different from # np.svd, which only returns a tuple. - return SVDResult(*map(Array._new, np.linalg.svd(x._array, full_matrices=full_matrices))) + return SVDResult(*map(partial(Array._new, device=x.device), np.linalg.svd(x._array, full_matrices=full_matrices))) # Note: svdvals is not in NumPy (but it is in SciPy). It is equivalent to # np.linalg.svd(compute_uv=False). @@ -372,7 +383,7 @@ def svd(x: Array, /, *, full_matrices: bool = True) -> SVDResult: def svdvals(x: Array, /) -> Union[Array, Tuple[Array, ...]]: if x.dtype not in _floating_dtypes: raise TypeError('Only floating-point dtypes are allowed in svdvals') - return Array._new(np.linalg.svd(x._array, compute_uv=False)) + return Array._new(np.linalg.svd(x._array, compute_uv=False), device=x.device) # Note: trace is the numpy top-level namespace, not np.linalg @requires_extension('linalg') @@ -397,7 +408,7 @@ def trace(x: Array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> Arr dtype = dtype._np_dtype # Note: trace always operates on the last two axes, whereas np.trace # operates on the first two axes by default - return Array._new(np.asarray(np.trace(x._array, offset=offset, axis1=-2, axis2=-1, dtype=dtype))) + return Array._new(np.asarray(np.trace(x._array, offset=offset, axis1=-2, axis2=-1, dtype=dtype)), device=x.device) # Note: the name here is different from norm(). The array API norm is split # into matrix_norm and vector_norm(). @@ -437,7 +448,7 @@ def vector_norm(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = No else: _axis = axis - res = Array._new(np.linalg.norm(a, axis=_axis, ord=ord)) + res = Array._new(np.linalg.norm(a, axis=_axis, ord=ord), device=x.device) if keepdims: # We can't reuse np.linalg.norm(keepdims) because of the reshape hacks diff --git a/array_api_strict/_linear_algebra_functions.py b/array_api_strict/_linear_algebra_functions.py index dcb654d..5ffdaa6 100644 --- a/array_api_strict/_linear_algebra_functions.py +++ b/array_api_strict/_linear_algebra_functions.py @@ -30,7 +30,10 @@ def matmul(x1: Array, x2: Array, /) -> Array: if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: raise TypeError('Only numeric dtypes are allowed in matmul') - return Array._new(np.matmul(x1._array, x2._array)) + if x1.device != x2.device: + raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") + + return Array._new(np.matmul(x1._array, x2._array), device=x1.device) # Note: tensordot is the numpy top-level namespace but not in np.linalg @@ -41,14 +44,17 @@ def tensordot(x1: Array, x2: Array, /, *, axes: Union[int, Tuple[Sequence[int], if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: raise TypeError('Only numeric dtypes are allowed in tensordot') - return Array._new(np.tensordot(x1._array, x2._array, axes=axes)) + if x1.device != x2.device: + raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") + + return Array._new(np.tensordot(x1._array, x2._array, axes=axes), device=x1.device) # Note: this function is new in the array API spec. Unlike transpose, it only # transposes the last two axes. def matrix_transpose(x: Array, /) -> Array: if x.ndim < 2: raise ValueError("x must be at least 2-dimensional for matrix_transpose") - return Array._new(np.swapaxes(x._array, -1, -2)) + return Array._new(np.swapaxes(x._array, -1, -2), device=x.device) # Note: vecdot is not in NumPy def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array: @@ -61,6 +67,9 @@ def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array: elif axis < min(-1, -x1.ndim, -x2.ndim): raise ValueError("axis is out of bounds for x1 and x2") + if x1.device != x2.device: + raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") + # In versions of the standard prior to 2023.12, vecdot applied axis after # broadcasting. This is different from applying it before broadcasting # when axis is nonnegative. The below code keeps this behavior for @@ -78,4 +87,4 @@ def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array: x2_ = np.moveaxis(x2_, axis, -1) res = x1_[..., None, :] @ x2_[..., None] - return Array._new(res[..., 0, 0]) + return Array._new(res[..., 0, 0], device=x1.device) diff --git a/array_api_strict/_manipulation_functions.py b/array_api_strict/_manipulation_functions.py index 702d259..d775835 100644 --- a/array_api_strict/_manipulation_functions.py +++ b/array_api_strict/_manipulation_functions.py @@ -25,8 +25,12 @@ def concat( # Note: Casting rules here are different from the np.concatenate default # (no for scalars with axis=None, no cross-kind casting) dtype = result_type(*arrays) + if len({a.device for a in arrays}) > 1: + raise ValueError("concat inputs must all be on the same device") + result_device = arrays[0].device + arrays = tuple(a._array for a in arrays) - return Array._new(np.concatenate(arrays, axis=axis, dtype=dtype._np_dtype)) + return Array._new(np.concatenate(arrays, axis=axis, dtype=dtype._np_dtype), device=result_device) def expand_dims(x: Array, /, *, axis: int) -> Array: @@ -35,7 +39,7 @@ def expand_dims(x: Array, /, *, axis: int) -> Array: See its docstring for more information. """ - return Array._new(np.expand_dims(x._array, axis)) + return Array._new(np.expand_dims(x._array, axis), device=x.device) def flip(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Array: @@ -44,7 +48,7 @@ def flip(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> See its docstring for more information. """ - return Array._new(np.flip(x._array, axis=axis)) + return Array._new(np.flip(x._array, axis=axis), device=x.device) @requires_api_version('2023.12') def moveaxis( @@ -58,7 +62,7 @@ def moveaxis( See its docstring for more information. """ - return Array._new(np.moveaxis(x._array, source, destination)) + return Array._new(np.moveaxis(x._array, source, destination), device=x.device) # Note: The function name is different here (see also matrix_transpose). # Unlike transpose(), the axes argument is required. @@ -68,7 +72,7 @@ def permute_dims(x: Array, /, axes: Tuple[int, ...]) -> Array: See its docstring for more information. """ - return Array._new(np.transpose(x._array, axes)) + return Array._new(np.transpose(x._array, axes), device=x.device) @requires_api_version('2023.12') def repeat( @@ -89,6 +93,8 @@ def repeat( raise RuntimeError("repeat() with repeats as an array requires data-dependent shapes, but the data_dependent_shapes flag has been disabled for array-api-strict") if repeats.dtype not in _integer_dtypes: raise TypeError("The repeats array must have an integer dtype") + if x.device != repeats.device: + raise ValueError(f"Arrays from two different devices ({x.device} and {repeats.device}) can not be combined.") elif isinstance(repeats, int): repeats = asarray(repeats) else: @@ -100,7 +106,7 @@ def repeat( # infeasable, and even if they are present by mistake, this will # lead to underflow and an error. repeats = astype(repeats, int64) - return Array._new(np.repeat(x._array, repeats._array, axis=axis)) + return Array._new(np.repeat(x._array, repeats._array, axis=axis), device=x.device) # Note: the optional argument is called 'shape', not 'newshape' def reshape(x: Array, @@ -123,7 +129,7 @@ def reshape(x: Array, if copy is False and not np.shares_memory(data, reshaped): raise AttributeError("Incompatible shape for in-place modification.") - return Array._new(reshaped) + return Array._new(reshaped, device=x.device) def roll( @@ -138,7 +144,7 @@ def roll( See its docstring for more information. """ - return Array._new(np.roll(x._array, shift, axis=axis)) + return Array._new(np.roll(x._array, shift, axis=axis), device=x.device) def squeeze(x: Array, /, axis: Union[int, Tuple[int, ...]]) -> Array: @@ -147,7 +153,7 @@ def squeeze(x: Array, /, axis: Union[int, Tuple[int, ...]]) -> Array: See its docstring for more information. """ - return Array._new(np.squeeze(x._array, axis=axis)) + return Array._new(np.squeeze(x._array, axis=axis), device=x.device) def stack(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: int = 0) -> Array: @@ -158,8 +164,11 @@ def stack(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: int = 0) -> """ # Call result type here just to raise on disallowed type combinations result_type(*arrays) + if len({a.device for a in arrays}) > 1: + raise ValueError("concat inputs must all be on the same device") + result_device = arrays[0].device arrays = tuple(a._array for a in arrays) - return Array._new(np.stack(arrays, axis=axis)) + return Array._new(np.stack(arrays, axis=axis), device=result_device) @requires_api_version('2023.12') @@ -172,7 +181,7 @@ def tile(x: Array, repetitions: Tuple[int, ...], /) -> Array: # Note: NumPy allows repetitions to be an int or array if not isinstance(repetitions, tuple): raise TypeError("repetitions must be a tuple") - return Array._new(np.tile(x._array, repetitions)) + return Array._new(np.tile(x._array, repetitions), device=x.device) # Note: this function is new @requires_api_version('2023.12') diff --git a/array_api_strict/_searching_functions.py b/array_api_strict/_searching_functions.py index 7314895..0d7c0c8 100644 --- a/array_api_strict/_searching_functions.py +++ b/array_api_strict/_searching_functions.py @@ -19,7 +19,7 @@ def argmax(x: Array, /, *, axis: Optional[int] = None, keepdims: bool = False) - """ if x.dtype not in _real_numeric_dtypes: raise TypeError("Only real numeric dtypes are allowed in argmax") - return Array._new(np.asarray(np.argmax(x._array, axis=axis, keepdims=keepdims))) + return Array._new(np.asarray(np.argmax(x._array, axis=axis, keepdims=keepdims)), device=x.device) def argmin(x: Array, /, *, axis: Optional[int] = None, keepdims: bool = False) -> Array: @@ -30,7 +30,7 @@ def argmin(x: Array, /, *, axis: Optional[int] = None, keepdims: bool = False) - """ if x.dtype not in _real_numeric_dtypes: raise TypeError("Only real numeric dtypes are allowed in argmin") - return Array._new(np.asarray(np.argmin(x._array, axis=axis, keepdims=keepdims))) + return Array._new(np.asarray(np.argmin(x._array, axis=axis, keepdims=keepdims)), device=x.device) @requires_data_dependent_shapes @@ -43,7 +43,7 @@ def nonzero(x: Array, /) -> Tuple[Array, ...]: # Note: nonzero is disallowed on 0-dimensional arrays if x.ndim == 0: raise ValueError("nonzero is not allowed on 0-dimensional arrays") - return tuple(Array._new(i) for i in np.nonzero(x._array)) + return tuple(Array._new(i, device=x.device) for i in np.nonzero(x._array)) @requires_api_version('2023.12') def searchsorted( @@ -61,12 +61,16 @@ def searchsorted( """ if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: raise TypeError("Only real numeric dtypes are allowed in searchsorted") + + if x1.device != x2.device: + raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") + sorter = sorter._array if sorter is not None else None # TODO: The sort order of nans and signed zeros is implementation # dependent. Should we error/warn if they are present? # x1 must be 1-D, but NumPy already requires this. - return Array._new(np.searchsorted(x1._array, x2._array, side=side, sorter=sorter)) + return Array._new(np.searchsorted(x1._array, x2._array, side=side, sorter=sorter), device=x1.device) def where(condition: Array, x1: Array, x2: Array, /) -> Array: """ @@ -76,5 +80,9 @@ def where(condition: Array, x1: Array, x2: Array, /) -> Array: """ # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) + + if len({a.device for a in (condition, x1, x2)}) > 1: + raise ValueError("where inputs must all be on the same device") + x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.where(condition._array, x1._array, x2._array)) + return Array._new(np.where(condition._array, x1._array, x2._array), device=x1.device) diff --git a/array_api_strict/_set_functions.py b/array_api_strict/_set_functions.py index e6ca939..7bd5bad 100644 --- a/array_api_strict/_set_functions.py +++ b/array_api_strict/_set_functions.py @@ -55,10 +55,10 @@ def unique_all(x: Array, /) -> UniqueAllResult: # See https://github.com/numpy/numpy/issues/20638 inverse_indices = inverse_indices.reshape(x.shape) return UniqueAllResult( - Array._new(values), - Array._new(indices), - Array._new(inverse_indices), - Array._new(counts), + Array._new(values, device=x.device), + Array._new(indices, device=x.device), + Array._new(inverse_indices, device=x.device), + Array._new(counts, device=x.device), ) @@ -72,7 +72,7 @@ def unique_counts(x: Array, /) -> UniqueCountsResult: equal_nan=False, ) - return UniqueCountsResult(*[Array._new(i) for i in res]) + return UniqueCountsResult(*[Array._new(i, device=x.device) for i in res]) @requires_data_dependent_shapes @@ -92,7 +92,8 @@ def unique_inverse(x: Array, /) -> UniqueInverseResult: # np.unique() flattens inverse indices, but they need to share x's shape # See https://github.com/numpy/numpy/issues/20638 inverse_indices = inverse_indices.reshape(x.shape) - return UniqueInverseResult(Array._new(values), Array._new(inverse_indices)) + return UniqueInverseResult(Array._new(values, device=x.device), + Array._new(inverse_indices, device=x.device)) @requires_data_dependent_shapes @@ -109,4 +110,4 @@ def unique_values(x: Array, /) -> Array: return_inverse=False, equal_nan=False, ) - return Array._new(res) + return Array._new(res, device=x.device) diff --git a/array_api_strict/_sorting_functions.py b/array_api_strict/_sorting_functions.py index 9b8cb04..765bd9e 100644 --- a/array_api_strict/_sorting_functions.py +++ b/array_api_strict/_sorting_functions.py @@ -33,7 +33,7 @@ def argsort( normalised_axis = axis if axis >= 0 else x.ndim + axis max_i = x.shape[normalised_axis] - 1 res = max_i - res - return Array._new(res) + return Array._new(res, device=x.device) # Note: the descending keyword argument is new in this function def sort( @@ -51,4 +51,4 @@ def sort( res = np.sort(x._array, axis=axis, kind=kind) if descending: res = np.flip(res, axis=axis) - return Array._new(res) + return Array._new(res, device=x.device) diff --git a/array_api_strict/_statistical_functions.py b/array_api_strict/_statistical_functions.py index 39e3736..6ea9746 100644 --- a/array_api_strict/_statistical_functions.py +++ b/array_api_strict/_statistical_functions.py @@ -44,7 +44,7 @@ def cumulative_sum( if axis < 0: axis += x.ndim x = concat([zeros(x.shape[:axis] + (1,) + x.shape[axis + 1:], dtype=dt), x], axis=axis) - return Array._new(np.cumsum(x._array, axis=axis, dtype=dtype)) + return Array._new(np.cumsum(x._array, axis=axis, dtype=dtype), device=x.device) def max( x: Array, @@ -55,7 +55,7 @@ def max( ) -> Array: if x.dtype not in _real_numeric_dtypes: raise TypeError("Only real numeric dtypes are allowed in max") - return Array._new(np.max(x._array, axis=axis, keepdims=keepdims)) + return Array._new(np.max(x._array, axis=axis, keepdims=keepdims), device=x.device) def mean( @@ -67,7 +67,7 @@ def mean( ) -> Array: if x.dtype not in _real_floating_dtypes: raise TypeError("Only real floating-point dtypes are allowed in mean") - return Array._new(np.mean(x._array, axis=axis, keepdims=keepdims)) + return Array._new(np.mean(x._array, axis=axis, keepdims=keepdims), device=x.device) def min( @@ -79,7 +79,7 @@ def min( ) -> Array: if x.dtype not in _real_numeric_dtypes: raise TypeError("Only real numeric dtypes are allowed in min") - return Array._new(np.min(x._array, axis=axis, keepdims=keepdims)) + return Array._new(np.min(x._array, axis=axis, keepdims=keepdims), device=x.device) def prod( @@ -104,7 +104,7 @@ def prod( dtype = np.complex128 else: dtype = dtype._np_dtype - return Array._new(np.prod(x._array, dtype=dtype, axis=axis, keepdims=keepdims)) + return Array._new(np.prod(x._array, dtype=dtype, axis=axis, keepdims=keepdims), device=x.device) def std( @@ -118,7 +118,7 @@ def std( # Note: the keyword argument correction is different here if x.dtype not in _real_floating_dtypes: raise TypeError("Only real floating-point dtypes are allowed in std") - return Array._new(np.std(x._array, axis=axis, ddof=correction, keepdims=keepdims)) + return Array._new(np.std(x._array, axis=axis, ddof=correction, keepdims=keepdims), device=x.device) def sum( @@ -143,7 +143,7 @@ def sum( dtype = np.complex128 else: dtype = dtype._np_dtype - return Array._new(np.sum(x._array, axis=axis, dtype=dtype, keepdims=keepdims)) + return Array._new(np.sum(x._array, axis=axis, dtype=dtype, keepdims=keepdims), device=x.device) def var( @@ -157,4 +157,4 @@ def var( # Note: the keyword argument correction is different here if x.dtype not in _real_floating_dtypes: raise TypeError("Only real floating-point dtypes are allowed in var") - return Array._new(np.var(x._array, axis=axis, ddof=correction, keepdims=keepdims)) + return Array._new(np.var(x._array, axis=axis, ddof=correction, keepdims=keepdims), device=x.device) diff --git a/array_api_strict/_typing.py b/array_api_strict/_typing.py index eb1b834..05a479c 100644 --- a/array_api_strict/_typing.py +++ b/array_api_strict/_typing.py @@ -27,7 +27,7 @@ Protocol, ) -from ._array_object import Array, _cpu_device +from ._array_object import Array, _device from ._dtypes import _DType _T_co = TypeVar("_T_co", covariant=True) @@ -37,7 +37,7 @@ def __getitem__(self, key: int, /) -> _T_co | NestedSequence[_T_co]: ... def __len__(self, /) -> int: ... -Device = _cpu_device +Device = _device Dtype = _DType diff --git a/array_api_strict/_utility_functions.py b/array_api_strict/_utility_functions.py index c91fa58..0d44ecb 100644 --- a/array_api_strict/_utility_functions.py +++ b/array_api_strict/_utility_functions.py @@ -21,7 +21,7 @@ def all( See its docstring for more information. """ - return Array._new(np.asarray(np.all(x._array, axis=axis, keepdims=keepdims))) + return Array._new(np.asarray(np.all(x._array, axis=axis, keepdims=keepdims)), device=x.device) def any( @@ -36,4 +36,4 @@ def any( See its docstring for more information. """ - return Array._new(np.asarray(np.any(x._array, axis=axis, keepdims=keepdims))) + return Array._new(np.asarray(np.any(x._array, axis=axis, keepdims=keepdims)), device=x.device) diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index dad6696..902c398 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -319,7 +319,7 @@ def test_python_scalar_construtors(): def test_device_property(): a = ones((3, 4)) assert a.device == CPU_DEVICE - assert a.device != 'cpu' + assert not isinstance(a.device, str) assert all(equal(a.to_device(CPU_DEVICE), a)) assert_raises(ValueError, lambda: a.to_device('cpu')) @@ -349,6 +349,17 @@ def test___array__(): assert np.all(np.equal(b, np.ones((2, 3), dtype=np.float64))) assert b.dtype == np.float64 +def test_array_conversion(): + # Check that arrays on the CPU device can be converted to NumPy + # but arrays on other devices can't + a = ones((2, 3)) + np.asarray(a) + + for device in ("device1", "device2"): + a = ones((2, 3), device=array_api_strict.Device(device)) + with pytest.raises(RuntimeError, match="Can not convert array"): + np.asarray(a) + def test_allow_newaxis(): a = ones(5) indexed_a = a[None, :] diff --git a/array_api_strict/tests/test_device_support.py b/array_api_strict/tests/test_device_support.py new file mode 100644 index 0000000..0f3d6b5 --- /dev/null +++ b/array_api_strict/tests/test_device_support.py @@ -0,0 +1,38 @@ +import pytest + +import array_api_strict + + +@pytest.mark.parametrize( + "func_name", + ( + "fft", + "ifft", + "fftn", + "ifftn", + "irfft", + "irfftn", + "hfft", + "fftshift", + "ifftshift", + ), +) +def test_fft_device_support_complex(func_name): + func = getattr(array_api_strict.fft, func_name) + x = array_api_strict.asarray( + [1, 2.0], + dtype=array_api_strict.complex64, + device=array_api_strict.Device("device1"), + ) + y = func(x) + + assert x.device == y.device + + +@pytest.mark.parametrize("func_name", ("rfft", "rfftn", "ihfft")) +def test_fft_device_support_real(func_name): + func = getattr(array_api_strict.fft, func_name) + x = array_api_strict.asarray([1, 2.0], device=array_api_strict.Device("device1")) + y = func(x) + + assert x.device == y.device diff --git a/array_api_strict/tests/test_elementwise_functions.py b/array_api_strict/tests/test_elementwise_functions.py index 870361e..de11edf 100644 --- a/array_api_strict/tests/test_elementwise_functions.py +++ b/array_api_strict/tests/test_elementwise_functions.py @@ -1,4 +1,4 @@ -from inspect import getfullargspec, getmodule +from inspect import signature, getmodule from numpy.testing import assert_raises @@ -17,9 +17,16 @@ ) from .._flags import set_array_api_strict_flags +import array_api_strict + def nargs(func): - return len(getfullargspec(func).args) + """Count number of 'array' arguments a function takes.""" + positional_only = 0 + for param in signature(func).parameters.values(): + if param.kind == param.POSITIONAL_ONLY: + positional_only += 1 + return positional_only elementwise_function_input_types = { @@ -90,12 +97,57 @@ def nargs(func): "trunc": "real numeric", } + +def test_nargs(): + # Explicitly check number of arguments for a few functions + assert nargs(array_api_strict.logaddexp) == 2 + assert nargs(array_api_strict.atan2) == 2 + assert nargs(array_api_strict.clip) == 1 + + # All elementwise functions take one or two array arguments + # if not, it is probably a bug in `nargs` or the definition + # of the function (missing trailing `, /`). + for func_name in elementwise_function_input_types: + func = getattr(_elementwise_functions, func_name) + assert nargs(func) in (1, 2) + + def test_missing_functions(): # Ensure the above dictionary is complete. import array_api_strict._elementwise_functions as mod mod_funcs = [n for n in dir(mod) if getmodule(getattr(mod, n)) is mod] assert set(mod_funcs) == set(elementwise_function_input_types) + +def test_function_device_persists(): + # Test that the device of the input and output array are the same + def _array_vals(dtypes): + for d in dtypes: + yield asarray(1., dtype=d) + + # Use the latest version of the standard so all functions are included + set_array_api_strict_flags(api_version="2023.12") + + for func_name, types in elementwise_function_input_types.items(): + dtypes = _dtype_categories[types] + func = getattr(_elementwise_functions, func_name) + + for x in _array_vals(dtypes): + if nargs(func) == 2: + # This way we don't have to deal with incompatible + # types of the two arguments. + r = func(x, x) + assert r.device == x.device + + else: + # `atanh` needs a slightly different input value from + # everyone else + if func_name == "atanh": + x -= 0.1 + r = func(x) + assert r.device == x.device + + def test_function_types(): # Test that every function accepts only the required input types. We only # test the negative cases here (error). The positive cases are tested in @@ -128,12 +180,12 @@ def _array_vals(): or x.dtype in _floating_dtypes and y.dtype not in _floating_dtypes or y.dtype in _floating_dtypes and x.dtype not in _floating_dtypes ): - assert_raises(TypeError, lambda: func(x, y)) + assert_raises(TypeError, func, x, y) if x.dtype not in dtypes or y.dtype not in dtypes: - assert_raises(TypeError, lambda: func(x, y)) + assert_raises(TypeError, func, x, y) else: if x.dtype not in dtypes: - assert_raises(TypeError, lambda: func(x)) + assert_raises(TypeError, func, x) def test_bitwise_shift_error(): diff --git a/array_api_strict/tests/test_indexing_functions.py b/array_api_strict/tests/test_indexing_functions.py index fabe688..f9fff58 100644 --- a/array_api_strict/tests/test_indexing_functions.py +++ b/array_api_strict/tests/test_indexing_functions.py @@ -22,3 +22,23 @@ def test_take_function(x, indices, axis, expected): indices = xp.asarray(indices) out = xp.take(x, indices, axis=axis) assert xp.all(out == xp.asarray(expected)) + + +def test_take_device(): + x = xp.asarray([2, 3]) + indices = xp.asarray([1, 1, 0]) + xp.take(x, indices) + + x = xp.asarray([2, 3]) + indices = xp.asarray([1, 1, 0], device=xp.Device("device1")) + with pytest.raises(ValueError, match="Arrays from two different devices"): + xp.take(x, indices) + + x = xp.asarray([2, 3], device=xp.Device("device1")) + indices = xp.asarray([1, 1, 0]) + with pytest.raises(ValueError, match="Arrays from two different devices"): + xp.take(x, indices) + + x = xp.asarray([2, 3], device=xp.Device("device1")) + indices = xp.asarray([1, 1, 0], device=xp.Device("device1")) + xp.take(x, indices) diff --git a/array_api_strict/tests/test_sorting_functions.py b/array_api_strict/tests/test_sorting_functions.py index c479260..716a651 100644 --- a/array_api_strict/tests/test_sorting_functions.py +++ b/array_api_strict/tests/test_sorting_functions.py @@ -21,3 +21,17 @@ def test_stable_desc_argsort(obj, axis, expected): x = xp.asarray(obj) out = xp.argsort(x, axis=axis, stable=True, descending=True) assert xp.all(out == xp.asarray(expected)) + + +def test_argsort_device(): + x = xp.asarray([1., 2., -1., 3.141], device=xp.Device("device1")) + y = xp.argsort(x) + + assert y.device == x.device + + +def test_sort_device(): + x = xp.asarray([1., 2., -1., 3.141], device=xp.Device("device1")) + y = xp.sort(x) + + assert y.device == x.device \ No newline at end of file