Skip to content

Commit

Permalink
Fix acos(-1) and add utils.extra_samples (#16)
Browse files Browse the repository at this point in the history
* Introduce utils.extra_samples. Use ULP-difference uniform samples.

* Fix evaluation of acos(-1)
  • Loading branch information
pearu authored Jul 30, 2024
1 parent bc35898 commit 7503d40
Show file tree
Hide file tree
Showing 14 changed files with 178 additions and 103 deletions.
8 changes: 7 additions & 1 deletion functional_algorithms/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,11 +496,17 @@ def real_acos(ctx, x: float):
To avoid cancellation errors at abs(x) close to 1, we'll use
1 - x * x == (1 - x) * (1 + x)
At x == -1, we have arctan2(0, 0) and the above expression will
evaluate to 0 instead of expected pi. Therefore, we explicitly
define that arccos(-1) == pi.
"""
one = ctx.constant(1, x)
none = ctx.constant(-1, x)
two = ctx.constant(2, x)
sq = ctx.sqrt((one - x) * (one + x))
return ctx(two * ctx.atan2(sq, one + x))
r = two * ctx.atan2(sq, one + x)
return ctx(ctx.select(x != none, r, ctx.constant("pi", x)))


def complex_acos(ctx, z: complex):
Expand Down
2 changes: 2 additions & 0 deletions functional_algorithms/targets/cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,8 @@ def make_comment(message):
largest="std::numeric_limits<{type}>::max()",
posinf="std::numeric_limits<{type}>::infinity()",
neginf="-std::numeric_limits<{type}>::infinity()",
pi="M_PI",
nan="NAN",
)


Expand Down
4 changes: 3 additions & 1 deletion functional_algorithms/targets/python.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,9 @@ def make_comment(message):
ne="({0}) != ({1})",
)

constant_to_target = dict(smallest="sys.float_info.min", largest="sys.float_info.max", posinf="math.inf", neginf="-math.inf")
constant_to_target = dict(
smallest="sys.float_info.min", largest="sys.float_info.max", posinf="math.inf", neginf="-math.inf", pi="math.pi"
)

type_to_target = dict(integer="int", float="float", complex="complex", boolean="bool")

Expand Down
1 change: 1 addition & 0 deletions functional_algorithms/targets/stablehlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ def make_comment(message):
smallest="StableHLO_ConstantLikeSmallestNormalizedValue",
posinf="StableHLO_ConstantLikePosInfValue",
neginf="StableHLO_ConstantLikeNegInfValue",
pi='StableHLO_ConstantLike<"M_PI">',
)


Expand Down
29 changes: 7 additions & 22 deletions functional_algorithms/tests/test_algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,13 @@ def test_unary(dtype_name, unary_func_name, flush_subnormals):
include_subnormal=not flush_subnormals,
).flatten()
else:
samples = utils.real_samples(size * size, dtype=dtype, include_subnormal=not flush_subnormals).flatten()
samples = utils.real_samples(
size * size,
dtype=dtype,
include_subnormal=not flush_subnormals,
).flatten()

samples = numpy.concatenate((samples, utils.extra_samples(unary_func_name, dtype)))

matches_with_reference, ulp_stats = utils.validate_function(
func, reference, samples, dtype, flush_subnormals=flush_subnormals
Expand All @@ -81,27 +87,6 @@ def test_unary(dtype_name, unary_func_name, flush_subnormals):

assert matches_with_reference # warning: also reference may be wrong

extra_samples = []
if unary_func_name == "absolute" and dtype_name.startswith("complex"):
extra_samples.extend([1.0011048e35 + 3.4028235e38j])

if extra_samples:
samples = numpy.array(extra_samples, dtype=dtype)
matches_with_reference, ulp_stats = utils.validate_function(
func, reference, samples, dtype, flush_subnormals=flush_subnormals
)
if not matches_with_reference:
print("Extra samples:")
gt3_ulp_total = 0
for ulp in sorted(ulp_stats):
if ulp in {0, 1, 2, 3}:
print(f" dULP={ulp}: {ulp_stats[ulp]}")
elif ulp > 0:
gt3_ulp_total += ulp_stats[ulp]
else:
print(f" dULP>3: {gt3_ulp_total}")
assert matches_with_reference # warning: also reference may be wrong


def test_binary(dtype_name, binary_func_name, flush_subnormals):
if dtype_name.startswith("complex") and binary_func_name in {"hypot"}:
Expand Down
47 changes: 23 additions & 24 deletions functional_algorithms/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,16 @@ def dtype(request):


def _check_real_samples(
r, include_infinity=None, include_zero=None, include_subnormal=None, include_nan=None, nonnegative=None, include_huge=None
r,
include_infinity=None,
include_zero=None,
include_subnormal=None,
include_nan=None,
nonnegative=None,
include_huge=None,
):
fi = numpy.finfo(r.dtype)
size = r.size
if nonnegative:
if include_zero:
assert r[0] == 0
Expand All @@ -30,27 +37,27 @@ def _check_real_samples(
if include_nan:
assert numpy.isnan(r[-1])
if include_infinity:
if include_huge and r.size > 9:
if include_huge and size > 9:
assert numpy.nextafter(r[-4], numpy.inf, dtype=r.dtype) == fi.max
assert r[-3] == fi.max
assert numpy.isposinf(r[-2])
else:
assert r[-2] == fi.max
if include_huge and r.size > 9:
if include_huge and size > 9:
assert numpy.nextafter(r[-3], numpy.inf, dtype=r.dtype) == fi.max
for i in range(r.size - 2):
for i in range(size - 2):
assert r[i] < r[i + 1]
else:
if include_infinity:
if include_huge and r.size > 8:
if include_huge and size > 8:
assert numpy.nextafter(r[-3], numpy.inf, dtype=r.dtype) == fi.max
assert r[-2] == fi.max
assert numpy.isposinf(r[-1])
else:
assert r[-1] == fi.max
if include_huge and r.size > 8:
if include_huge and size > 8:
assert numpy.nextafter(r[-2], numpy.inf, dtype=r.dtype) == fi.max
for i in range(r.size - 1):
for i in range(size - 1):
assert r[i] < r[i + 1]
else:
if include_infinity:
Expand All @@ -74,20 +81,7 @@ def _check_real_samples(
for i in range(r.size - 1):
assert r[i] < r[i + 1]
if include_zero:
size = r.size
if include_nan:
size -= 1
if include_infinity:
if include_huge:
loc = (size - 1) // 2
else:
loc = (size - 1) // 2
else:
if include_huge:
loc = (size - 1) // 2
else:
loc = (size - 1) // 2
assert r[loc] == 0, (include_nan, include_infinity, include_huge)
loc = numpy.where(r == 0)[0][0]
if include_subnormal:
assert r[loc + 1] == fi.smallest_subnormal
assert r[loc - 1] == -fi.smallest_subnormal
Expand All @@ -97,9 +91,14 @@ def _check_real_samples(


def _iter_samples_parameters():
for include_huge, include_subnormal, include_infinity, include_zero, include_nan, nonnegative in itertools.product(
*(([False, True],) * 6)
):
for (
include_huge,
include_subnormal,
include_infinity,
include_zero,
include_nan,
nonnegative,
) in itertools.product(*(([False, True],) * 6)):
yield dict(
include_huge=include_huge,
include_subnormal=include_subnormal,
Expand Down
83 changes: 69 additions & 14 deletions functional_algorithms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,6 +579,40 @@ def worker(ctx, s, e, r, v):
assert 0 # unreachable


def extra_samples(name, dtype):
"""Return a list of samples that are special to a given function.
Parameters
----------
name: str
The name of a function
dtype:
Floating-point or complex dtype
Returns
-------
values: list
Values of function inputs.
"""
is_complex = "complex" in str(dtype)
is_float = "float" in str(dtype)
assert is_float or is_complex, dtype
values = []
# Notice that real/complex_samples already include special values
# such as 0, -inf, inf, smallest subnormals or normals, so don't
# specify these here.
if is_float:
if name in {"acos", "asin"}:
for v in [-1, 1]:
values.append(numpy.nextafter(v, v - 1, dtype=dtype))
values.append(v)
values.append(numpy.nextafter(v, v + 1, dtype=dtype))
if is_complex:
if name == "absolute":
values.append(1.0011048e35 + 3.4028235e38j)
return numpy.array(values, dtype=dtype)


def real_samples(
size=10,
dtype=numpy.float32,
Expand Down Expand Up @@ -615,31 +649,52 @@ def real_samples(
if isinstance(dtype, str):
dtype = getattr(numpy, dtype)
assert dtype in {numpy.float32, numpy.float64}, dtype
utype = {numpy.float32: numpy.uint32, numpy.float64: numpy.uint64}[dtype]
fi = numpy.finfo(dtype)
start = fi.minexp + fi.negep + 1 if include_subnormal else fi.minexp
end = fi.maxexp
num = size // 2 if not nonnegative else size
if include_infinity:
num -= 1
with warnings.catch_warnings(action="ignore"):
finite_positive = numpy.logspace(start, end, base=2, num=num, dtype=dtype)
min_value = dtype(fi.smallest_subnormal if include_subnormal else fi.smallest_normal)
max_value = dtype(fi.max)
if 1:
# The following method gives a sample distibution that is
# uniform with respect to ULP distance between positive
# neighboring samples
finite_positive = numpy.linspace(min_value.view(utype), max_value.view(utype), num=num, dtype=utype).view(dtype)
else:
start = fi.minexp + fi.negep + 1 if include_subnormal else fi.minexp
end = fi.maxexp
with warnings.catch_warnings(action="ignore"):
# Note that logspace gives samples distribution that is
# approximately uniform with respect to ULP distance between
# neighboring normal samples. For subnormal samples, logspace
# produces repeated samples that will be eliminated below via
# numpy.unique.
finite_positive = numpy.logspace(start, end, base=2, num=num, dtype=dtype)
finite_positive[-1] = max_value

if include_huge and num > 3:
huge = -numpy.nextafter(-max_value, numpy.inf, dtype=dtype)
finite_positive[-2] = huge

finite_positive[-1] = fi.max
if include_huge and size > 7:
finite_positive[-2] = -numpy.nextafter(-fi.max, numpy.inf, dtype=dtype)
parts = []
extra = []
if not nonnegative:
if include_infinity:
parts.append(numpy.array([-numpy.inf], dtype=dtype))
extra.append(-numpy.inf)
parts.append(-finite_positive[::-1])
if include_zero:
parts.append(numpy.array([0], dtype=dtype))
extra.append(0)
parts.append(finite_positive)
if include_infinity:
parts.append(numpy.array([numpy.inf], dtype=dtype))
extra.append(numpy.inf)
if include_nan:
parts.append(numpy.array([numpy.nan], dtype=dtype))
return numpy.concatenate(parts)
extra.append(numpy.nan)
parts.append(numpy.array(extra, dtype=dtype))

# Using unique because logspace produces repeated subnormals when
# size is large
return numpy.unique(numpy.concatenate(parts))


def complex_samples(
Expand Down Expand Up @@ -777,8 +832,8 @@ def complex_pair_samples(
include_zero=include_zero,
include_subnormal=include_subnormal,
include_nan=include_nan,
nonnegative=nonnegative,
include_huge=include_huge,
nonnegative=nonnegative,
)
s2 = complex_samples(
size=size[1],
Expand All @@ -787,8 +842,8 @@ def complex_pair_samples(
include_zero=include_zero,
include_subnormal=include_subnormal,
include_nan=include_nan,
nonnegative=nonnegative,
include_huge=include_huge,
nonnegative=nonnegative,
)
shape1 = s1.shape
shape2 = s2.shape
Expand Down
44 changes: 22 additions & 22 deletions results/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,25 +10,25 @@ MPMath functions using multi-precision arithmetic.
| -------- | ----- | ------------- | ----- | ----- | ----- | ----- | --------- |
| absolute | float32 | 1000001 | - | - | - | - | - |
| absolute | float64 | 1000001 | - | - | - | - | - |
| absolute | complex64 | 977109 | 24892 | - | - | - | - |
| absolute | complex128 | 989613 | 12372 | 16 | - | - | - |
| acos | float32 | 504842 | 488335 | 6599 | 225 | - | - |
| acos | float64 | 986273 | 12357 | 1366 | 5 | - | - |
| acos | complex64 | 804838 | 196477 | 680 | 6 | - | - |
| acos | complex128 | 701506 | 300255 | 238 | 2 | - | - |
| acosh | float32 | 989143 | 10829 | 29 | - | - | - |
| acosh | float64 | 947718 | 52275 | 8 | - | - | - |
| acosh | complex64 | 804838 | 196477 | 680 | 6 | - | - |
| acosh | complex128 | 701506 | 300255 | 238 | 2 | - | - |
| asin | float32 | 937353 | 61812 | 810 | 26 | - | - |
| asin | float64 | 983317 | 16662 | 22 | - | - | - |
| asin | complex64 | 808487 | 191878 | 1592 | 44 | - | - |
| asin | complex128 | 694103 | 307326 | 560 | 12 | - | - |
| asinh | float32 | 922791 | 77144 | 66 | - | - | - |
| asinh | float64 | 829637 | 170270 | 94 | - | - | - |
| asinh | complex64 | 808487 | 191878 | 1592 | 44 | - | - |
| asinh | complex128 | 694103 | 307326 | 560 | 12 | - | - |
| square | float32 | 997293 | 2708 | - | - | - | - |
| square | float64 | 999649 | 352 | - | - | - | - |
| square | complex64 | 974577 | 27424 | - | - | - | - |
| square | complex128 | 994505 | 7496 | - | - | - | - |
| absolute | complex64 | 967753 | 33696 | 552 | - | - | - |
| absolute | complex128 | 991753 | 10104 | 144 | - | - | - |
| acos | float32 | 548396 | 444291 | 7072 | 242 | - | - |
| acos | float64 | 985930 | 12727 | 1338 | 6 | - | - |
| acos | complex64 | 810108 | 191263 | 622 | 8 | - | - |
| acos | complex128 | 690209 | 311554 | 238 | - | - | - |
| acosh | float32 | 988269 | 11704 | 28 | - | - | - |
| acosh | float64 | 946246 | 53752 | 3 | - | - | - |
| acosh | complex64 | 810108 | 191263 | 622 | 8 | - | - |
| acosh | complex128 | 690209 | 311554 | 238 | - | - | - |
| asin | float32 | 974679 | 24368 | 942 | 12 | - | - |
| asin | float64 | 995197 | 4776 | 28 | - | - | - |
| asin | complex64 | 807415 | 193174 | 1320 | 92 | - | - |
| asin | complex128 | 687179 | 313978 | 844 | - | - | - |
| asinh | float32 | 916129 | 83790 | 82 | - | - | - |
| asinh | float64 | 825453 | 174482 | 66 | - | - | - |
| asinh | complex64 | 807415 | 193174 | 1320 | 92 | - | - |
| asinh | complex128 | 687179 | 313978 | 844 | - | - | - |
| square | float32 | 997347 | 2654 | - | - | - | - |
| square | float64 | 999593 | 408 | - | - | - | - |
| square | complex64 | 976809 | 25192 | - | - | - | - |
| square | complex128 | 995833 | 6168 | - | - | - | - |
10 changes: 8 additions & 2 deletions results/cpp/acos.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,19 @@
float acos_0(float z) {
float one = 1;
float add_one_z = (one) + (z);
return (2) * (std::atan2(std::sqrt(((one) - (z)) * (add_one_z)), add_one_z));
return (((z) != (-1))
? ((2) * (std::atan2(std::sqrt(((one) - (z)) * (add_one_z)),
add_one_z)))
: (M_PI));
}

double acos_1(double z) {
double one = 1;
double add_one_z = (one) + (z);
return (2) * (std::atan2(std::sqrt(((one) - (z)) * (add_one_z)), add_one_z));
return (((z) != (-1))
? ((2) * (std::atan2(std::sqrt(((one) - (z)) * (add_one_z)),
add_one_z)))
: (M_PI));
}

std::complex<float> acos_2(std::complex<float> z) {
Expand Down
Loading

0 comments on commit 7503d40

Please sign in to comment.