Skip to content

Commit b32c173

Browse files
committed
Favour ParseError to assertions
1 parent c8d2339 commit b32c173

File tree

2 files changed

+83
-36
lines changed

2 files changed

+83
-36
lines changed

Diff for: array_api_tests/meta/test_special_cases.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,7 @@
44

55

66
def test_parse_result():
7-
s_result = "an implementation-dependent approximation to ``+3π/4``"
8-
assert parse_result(s_result).value == 3 * math.pi / 4
7+
check_result, _ = parse_result(
8+
"an implementation-dependent approximation to ``+3π/4``"
9+
)
10+
assert check_result(3 * math.pi / 4)

Diff for: array_api_tests/test_special_cases.py

+79-34
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,13 @@
2727

2828
pytestmark = pytest.mark.ci
2929

30+
# The special case test casess are built on runtime via the parametrized
31+
# test_unary and test_binary functions. Most of this file consists of utility
32+
# classes and functions, all bought together to create the test cases (pytest
33+
# params), to finally be run through the general test logic of either test_unary
34+
# or test_binary.
35+
36+
3037
UnaryCheck = Callable[[float], bool]
3138
BinaryCheck = Callable[[float, float], bool]
3239

@@ -46,13 +53,13 @@ def strict_eq(i: float) -> bool:
4653
return strict_eq
4754

4855

49-
def make_neq(v: float) -> UnaryCheck:
50-
eq = make_strict_eq(v)
56+
def make_strict_neq(v: float) -> UnaryCheck:
57+
strict_eq = make_strict_eq(v)
5158

52-
def neq(i: float) -> bool:
53-
return not eq(i)
59+
def strict_neq(i: float) -> bool:
60+
return not strict_eq(i)
5461

55-
return neq
62+
return strict_neq
5663

5764

5865
def make_rough_eq(v: float) -> UnaryCheck:
@@ -121,14 +128,25 @@ def abs_cond(i: float) -> bool:
121128

122129

123130
@dataclass
124-
class ValueParseError(ValueError):
131+
class ParseError(ValueError):
125132
value: str
126133

127134

128135
def parse_value(value_str: str) -> float:
136+
"""
137+
Parse a value string to return a float, e.g.
138+
139+
>>> parse_value('1')
140+
1.
141+
>>> parse_value('-infinity')
142+
-float('inf')
143+
>>> parse_value('3π/4')
144+
2.356194490192345
145+
146+
"""
129147
m = r_value.match(value_str)
130148
if m is None:
131-
raise ValueParseError(value_str)
149+
raise ParseError(value_str)
132150
if pi_m := r_pi.match(m.group(2)):
133151
value = math.pi
134152
if numerator := pi_m.group(1):
@@ -150,10 +168,19 @@ def parse_value(value_str: str) -> float:
150168

151169

152170
def parse_inline_code(inline_code: str) -> float:
171+
"""
172+
Parse a Sphinx code string to return a float, e.g.
173+
174+
>>> parse_value('``0``')
175+
0.
176+
>>> parse_value('``NaN``')
177+
float('nan')
178+
179+
"""
153180
if m := r_code.match(inline_code):
154181
return parse_value(m.group(1))
155182
else:
156-
raise ValueParseError(inline_code)
183+
raise ParseError(inline_code)
157184

158185

159186
r_not = re.compile("not (.+)")
@@ -165,16 +192,37 @@ def parse_inline_code(inline_code: str) -> float:
165192

166193

167194
class FromDtypeFunc(Protocol):
195+
"""
196+
Type hint for functions that return an elements strategy for arrays of the
197+
given dtype, e.g. xps.from_dtype().
198+
"""
199+
168200
def __call__(self, dtype: DataType, **kw) -> st.SearchStrategy[float]:
169201
...
170202

171203

