Skip to content

Commit

Permalink
Note check_result def problems
Browse files Browse the repository at this point in the history
  • Loading branch information
honno committed Mar 14, 2022
1 parent b32c173 commit 2b677e3
Showing 1 changed file with 11 additions and 3 deletions.
14 changes: 11 additions & 3 deletions array_api_tests/test_special_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,6 +633,15 @@ def check_result(i1: float, i2: float, result: float) -> bool:
else:
raise ValueError(f"{eq_to=} must be FIRST or SECOND")

return check_result


def make_check_result(check_just_result: UnaryCheck) -> BinaryResultCheck:
def check_result(i1: float, i2: float, result: float) -> bool:
return check_just_result(result)

return check_result


def integers_from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[float]:
for k in kw.keys():
Expand Down Expand Up @@ -809,6 +818,7 @@ def _x2_cond_from_dtype(dtype) -> st.SearchStrategy[float]:
if result_m is None:
raise ParseError(case_m.group(2))
result_str = result_m.group(1)
# Like with partial_cond, do not define check_result via the def keyword
if m := r_array_element.match(result_str):
sign, x_no = m.groups()
result_expr = f"{sign}x{x_no}_i"
Expand All @@ -817,9 +827,7 @@ def _x2_cond_from_dtype(dtype) -> st.SearchStrategy[float]:
)
else:
_check_result, result_expr = parse_result(result_m.group(1))

def check_result(i1: float, i2: float, result: float) -> bool:
return _check_result(result)
check_result = make_check_result(_check_result)

cond_expr = " and ".join(partial_exprs)

Expand Down

0 comments on commit 2b677e3

Please sign in to comment.