Skip to content

Commit

Permalink
Add acosh. (#12)
Browse files Browse the repository at this point in the history
* Add acosh.

* Fix rebase
  • Loading branch information
pearu committed Jun 24, 2024
1 parent c788ffd commit df06098
Show file tree
Hide file tree
Showing 20 changed files with 1,208 additions and 21 deletions.
67 changes: 61 additions & 6 deletions functional_algorithms/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,12 +338,6 @@ def real_asinh(ctx, x: float):
This algorithm is based on the StableHLO v1.1.4 function CHLO_AsinhOp.
To avoid overflow in x * x, we use
asinh(x) = log(2) + log(x)
when abs(x) > sqrt(max),
To avoid underflow in 1 + x * x, we'll define z = hypot(1, x) and
write
Expand All @@ -360,6 +354,12 @@ def real_asinh(ctx, x: float):
It turns out, this is accurate for all abs(x) < sqrt(max).
To avoid overflow in x ** 2, we'll use
asinh(x) = log(2) + log(x)
when abs(x) > sqrt(max),
For x < 0, we'll use
asinh(x) = -asinh(-x)
Expand Down Expand Up @@ -512,3 +512,58 @@ def acos(ctx, z: complex | float):
if z.is_complex:
return complex_acos(ctx, z)
return real_acos(ctx, z)


def complex_acosh(ctx, z: complex):
"""Inverse hyperbolic cosine on complex input:
acosh(z) = sqrt(z - 1) / sqrt(1 - z) * acos(z)
= I * acos(z) # when z.imag >= 0
= -I * acos(z) # otherwise
"""
w = complex_acos(ctx, z)
r = ctx.complex(-w.imag, w.real)
return ctx.select(z.imag < 0, -r, r)


def real_acosh(ctx, x: float):
"""Inverse hyperbolic cosine on real input:
acosh(x) = log(x + sqrt(x * x - 1))
= log(x + sqrt(x+1)*sqrt(x-1)))
= log(1 + x-1 + sqrt(x+1)*sqrt(x-1)))
= log1p(sqrt(x-1) * (sqrt(x+1) + sqrt(x-1)))
The last expression avoids errors from cancellations when x is
close to one. This also ensures the nan result when x < 1 because
sqrt(x') returns nan when x' < 0.
To avoid overflow in multiplication for large x (x > max / 2),
we'll use
acosh(x) = log(2) + log(x)
"""
one = ctx.constant(1, x)
two = ctx.constant(2, x)
sqxm1 = ctx.sqrt(x - one)
sqxp1 = ctx.sqrt(x + one)
a0 = ctx.log(two) + ctx.log(x)
a1 = ctx.log1p(sqxm1 * (sqxp1 + sqxm1))

safe_max_limit_coefficient = ctx.parameters.get("safe_max_limit_coefficient", None)
if safe_max_limit_coefficient is None:
safe_max_limit = ctx.constant("largest", x) / 2
else:
safe_max_limit = ctx.constant("largest", x) * safe_max_limit_coefficient
return ctx.select(x >= safe_max_limit, a0, a1)


def acosh(ctx, z: complex | float):
"""Inverse hyperbolic cosine on complex and real inputs.
See complex_acosh and real_acosh for more information.
"""
if z.is_complex:
return complex_acosh(ctx, z)
return real_acosh(ctx, z)
4 changes: 2 additions & 2 deletions functional_algorithms/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,8 +388,8 @@ def ceil(self, x):
def floor(self, x):
return Expr(self, "floor", (x,))

def copysign(self, x):
return Expr(self, "copysign", (x,))
def copysign(self, x, y):
return Expr(self, "copysign", (x, y))

def round(self, x):
return Expr(self, "round", (x,))
Expand Down
5 changes: 5 additions & 0 deletions functional_algorithms/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,6 +526,9 @@ def real(self):
def imag(self):
return self.context.imag(self)

def conj(self):
return self.context.conj(self)

def __lt__(self, other):
return self.context.lt(self, other)

Expand Down Expand Up @@ -653,6 +656,8 @@ def get_type(self):
"floor",
"logical_not",
"sign",
"copysign",
"conj",
}:
return self.operands[0].get_type()
elif self.kind in {
Expand Down
18 changes: 18 additions & 0 deletions functional_algorithms/rewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,24 @@ def negative(self, expr):
if x.kind == "negative":
return x.operands[0]

def conj(self, expr):

(x,) = expr.operands

if x.kind == "constant":
value, like = x.operands
if isinstance(value, (int, float)):
return x
if isinstance(value, complex):
return x.context.constant(value.conjugate(), like)

if x.kind == "complex":
real, imag = x.operands
return x.context.complex(real, -imag)

if x.kind == "conj":
return x

def real(self, expr):

(x,) = expr.operands
Expand Down
1 change: 1 addition & 0 deletions functional_algorithms/targets/cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
absolute=[(":float32",), (":float64",), (":complex64",), (":complex128",)],
hypot=[(":float32", ":float32"), (":float64", ":float64")],
acos=[(":float32",), (":float64",), (":complex64",), (":complex128",)],
acosh=[(":float32",), (":float64",), (":complex64",), (":complex128",)],
asin=[(":float32",), (":float64",), (":complex64",), (":complex128",)],
asinh=[(":float32",), (":float64",), (":complex64",), (":complex128",)],
# complex_asin=[(":complex", ":complex")],
Expand Down
6 changes: 4 additions & 2 deletions functional_algorithms/targets/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def make_complex(r, i):
trace_arguments = dict(
absolute=[(":complex128",), (":complex64",)],
acos=[(":complex128",), (":complex64",), (":float64",), (":float32",)],
acosh=[(":complex128",), (":complex64",), (":float64",), (":float32",)],
asin=[(":complex128",), (":complex64",), (":float64",), (":float32",)],
asinh=[(":complex128",), (":complex64",), (":float64",), (":float32",)],
hypot=[(":float32", ":float32"), (":float64", ":float64")],
Expand Down Expand Up @@ -86,7 +87,7 @@ def make_comment(message):
log10="numpy.log10({0})",
ceil="numpy.ceil({0})",
floor="numpy.floor({0})",
copysign="numpy.copysign({0})",
copysign="numpy.copysign({0}, {1})",
round=NotImplemented,
sign="numpy.sign({0})",
trunc="numpy.trunc({0})",
Expand All @@ -111,7 +112,8 @@ def make_comment(message):
largest="numpy.finfo({type}).max",
posinf="{type}(numpy.inf)",
neginf="-{type}(numpy.inf)",
pi="numpy.pi",
pi="{type}(numpy.pi)",
nan="{type}(numpy.nan)",
)

type_to_target = dict(
Expand Down
1 change: 1 addition & 0 deletions functional_algorithms/targets/python.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
trace_arguments = dict(
absolute=[(":complex",)],
acos=[(":complex",), (":float",)],
acosh=[(":complex",), (":float",)],
asin=[(":complex",), (":float",)],
asinh=[(":complex",), (":float",)],
hypot=[(":float", ":float")],
Expand Down
7 changes: 4 additions & 3 deletions functional_algorithms/targets/stablehlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
trace_arguments = dict(
absolute=[(":complex",)],
acos=[(":float",), (":complex",)],
acosh=[(":float",), (":complex",)],
asin=[(":float",), (":complex",)],
asinh=[(":float",), (":complex",)],
hypot=[(":float", ":float")],
Expand Down Expand Up @@ -55,10 +56,10 @@ def make_comment(message):
bitwise_right_shift=NotImplemented,
maximum="StableHLO_MaxOp",
minimum="StableHLO_MinOp",
acos=NotImplemented,
acosh=NotImplemented,
acos="CHLO_AcosOp",
acosh="CHLO_AcoshOp",
asin="CHLO_AsinOp",
asinh=NotImplemented,
asinh="CHLO_AsinhOp",
atan=NotImplemented,
atanh=NotImplemented,
atan2="StableHLO_Atan2Op",
Expand Down
3 changes: 3 additions & 0 deletions functional_algorithms/targets/xla_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,15 @@
hypot=[(":float", ":float")],
complex_acos=[(":complex", ":complex")],
real_acos=[(":float", ":float")],
complex_acosh=[(":complex", ":complex")],
real_acosh=[(":float", ":float")],
complex_asin=[(":complex", ":complex")],
real_asin=[(":float", ":float")],
complex_asinh=[(":complex", ":complex")],
real_asinh=[(":float", ":float")],
square=[(":float", ":float"), (":complex", ":complex")],
acos=[(":float", ":float"), (":complex", ":complex")],
acosh=[(":float", ":float"), (":complex", ":complex")],
asin=[(":float", ":float"), (":complex", ":complex")],
asinh=[(":float", ":float"), (":complex", ":complex")],
)
Expand Down
13 changes: 8 additions & 5 deletions functional_algorithms/tests/test_algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@ def dtype_name(request):
return request.param


@pytest.fixture(scope="function", params=["absolute", "acos", "asin", "asinh", "hypot", "square"])
@pytest.fixture(scope="function", params=["absolute", "acos", "acosh", "asin", "asinh", "hypot", "square"])
def func_name(request):
return request.param


@pytest.fixture(scope="function", params=["absolute", "acos", "asin", "asinh", "square"])
@pytest.fixture(scope="function", params=["absolute", "acos", "acosh", "asin", "asinh", "square"])
def unary_func_name(request):
return request.param

Expand All @@ -45,7 +45,7 @@ def test_unary(dtype_name, unary_func_name, flush_subnormals):

func = targets.numpy.as_function(graph2, debug=0)

if unary_func_name in {"acos", "asin", "asinh"}:
if unary_func_name in {"acos", "asin", "asinh", "acosh"}:
extra_prec_multiplier = 20
else:
extra_prec_multiplier = 1
Expand All @@ -54,10 +54,13 @@ def test_unary(dtype_name, unary_func_name, flush_subnormals):
unary_func_name,
)

size = 31
size = 51
if dtype in {numpy.complex64, numpy.complex128}:
samples = utils.complex_samples(
(size, size), dtype=dtype, include_huge=True, include_subnormal=not flush_subnormals
(size, size),
dtype=dtype,
include_huge=True,
include_subnormal=not flush_subnormals,
).flatten()
else:
samples = utils.real_samples(size * size, dtype=dtype, include_subnormal=not flush_subnormals).flatten()
Expand Down
31 changes: 29 additions & 2 deletions functional_algorithms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,33 @@ def arccos(self, x):
return ctx.nan
return ctx.acos(x)

def arccosh(self, x):
ctx = x.context

if isinstance(x, ctx.mpc):
# Workaround mpmath 1.3 bug in acosh(+-inf+-infj) evaluation
# (see mpmath/mpmath#749).
pi = ctx.pi
inf = ctx.inf
zero = ctx.zero
if ctx.isinf(x.real):
sign_imag = -1 if x.imag < 0 else 1
imag = (
(3 if x.real < 0 else 1) * sign_imag * pi / 4
if ctx.isinf(x.imag)
else (sign_imag * pi if x.real < 0 else zero)
)
return ctx.make_mpc((inf._mpf_, imag._mpf_))
elif ctx.isinf(x.imag):
sign_imag = -1 if x.imag < 0 else 1
imag = sign_imag * pi / 2
return ctx.make_mpc((inf._mpf_, imag._mpf_))
else:
if x < 1:
# otherwise, mpmath.acosh would return complex value
return ctx.nan
return ctx.acosh(x)


class numpy_with_mpmath:
"""Namespace of universal functions on numpy arrays that use mpmath
Expand All @@ -482,7 +509,7 @@ def __init__(self, **params):
self.params = params

def __getattr__(self, name):
name = dict(asinh="arcsinh", acos="arccos", asin="arcsin").get(name, name)
name = dict(asinh="arcsinh", acos="arccos", asin="arcsin", acosh="arccosh").get(name, name)
if name in self._vfunc_cache:
return self._vfunc_cache[name]
if hasattr(mpmath_array_api, name):
Expand Down Expand Up @@ -847,7 +874,7 @@ def validate_function(
else:
v1 = func(sample)
v2 = reference_results[index][()]
assert v1.dtype == v2.dtype, (v1, v2)
assert v1.dtype == v2.dtype, (sample, v1, v2)
ulp = diff_ulp(v1, v2, flush_subnormals=flush_subnormals)
ulp_stats[ulp] += 1
if ulp > 2 and verbose:
Expand Down
4 changes: 4 additions & 0 deletions results/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ MPMath functions using multi-precision arithmetic.
| acos | float64 | 986273 | 12357 | 1366 | 5 | - | - |
| acos | complex64 | 804838 | 196477 | 680 | 6 | - | - |
| acos | complex128 | 701506 | 300255 | 238 | 2 | - | - |
| acosh | float32 | 989143 | 10829 | 29 | - | - | - |
| acosh | float64 | 947718 | 52275 | 8 | - | - | - |
| acosh | complex64 | 804838 | 196477 | 680 | 6 | - | - |
| acosh | complex128 | 701506 | 300255 | 238 | 2 | - | - |
| asin | float32 | 937353 | 61812 | 810 | 26 | - | - |
| asin | float64 | 983317 | 16662 | 22 | - | - | - |
| asin | complex64 | 808487 | 191878 | 1592 | 44 | - | - |
Expand Down
Loading

0 comments on commit df06098

Please sign in to comment.