1
1
import math
2
2
from itertools import product
3
- from typing import List , get_args
3
+ from typing import List , Sequence , Tuple , Union , get_args
4
4
5
5
import pytest
6
6
from hypothesis import assume , given , note
12
12
from . import pytest_helpers as ph
13
13
from . import shape_helpers as sh
14
14
from . import xps
15
- from .typing import DataType , Param , Scalar , ScalarType , Shape
15
+ from .test_operators_and_elementwise_functions import oneway_promotable_dtypes
16
+ from .typing import DataType , Index , Param , Scalar , ScalarType , Shape
16
17
17
18
pytestmark = pytest .mark .ci
18
19
19
20
20
- def scalar_objects (dtype : DataType , shape : Shape ) -> st .SearchStrategy [List [Scalar ]]:
21
+ def scalar_objects (
22
+ dtype : DataType , shape : Shape
23
+ ) -> st .SearchStrategy [Union [Scalar , List [Scalar ]]]:
21
24
"""Generates scalars or nested sequences which are valid for xp.asarray()"""
22
25
size = math .prod (shape )
23
26
return st .lists (xps .from_dtype (dtype ), min_size = size , max_size = size ).map (
24
27
lambda l : sh .reshape (l , shape )
25
28
)
26
29
27
30
28
- @given (hh .shapes (min_side = 1 ), st .data ()) # TODO: test 0-sided arrays
29
- def test_getitem (shape , data ):
30
- dtype = data .draw (xps .scalar_dtypes (), label = "dtype" )
31
- obj = data .draw (scalar_objects (dtype , shape ), label = "obj" )
32
- x = xp .asarray (obj , dtype = dtype )
33
- note (f"{ x = } " )
34
- key = data .draw (xps .indices (shape = shape , allow_newaxis = True ), label = "key" )
35
-
36
- out = x [key ]
31
+ def normalise_key (key : Index , shape : Shape ) -> Tuple [Union [int , slice ], ...]:
32
+ """
33
+ Normalise an indexing key.
37
34
38
- ph .assert_dtype ("__getitem__" , x .dtype , out .dtype )
35
+ * If a non-tuple index, wrap as a tuple.
36
+ * Represent ellipsis as equivalent slices.
37
+ """
39
38
_key = tuple (key ) if isinstance (key , tuple ) else (key ,)
40
39
if Ellipsis in _key :
41
40
nonexpanding_key = tuple (i for i in _key if i is not None )
@@ -44,71 +43,109 @@ def test_getitem(shape, data):
44
43
slices = tuple (slice (None ) for _ in range (start_a , stop_a ))
45
44
start_pos = _key .index (Ellipsis )
46
45
_key = _key [:start_pos ] + slices + _key [start_pos + 1 :]
46
+ return _key
47
+
48
+
49
+ def get_indexed_axes_and_out_shape (
50
+ key : Tuple [Union [int , slice , None ], ...], shape : Shape
51
+ ) -> Tuple [Tuple [Sequence [int ], ...], Shape ]:
52
+ """
53
+ From the (normalised) key and input shape, calculates:
54
+
55
+ * indexed_axes: For each dimension, the axes which the key indexes.
56
+ * out_shape: The resulting shape of indexing an array (of the input shape)
57
+ with the key.
58
+ """
47
59
axes_indices = []
48
60
out_shape = []
49
61
a = 0
50
- for i in _key :
62
+ for i in key :
51
63
if i is None :
52
64
out_shape .append (1 )
53
65
else :
66
+ side = shape [a ]
54
67
if isinstance (i , int ):
55
- axes_indices .append ([i ])
68
+ if i < 0 :
69
+ i += side
70
+ axes_indices .append ((i ,))
56
71
else :
57
- assert isinstance (i , slice ) # sanity check
58
- side = shape [a ]
59
72
indices = range (side )[i ]
60
73
axes_indices .append (indices )
61
74
out_shape .append (len (indices ))
62
75
a += 1
63
- out_shape = tuple (out_shape )
76
+ return tuple (axes_indices ), tuple (out_shape )
77
+
78
+
79
+ @given (shape = hh .shapes (), dtype = xps .scalar_dtypes (), data = st .data ())
80
+ def test_getitem (shape , dtype , data ):
81
+ zero_sided = any (side == 0 for side in shape )
82
+ if zero_sided :
83
+ x = xp .zeros (shape , dtype = dtype )
84
+ else :
85
+ obj = data .draw (scalar_objects (dtype , shape ), label = "obj" )
86
+ x = xp .asarray (obj , dtype = dtype )
87
+ note (f"{ x = } " )
88
+ key = data .draw (xps .indices (shape = shape , allow_newaxis = True ), label = "key" )
89
+
90
+ out = x [key ]
91
+
92
+ ph .assert_dtype ("__getitem__" , x .dtype , out .dtype )
93
+ _key = normalise_key (key , shape )
94
+ axes_indices , out_shape = get_indexed_axes_and_out_shape (_key , shape )
64
95
ph .assert_shape ("__getitem__" , out .shape , out_shape )
65
- assume (all (len (indices ) > 0 for indices in axes_indices ))
66
- out_obj = []
67
- for idx in product (* axes_indices ):
68
- val = obj
69
- for i in idx :
70
- val = val [i ]
71
- out_obj .append (val )
72
- out_obj = sh .reshape (out_obj , out_shape )
73
- expected = xp .asarray (out_obj , dtype = dtype )
74
- ph .assert_array ("__getitem__" , out , expected )
75
-
76
-
77
- @given (hh .shapes (min_side = 1 ), st .data ()) # TODO: test 0-sided arrays
78
- def test_setitem (shape , data ):
79
- dtype = data .draw (xps .scalar_dtypes (), label = "dtype" )
80
- obj = data .draw (scalar_objects (dtype , shape ), label = "obj" )
81
- x = xp .asarray (obj , dtype = dtype )
96
+ out_zero_sided = any (side == 0 for side in out_shape )
97
+ if not zero_sided and not out_zero_sided :
98
+ out_obj = []
99
+ for idx in product (* axes_indices ):
100
+ val = obj
101
+ for i in idx :
102
+ val = val [i ]
103
+ out_obj .append (val )
104
+ out_obj = sh .reshape (out_obj , out_shape )
105
+ expected = xp .asarray (out_obj , dtype = dtype )
106
+ ph .assert_array_elements ("__getitem__" , out , expected )
107
+
108
+
109
+ @given (
110
+ shape = hh .shapes (),
111
+ dtypes = oneway_promotable_dtypes (dh .all_dtypes ),
112
+ data = st .data (),
113
+ )
114
+ def test_setitem (shape , dtypes , data ):
115
+ zero_sided = any (side == 0 for side in shape )
116
+ if zero_sided :
117
+ x = xp .zeros (shape , dtype = dtypes .result_dtype )
118
+ else :
119
+ obj = data .draw (scalar_objects (dtypes .result_dtype , shape ), label = "obj" )
120
+ x = xp .asarray (obj , dtype = dtypes .result_dtype )
82
121
note (f"{ x = } " )
83
- # TODO: test setting non-0d arrays
84
- key = data .draw (xps .indices (shape = shape , max_dims = 0 ), label = "key" )
85
- value = data .draw (
86
- xps .from_dtype (dtype ) | xps .arrays (dtype = dtype , shape = ()), label = "value"
87
- )
122
+ key = data .draw (xps .indices (shape = shape ), label = "key" )
123
+ _key = normalise_key (key , shape )
124
+ axes_indices , out_shape = get_indexed_axes_and_out_shape (_key , shape )
125
+ value_strat = xps .arrays (dtype = dtypes .result_dtype , shape = out_shape )
126
+ if out_shape == ():
127
+ # We can pass scalars if we're only indexing one element
128
+ value_strat |= xps .from_dtype (dtypes .result_dtype )
129
+ value = data .draw (value_strat , label = "value" )
88
130
89
131
res = xp .asarray (x , copy = True )
90
132
res [key ] = value
91
133
92
134
ph .assert_dtype ("__setitem__" , x .dtype , res .dtype , repr_name = "x.dtype" )
93
135
ph .assert_shape ("__setitem__" , res .shape , x .shape , repr_name = "x.shape" )
136
+ f_res = sh .fmt_idx ("x" , key )
94
137
if isinstance (value , get_args (Scalar )):
95
- msg = f"x[ { key } ] ={ res [key ]!r} , but should be { value = } [__setitem__()]"
138
+ msg = f"{ f_res } ={ res [key ]!r} , but should be { value = } [__setitem__()]"
96
139
if math .isnan (value ):
97
140
assert xp .isnan (res [key ]), msg
98
141
else :
99
142
assert res [key ] == value , msg
100
143
else :
101
- ph .assert_0d_equals (
102
- "__setitem__" , "value" , value , f"modified x[{ key } ]" , res [key ]
103
- )
104
- _key = key if isinstance (key , tuple ) else (key ,)
105
- assume (all (isinstance (i , int ) for i in _key )) # TODO: normalise slices and ellipsis
106
- _key = tuple (i if i >= 0 else s + i for i , s in zip (_key , x .shape ))
107
- unaffected_indices = list (sh .ndindex (res .shape ))
108
- unaffected_indices .remove (_key )
144
+ ph .assert_array_elements ("__setitem__" , res [key ], value , out_repr = f_res )
145
+ unaffected_indices = set (sh .ndindex (res .shape )) - set (product (* axes_indices ))
109
146
for idx in unaffected_indices :
110
147
ph .assert_0d_equals (
111
- "__setitem__" , f"old x[ { idx } ] " , x [idx ], f"modified x[ { idx } ] " , res [idx ]
148
+ "__setitem__" , f"old { f_res } " , x [idx ], f"modified { f_res } " , res [idx ]
112
149
)
113
150
114
151
0 commit comments