27
27
28
28
pytestmark = pytest .mark .ci
29
29
30
+ # The special case test casess are built on runtime via the parametrized
31
+ # test_unary and test_binary functions. Most of this file consists of utility
32
+ # classes and functions, all bought together to create the test cases (pytest
33
+ # params), to finally be run through the general test logic of either test_unary
34
+ # or test_binary.
35
+
36
+
30
37
UnaryCheck = Callable [[float ], bool ]
31
38
BinaryCheck = Callable [[float , float ], bool ]
32
39
@@ -46,13 +53,13 @@ def strict_eq(i: float) -> bool:
46
53
return strict_eq
47
54
48
55
49
- def make_neq (v : float ) -> UnaryCheck :
50
- eq = make_strict_eq (v )
56
+ def make_strict_neq (v : float ) -> UnaryCheck :
57
+ strict_eq = make_strict_eq (v )
51
58
52
- def neq (i : float ) -> bool :
53
- return not eq (i )
59
+ def strict_neq (i : float ) -> bool :
60
+ return not strict_eq (i )
54
61
55
- return neq
62
+ return strict_neq
56
63
57
64
58
65
def make_rough_eq (v : float ) -> UnaryCheck :
@@ -121,14 +128,25 @@ def abs_cond(i: float) -> bool:
121
128
122
129
123
130
@dataclass
124
- class ValueParseError (ValueError ):
131
+ class ParseError (ValueError ):
125
132
value : str
126
133
127
134
128
135
def parse_value (value_str : str ) -> float :
136
+ """
137
+ Parse a value string to return a float, e.g.
138
+
139
+ >>> parse_value('1')
140
+ 1.
141
+ >>> parse_value('-infinity')
142
+ -float('inf')
143
+ >>> parse_value('3π/4')
144
+ 2.356194490192345
145
+
146
+ """
129
147
m = r_value .match (value_str )
130
148
if m is None :
131
- raise ValueParseError (value_str )
149
+ raise ParseError (value_str )
132
150
if pi_m := r_pi .match (m .group (2 )):
133
151
value = math .pi
134
152
if numerator := pi_m .group (1 ):
@@ -150,10 +168,19 @@ def parse_value(value_str: str) -> float:
150
168
151
169
152
170
def parse_inline_code (inline_code : str ) -> float :
171
+ """
172
+ Parse a Sphinx code string to return a float, e.g.
173
+
174
+ >>> parse_value('``0``')
175
+ 0.
176
+ >>> parse_value('``NaN``')
177
+ float('nan')
178
+
179
+ """
153
180
if m := r_code .match (inline_code ):
154
181
return parse_value (m .group (1 ))
155
182
else :
156
- raise ValueParseError (inline_code )
183
+ raise ParseError (inline_code )
157
184
158
185
159
186
r_not = re .compile ("not (.+)" )
@@ -165,16 +192,37 @@ def parse_inline_code(inline_code: str) -> float:
165
192
166
193
167
194
class FromDtypeFunc (Protocol ):
195
+ """
196
+ Type hint for functions that return an elements strategy for arrays of the
197
+ given dtype, e.g. xps.from_dtype().
198
+ """
199
+
168
200
def __call__ (self , dtype : DataType , ** kw ) -> st .SearchStrategy [float ]:
169
201
...
170
202
171
203
172
204
@dataclass
173
205
class BoundFromDtype (FromDtypeFunc ):
206
+ """
207
+ A callable which bounds kwargs and strategy filters to xps.from_dtype() or
208
+ equivalent function.
209
+
210
+
211
+
212
+ """
213
+
174
214
kwargs : Dict [str , Any ] = field (default_factory = dict )
175
215
filter_ : Optional [Callable [[Array ], bool ]] = None
176
216
base_func : Optional [FromDtypeFunc ] = None
177
217
218
+ def __call__ (self , dtype : DataType , ** kw ) -> st .SearchStrategy [float ]:
219
+ assert len (kw ) == 0 # sanity check
220
+ from_dtype = self .base_func or xps .from_dtype
221
+ strat = from_dtype (dtype , ** self .kwargs )
222
+ if self .filter_ is not None :
223
+ strat = strat .filter (self .filter_ )
224
+ return strat
225
+
178
226
def __add__ (self , other : BoundFromDtype ) -> BoundFromDtype :
179
227
for k in self .kwargs .keys ():
180
228
if k in other .kwargs .keys ():
@@ -202,14 +250,6 @@ def __add__(self, other: BoundFromDtype) -> BoundFromDtype:
202
250
203
251
return BoundFromDtype (kwargs , filter_ , base_func )
204
252
205
- def __call__ (self , dtype : DataType , ** kw ) -> st .SearchStrategy [float ]:
206
- assert len (kw ) == 0 # sanity check
207
- from_dtype = self .base_func or xps .from_dtype
208
- strat = from_dtype (dtype , ** self .kwargs )
209
- if self .filter_ is not None :
210
- strat = strat .filter (self .filter_ )
211
- return strat
212
-
213
253
214
254
def wrap_strat_as_from_dtype (strat : st .SearchStrategy [float ]) -> FromDtypeFunc :
215
255
def from_dtype (dtype : DataType , ** kw ) -> st .SearchStrategy [float ]:
@@ -238,7 +278,8 @@ def parse_cond(cond_str: str) -> Tuple[UnaryCheck, str, FromDtypeFunc]:
238
278
strat = st .just (value )
239
279
elif m := r_equal_to .match (cond_str ):
240
280
value = parse_value (m .group (1 ))
241
- assert not math .isnan (value ) # sanity check
281
+ if math .isnan (value ):
282
+ raise ParseError (cond_str )
242
283
cond = lambda i : i == value
243
284
expr_template = "{} == " + m .group (1 )
244
285
elif m := r_gt .match (cond_str ):
@@ -317,14 +358,16 @@ def from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[float]:
317
358
return integers_from_dtype (dtype , ** kw ).filter (lambda n : n % 2 == 1 )
318
359
319
360
else :
320
- raise ValueParseError (cond_str )
361
+ raise ParseError (cond_str )
321
362
322
363
if strat is not None :
323
- # sanity checks
324
- assert not not_cond
325
- assert kwargs == {}
326
- assert filter_ is None
327
- assert from_dtype is None
364
+ if (
365
+ not_cond
366
+ or len (kwargs ) != 0
367
+ or filter_ is not None
368
+ or from_dtype is not None
369
+ ):
370
+ raise ParseError (cond_str )
328
371
return cond , expr_template , wrap_strat_as_from_dtype (strat )
329
372
330
373
if not_cond :
@@ -365,7 +408,7 @@ def check_result(result: float) -> bool:
365
408
366
409
expr = "-"
367
410
else :
368
- raise ValueParseError (result_str )
411
+ raise ParseError (result_str )
369
412
370
413
return check_result , expr
371
414
@@ -461,7 +504,7 @@ def parse_unary_docstring(docstring: str) -> List[UnaryCase]:
461
504
if m := r_unary_case .search (case ):
462
505
try :
463
506
case = UnaryCase .from_strings (* m .groups ())
464
- except ValueParseError as e :
507
+ except ParseError as e :
465
508
warn (f"not machine-readable: '{ e .value } '" )
466
509
continue
467
510
cases .append (case )
@@ -609,7 +652,8 @@ def integers_from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[float]:
609
652
610
653
def parse_binary_case (case_str : str ) -> BinaryCase :
611
654
case_m = r_binary_case .match (case_str )
612
- assert case_m is not None # sanity check
655
+ if case_m is None :
656
+ raise ParseError (case_str )
613
657
cond_strs = r_cond_sep .split (case_m .group (1 ))
614
658
615
659
partial_conds = []
@@ -619,7 +663,8 @@ def parse_binary_case(case_str: str) -> BinaryCase:
619
663
for cond_str in cond_strs :
620
664
if m := r_input_is_array_element .match (cond_str ):
621
665
in_sign , in_no , other_sign , other_no = m .groups ()
622
- assert in_sign == "" and other_no != in_no # sanity check
666
+ if in_sign != "" or other_no == in_no :
667
+ raise ParseError (cond_str )
623
668
partial_expr = f"{ in_sign } x{ in_no } _i == { other_sign } x{ other_no } _i"
624
669
input_wrapper = lambda i : - i if other_sign == "-" else noop
625
670
shared_from_dtype = lambda d , ** kw : st .shared (
@@ -649,7 +694,7 @@ def _x2_cond_from_dtype(dtype, **kw) -> st.SearchStrategy[float]:
649
694
return shared_from_dtype (dtype , ** kw ).map (input_wrapper )
650
695
651
696
else :
652
- raise ValueParseError (cond_str )
697
+ raise ParseError (cond_str )
653
698
654
699
x1_cond_from_dtypes .append (BoundFromDtype (base_func = _x1_cond_from_dtype ))
655
700
x2_cond_from_dtypes .append (BoundFromDtype (base_func = _x2_cond_from_dtype ))
@@ -667,7 +712,7 @@ def _x2_cond_from_dtype(dtype, **kw) -> st.SearchStrategy[float]:
667
712
else :
668
713
cond_m = r_cond .match (cond_str )
669
714
if cond_m is None :
670
- raise ValueParseError (cond_str )
715
+ raise ParseError (cond_str )
671
716
input_str , value_str = cond_m .groups ()
672
717
673
718
if value_str == "the same mathematical sign" :
@@ -712,7 +757,7 @@ def partial_cond(i1: float, i2: float) -> bool:
712
757
partial_expr = f"({ partial_expr } )"
713
758
cond_arg = BinaryCondArg .EITHER
714
759
else :
715
- raise ValueParseError (input_str )
760
+ raise ParseError (input_str )
716
761
partial_cond = make_binary_cond ( # type: ignore
717
762
cond_arg , unary_cond , input_wrapper = input_wrapper
718
763
)
@@ -762,7 +807,7 @@ def _x2_cond_from_dtype(dtype) -> st.SearchStrategy[float]:
762
807
763
808
result_m = r_result .match (case_m .group (2 ))
764
809
if result_m is None :
765
- raise ValueParseError (case_m .group (2 ))
810
+ raise ParseError (case_m .group (2 ))
766
811
result_str = result_m .group (1 )
767
812
if m := r_array_element .match (result_str ):
768
813
sign , x_no = m .groups ()
@@ -787,15 +832,15 @@ def cond(i1: float, i2: float) -> bool:
787
832
x1_cond_from_dtype = x1_cond_from_dtypes [0 ]
788
833
else :
789
834
if not all (isinstance (fd , BoundFromDtype ) for fd in x1_cond_from_dtypes ):
790
- raise ValueParseError (case_str )
835
+ raise ParseError (case_str )
791
836
x1_cond_from_dtype = sum (x1_cond_from_dtypes , start = BoundFromDtype ())
792
837
if len (x2_cond_from_dtypes ) == 0 :
793
838
x2_cond_from_dtype = xps .from_dtype
794
839
elif len (x2_cond_from_dtypes ) == 1 :
795
840
x2_cond_from_dtype = x2_cond_from_dtypes [0 ]
796
841
else :
797
842
if not all (isinstance (fd , BoundFromDtype ) for fd in x2_cond_from_dtypes ):
798
- raise ValueParseError (case_str )
843
+ raise ParseError (case_str )
799
844
x2_cond_from_dtype = sum (x2_cond_from_dtypes , start = BoundFromDtype ())
800
845
801
846
return BinaryCase (
@@ -829,7 +874,7 @@ def parse_binary_docstring(docstring: str) -> List[BinaryCase]:
829
874
try :
830
875
case = parse_binary_case (case_str )
831
876
cases .append (case )
832
- except ValueParseError as e :
877
+ except ParseError as e :
833
878
warn (f"not machine-readable: '{ e .value } '" )
834
879
else :
835
880
if not r_remaining_case .match (case_str ):
0 commit comments