Skip to content

Commit 6503181

Browse files
committed
BUG: expand_dims: handle positive/negative duplicates
1 parent 0a0bed1 commit 6503181

File tree

2 files changed

+15
-4
lines changed

2 files changed

+15
-4
lines changed

src/array_api_extra/_funcs.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,9 @@ def expand_dims(
6565
a : array
6666
axis : int or tuple of ints
6767
Position(s) in the expanded axes where the new axis (or axes) is/are placed.
68-
If multiple positions are provided, they should be unique.
68+
If multiple positions are provided, they should be unique (note that a position
69+
given by a positive index could also be referred to by a negative index -
70+
that will also result in an error).
6971
Default: ``(0,)``.
7072
xp : array_namespace
7173
The standard-compatible namespace for `a`.
@@ -114,16 +116,19 @@ def expand_dims(
114116
"""
115117
if not isinstance(axis, tuple):
116118
axis = (axis,)
117-
if len(set(axis)) != len(axis):
118-
err_msg = "Duplicate dimensions specified in `axis`."
119-
raise ValueError(err_msg)
120119
ndim = a.ndim + len(axis)
121120
if axis != () and (min(axis) < -ndim or max(axis) >= ndim):
122121
err_msg = (
123122
f"a provided axis position is out of bounds for array of dimension {a.ndim}"
124123
)
125124
raise IndexError(err_msg)
126125
axis = tuple(dim % ndim for dim in axis)
126+
if len(set(axis)) != len(axis):
127+
err_msg = "Duplicate dimensions specified in `axis`."
128+
raise ValueError(err_msg)
129+
if len(set(axis)) != len(axis):
130+
err_msg = "Duplicate dimensions specified in `axis`."
131+
raise ValueError(err_msg)
127132
for i in sorted(axis):
128133
a = xp.expand_dims(a, axis=i)
129134
return a

tests/test_funcs.py

+6
Original file line numberDiff line numberDiff line change
@@ -182,3 +182,9 @@ def test_repeated_axis(self):
182182
a = xp.empty((3, 3, 3))
183183
with pytest.raises(ValueError, match="Duplicate dimensions"):
184184
expand_dims(a, axis=(1, 1), xp=xp)
185+
186+
def test_positive_negative_repeated(self):
187+
# https://github.com/data-apis/array-api/issues/760#issuecomment-1989449817
188+
a = xp.empty((2, 3, 4, 5))
189+
with pytest.raises(ValueError, match="Duplicate dimensions"):
190+
expand_dims(a, axis=(3, -3), xp=xp)

0 commit comments

Comments
 (0)