172204
@dataclass
173205
class BoundFromDtype(FromDtypeFunc):
206+
"""
207+
A callable which bounds kwargs and strategy filters to xps.from_dtype() or
208+
equivalent function.
209+
210+
211+
212+
"""
213+
174214
kwargs: Dict[str, Any] = field(default_factory=dict)
175215
filter_: Optional[Callable[[Array], bool]] = None
176216
base_func: Optional[FromDtypeFunc] = None
177217

218+
def __call__(self, dtype: DataType, **kw) -> st.SearchStrategy[float]:
219+
assert len(kw) == 0 # sanity check
220+
from_dtype = self.base_func or xps.from_dtype
221+
strat = from_dtype(dtype, **self.kwargs)
222+
if self.filter_ is not None:
223+
strat = strat.filter(self.filter_)
224+
return strat
225+
178226
def __add__(self, other: BoundFromDtype) -> BoundFromDtype:
179227
for k in self.kwargs.keys():
180228
if k in other.kwargs.keys():
@@ -202,14 +250,6 @@ def __add__(self, other: BoundFromDtype) -> BoundFromDtype:
202250

203251
return BoundFromDtype(kwargs, filter_, base_func)
204252

205-
def __call__(self, dtype: DataType, **kw) -> st.SearchStrategy[float]:
206-
assert len(kw) == 0 # sanity check
207-
from_dtype = self.base_func or xps.from_dtype
208-
strat = from_dtype(dtype, **self.kwargs)
209-
if self.filter_ is not None:
210-
strat = strat.filter(self.filter_)
211-
return strat
212-
213253

214254
def wrap_strat_as_from_dtype(strat: st.SearchStrategy[float]) -> FromDtypeFunc:
215255
def from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[float]:
@@ -238,7 +278,8 @@ def parse_cond(cond_str: str) -> Tuple[UnaryCheck, str, FromDtypeFunc]:
238278
strat = st.just(value)
239279
elif m := r_equal_to.match(cond_str):
240280
value = parse_value(m.group(1))
241-
assert not math.isnan(value) # sanity check
281+
if math.isnan(value):
282+
raise ParseError(cond_str)
242283
cond = lambda i: i == value
243284
expr_template = "{} == " + m.group(1)
244285
elif m := r_gt.match(cond_str):
@@ -317,14 +358,16 @@ def from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[float]:
317358
return integers_from_dtype(dtype, **kw).filter(lambda n: n % 2 == 1)
318359

319360
else:
320-
raise ValueParseError(cond_str)
361+
raise ParseError(cond_str)
321362

322363
if strat is not None:
323-
# sanity checks
324-
assert not not_cond
325-
assert kwargs == {}
326-
assert filter_ is None
327-
assert from_dtype is None
364+
if (
365+
not_cond
366+
or len(kwargs) != 0
367+
or filter_ is not None
368+
or from_dtype is not None
369+
):
370+
raise ParseError(cond_str)
328371
return cond, expr_template, wrap_strat_as_from_dtype(strat)
329372

330373
if not_cond:
@@ -365,7 +408,7 @@ def check_result(result: float) -> bool:
365408

366409
expr = "-"
367410
else:
368-
raise ValueParseError(result_str)
411+
raise ParseError(result_str)
369412

370413
return check_result, expr
371414

@@ -461,7 +504,7 @@ def parse_unary_docstring(docstring: str) -> List[UnaryCase]:
461504
if m := r_unary_case.search(case):
462505
try:
463506
case = UnaryCase.from_strings(*m.groups())
464-
except ValueParseError as e:
507+
except ParseError as e:
465508
warn(f"not machine-readable: '{e.value}'")
466509
continue
467510
cases.append(case)
@@ -609,7 +652,8 @@ def integers_from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[float]:
609652

610653
def parse_binary_case(case_str: str) -> BinaryCase:
611654
case_m = r_binary_case.match(case_str)
612-
assert case_m is not None # sanity check
655+
if case_m is None:
656+
raise ParseError(case_str)
613657
cond_strs = r_cond_sep.split(case_m.group(1))
614658

