diff --git a/.github/workflows/array-api-tests.yml b/.github/workflows/array-api-tests.yml index 9f168cb..1f58f36 100644 --- a/.github/workflows/array-api-tests.yml +++ b/.github/workflows/array-api-tests.yml @@ -11,11 +11,11 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ['3.9', '3.10', '3.11', '3.12'] - numpy-version: ['1.26', 'dev'] + python-version: ['3.12', '3.13'] + numpy-version: ['1.26', '2.2', 'dev'] exclude: - - python-version: '3.8' - numpy-version: 'dev' + - python-version: '3.13' + numpy-version: '1.26' steps: - name: Checkout array-api-strict @@ -38,7 +38,7 @@ jobs: if [[ "${{ matrix.numpy-version }}" == "dev" ]]; then python -m pip install --pre --extra-index https://pypi.anaconda.org/scientific-python-nightly-wheels/simple numpy; else - python -m pip install 'numpy>=1.26,<2.0'; + python -m pip install 'numpy=='${{ matrix.numpy-version }}; fi python -m pip install ${GITHUB_WORKSPACE}/array-api-strict python -m pip install -r ${GITHUB_WORKSPACE}/array-api-tests/requirements.txt diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index a917441..9c2d22c 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -66,9 +66,6 @@ def __hash__(self): _default = object() -# See https://github.com/data-apis/array-api-strict/issues/67 and the comment -# on __array__ below. -_allow_array = True class Array: """ @@ -157,26 +154,22 @@ def __repr__(self: Array, /) -> str: # This was implemented historically for compatibility, and removing it has # caused issues for some libraries (see # https://github.com/data-apis/array-api-strict/issues/67). - def __array__(self, dtype: None | np.dtype[Any] = None, copy: None | bool = None) -> npt.NDArray[Any]: - # We have to allow this to be internally enabled as there's no other - # easy way to parse a list of Array objects in asarray(). - if _allow_array: - if self._device != CPU_DEVICE: - raise RuntimeError(f"Can not convert array on the '{self._device}' device to a Numpy array.") - # copy keyword is new in 2.0.0; for older versions don't use it - # retry without that keyword. - if np.__version__[0] < '2': - return np.asarray(self._array, dtype=dtype) - elif np.__version__.startswith('2.0.0-dev0'): - # Handle dev version for which we can't know based on version - # number whether or not the copy keyword is supported. - try: - return np.asarray(self._array, dtype=dtype, copy=copy) - except TypeError: - return np.asarray(self._array, dtype=dtype) - else: - return np.asarray(self._array, dtype=dtype, copy=copy) - raise ValueError("Conversion from an array_api_strict array to a NumPy ndarray is not supported") + + # Instead of `__array__` we now implement the buffer protocol. + # Note that it makes array-apis-strict requiring python>=3.12 + def __buffer__(self, flags): + if self._device != CPU_DEVICE: + raise RuntimeError(f"Can not convert array on the '{self._device}' device to a Numpy array.") + return memoryview(self._array) + def __release_buffer(self, buffer): + # XXX anything to do here? + pass + + def __array__(self, *args, **kwds): + # a stub for python < 3.12; otherwise numpy silently produces object arrays + raise TypeError( + "Interoperation with NumPy requires python >= 3.12. Please upgrade." + ) # These are various helper functions to make the array behavior match the # spec in places where it either deviates from or is more strict than diff --git a/array_api_strict/_creation_functions.py b/array_api_strict/_creation_functions.py index 460dba9..015cbca 100644 --- a/array_api_strict/_creation_functions.py +++ b/array_api_strict/_creation_functions.py @@ -1,6 +1,5 @@ from __future__ import annotations -from contextlib import contextmanager from typing import TYPE_CHECKING, List, Optional, Tuple, Union if TYPE_CHECKING: @@ -16,19 +15,6 @@ import numpy as np -@contextmanager -def allow_array(): - """ - Temporarily enable Array.__array__. This is needed for np.array to parse - list of lists of Array objects. - """ - from . import _array_object - original_value = _array_object._allow_array - try: - _array_object._allow_array = True - yield - finally: - _array_object._allow_array = original_value def _check_valid_dtype(dtype): # Note: Only spelling dtypes as the dtype objects is supported. @@ -112,8 +98,8 @@ def asarray( # Give a better error message in this case. NumPy would convert this # to an object array. TODO: This won't handle large integers in lists. raise OverflowError("Integer out of bounds for array dtypes") - with allow_array(): - res = np.array(obj, dtype=_np_dtype, copy=copy) + + res = np.array(obj, dtype=_np_dtype, copy=copy) return Array._new(res, device=device) diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index 8f185f0..29d4013 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -1,3 +1,4 @@ +import sys import operator from builtins import all as all_ @@ -351,6 +352,10 @@ def test_array_properties(): assert b.mT.shape == (3, 2) +@pytest.mark.xfail(sys.version_info.major*100 + sys.version_info.minor < 312, + reason="array conversion relies on buffer protocol, and " + "requires python >= 3.12" +) def test_array_conversion(): # Check that arrays on the CPU device can be converted to NumPy # but arrays on other devices can't. Note this is testing the logic in @@ -361,25 +366,23 @@ def test_array_conversion(): for device in ("device1", "device2"): a = ones((2, 3), device=array_api_strict.Device(device)) - with pytest.raises(RuntimeError, match="Can not convert array"): + with pytest.raises((RuntimeError, TypeError)): asarray([a]) -def test__array__(): - # __array__ should work for now + # __buffer__ should work for now for conversion to numpy a = ones((2, 3)) - np.array(a) - - # Test the _allow_array private global flag for disabling it in the - # future. - from .. import _array_object - original_value = _array_object._allow_array - try: - _array_object._allow_array = False - a = ones((2, 3)) - with pytest.raises(ValueError, match="Conversion from an array_api_strict array to a NumPy ndarray is not supported"): - np.array(a) - finally: - _array_object._allow_array = original_value + na = np.array(a) + assert na.shape == (2, 3) + assert na.dtype == np.float64 + +@pytest.mark.skipif(not sys.version_info.major*100 + sys.version_info.minor < 312, + reason="conversion to numpy errors out unless python >= 3.12" +) +def test_array_conversion_2(): + a = ones((2, 3)) + with pytest.raises(TypeError): + np.array(a) + def test_allow_newaxis(): a = ones(5) diff --git a/array_api_strict/tests/test_creation_functions.py b/array_api_strict/tests/test_creation_functions.py index fc4e3cb..6908bc4 100644 --- a/array_api_strict/tests/test_creation_functions.py +++ b/array_api_strict/tests/test_creation_functions.py @@ -1,3 +1,4 @@ +import sys import warnings from numpy.testing import assert_raises @@ -97,7 +98,12 @@ def test_asarray_copy(): a[0] = 0 assert all(b[0] == 0) -def test_asarray_list_of_lists(): + +@pytest.mark.xfail(sys.version_info.major*100 + sys.version_info.minor < 312, + reason="array conversion relies on buffer protocol, and " + "requires python >= 3.12" +) +def test_asarray_list_of_arrays(): a = asarray(1, dtype=int16) b = asarray([1], dtype=int16) res = asarray([a, a])