@@ -31,27 +31,35 @@ def test_getitem(shape, data):
31
31
obj = data .draw (scalar_objects (dtype , shape ), label = "obj" )
32
32
x = xp .asarray (obj , dtype = dtype )
33
33
note (f"{ x = } " )
34
- key = data .draw (xps .indices (shape = shape ), label = "key" )
34
+ key = data .draw (xps .indices (shape = shape , allow_newaxis = True ), label = "key" )
35
35
36
36
out = x [key ]
37
37
38
38
ph .assert_dtype ("__getitem__" , x .dtype , out .dtype )
39
39
_key = tuple (key ) if isinstance (key , tuple ) else (key ,)
40
40
if Ellipsis in _key :
41
- start_a = _key .index (Ellipsis )
42
- stop_a = start_a + (len (shape ) - (len (_key ) - 1 ))
43
- slices = tuple (slice (None , None ) for _ in range (start_a , stop_a ))
44
- _key = _key [:start_a ] + slices + _key [start_a + 1 :]
41
+ nonexpanding_key = tuple (i for i in _key if i is not None )
42
+ start_a = nonexpanding_key .index (Ellipsis )
43
+ stop_a = start_a + (len (shape ) - (len (nonexpanding_key ) - 1 ))
44
+ slices = tuple (slice (None ) for _ in range (start_a , stop_a ))
45
+ start_pos = _key .index (Ellipsis )
46
+ _key = _key [:start_pos ] + slices + _key [start_pos + 1 :]
45
47
axes_indices = []
46
48
out_shape = []
47
- for a , i in enumerate (_key ):
48
- if isinstance (i , int ):
49
- axes_indices .append ([i ])
49
+ a = 0
50
+ for i in _key :
51
+ if i is None :
52
+ out_shape .append (1 )
50
53
else :
51
- side = shape [a ]
52
- indices = range (side )[i ]
53
- axes_indices .append (indices )
54
- out_shape .append (len (indices ))
54
+ if isinstance (i , int ):
55
+ axes_indices .append ([i ])
56
+ else :
57
+ assert isinstance (i , slice ) # sanity check
58
+ side = shape [a ]
59
+ indices = range (side )[i ]
60
+ axes_indices .append (indices )
61
+ out_shape .append (len (indices ))
62
+ a += 1
55
63
out_shape = tuple (out_shape )
56
64
ph .assert_shape ("__getitem__" , out .shape , out_shape )
57
65
assume (all (len (indices ) > 0 for indices in axes_indices ))
@@ -104,8 +112,6 @@ def test_setitem(shape, data):
104
112
)
105
113
106
114
107
- # TODO: make mask tests optional
108
-
109
115
@pytest .mark .data_dependent_shapes
110
116
@given (hh .shapes (), st .data ())
111
117
def test_getitem_masking (shape , data ):
0 commit comments