Skip to content

Commit 9c4b26b

Browse files
committed
Note check_result def problems
1 parent 40e0f30 commit 9c4b26b

File tree

1 file changed

+13
-4
lines changed

1 file changed

+13
-4
lines changed

Diff for: array_api_tests/test_special_cases.py

+13-4
Original file line numberDiff line numberDiff line change
@@ -633,6 +633,15 @@ def check_result(i1: float, i2: float, result: float) -> bool:
633633
else:
634634
raise ValueError(f"{eq_to=} must be FIRST or SECOND")
635635

636+
return check_result
637+
638+
639+
def make_check_result(check_just_result: UnaryCheck) -> BinaryResultCheck:
640+
def check_result(i1: float, i2: float, result: float) -> bool:
641+
return check_just_result(result)
642+
643+
return check_result
644+
636645

637646
def integers_from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[float]:
638647
for k in kw.keys():
@@ -733,7 +742,8 @@ def partial_cond(i1: float, i2: float) -> bool:
733742
# partial_cond definition can mess up previous definitions
734743
# in the partial_conds list. This is a hard-limitation of
735744
# using local functions with the same name and that use the same
736-
# outer variables (i.e. unary_cond).
745+
# outer variables (i.e. unary_cond). Use def in a called
746+
# function avoids this problem.
737747
input_wrapper = None
738748
if m := r_input.match(input_str):
739749
x_no = m.group(1)
@@ -809,6 +819,7 @@ def _x2_cond_from_dtype(dtype) -> st.SearchStrategy[float]:
809819
if result_m is None:
810820
raise ParseError(case_m.group(2))
811821
result_str = result_m.group(1)
822+
# Like with partial_cond, do not define check_result via the def keyword
812823
if m := r_array_element.match(result_str):
813824
sign, x_no = m.groups()
814825
result_expr = f"{sign}x{x_no}_i"
@@ -817,9 +828,7 @@ def _x2_cond_from_dtype(dtype) -> st.SearchStrategy[float]:
817828
)
818829
else:
819830
_check_result, result_expr = parse_result(result_m.group(1))
820-
821-
def check_result(i1: float, i2: float, result: float) -> bool:
822-
return _check_result(result)
831+
check_result = make_check_result(_check_result)
823832

824833
cond_expr = " and ".join(partial_exprs)
825834

0 commit comments

Comments
 (0)