@@ -65,7 +65,9 @@ def expand_dims(
65
65
a : array
66
66
axis : int or tuple of ints
67
67
Position(s) in the expanded axes where the new axis (or axes) is/are placed.
68
- If multiple positions are provided, they should be unique.
68
+ If multiple positions are provided, they should be unique (note that a position
69
+ given by a positive index could also be referred to by a negative index -
70
+ that will also result in an error).
69
71
Default: ``(0,)``.
70
72
xp : array_namespace
71
73
The standard-compatible namespace for `a`.
@@ -114,16 +116,19 @@ def expand_dims(
114
116
"""
115
117
if not isinstance (axis , tuple ):
116
118
axis = (axis ,)
117
- if len (set (axis )) != len (axis ):
118
- err_msg = "Duplicate dimensions specified in `axis`."
119
- raise ValueError (err_msg )
120
119
ndim = a .ndim + len (axis )
121
120
if axis != () and (min (axis ) < - ndim or max (axis ) >= ndim ):
122
121
err_msg = (
123
122
f"a provided axis position is out of bounds for array of dimension { a .ndim } "
124
123
)
125
124
raise IndexError (err_msg )
126
125
axis = tuple (dim % ndim for dim in axis )
126
+ if len (set (axis )) != len (axis ):
127
+ err_msg = "Duplicate dimensions specified in `axis`."
128
+ raise ValueError (err_msg )
129
+ if len (set (axis )) != len (axis ):
130
+ err_msg = "Duplicate dimensions specified in `axis`."
131
+ raise ValueError (err_msg )
127
132
for i in sorted (axis ):
128
133
a = xp .expand_dims (a , axis = i )
129
134
return a
0 commit comments