@@ -633,6 +633,15 @@ def check_result(i1: float, i2: float, result: float) -> bool:
633
633
else :
634
634
raise ValueError (f"{ eq_to = } must be FIRST or SECOND" )
635
635
636
+ return check_result
637
+
638
+
639
+ def make_check_result (check_just_result : UnaryCheck ) -> BinaryResultCheck :
640
+ def check_result (i1 : float , i2 : float , result : float ) -> bool :
641
+ return check_just_result (result )
642
+
643
+ return check_result
644
+
636
645
637
646
def integers_from_dtype (dtype : DataType , ** kw ) -> st .SearchStrategy [float ]:
638
647
for k in kw .keys ():
@@ -733,7 +742,8 @@ def partial_cond(i1: float, i2: float) -> bool:
733
742
# partial_cond definition can mess up previous definitions
734
743
# in the partial_conds list. This is a hard-limitation of
735
744
# using local functions with the same name and that use the same
736
- # outer variables (i.e. unary_cond).
745
+ # outer variables (i.e. unary_cond). Use def in a called
746
+ # function avoids this problem.
737
747
input_wrapper = None
738
748
if m := r_input .match (input_str ):
739
749
x_no = m .group (1 )
@@ -809,6 +819,7 @@ def _x2_cond_from_dtype(dtype) -> st.SearchStrategy[float]:
809
819
if result_m is None :
810
820
raise ParseError (case_m .group (2 ))
811
821
result_str = result_m .group (1 )
822
+ # Like with partial_cond, do not define check_result via the def keyword
812
823
if m := r_array_element .match (result_str ):
813
824
sign , x_no = m .groups ()
814
825
result_expr = f"{ sign } x{ x_no } _i"
@@ -817,9 +828,7 @@ def _x2_cond_from_dtype(dtype) -> st.SearchStrategy[float]:
817
828
)
818
829
else :
819
830
_check_result , result_expr = parse_result (result_m .group (1 ))
820
-
821
- def check_result (i1 : float , i2 : float , result : float ) -> bool :
822
- return _check_result (result )
831
+ check_result = make_check_result (_check_result )
823
832
824
833
cond_expr = " and " .join (partial_exprs )
825
834
0 commit comments