Skip to content

Commit

Permalink
Add complex log1p (#21)
Browse files Browse the repository at this point in the history
* Add complex log1p

* Increase mpmath precision for log1p

* Add ulp-range test to function validation.
  • Loading branch information
pearu committed Aug 17, 2024
1 parent b8a67c5 commit d3be68a
Show file tree
Hide file tree
Showing 16 changed files with 772 additions and 99 deletions.
95 changes: 95 additions & 0 deletions functional_algorithms/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,3 +672,98 @@ def sqrt(ctx, z: complex | float):

def angle(ctx, z: complex):
return ctx.atan2(z.imag, z.real)


def kahan3(ctx, x1: float, x2: float, x3: float):
"""Kahan sum of three floating-point numbers"""
s = x1 + x2
c = (s - x1) - x2
y3 = x3 - c
return s + y3


def kahan4(ctx, x1: float, x2: float, x3: float, x4: float):
"""Kahan sum of four floating-point numbers"""
t2 = x1 + x2
c2 = (t2 - x1) - x2
y3 = x3 - c2
t3 = t2 + y3
c3 = (t3 - t2) - y3
y4 = x4 - c3
return t3 + y4


def fma(ctx, x, y, z):
"""Evaluate x * y + z"""
return x * y + z


def complex_log1p(ctx, z: complex):
"""Logarithm of 1 + z on complex input:
log1p(x + I * y) = 0.5 * log((x+1)**2 + y**2) + I * arctan2(y, x + 1)
where
x and y are real and imaginary parts of the input to log1p, and
I is imaginary unit.
Let's define
mx = max(abs(x + 1), abs(y))
mn = min(abs(x + 1), abs(y))
then the real part of the complex log1p value reads
real(log(x + I * y)) = log(hypot(x + 1, y))
= log(mx * sqrt(1 + (mn / mx) ** 2))
= log(mx) + 0.5 * log1p((mn / mx) ** 2)
where
log(mx) = log(max(abs(x + 1), abs(y)))
= log1p(max(abs(x + 1) - 1, abs(y) - 1))
= log1p(select(x + 1 >= abs(y), x, mx - 1))
To handle mn == mx == inf case, we'll use
log1p((mn / mx) ** 2) = log1p(select(mn == mx, one, (mn / mx) ** 2))
Problematic regions
-------------------
Notice that when abs(y) < 1 and abs(x + 0.5 * y ** 2) is small,
catastrophic cancellation errors occur in evaluating the real part
of complex log1p:
log(hypot(x + 1, y)) = 0.5 * log1p(2 * x + y * y + x * x)
where the magnitude of the correct value `x * x` is smaller than
the rounding errors occurring from addition `2 * x + y * y`,
especially when using FP32. A similar phenomenon is expected when
x is close to -2 and abs(y) is small so that rounding errors from
`2 * x + x ** 2` dominate over `y ** 2`.
"""
x = z.real
y = z.imag
one = ctx.constant(1, x)
half = ctx.constant(0.5, x)
xp1 = x + one
axp1 = abs(xp1)
ay = abs(y)
mx = ctx.maximum(axp1, ay)
mn = ctx.minimum(axp1, ay)
r = mn / mx
re = ctx.log1p(ctx.select(xp1 >= ay, x, mx - one)) + half * ctx.log1p(ctx.select(ctx.eq(mn, mx), one, r * r))
im = ctx.atan2(y, xp1)
return ctx(ctx.complex(re, im))


def log1p(ctx, z: complex | float):
"""log(1 + z)
See complex_log1p for more information.
"""
if z.is_complex:
return complex_log1p(ctx, z)
return ctx.log1p(z)
21 changes: 20 additions & 1 deletion functional_algorithms/expr.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import struct
import warnings
from .utils import UNSPECIFIED
from . import algorithms
Expand Down Expand Up @@ -96,6 +97,24 @@ def normalize(context, operands):
return tuple(new_operands)


def toidentifier(value):
if isinstance(value, bool):
return str(value)
elif isinstance(value, int):
if value < 0:
return "neg" + str(-value)
return str(value)
elif isinstance(value, float):
return "f" + hex(struct.unpack("<I", struct.pack("<f", value))[0])[1:]
elif isinstance(value, complex):
return "c" + toidentifier(value.real) + toidentifier(value.imag)
elif isinstance(value, str):
assert value.isidentifier(), value
return value
else:
raise NotImplementedError(type(value))


def make_ref(expr):
ref = expr.props.get("ref", UNSPECIFIED)
if ref is not UNSPECIFIED:
Expand All @@ -105,7 +124,7 @@ def make_ref(expr):
if expr.kind == "constant":
if isinstance(expr.operands[0], Expr):
return f"{expr.kind}_{make_ref(expr.operands[0])}"
return f"{expr.kind}_{expr.operands[0]}"
return f"{expr.kind}_{toidentifier(expr.operands[0])}"
lst = [expr.kind] + list(map(make_ref, expr.operands))
return "_".join(lst)

Expand Down
1 change: 1 addition & 0 deletions functional_algorithms/targets/cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
trace_arguments = dict(
square=[(":float32",), (":float64",), (":complex64",), (":complex128",)],
absolute=[(":float32",), (":float64",), (":complex64",), (":complex128",)],
log1p=[(":float32",), (":float64",), (":complex64",), (":complex128",)],
hypot=[(":float32", ":float32"), (":float64", ":float64")],
asin_acos_kernel=[(":complex64",), (":complex128",)],
acos=[(":float32",), (":float64",), (":complex64",), (":complex128",)],
Expand Down
1 change: 1 addition & 0 deletions functional_algorithms/targets/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def make_complex(r, i):
acosh=[(":complex128",), (":complex64",), (":float64",), (":float32",)],
asin=[(":complex128",), (":complex64",), (":float64",), (":float32",)],
asinh=[(":complex128",), (":complex64",), (":float64",), (":float32",)],
log1p=[(":complex128",), (":complex64",), (":float64",), (":float32",)],
hypot=[(":float32", ":float32"), (":float64", ":float64")],
square=[(":complex128",), (":complex64",), (":float64",), (":float32",)],
)
Expand Down
1 change: 1 addition & 0 deletions functional_algorithms/targets/python.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
asinh=[(":complex",), (":float",)],
hypot=[(":float", ":float")],
square=[(":float",), (":complex",)],
log1p=[(":float",), (":complex",)],
)


