Skip to content

Commit

Permalink
Introduce use_upcast_<function> context parameter. Add native tan/tan…
Browse files Browse the repository at this point in the history
…h. (#31)

* Introduce use_upcast_<function> context parameter.
* Add native tan/tanh testing support.
  • Loading branch information
pearu authored Sep 5, 2024
1 parent ca34113 commit 23ae022
Show file tree
Hide file tree
Showing 6 changed files with 75 additions and 25 deletions.
33 changes: 30 additions & 3 deletions functional_algorithms/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,19 +75,41 @@ def wrapper(ctx, *args, **kwargs):
raise NotImplementedError(
f"definition for {domain} {self.native_func_name} is not provided in algorithms"
)
return defn(ctx, *args, **kwargs)

if ctx.parameters.get(f"use_upcast_{self.native_func_name}", False):
ctx.parameters["using"].add(f"upcast {self.native_func_name}")
assert len(args) == 1, args
args = (ctx.upcast(args[0]),)

result = defn(ctx, *args, **kwargs)

if ctx.parameters.get(f"use_upcast_{self.native_func_name}", False):
result = ctx.downcast(result)

return result

return wrapper

@functools.wraps(func)
def wrapper(ctx, *args, **kwargs):
if ctx.parameters.get(f"use_native_{self.native_func_name}", False):
ctx.parameters["using"].add(f"native {self.native_func_name}")
return getattr(ctx, self.native_func_name)(*args, **kwargs)
func_ = getattr(type(ctx), self.native_func_name)
else:
func_ = func

if ctx.parameters.get(f"use_upcast_{self.native_func_name}", False):
ctx.parameters["using"].add(f"upcast {self.native_func_name}")
assert len(args) == 1, args
args = (ctx.upcast(args[0]),)

result = func(ctx, *args, **kwargs)
result = func_(ctx, *args, **kwargs)
if result is NotImplemented:
raise NotImplementedError(f"{self.native_func_name} not implemented for {self.domain} domain: {func.__name__}")

if ctx.parameters.get(f"use_upcast_{self.native_func_name}", False):
return ctx.downcast(result)

return result

self.registry[self.native_func_name] = wrapper
Expand Down Expand Up @@ -1114,6 +1136,11 @@ def atan(ctx, z: complex | float):
assert 0 # unreachable


@definition("tan", domain="real")
def real_tan(ctx, z: complex | float):
return NotImplemented


@definition("tan")
def tan(ctx, z: complex | float):
"""Tangent on complex and real inputs.
Expand Down
5 changes: 4 additions & 1 deletion functional_algorithms/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,10 @@ def implement_missing(self, target):
paths = ":".join([m.__name__ for m in self.context._paths])
raise NotImplementedError(f'{self.kind} for {target.__name__.split(".")[-1]} target [paths={paths}]')

result = self.context.call(func, self.operands).implement_missing(target)
result = self.context.call(func, self.operands)
if self._serialized == result._serialized:
return self
result = result.implement_missing(target)
else:
operands = tuple([operand.implement_missing(target) for operand in self.operands])
for o1, o2 in zip(operands, self.operands):
Expand Down
4 changes: 2 additions & 2 deletions functional_algorithms/tests/test_algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@ def dtype_name(request):

@pytest.fixture(
scope="function",
params=["absolute", "acos", "acosh", "asin", "asinh", "hypot", "square", "sqrt", "angle", "atan", "atanh"],
params=["absolute", "acos", "acosh", "asin", "asinh", "hypot", "square", "sqrt", "angle", "atan", "atanh", "tan", "tanh"],
)
def func_name(request):
return request.param


@pytest.fixture(
scope="function",
params=["absolute", "acos", "acosh", "asin", "asinh", "square", "sqrt", "angle", "log1p", "atan", "atanh"],
params=["absolute", "acos", "acosh", "asin", "asinh", "square", "sqrt", "angle", "log1p", "atan", "atanh", "tan", "tanh"],
)
def unary_func_name(request):
return request.param
Expand Down
31 changes: 19 additions & 12 deletions functional_algorithms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,25 +375,30 @@ def log1p(self, x):
def tan(self, x):
ctx = x.context
if isinstance(x, ctx.mpc):
# Workaround mpmath 1.3 bug in tan(+-inf+-infj) evaluation (see mpmath/mpmath#781).
if ctx.isinf(x.imag) and (ctx.isinf(x.real) or ctx.isfinite(x.real)):
if x.imag > 0:
return ctx.make_mpc((ctx.zero._mpf_, ctx.one._mpf_))
return ctx.make_mpc((ctx.zero._mpf_, (-ctx.one)._mpf_))
if ctx.isinf(x.real) and ctx.isfinite(x.imag):
return ctx.make_mpc((ctx.nan._mpf_, ctx.nan._mpf_))
if not (ctx.isfinite(x.real) and ctx.isfinite(x.imag)):
# tan(z) = -i * std::tanh(i * z)
ix = ctx.make_mpc(((-x.imag)._mpf_, x.real._mpf_))
w = self.tanh(ix)
return ctx.make_mpc((w.imag._mpf_, (-w.real)._mpf_))
return ctx.tan(x)

def tanh(self, x):
ctx = x.context
if isinstance(x, ctx.mpc):
# Workaround mpmath 1.3 bug in tanh(+-inf+-infj) evaluation (see mpmath/mpmath#781).
if ctx.isinf(x.imag) and (ctx.isinf(x.real) or ctx.isfinite(x.real)):
if x.imag > 0:
return ctx.make_mpc((ctx.zero._mpf_, ctx.one._mpf_))
return ctx.make_mpc((ctx.zero._mpf_, (-ctx.one)._mpf_))
if ctx.isinf(x.real) and ctx.isfinite(x.imag):
if ctx.isfinite(x.real) and not ctx.isfinite(x.imag):
if x.real == 0:
return ctx.make_mpc((ctx.zero._mpf_, ctx.nan._mpf_))
return ctx.make_mpc((ctx.nan._mpf_, ctx.nan._mpf_))
elif ctx.isinf(x.real):
if x.real >= 0:
return ctx.make_mpc((ctx.one._mpf_, ctx.zero._mpf_))
return ctx.make_mpc(((-ctx.one)._mpf_, ctx.zero._mpf_))
elif ctx.isnan(x.real):
if x.imag == 0:
return ctx.make_mpc((ctx.nan._mpf_, ctx.zero._mpf_))
return ctx.make_mpc((ctx.nan._mpf_, ctx.nan._mpf_))

return ctx.tanh(x)

def log2(self, x):
Expand Down Expand Up @@ -1297,6 +1302,8 @@ def function_validation_parameters(func_name, dtype):
max_bound_ulp_width = dict(complex64=3, complex128=3).get(dtype_name, max_bound_ulp_width)
elif func_name in {"atanh", "atan"}:
extra_prec_multiplier = 20
elif func_name in {"tanh", "tan"}:
extra_prec_multiplier = 20
return dict(
extra_prec_multiplier=extra_prec_multiplier,
max_valid_ulp_count=max_valid_ulp_count,
Expand Down
15 changes: 11 additions & 4 deletions results/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,11 @@ limit in general: there may exist function-function dependent regions
in complex plane where `ulp_width` needs to be larger to pass the
"out-of-ulp-range counts is zero" test.

Finally, "using native <function?" means "using the corresponding
numpy <function>".
Finally,
- "using native <function>" means "using the corresponding numpy <function>",
- "using upcast <function>" means that the function arguments are
upcasted to a dtype with bits doubled, and the function results are
downcasted to a dtype with bits split half.

| Function | dtype | dULP=0 (exact) | dULP=1 | dULP=2 | dULP=3 | dULP>3 | Notes |
| -------- | ----- | -------------- | ------ | ------ | ------ | ------ | ----- |
Expand Down Expand Up @@ -91,6 +94,10 @@ numpy <function>".
| log1p<sup>3</sup> | complex64 | 902287 | 97840<sup>41454</sup> | 1582<sup>44</sup> | 102 | 190 | - |
| log1p<sup>1</sup> | complex128 | 801864 | 200067<sup>188447</sup> | 64<sup>10</sup> | 6 | - | - |
| tan | float32 | 866723 | 132062 | 1168 | 48 | - | using native tan |
| tan | complex64 | 817727 | 159602 | 21188 | 2958 | 526 | using upcast tan |
| tan<sub>2</sub> | float32 | 1000001 | - | - | - | - | using upcast tan, native tan |
| tan | complex64 | 783679 | 197584 | 19902 | 776 | 60 | using native tan |
| tan<sub>2</sub> | complex64 | 1001417 | 584 | - | - | - | using upcast tan, native tan |
| tanh | float32 | 985109 | 14892 | - | - | - | using native tanh |
| tanh | complex64 | 779679 | 197584 | 19902 | 776 | 4060 | using native tanh |
| tanh<sub>2</sub> | float32 | 1000001 | - | - | - | - | using native tanh, upcast tanh |
| tanh | complex64 | 783679 | 197584 | 19902 | 776 | 60 | using native tanh |
| tanh<sub>2</sub> | complex64 | 1001417 | 584 | - | - | - | using native tanh, upcast tanh |
12 changes: 9 additions & 3 deletions results/estimate_accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,16 +94,19 @@ def get_inputs():
("log1p", np.complex128, {}),
# ("tan", np.float32, dict()), # real_tan is not implemented
("tan", np.float32, dict(use_native_tan=True)),
# ("tan", np.float32, dict(use_upcast_tan=True)),
("tan", np.float32, dict(use_native_tan=True, use_upcast_tan=True)),
# ("tan", np.float64, {}), # real_tan is not implemented
# ("tan", np.complex64, {}), # tan is not implemented
("tan", np.complex64, dict(use_native_tan=True)),
("tan", np.complex64, dict(use_native_tan=True, use_upcast_tan=True)),
# ("tan", np.complex128, {}), # tan is not implemented
("tanh", np.float32, dict(use_native_tanh=True)),
("tanh", np.float32, dict(use_native_tanh=True, use_upcast_tanh=True)),
# ("tanh", np.float64, {}), # real_tanh is not implemented
# ("tanh", np.complex64, {}), # tanh is not implemented
# ("tanh", np.complex64, dict(use_upcast_tan=True, use_upcast_tanh=True, use_upcast_cos=True)),
("tanh", np.complex64, dict(use_native_tanh=True)),
("tanh", np.complex64, dict(use_native_tanh=True, use_upcast_tanh=True)),
# ("tanh", np.complex128, {}), # tanh is not implemented
]:
validation_parameters = fa.utils.function_validation_parameters(func_name, dtype)
Expand Down Expand Up @@ -190,8 +193,11 @@ def main():
in complex plane where `ulp_width` needs to be larger to pass the
"out-of-ulp-range counts is zero" test.
Finally, "using native <function?" means "using the corresponding
numpy <function>".
Finally,
- "using native <function>" means "using the corresponding numpy <function>",
- "using upcast <function>" means that the function arguments are
upcasted to a dtype with bits doubled, and the function results are
downcasted to a dtype with bits split half.
""",
file=f,
)
Expand Down

0 comments on commit 23ae022

Please sign in to comment.