Skip to content

Commit 2b677e3

Browse files
committed
Note check_result def problems
1 parent b32c173 commit 2b677e3

File tree

1 file changed

+11
-3
lines changed

1 file changed

+11
-3
lines changed

Diff for: array_api_tests/test_special_cases.py

+11-3
Original file line numberDiff line numberDiff line change
@@ -633,6 +633,15 @@ def check_result(i1: float, i2: float, result: float) -> bool:
633633
else:
634634
raise ValueError(f"{eq_to=} must be FIRST or SECOND")
635635

636+
return check_result
637+
638+
639+
def make_check_result(check_just_result: UnaryCheck) -> BinaryResultCheck:
640+
def check_result(i1: float, i2: float, result: float) -> bool:
641+
return check_just_result(result)
642+
643+
return check_result
644+
636645

637646
def integers_from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[float]:
638647
for k in kw.keys():
@@ -809,6 +818,7 @@ def _x2_cond_from_dtype(dtype) -> st.SearchStrategy[float]:
809818
if result_m is None:
810819
raise ParseError(case_m.group(2))
811820
result_str = result_m.group(1)
821+
# Like with partial_cond, do not define check_result via the def keyword
812822
if m := r_array_element.match(result_str):
813823
sign, x_no = m.groups()
814824
result_expr = f"{sign}x{x_no}_i"
@@ -817,9 +827,7 @@ def _x2_cond_from_dtype(dtype) -> st.SearchStrategy[float]:
817827
)
818828
else:
819829
_check_result, result_expr = parse_result(result_m.group(1))
820-
821-
def check_result(i1: float, i2: float, result: float) -> bool:
822-
return _check_result(result)
830+
check_result = make_check_result(_check_result)
823831

824832
cond_expr = " and ".join(partial_exprs)
825833

0 commit comments

Comments
 (0)