Expand Down
1 change: 1 addition & 0 deletions functional_algorithms/targets/stablehlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
acosh=[(":float",), (":complex",)],
asin=[(":float",), (":complex",)],
asinh=[(":float",), (":complex",)],
log1p=[(":float",), (":complex",)],
hypot=[(":float", ":float")],
square=[(":float",), (":complex",)],
)
Expand Down
1 change: 1 addition & 0 deletions functional_algorithms/targets/xla_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
acosh=[(":float", ":float"), (":complex", ":complex")],
asin=[(":float", ":float"), (":complex", ":complex")],
asinh=[(":float", ":float"), (":complex", ":complex")],
complex_log1p=[(":complex", ":complex")],
)

source_file_extension = ".cc"
Expand Down
46 changes: 36 additions & 10 deletions functional_algorithms/tests/test_algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def func_name(request):
return request.param


@pytest.fixture(scope="function", params=["absolute", "acos", "acosh", "asin", "asinh", "square", "sqrt", "angle"])
@pytest.fixture(scope="function", params=["absolute", "acos", "acosh", "asin", "asinh", "square", "sqrt", "angle", "log1p"])
def unary_func_name(request):
return request.param

Expand All @@ -48,15 +48,19 @@ def test_unary(dtype_name, unary_func_name, flush_subnormals):

func = targets.numpy.as_function(graph2, debug=0)

if unary_func_name in {"acos", "asin", "asinh", "acosh"}:
extra_prec_multiplier = 20
else:
extra_prec_multiplier = 1
params = utils.function_validation_parameters(unary_func_name, dtype_name)
max_valid_ulp_count = params["max_valid_ulp_count"]
max_bound_ulp_width = params["max_bound_ulp_width"]
extra_prec_multiplier = params["extra_prec_multiplier"]

reference = getattr(
utils.numpy_with_mpmath(extra_prec_multiplier=extra_prec_multiplier, flush_subnormals=flush_subnormals),
unary_func_name,
)

# samples consist of log-uniform grid of the complex plane plus
# any extra samples that cover the special regions for the given
# function.
size = 51
if dtype in {numpy.complex64, numpy.complex128}:
samples = utils.complex_samples(
Expand All @@ -74,19 +78,41 @@ def test_unary(dtype_name, unary_func_name, flush_subnormals):

samples = numpy.concatenate((samples, utils.extra_samples(unary_func_name, dtype)))

matches_with_reference, ulp_stats = utils.validate_function(
func, reference, samples, dtype, flush_subnormals=flush_subnormals
matches_with_reference, stats = utils.validate_function(
func,
reference,
samples,
dtype,
flush_subnormals=flush_subnormals,
max_valid_ulp_count=max_valid_ulp_count,
max_bound_ulp_width=max_bound_ulp_width,
)
if not matches_with_reference:
print("Samples:")
gt3_ulp_total = 0
gt3_outrange = 0
ulp_stats = stats["ulp"]
for ulp in sorted(ulp_stats):
if ulp in {0, 1, 2, 3}:
print(f" dULP={ulp}: {ulp_stats[ulp]}")
if ulp >= 0 and ulp <= max_valid_ulp_count:
outrange = stats["outrange"][ulp]
if outrange:
print(f" dULP={ulp}: {ulp_stats[ulp]} ({outrange=})")
else:
print(f" dULP={ulp}: {ulp_stats[ulp]}")
elif ulp > 0:
gt3_ulp_total += ulp_stats[ulp]
gt3_outrange += stats["outrange"][ulp]
elif ulp == -1:
c = ulp_stats[ulp]
if c:
print(f" total number of mismatches: {c}")
else:
assert 0, ulp # unreachable
else:
print(f" dULP>3: {gt3_ulp_total}")
if gt3_outrange:
print(f" dULP>{max_valid_ulp_count}: {gt3_ulp_total} (outrange={gt3_outrange})")
else:
print(f" dULP>{max_valid_ulp_count}: {gt3_ulp_total}")

assert matches_with_reference # warning: also reference may be wrong

Expand Down
Loading

0 comments on commit d3be68a

Please sign in to comment.