From 7503d40ad260d4f3d363b80bf2a22b228db11526 Mon Sep 17 00:00:00 2001 From: Pearu Peterson Date: Tue, 30 Jul 2024 16:37:50 +0300 Subject: [PATCH] Fix acos(-1) and add utils.extra_samples (#16) * Introduce utils.extra_samples. Use ULP-difference uniform samples. * Fix evaluation of acos(-1) --- functional_algorithms/algorithms.py | 8 +- functional_algorithms/targets/cpp.py | 2 + functional_algorithms/targets/python.py | 4 +- functional_algorithms/targets/stablehlo.py | 1 + .../tests/test_algorithms.py | 29 ++----- functional_algorithms/tests/test_utils.py | 47 +++++------ functional_algorithms/utils.py | 83 +++++++++++++++---- results/README.md | 44 +++++----- results/cpp/acos.cpp | 10 ++- results/numpy/acos.py | 12 ++- results/python/acos.py | 2 +- results/stablehlo/acos.td | 27 +++--- results/xla_client/acos.cc | 6 +- results/xla_client/real_acos.cc | 6 +- 14 files changed, 178 insertions(+), 103 deletions(-) diff --git a/functional_algorithms/algorithms.py b/functional_algorithms/algorithms.py index 6cc21cb..4aba6ec 100644 --- a/functional_algorithms/algorithms.py +++ b/functional_algorithms/algorithms.py @@ -496,11 +496,17 @@ def real_acos(ctx, x: float): To avoid cancellation errors at abs(x) close to 1, we'll use 1 - x * x == (1 - x) * (1 + x) + + At x == -1, we have arctan2(0, 0) and the above expression will + evaluate to 0 instead of expected pi. Therefore, we explicitly + define that arccos(-1) == pi. """ one = ctx.constant(1, x) + none = ctx.constant(-1, x) two = ctx.constant(2, x) sq = ctx.sqrt((one - x) * (one + x)) - return ctx(two * ctx.atan2(sq, one + x)) + r = two * ctx.atan2(sq, one + x) + return ctx(ctx.select(x != none, r, ctx.constant("pi", x))) def complex_acos(ctx, z: complex): diff --git a/functional_algorithms/targets/cpp.py b/functional_algorithms/targets/cpp.py index b9ceecc..46a4b8b 100644 --- a/functional_algorithms/targets/cpp.py +++ b/functional_algorithms/targets/cpp.py @@ -119,6 +119,8 @@ def make_comment(message): largest="std::numeric_limits<{type}>::max()", posinf="std::numeric_limits<{type}>::infinity()", neginf="-std::numeric_limits<{type}>::infinity()", + pi="M_PI", + nan="NAN", ) diff --git a/functional_algorithms/targets/python.py b/functional_algorithms/targets/python.py index 99f321e..c99716a 100644 --- a/functional_algorithms/targets/python.py +++ b/functional_algorithms/targets/python.py @@ -101,7 +101,9 @@ def make_comment(message): ne="({0}) != ({1})", ) -constant_to_target = dict(smallest="sys.float_info.min", largest="sys.float_info.max", posinf="math.inf", neginf="-math.inf") +constant_to_target = dict( + smallest="sys.float_info.min", largest="sys.float_info.max", posinf="math.inf", neginf="-math.inf", pi="math.pi" +) type_to_target = dict(integer="int", float="float", complex="complex", boolean="bool") diff --git a/functional_algorithms/targets/stablehlo.py b/functional_algorithms/targets/stablehlo.py index c29d1f6..250139d 100644 --- a/functional_algorithms/targets/stablehlo.py +++ b/functional_algorithms/targets/stablehlo.py @@ -105,6 +105,7 @@ def make_comment(message): smallest="StableHLO_ConstantLikeSmallestNormalizedValue", posinf="StableHLO_ConstantLikePosInfValue", neginf="StableHLO_ConstantLikeNegInfValue", + pi='StableHLO_ConstantLike<"M_PI">', ) diff --git a/functional_algorithms/tests/test_algorithms.py b/functional_algorithms/tests/test_algorithms.py index 52876ce..9eed14f 100644 --- a/functional_algorithms/tests/test_algorithms.py +++ b/functional_algorithms/tests/test_algorithms.py @@ -63,7 +63,13 @@ def test_unary(dtype_name, unary_func_name, flush_subnormals): include_subnormal=not flush_subnormals, ).flatten() else: - samples = utils.real_samples(size * size, dtype=dtype, include_subnormal=not flush_subnormals).flatten() + samples = utils.real_samples( + size * size, + dtype=dtype, + include_subnormal=not flush_subnormals, + ).flatten() + + samples = numpy.concatenate((samples, utils.extra_samples(unary_func_name, dtype))) matches_with_reference, ulp_stats = utils.validate_function( func, reference, samples, dtype, flush_subnormals=flush_subnormals @@ -81,27 +87,6 @@ def test_unary(dtype_name, unary_func_name, flush_subnormals): assert matches_with_reference # warning: also reference may be wrong - extra_samples = [] - if unary_func_name == "absolute" and dtype_name.startswith("complex"): - extra_samples.extend([1.0011048e35 + 3.4028235e38j]) - - if extra_samples: - samples = numpy.array(extra_samples, dtype=dtype) - matches_with_reference, ulp_stats = utils.validate_function( - func, reference, samples, dtype, flush_subnormals=flush_subnormals - ) - if not matches_with_reference: - print("Extra samples:") - gt3_ulp_total = 0 - for ulp in sorted(ulp_stats): - if ulp in {0, 1, 2, 3}: - print(f" dULP={ulp}: {ulp_stats[ulp]}") - elif ulp > 0: - gt3_ulp_total += ulp_stats[ulp] - else: - print(f" dULP>3: {gt3_ulp_total}") - assert matches_with_reference # warning: also reference may be wrong - def test_binary(dtype_name, binary_func_name, flush_subnormals): if dtype_name.startswith("complex") and binary_func_name in {"hypot"}: diff --git a/functional_algorithms/tests/test_utils.py b/functional_algorithms/tests/test_utils.py index 42a693e..38074db 100644 --- a/functional_algorithms/tests/test_utils.py +++ b/functional_algorithms/tests/test_utils.py @@ -11,9 +11,16 @@ def dtype(request): def _check_real_samples( - r, include_infinity=None, include_zero=None, include_subnormal=None, include_nan=None, nonnegative=None, include_huge=None + r, + include_infinity=None, + include_zero=None, + include_subnormal=None, + include_nan=None, + nonnegative=None, + include_huge=None, ): fi = numpy.finfo(r.dtype) + size = r.size if nonnegative: if include_zero: assert r[0] == 0 @@ -30,27 +37,27 @@ def _check_real_samples( if include_nan: assert numpy.isnan(r[-1]) if include_infinity: - if include_huge and r.size > 9: + if include_huge and size > 9: assert numpy.nextafter(r[-4], numpy.inf, dtype=r.dtype) == fi.max assert r[-3] == fi.max assert numpy.isposinf(r[-2]) else: assert r[-2] == fi.max - if include_huge and r.size > 9: + if include_huge and size > 9: assert numpy.nextafter(r[-3], numpy.inf, dtype=r.dtype) == fi.max - for i in range(r.size - 2): + for i in range(size - 2): assert r[i] < r[i + 1] else: if include_infinity: - if include_huge and r.size > 8: + if include_huge and size > 8: assert numpy.nextafter(r[-3], numpy.inf, dtype=r.dtype) == fi.max assert r[-2] == fi.max assert numpy.isposinf(r[-1]) else: assert r[-1] == fi.max - if include_huge and r.size > 8: + if include_huge and size > 8: assert numpy.nextafter(r[-2], numpy.inf, dtype=r.dtype) == fi.max - for i in range(r.size - 1): + for i in range(size - 1): assert r[i] < r[i + 1] else: if include_infinity: @@ -74,20 +81,7 @@ def _check_real_samples( for i in range(r.size - 1): assert r[i] < r[i + 1] if include_zero: - size = r.size - if include_nan: - size -= 1 - if include_infinity: - if include_huge: - loc = (size - 1) // 2 - else: - loc = (size - 1) // 2 - else: - if include_huge: - loc = (size - 1) // 2 - else: - loc = (size - 1) // 2 - assert r[loc] == 0, (include_nan, include_infinity, include_huge) + loc = numpy.where(r == 0)[0][0] if include_subnormal: assert r[loc + 1] == fi.smallest_subnormal assert r[loc - 1] == -fi.smallest_subnormal @@ -97,9 +91,14 @@ def _check_real_samples( def _iter_samples_parameters(): - for include_huge, include_subnormal, include_infinity, include_zero, include_nan, nonnegative in itertools.product( - *(([False, True],) * 6) - ): + for ( + include_huge, + include_subnormal, + include_infinity, + include_zero, + include_nan, + nonnegative, + ) in itertools.product(*(([False, True],) * 6)): yield dict( include_huge=include_huge, include_subnormal=include_subnormal, diff --git a/functional_algorithms/utils.py b/functional_algorithms/utils.py index 7054cb2..15a93c8 100644 --- a/functional_algorithms/utils.py +++ b/functional_algorithms/utils.py @@ -579,6 +579,40 @@ def worker(ctx, s, e, r, v): assert 0 # unreachable +def extra_samples(name, dtype): + """Return a list of samples that are special to a given function. + + Parameters + ---------- + name: str + The name of a function + dtype: + Floating-point or complex dtype + + Returns + ------- + values: list + Values of function inputs. + """ + is_complex = "complex" in str(dtype) + is_float = "float" in str(dtype) + assert is_float or is_complex, dtype + values = [] + # Notice that real/complex_samples already include special values + # such as 0, -inf, inf, smallest subnormals or normals, so don't + # specify these here. + if is_float: + if name in {"acos", "asin"}: + for v in [-1, 1]: + values.append(numpy.nextafter(v, v - 1, dtype=dtype)) + values.append(v) + values.append(numpy.nextafter(v, v + 1, dtype=dtype)) + if is_complex: + if name == "absolute": + values.append(1.0011048e35 + 3.4028235e38j) + return numpy.array(values, dtype=dtype) + + def real_samples( size=10, dtype=numpy.float32, @@ -615,31 +649,52 @@ def real_samples( if isinstance(dtype, str): dtype = getattr(numpy, dtype) assert dtype in {numpy.float32, numpy.float64}, dtype + utype = {numpy.float32: numpy.uint32, numpy.float64: numpy.uint64}[dtype] fi = numpy.finfo(dtype) - start = fi.minexp + fi.negep + 1 if include_subnormal else fi.minexp - end = fi.maxexp num = size // 2 if not nonnegative else size if include_infinity: num -= 1 - with warnings.catch_warnings(action="ignore"): - finite_positive = numpy.logspace(start, end, base=2, num=num, dtype=dtype) + min_value = dtype(fi.smallest_subnormal if include_subnormal else fi.smallest_normal) + max_value = dtype(fi.max) + if 1: + # The following method gives a sample distibution that is + # uniform with respect to ULP distance between positive + # neighboring samples + finite_positive = numpy.linspace(min_value.view(utype), max_value.view(utype), num=num, dtype=utype).view(dtype) + else: + start = fi.minexp + fi.negep + 1 if include_subnormal else fi.minexp + end = fi.maxexp + with warnings.catch_warnings(action="ignore"): + # Note that logspace gives samples distribution that is + # approximately uniform with respect to ULP distance between + # neighboring normal samples. For subnormal samples, logspace + # produces repeated samples that will be eliminated below via + # numpy.unique. + finite_positive = numpy.logspace(start, end, base=2, num=num, dtype=dtype) + finite_positive[-1] = max_value + + if include_huge and num > 3: + huge = -numpy.nextafter(-max_value, numpy.inf, dtype=dtype) + finite_positive[-2] = huge - finite_positive[-1] = fi.max - if include_huge and size > 7: - finite_positive[-2] = -numpy.nextafter(-fi.max, numpy.inf, dtype=dtype) parts = [] + extra = [] if not nonnegative: if include_infinity: - parts.append(numpy.array([-numpy.inf], dtype=dtype)) + extra.append(-numpy.inf) parts.append(-finite_positive[::-1]) if include_zero: - parts.append(numpy.array([0], dtype=dtype)) + extra.append(0) parts.append(finite_positive) if include_infinity: - parts.append(numpy.array([numpy.inf], dtype=dtype)) + extra.append(numpy.inf) if include_nan: - parts.append(numpy.array([numpy.nan], dtype=dtype)) - return numpy.concatenate(parts) + extra.append(numpy.nan) + parts.append(numpy.array(extra, dtype=dtype)) + + # Using unique because logspace produces repeated subnormals when + # size is large + return numpy.unique(numpy.concatenate(parts)) def complex_samples( @@ -777,8 +832,8 @@ def complex_pair_samples( include_zero=include_zero, include_subnormal=include_subnormal, include_nan=include_nan, - nonnegative=nonnegative, include_huge=include_huge, + nonnegative=nonnegative, ) s2 = complex_samples( size=size[1], @@ -787,8 +842,8 @@ def complex_pair_samples( include_zero=include_zero, include_subnormal=include_subnormal, include_nan=include_nan, - nonnegative=nonnegative, include_huge=include_huge, + nonnegative=nonnegative, ) shape1 = s1.shape shape2 = s2.shape diff --git a/results/README.md b/results/README.md index 8fa74ac..4446ed4 100644 --- a/results/README.md +++ b/results/README.md @@ -10,25 +10,25 @@ MPMath functions using multi-precision arithmetic. | -------- | ----- | ------------- | ----- | ----- | ----- | ----- | --------- | | absolute | float32 | 1000001 | - | - | - | - | - | | absolute | float64 | 1000001 | - | - | - | - | - | -| absolute | complex64 | 977109 | 24892 | - | - | - | - | -| absolute | complex128 | 989613 | 12372 | 16 | - | - | - | -| acos | float32 | 504842 | 488335 | 6599 | 225 | - | - | -| acos | float64 | 986273 | 12357 | 1366 | 5 | - | - | -| acos | complex64 | 804838 | 196477 | 680 | 6 | - | - | -| acos | complex128 | 701506 | 300255 | 238 | 2 | - | - | -| acosh | float32 | 989143 | 10829 | 29 | - | - | - | -| acosh | float64 | 947718 | 52275 | 8 | - | - | - | -| acosh | complex64 | 804838 | 196477 | 680 | 6 | - | - | -| acosh | complex128 | 701506 | 300255 | 238 | 2 | - | - | -| asin | float32 | 937353 | 61812 | 810 | 26 | - | - | -| asin | float64 | 983317 | 16662 | 22 | - | - | - | -| asin | complex64 | 808487 | 191878 | 1592 | 44 | - | - | -| asin | complex128 | 694103 | 307326 | 560 | 12 | - | - | -| asinh | float32 | 922791 | 77144 | 66 | - | - | - | -| asinh | float64 | 829637 | 170270 | 94 | - | - | - | -| asinh | complex64 | 808487 | 191878 | 1592 | 44 | - | - | -| asinh | complex128 | 694103 | 307326 | 560 | 12 | - | - | -| square | float32 | 997293 | 2708 | - | - | - | - | -| square | float64 | 999649 | 352 | - | - | - | - | -| square | complex64 | 974577 | 27424 | - | - | - | - | -| square | complex128 | 994505 | 7496 | - | - | - | - | +| absolute | complex64 | 967753 | 33696 | 552 | - | - | - | +| absolute | complex128 | 991753 | 10104 | 144 | - | - | - | +| acos | float32 | 548396 | 444291 | 7072 | 242 | - | - | +| acos | float64 | 985930 | 12727 | 1338 | 6 | - | - | +| acos | complex64 | 810108 | 191263 | 622 | 8 | - | - | +| acos | complex128 | 690209 | 311554 | 238 | - | - | - | +| acosh | float32 | 988269 | 11704 | 28 | - | - | - | +| acosh | float64 | 946246 | 53752 | 3 | - | - | - | +| acosh | complex64 | 810108 | 191263 | 622 | 8 | - | - | +| acosh | complex128 | 690209 | 311554 | 238 | - | - | - | +| asin | float32 | 974679 | 24368 | 942 | 12 | - | - | +| asin | float64 | 995197 | 4776 | 28 | - | - | - | +| asin | complex64 | 807415 | 193174 | 1320 | 92 | - | - | +| asin | complex128 | 687179 | 313978 | 844 | - | - | - | +| asinh | float32 | 916129 | 83790 | 82 | - | - | - | +| asinh | float64 | 825453 | 174482 | 66 | - | - | - | +| asinh | complex64 | 807415 | 193174 | 1320 | 92 | - | - | +| asinh | complex128 | 687179 | 313978 | 844 | - | - | - | +| square | float32 | 997347 | 2654 | - | - | - | - | +| square | float64 | 999593 | 408 | - | - | - | - | +| square | complex64 | 976809 | 25192 | - | - | - | - | +| square | complex128 | 995833 | 6168 | - | - | - | - | diff --git a/results/cpp/acos.cpp b/results/cpp/acos.cpp index 8143f16..b61266c 100644 --- a/results/cpp/acos.cpp +++ b/results/cpp/acos.cpp @@ -14,13 +14,19 @@ float acos_0(float z) { float one = 1; float add_one_z = (one) + (z); - return (2) * (std::atan2(std::sqrt(((one) - (z)) * (add_one_z)), add_one_z)); + return (((z) != (-1)) + ? ((2) * (std::atan2(std::sqrt(((one) - (z)) * (add_one_z)), + add_one_z))) + : (M_PI)); } double acos_1(double z) { double one = 1; double add_one_z = (one) + (z); - return (2) * (std::atan2(std::sqrt(((one) - (z)) * (add_one_z)), add_one_z)); + return (((z) != (-1)) + ? ((2) * (std::atan2(std::sqrt(((one) - (z)) * (add_one_z)), + add_one_z))) + : (M_PI)); } std::complex acos_2(std::complex z) { diff --git a/results/numpy/acos.py b/results/numpy/acos.py index e88753e..2926441 100644 --- a/results/numpy/acos.py +++ b/results/numpy/acos.py @@ -218,7 +218,11 @@ def acos_2(z: numpy.float64) -> numpy.float64: z = numpy.float64(z) one: numpy.float64 = numpy.float64(1) add_one_z: numpy.float64 = (one) + (z) - result = (numpy.float64(2)) * (numpy.arctan2(numpy.sqrt(((one) - (z)) * (add_one_z)), add_one_z)) + result = ( + ((numpy.float64(2)) * (numpy.arctan2(numpy.sqrt(((one) - (z)) * (add_one_z)), add_one_z))) + if ((z) != (numpy.float64(-1))) + else (numpy.float64(numpy.float64(numpy.pi))) + ) return result @@ -227,5 +231,9 @@ def acos_3(z: numpy.float32) -> numpy.float32: z = numpy.float32(z) one: numpy.float32 = numpy.float32(1) add_one_z: numpy.float32 = (one) + (z) - result = (numpy.float32(2)) * (numpy.arctan2(numpy.sqrt(((one) - (z)) * (add_one_z)), add_one_z)) + result = ( + ((numpy.float32(2)) * (numpy.arctan2(numpy.sqrt(((one) - (z)) * (add_one_z)), add_one_z))) + if ((z) != (numpy.float32(-1))) + else (numpy.float32(numpy.float32(numpy.pi))) + ) return result diff --git a/results/python/acos.py b/results/python/acos.py index 148e945..0fc3ee2 100644 --- a/results/python/acos.py +++ b/results/python/acos.py @@ -96,4 +96,4 @@ def acos_0(z: complex) -> complex: def acos_1(z: float) -> float: one: float = 1 add_one_z: float = (one) + (z) - return (2) * (math.atan2(math.sqrt(((one) - (z)) * (add_one_z)), add_one_z)) + return ((2) * (math.atan2(math.sqrt(((one) - (z)) * (add_one_z)), add_one_z))) if ((z) != (-1)) else (math.pi) diff --git a/results/stablehlo/acos.td b/results/stablehlo/acos.td index bdb1384..8d3051a 100644 --- a/results/stablehlo/acos.td +++ b/results/stablehlo/acos.td @@ -6,16 +6,23 @@ def : Pat<(acos_0 NonComplexElementType:$z), - (StableHLO_MulOp - (StableHLO_ConstantLike<"2"> $z), - (StableHLO_Atan2Op - (StableHLO_SqrtOp - (StableHLO_MulOp - (StableHLO_SubtractOp - (StableHLO_ConstantLike<"1">:$one $z), - $z), - (StableHLO_AddOp:$add_one_z $one, $z))), - $add_one_z))>; + (StableHLO_SelectOp + (StableHLO_CompareOp + $z, + (StableHLO_ConstantLike<"-1"> $z), + StableHLO_ComparisonDirectionValue<"NE">, + (STABLEHLO_DEFAULT_COMPARISON_TYPE)), + (StableHLO_MulOp + (StableHLO_ConstantLike<"2"> $z), + (StableHLO_Atan2Op + (StableHLO_SqrtOp + (StableHLO_MulOp + (StableHLO_SubtractOp + (StableHLO_ConstantLike<"1">:$one $z), + $z), + (StableHLO_AddOp:$add_one_z $one, $z))), + $add_one_z)), + (StableHLO_ConstantLike<"M_PI"> $z))>; def : Pat<(acos_1 ComplexElementType:$z), (StableHLO_ComplexOp diff --git a/results/xla_client/acos.cc b/results/xla_client/acos.cc index 40a7293..dc3709e 100644 --- a/results/xla_client/acos.cc +++ b/results/xla_client/acos.cc @@ -13,8 +13,10 @@ template XlaOp acos_0(XlaOp z) { XlaOp one = ScalarLike(z, 1); XlaOp add_one_z = Add(one, z); - return Mul(ScalarLike(z, 2), - Atan2(Sqrt(Mul(Sub(one, z), add_one_z)), add_one_z)); + return Select(Ne(z, ScalarLike(z, -1)), + Mul(ScalarLike(z, 2), + Atan2(Sqrt(Mul(Sub(one, z), add_one_z)), add_one_z)), + ScalarLike(z, M_PI)); } template diff --git a/results/xla_client/real_acos.cc b/results/xla_client/real_acos.cc index 8035511..55495f0 100644 --- a/results/xla_client/real_acos.cc +++ b/results/xla_client/real_acos.cc @@ -13,6 +13,8 @@ template XlaOp real_acos_0(XlaOp x) { XlaOp one = ScalarLike(x, 1); XlaOp add_one_x = Add(one, x); - return Mul(ScalarLike(x, 2), - Atan2(Sqrt(Mul(Sub(one, x), add_one_x)), add_one_x)); + return Select(Ne(x, ScalarLike(x, -1)), + Mul(ScalarLike(x, 2), + Atan2(Sqrt(Mul(Sub(one, x), add_one_x)), add_one_x)), + ScalarLike(x, M_PI)); } \ No newline at end of file