Skip to content

Commit 6b0079b

Browse files
authored
Merge pull request #48 from asmeurer/less-strict-iter
Allow iteration on 1-D arrays
2 parents b1303a1 + 7c55e47 commit 6b0079b

File tree

3 files changed

+27
-6
lines changed

3 files changed

+27
-6
lines changed

Diff for: array_api_strict/_array_object.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -677,10 +677,16 @@ def __iter__(self: Array, /):
677677
"""
678678
Performs the operation __iter__.
679679
"""
680-
# Manually disable iteration, since __getitem__ raises IndexError on
681-
# things like ones((3, 3))[0], which causes list(ones((3, 3))) to give
682-
# [].
683-
raise TypeError("array iteration is not allowed in array-api-strict")
680+
# Manually disable iteration on higher dimensional arrays, since
681+
# __getitem__ raises IndexError on things like ones((3, 3))[0], which
682+
# causes list(ones((3, 3))) to give [].
683+
if self.ndim > 1:
684+
raise TypeError("array iteration is not allowed in array-api-strict")
685+
# Allow iteration for 1-D arrays. The array API doesn't strictly
686+
# define __iter__, but it doesn't disallow it. The default Python
687+
# behavior is to implement iter as a[0], a[1], ... when __getitem__ is
688+
# implemented, which implies iteration on 1-D arrays.
689+
return (Array._new(i) for i in self._array)
684690

685691
def __le__(self: Array, other: Union[int, float, Array], /) -> Array:
686692
"""

Diff for: array_api_strict/tests/test_array_object.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import operator
2+
from builtins import all as all_
23

34
from numpy.testing import assert_raises, suppress_warnings
45
import numpy as np
@@ -21,6 +22,7 @@
2122
int32,
2223
int64,
2324
uint64,
25+
float64,
2426
bool as bool_,
2527
)
2628
from .._flags import set_array_api_strict_flags
@@ -423,8 +425,12 @@ def test_array_namespace():
423425
pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2021.11"))
424426
pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2024.12"))
425427

426-
def test_no_iter():
427-
pytest.raises(TypeError, lambda: iter(ones(3)))
428+
def test_iter():
429+
pytest.raises(TypeError, lambda: iter(asarray(3)))
430+
assert list(ones(3)) == [asarray(1.), asarray(1.), asarray(1.)]
431+
assert all_(isinstance(a, Array) for a in iter(ones(3)))
432+
assert all_(a.shape == () for a in iter(ones(3)))
433+
assert all_(a.dtype == float64 for a in iter(ones(3)))
428434
pytest.raises(TypeError, lambda: iter(ones((3, 3))))
429435

430436
@pytest.mark.parametrize("api_version", ['2021.12', '2022.12', '2023.12'])

Diff for: docs/changelog.md

+9
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,14 @@
11
# Changelog
22

3+
### 2.0.1 (2024-07-01)
4+
5+
## Minor Changes
6+
7+
- Re-allow iteration on 1-D arrays. A change from 2.0 fixed iter() raising on
8+
n-D arrays but also made 1-D arrays raise. The standard does not explicitly
9+
disallow iteration on 1-D arrays, and the default Python `__iter__`
10+
implementation allows it to work, so for now, it is kept intact as working.
11+
312
## 2.0 (2024-06-27)
413

514
### Major Changes

0 commit comments

Comments
 (0)