Support testing with XLA CPU client that has FTZ disabled. #75

Merged: 2 commits, Nov 26, 2024
67 changes: 57 additions & 10 deletions functional_algorithms/tests/test_accuracy.py
@@ -36,16 +36,23 @@ def test_unary(unary_func_name, backend, device, dtype):
max_valid_ulp_count = params["max_valid_ulp_count"]
extra_prec_multiplier = params["extra_prec_multiplier"]
samples_limits = params["samples_limits"]
# JAX flushes subnormals to zero
include_subnormal = backend != "jax"
# JAX with CPU flushes subnormals to zero
include_subnormal = False if device == "cpu" and backend == "jax" else True
# include_subnormal = True

mpmath = fa.utils.numpy_with_mpmath(extra_prec_multiplier=extra_prec_multiplier)
mpmath = fa.utils.numpy_with_mpmath(extra_prec_multiplier=extra_prec_multiplier, flush_subnormals=not include_subnormal)

reference = getattr(mpmath, unary_func_name)
npy_reference = getattr(fa.utils.numpy_with_numpy(), unary_func_name)

re_blocks, im_blocks = 101, 51
re_blocks, im_blocks = 101, 52
re_blocksize, im_blocksize = 20, 20

if 0:
# for testing
re_blocks, im_blocks = 51, 26
re_blocksize, im_blocksize = 5, 5

re_size, im_size = re_blocks * re_blocksize, im_blocks * im_blocksize

if dtype in {numpy.complex64, numpy.complex128}:
@@ -57,9 +64,47 @@ def test_unary(unary_func_name, backend, device, dtype):
assert samples.shape == (im_size, re_size)

expected = reference.call(samples)
result = func(samples)

ulp = fa.utils.diff_ulp(result, expected, flush_subnormals=not include_subnormal)
if backend == "jax" and device == "cpu" and include_subnormal:
# XLA CPU client enables FTZ to follow TF convention. To
# disable FTZ, replace all occurrences of
# tsl::port::ScopedFlushDenormal flush;
# with
# tsl::port::ScopedDontFlushDenormal flush;
# in xla/pjrt/cpu/cpu_client.cc.
#
# However, disabling FTZ in the XLA CPU client is effective only for the
# first part of the samples evaluation. To work around this, we'll evaluate
# JAX functions blockwise:
eval_blocksize = im_size
assert re_size < 2**14
while eval_blocksize * re_size > 2**14:
for p in [2, 3, 5, 7, 11, 13, 17, 19, 23]:
if eval_blocksize % p == 0:
eval_blocksize //= p
break
else:
assert 0 # adjust re/im_size/blocksize parameters to avoid this
assert im_size % eval_blocksize == 0, (im_size, eval_blocksize)
result = numpy.concatenate(
tuple(func(samples[k * eval_blocksize : (k + 1) * eval_blocksize]) for k in range(im_size // eval_blocksize))
)
else:
result = func(samples)

if 0:
# for sanity check
for j in range(im_size):
for i in range(re_size):
r = func(samples[j, i])[()]
assert numpy.array_equal(r, result[j, i], equal_nan=True), (
(j, i),
samples[j, i],
r,
result[j, i],
)

ulp = fa.utils.diff_ulp(result, expected, flush_subnormals=not include_subnormal, equal_nan=True)

if numpy.all(ulp == 0):
return
@@ -71,11 +116,11 @@ def test_unary(unary_func_name, backend, device, dtype):
bulp = numpy.zeros((im_blocks, re_blocks), dtype=ulp.dtype)
for j, blocks in enumerate(numpy.split(ulp, im_blocks, axis=0)):
for i, block in enumerate(numpy.split(blocks, re_blocks, axis=1)):
samples_block = samples[j * im_blocksize : (j + 1) * im_blocksize, i * re_blocksize : (i + 1) * re_blocksize]
ind = numpy.unravel_index(numpy.argmax(block, axis=None), block.shape)
j_, i_ = j * im_blocksize + ind[0], i * re_blocksize + ind[1]
bsamples[j, i] = samples[j_, i_]
bulp[j, i] = ulp[j_, i_]

assert block[ind[0], ind[1]] == numpy.max(block)
bsamples[j, i] = samples_block[ind[0], ind[1]]
bulp[j, i] = block[ind[0], ind[1]]
try:
fa_reference = getattr(fa.utils.numpy_with_algorithms(dtype=dtype), unary_func_name)
except Exception as msg:
@@ -138,6 +183,8 @@ def test_unary(unary_func_name, backend, device, dtype):
np_value = npy_reference(samples[re, im])
r = func(samples[re, im])
e = reference(samples[re, im])
u = fa.utils.diff_ulp(r, e, flush_subnormals=not include_subnormal, equal_nan=True)
assert u == value, (u, value, (re, im))
if fa_reference is not None:
fa_value = fa_reference(samples[re, im])
rows.append((value, samples[re, im], r, e, np_value, fa_value))
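For reference, a standalone sketch of the blockwise-evaluation workaround introduced above; the eval_blockwise helper and the toy call at the end are illustrative only, while the 2**14 element limit, the prime-factor reduction, and the sample dimensions mirror the diff:

import numpy

def eval_blockwise(func, samples, limit=2**14):
    # Evaluate func over samples row-block by row-block so that a single call
    # never covers more than `limit` elements (mirrors the loop in test_unary).
    im_size, re_size = samples.shape
    assert re_size < limit
    eval_blocksize = im_size
    while eval_blocksize * re_size > limit:
        # shrink the block height by small prime factors until it fits
        for p in [2, 3, 5, 7, 11, 13, 17, 19, 23]:
            if eval_blocksize % p == 0:
                eval_blocksize //= p
                break
        else:
            raise ValueError("adjust re/im sizes: no small prime factor left")
    assert im_size % eval_blocksize == 0
    return numpy.concatenate(
        tuple(func(samples[k * eval_blocksize : (k + 1) * eval_blocksize]) for k in range(im_size // eval_blocksize))
    )

# toy usage with a plain numpy function standing in for the JAX-backed func:
samples = numpy.zeros((1040, 2020), dtype=numpy.complex64)
assert eval_blockwise(numpy.square, samples).shape == samples.shape
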
20 changes: 20 additions & 0 deletions functional_algorithms/tests/test_utils.py
@@ -15,6 +15,26 @@ def dtype(request):
return request.param


def test_diff_ulp(real_dtype):
if real_dtype == numpy.longdouble:
pytest.skip(f"support not implemented")
fi = numpy.finfo(real_dtype)

assert utils.diff_ulp(real_dtype(0), fi.tiny, flush_subnormals=True) == 1
assert utils.diff_ulp(real_dtype(0), numpy.nextafter(fi.tiny, fi.max), flush_subnormals=True) == 2

assert utils.diff_ulp(real_dtype(0), -fi.tiny, flush_subnormals=True) == 1
assert utils.diff_ulp(real_dtype(0), fi.smallest_subnormal, flush_subnormals=False) == 1
assert utils.diff_ulp(real_dtype(0), -fi.smallest_subnormal, flush_subnormals=False) == 1

assert utils.diff_ulp(fi.tiny, fi.tiny, flush_subnormals=True) == 0
assert utils.diff_ulp(fi.tiny, fi.tiny, flush_subnormals=False) == 0
assert utils.diff_ulp(fi.tiny, numpy.nextafter(fi.tiny, fi.max), flush_subnormals=True) == 1
assert utils.diff_ulp(fi.tiny, numpy.nextafter(fi.tiny, fi.max), flush_subnormals=False) == 1

assert utils.diff_ulp(-fi.tiny, fi.tiny, flush_subnormals=True) == 2


def _check_real_samples(
r,
include_infinity=None,
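A rough numpy-only illustration of the flush_subnormals convention that the new test encodes: with flushing, the entire subnormal range collapses so that zero and the smallest normal are one step apart, whereas the raw bit-pattern distance counts every subnormal (the 2**23 count below is specific to float32):

import numpy

fi = numpy.finfo(numpy.float32)
# raw bit-pattern distance from +0.0 to the smallest normal: all 2**23
# float32 subnormals lie in between
assert int(fi.tiny.view(numpy.uint32)) == 2**23
# with flush_subnormals=True, diff_ulp collapses that range, which is why the
# test above expects diff_ulp(0, fi.tiny, flush_subnormals=True) == 1
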
31 changes: 21 additions & 10 deletions functional_algorithms/utils.py
@@ -442,8 +442,8 @@ def __call__(self, *args, **kwargs):
assert self.device.upper() in str(getattr(a, "device", "cpu")).upper()
mp_args.append(a)

with self.backend_context(context):
with warnings.catch_warnings(action="ignore"):
with warnings.catch_warnings(action="ignore"):
with self.backend_context(context):
if self.pyfunc_is_vectorized:
result = self.pyfunc(*mp_args, **kwargs)
else:
@@ -1536,7 +1536,7 @@ def isfloat(value):
return isinstance(value, (float, numpy.floating))


def diff_ulp(x, y, flush_subnormals=UNSPECIFIED) -> int:
def diff_ulp(x, y, flush_subnormals=UNSPECIFIED, equal_nan=False) -> int:
"""Return ULP distance between two floating point numbers.

For complex inputs, return largest ULP among real and imaginary
@@ -1545,11 +1545,14 @@ def diff_ulp(x, y, flush_subnormals=UNSPECIFIED) -> int:
When flush_subnormals is set to True, ULP difference does not
account for subnormals while subnormal values are rounded to
nearest normal, ties to even.

When equal_nan is set to True, the ULP difference between nan
values, whether quiet or signaling, is defined as 0.
"""
if isinstance(x, numpy.floating):
uint = {numpy.float64: numpy.uint64, numpy.float32: numpy.uint32, numpy.float16: numpy.uint16}[x.dtype.type]
sx = -1 if x <= 0 else 1
sy = -1 if y <= 0 else 1
sx = -1 if x < 0 else (1 if x > 0 else 0)
sy = -1 if y < 0 else (1 if y > 0 else 0)
x, y = abs(x), abs(y)
ix, iy = int(x.view(uint)), int(y.view(uint))
if numpy.isfinite(x) and numpy.isfinite(y):
@@ -1561,18 +1564,26 @@ def diff_ulp(x, y, flush_subnormals=UNSPECIFIED) -> int:
iy = iy - i if iy > i else (0 if 2 * iy <= i else 1)
if sx != sy:
# distance is measured through 0 value
return ix + iy
return ix - iy if ix >= iy else iy - ix
result = ix + iy
else:
result = ix - iy if ix >= iy else iy - ix
return result
elif ix == iy and sx == sy:
return 0
elif numpy.isnan(x) and numpy.isnan(y):
if equal_nan:
return 0
return {numpy.float64: 2**64, numpy.float32: 2**32, numpy.float16: 2**16}[x.dtype.type]
elif isinstance(x, numpy.complexfloating):
return max(
diff_ulp(x.real, y.real, flush_subnormals=flush_subnormals),
diff_ulp(x.imag, y.imag, flush_subnormals=flush_subnormals),
diff_ulp(x.real, y.real, flush_subnormals=flush_subnormals, equal_nan=equal_nan),
diff_ulp(x.imag, y.imag, flush_subnormals=flush_subnormals, equal_nan=equal_nan),
)
elif isinstance(x, numpy.ndarray) and isinstance(y, numpy.ndarray):
return numpy.array([diff_ulp(x_, y_, flush_subnormals=flush_subnormals) for x_, y_ in zip(x, y)])
if x.shape == () and y.shape == ():
return numpy.array(diff_ulp(x[()], y[()], flush_subnormals=flush_subnormals, equal_nan=equal_nan))
assert x.shape == y.shape, (x.shape, y.shape)
return numpy.array([diff_ulp(x_, y_, flush_subnormals=flush_subnormals, equal_nan=equal_nan) for x_, y_ in zip(x, y)])

raise NotImplementedError(type(x))

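And a short usage sketch of the extended diff_ulp signature; it assumes the package is importable as functional_algorithms (aliased to fa, as in the tests), and the NaN payloads are constructed ad hoc for illustration:

import numpy
import functional_algorithms as fa

# two float32 NaNs with different payloads
nan1 = numpy.uint32(0x7FC00000).view(numpy.float32)
nan2 = numpy.uint32(0x7FC00001).view(numpy.float32)

# a NaN-vs-NaN mismatch counts as zero ULP only when equal_nan is set;
# otherwise the float32 penalty value 2**32 is returned:
assert fa.utils.diff_ulp(nan1, nan2, flush_subnormals=False, equal_nan=True) == 0
assert fa.utils.diff_ulp(nan1, nan2, flush_subnormals=False, equal_nan=False) == 2**32

# 0-d arrays are now unwrapped instead of being iterated over:
a = numpy.array(numpy.float32(1.0))
assert fa.utils.diff_ulp(a, a, flush_subnormals=False, equal_nan=True) == 0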