Skip to content

Commit

Permalink
Update casting behaviour
Browse files Browse the repository at this point in the history
  • Loading branch information
dcherian committed Apr 1, 2024
1 parent a9097c2 commit 28db1dc
Showing 1 changed file with 12 additions and 3 deletions.
15 changes: 12 additions & 3 deletions tests/test_properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@ def not_overflowing_array(array) -> bool:
else:
return True

return bool(np.all((array < info.max / array.size) & (array > info.min / array.size)))
result = bool(np.all((array < info.max / array.size) & (array > info.min / array.size)))
# note(f"returning {result}, {array.min()} vs {info.min}, {array.max()} vs {info.max}")
return result


@settings(suppress_health_check=[HealthCheck.filter_too_much])
Expand All @@ -57,14 +59,19 @@ def not_overflowing_array(array) -> bool:
def test_groupby_reduce(array, dtype, func):
# overflow behaviour differs between bincount and sum (for example)
assume(not_overflowing_array(array))
# numpy-groupies always does the calculation in float64
assume(func != "var" and "f2" not in array.dtype.str)
# arg* with nans in array are weird
assume("arg" not in func and not np.any(np.isnan(array).ravel()))

axis = -1
by = np.ones((array.shape[-1],), dtype=dtype)
kwargs = {"q": 0.8} if "quantile" in func else {}
# numpy-groupies always does the calculation in float64
if ("var" in func or "std" in func or "sum" in func) and array.dtype.kinde == "f":
# bincount accumulates in float64
kwargs.setdefault("dtype", np.float64)
cast_to = array.dtype
else:
cast_to = None

with np.errstate(invalid="ignore", divide="ignore"):
actual, _ = groupby_reduce(
Expand All @@ -73,4 +80,6 @@ def test_groupby_reduce(array, dtype, func):
expected = getattr(np, func)(array, axis=axis, keepdims=True, **kwargs)
note(("expected: ", expected, "actual: ", actual))
tolerance = {"rtol": 1e-13, "atol": 1e-16} if "var" in func or "std" in func else {}
if cast_to:
actual = actual.astype(cast_to)
assert_equal(expected, actual, tolerance)

0 comments on commit 28db1dc

Please sign in to comment.