615659
partial_conds = []
@@ -619,7 +663,8 @@ def parse_binary_case(case_str: str) -> BinaryCase:
619663
for cond_str in cond_strs:
620664
if m := r_input_is_array_element.match(cond_str):
621665
in_sign, in_no, other_sign, other_no = m.groups()
622-
assert in_sign == "" and other_no != in_no # sanity check
666+
if in_sign != "" or other_no == in_no:
667+
raise ParseError(cond_str)
623668
partial_expr = f"{in_sign}x{in_no}_i == {other_sign}x{other_no}_i"
624669
input_wrapper = lambda i: -i if other_sign == "-" else noop
625670
shared_from_dtype = lambda d, **kw: st.shared(
@@ -649,7 +694,7 @@ def _x2_cond_from_dtype(dtype, **kw) -> st.SearchStrategy[float]:
649694
return shared_from_dtype(dtype, **kw).map(input_wrapper)
650695

651696
else:
652-
raise ValueParseError(cond_str)
697+
raise ParseError(cond_str)
653698

654699
x1_cond_from_dtypes.append(BoundFromDtype(base_func=_x1_cond_from_dtype))
655700
x2_cond_from_dtypes.append(BoundFromDtype(base_func=_x2_cond_from_dtype))
@@ -667,7 +712,7 @@ def _x2_cond_from_dtype(dtype, **kw) -> st.SearchStrategy[float]:
667712
else:
668713
cond_m = r_cond.match(cond_str)
669714
if cond_m is None:
670-
raise ValueParseError(cond_str)
715+
raise ParseError(cond_str)
671716
input_str, value_str = cond_m.groups()
672717

673718
if value_str == "the same mathematical sign":
@@ -712,7 +757,7 @@ def partial_cond(i1: float, i2: float) -> bool:
712757
partial_expr = f"({partial_expr})"
713758
cond_arg = BinaryCondArg.EITHER
714759
else:
715-
raise ValueParseError(input_str)
760+
raise ParseError(input_str)
716761
partial_cond = make_binary_cond( # type: ignore
717762
cond_arg, unary_cond, input_wrapper=input_wrapper
718763
)
@@ -762,7 +807,7 @@ def _x2_cond_from_dtype(dtype) -> st.SearchStrategy[float]:
762807

763808
result_m = r_result.match(case_m.group(2))
764809
if result_m is None:
765-
raise ValueParseError(case_m.group(2))
810+
raise ParseError(case_m.group(2))
766811
result_str = result_m.group(1)
767812
if m := r_array_element.match(result_str):
768813
sign, x_no = m.groups()
@@ -787,15 +832,15 @@ def cond(i1: float, i2: float) -> bool:
787832
x1_cond_from_dtype = x1_cond_from_dtypes[0]
788833
else:
789834
if not all(isinstance(fd, BoundFromDtype) for fd in x1_cond_from_dtypes):
790-
raise ValueParseError(case_str)
835+
raise ParseError(case_str)
791836
x1_cond_from_dtype = sum(x1_cond_from_dtypes, start=BoundFromDtype())
792837
if len(x2_cond_from_dtypes) == 0:
793838
x2_cond_from_dtype = xps.from_dtype
794839
elif len(x2_cond_from_dtypes) == 1:
795840
x2_cond_from_dtype = x2_cond_from_dtypes[0]
796841
else:
797842
if not all(isinstance(fd, BoundFromDtype) for fd in x2_cond_from_dtypes):
798-
raise ValueParseError(case_str)
843+
raise ParseError(case_str)
799844
x2_cond_from_dtype = sum(x2_cond_from_dtypes, start=BoundFromDtype())
800845

801846
return BinaryCase(
@@ -829,7 +874,7 @@ def parse_binary_docstring(docstring: str) -> List[BinaryCase]:
829874
try:
830875
case = parse_binary_case(case_str)
831876
cases.append(case)
832-
except ValueParseError as e:
877+
except ParseError as e:
833878
warn(f"not machine-readable: '{e.value}'")
834879
else:
835880
if not r_remaining_case.match(case_str):

0 commit comments

Comments
 (0)