From 772943c64b779a6f914309d95d7e6a19d02c9aac Mon Sep 17 00:00:00 2001 From: Pearu Peterson Date: Tue, 26 Nov 2024 11:48:30 +0200 Subject: [PATCH 1/2] Support testing with XLA CPU client that has FTZ disabled. --- functional_algorithms/tests/test_accuracy.py | 64 +++++++++++++++++--- functional_algorithms/tests/test_utils.py | 20 ++++++ functional_algorithms/utils.py | 31 +++++++--- 3 files changed, 95 insertions(+), 20 deletions(-) diff --git a/functional_algorithms/tests/test_accuracy.py b/functional_algorithms/tests/test_accuracy.py index c70707b..ae6c3d8 100644 --- a/functional_algorithms/tests/test_accuracy.py +++ b/functional_algorithms/tests/test_accuracy.py @@ -36,16 +36,23 @@ def test_unary(unary_func_name, backend, device, dtype): max_valid_ulp_count = params["max_valid_ulp_count"] extra_prec_multiplier = params["extra_prec_multiplier"] samples_limits = params["samples_limits"] - # JAX flushes subnormals to zero - include_subnormal = backend != "jax" + # JAX with CPU flushes subnormals to zero + include_subnormal = False if device == "cpu" and backend == "jax" else True + # include_subnormal = True - mpmath = fa.utils.numpy_with_mpmath(extra_prec_multiplier=extra_prec_multiplier) + mpmath = fa.utils.numpy_with_mpmath(extra_prec_multiplier=extra_prec_multiplier, flush_subnormals=not include_subnormal) reference = getattr(mpmath, unary_func_name) npy_reference = getattr(fa.utils.numpy_with_numpy(), unary_func_name) - re_blocks, im_blocks = 101, 51 + re_blocks, im_blocks = 101, 52 re_blocksize, im_blocksize = 20, 20 + + if 0: + # for testing + re_blocks, im_blocks = 51, 26 + re_blocksize, im_blocksize = 5, 5 + re_size, im_size = re_blocks * re_blocksize, im_blocks * im_blocksize if dtype in {numpy.complex64, numpy.complex128}: @@ -57,9 +64,44 @@ def test_unary(unary_func_name, backend, device, dtype): assert samples.shape == (im_size, re_size) expected = reference.call(samples) - result = func(samples) - ulp = fa.utils.diff_ulp(result, expected, 
flush_subnormals=not include_subnormal) + if backend == "jax" and device == "cpu" and include_subnormal: + # XLA CPU client enables FTZ to follow TF convention. To + # disable FTZ, replace all occurrences of + # tsl::port::ScopedFlushDenormal flush; + # with + # tsl::port::ScopedDontFlushDenormal flush; + # in xla/pjrt/cpu/cpu_client.cc. + # + # However, disabling FTZ in XLA CPU client is effective only for the first + # part of samples evaluations. To work around this, we'll evaluate + # JAX functions blockwise: + eval_blocksize = im_size + assert re_size < 2**14 + while eval_blocksize * re_size > 2**14: + for p in [2, 3, 5, 7, 11, 13, 17, 19, 23]: + if eval_blocksize % p == 0: + eval_blocksize //= p + break + else: + assert 0 # adjust re/im_size/blocksize parameters to avoid this + assert im_size % eval_blocksize == 0, (im_size, eval_blocksize) + result = numpy.concatenate( + tuple(func(samples[k * eval_blocksize : (k + 1) * eval_blocksize]) for k in range(im_size // eval_blocksize)) + ) + else: + result = func(samples) + for j in range(im_size): + for i in range(re_size): + r = func(samples[j, i])[()] + assert numpy.array_equal(r, result[j, i], equal_nan=True), ( + (j, i), + samples[j, i], + r, + result[j, i], + ) + + ulp = fa.utils.diff_ulp(result, expected, flush_subnormals=not include_subnormal, equal_nan=True) if numpy.all(ulp == 0): return @@ -71,11 +113,11 @@ def test_unary(unary_func_name, backend, device, dtype): bulp = numpy.zeros((im_blocks, re_blocks), dtype=ulp.dtype) for j, blocks in enumerate(numpy.split(ulp, im_blocks, axis=0)): for i, block in enumerate(numpy.split(blocks, re_blocks, axis=1)): + samples_block = samples[j * im_blocksize : (j + 1) * im_blocksize, i * re_blocksize : (i + 1) * re_blocksize] ind = numpy.unravel_index(numpy.argmax(block, axis=None), block.shape) - j_, i_ = j * im_blocksize + ind[0], i * re_blocksize + ind[1] - bsamples[j, i] = samples[j_, i_] - bulp[j, i] = ulp[j_, i_] - + assert block[ind[0], ind[1]] == numpy.max(block) 
+ bsamples[j, i] = samples_block[ind[0], ind[1]] + bulp[j, i] = block[ind[0], ind[1]] try: fa_reference = getattr(fa.utils.numpy_with_algorithms(dtype=dtype), unary_func_name) except Exception as msg: @@ -138,6 +180,8 @@ def test_unary(unary_func_name, backend, device, dtype): np_value = npy_reference(samples[re, im]) r = func(samples[re, im]) e = reference(samples[re, im]) + u = fa.utils.diff_ulp(r, e, flush_subnormals=not include_subnormal, equal_nan=True) + assert u == value, (u, value, (re, im)) if fa_reference is not None: fa_value = fa_reference(samples[re, im]) rows.append((value, samples[re, im], r, e, np_value, fa_value)) diff --git a/functional_algorithms/tests/test_utils.py b/functional_algorithms/tests/test_utils.py index 3f0a8fe..5e130a3 100644 --- a/functional_algorithms/tests/test_utils.py +++ b/functional_algorithms/tests/test_utils.py @@ -15,6 +15,26 @@ def dtype(request): return request.param +def test_diff_ulp(real_dtype): + if real_dtype == numpy.longdouble: + pytest.skip(f"support not implemented") + fi = numpy.finfo(real_dtype) + + assert utils.diff_ulp(real_dtype(0), fi.tiny, flush_subnormals=True) == 1 + assert utils.diff_ulp(real_dtype(0), numpy.nextafter(fi.tiny, fi.max), flush_subnormals=True) == 2 + + assert utils.diff_ulp(real_dtype(0), -fi.tiny, flush_subnormals=True) == 1 + assert utils.diff_ulp(real_dtype(0), fi.smallest_subnormal, flush_subnormals=False) == 1 + assert utils.diff_ulp(real_dtype(0), -fi.smallest_subnormal, flush_subnormals=False) == 1 + + assert utils.diff_ulp(fi.tiny, fi.tiny, flush_subnormals=True) == 0 + assert utils.diff_ulp(fi.tiny, fi.tiny, flush_subnormals=False) == 0 + assert utils.diff_ulp(fi.tiny, numpy.nextafter(fi.tiny, fi.max), flush_subnormals=True) == 1 + assert utils.diff_ulp(fi.tiny, numpy.nextafter(fi.tiny, fi.max), flush_subnormals=False) == 1 + + assert utils.diff_ulp(-fi.tiny, fi.tiny, flush_subnormals=True) == 2 + + def _check_real_samples( r, include_infinity=None, diff --git 
a/functional_algorithms/utils.py b/functional_algorithms/utils.py index f820a4f..14f75c0 100644 --- a/functional_algorithms/utils.py +++ b/functional_algorithms/utils.py @@ -442,8 +442,8 @@ def __call__(self, *args, **kwargs): assert self.device.upper() in str(getattr(a, "device", "cpu")).upper() mp_args.append(a) - with self.backend_context(context): - with warnings.catch_warnings(action="ignore"): + with warnings.catch_warnings(action="ignore"): + with self.backend_context(context): if self.pyfunc_is_vectorized: result = self.pyfunc(*mp_args, **kwargs) else: @@ -1536,7 +1536,7 @@ def isfloat(value): return isinstance(value, (float, numpy.floating)) -def diff_ulp(x, y, flush_subnormals=UNSPECIFIED) -> int: +def diff_ulp(x, y, flush_subnormals=UNSPECIFIED, equal_nan=False) -> int: """Return ULP distance between two floating point numbers. For complex inputs, return largest ULP among real and imaginary @@ -1545,11 +1545,14 @@ def diff_ulp(x, y, flush_subnormals=UNSPECIFIED) -> int: When flush_subnormals is set to True, ULP difference does not account for subnormals while subnormal values are rounded to nearest normal, ties to even. + + When equal_nan is set to True, ULP difference between nan values + of both quiet and signaling kinds is defined as 0. 
""" if isinstance(x, numpy.floating): uint = {numpy.float64: numpy.uint64, numpy.float32: numpy.uint32, numpy.float16: numpy.uint16}[x.dtype.type] - sx = -1 if x <= 0 else 1 - sy = -1 if y <= 0 else 1 + sx = -1 if x < 0 else (1 if x > 0 else 0) + sy = -1 if y < 0 else (1 if y > 0 else 0) x, y = abs(x), abs(y) ix, iy = int(x.view(uint)), int(y.view(uint)) if numpy.isfinite(x) and numpy.isfinite(y): @@ -1561,18 +1564,26 @@ def diff_ulp(x, y, flush_subnormals=UNSPECIFIED) -> int: iy = iy - i if iy > i else (0 if 2 * iy <= i else 1) if sx != sy: # distance is measured through 0 value - return ix + iy - return ix - iy if ix >= iy else iy - ix + result = ix + iy + else: + result = ix - iy if ix >= iy else iy - ix + return result elif ix == iy and sx == sy: return 0 + elif numpy.isnan(x) and numpy.isnan(y): + if equal_nan: + return 0 return {numpy.float64: 2**64, numpy.float32: 2**32, numpy.float16: 2**16}[x.dtype.type] elif isinstance(x, numpy.complexfloating): return max( - diff_ulp(x.real, y.real, flush_subnormals=flush_subnormals), - diff_ulp(x.imag, y.imag, flush_subnormals=flush_subnormals), + diff_ulp(x.real, y.real, flush_subnormals=flush_subnormals, equal_nan=equal_nan), + diff_ulp(x.imag, y.imag, flush_subnormals=flush_subnormals, equal_nan=equal_nan), ) elif isinstance(x, numpy.ndarray) and isinstance(y, numpy.ndarray): - return numpy.array([diff_ulp(x_, y_, flush_subnormals=flush_subnormals) for x_, y_ in zip(x, y)]) + if x.shape == () and y.shape == (): + return numpy.array(diff_ulp(x[()], y[()], flush_subnormals=flush_subnormals, equal_nan=equal_nan)) + assert x.shape == y.shape, (x.shape, y.shape) + return numpy.array([diff_ulp(x_, y_, flush_subnormals=flush_subnormals, equal_nan=equal_nan) for x_, y_ in zip(x, y)]) raise NotImplementedError(type(x)) From 5331262a062522040b0ec0f8b4ded67738873346 Mon Sep 17 00:00:00 2001 From: Pearu Peterson Date: Tue, 26 Nov 2024 11:53:38 +0200 Subject: [PATCH 2/2] Disable expensive sanity check --- 
functional_algorithms/tests/test_accuracy.py | 21 +++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/functional_algorithms/tests/test_accuracy.py b/functional_algorithms/tests/test_accuracy.py index ae6c3d8..d132136 100644 --- a/functional_algorithms/tests/test_accuracy.py +++ b/functional_algorithms/tests/test_accuracy.py @@ -91,15 +91,18 @@ def test_unary(unary_func_name, backend, device, dtype): ) else: result = func(samples) - for j in range(im_size): - for i in range(re_size): - r = func(samples[j, i])[()] - assert numpy.array_equal(r, result[j, i], equal_nan=True), ( - (j, i), - samples[j, i], - r, - result[j, i], - ) + + if 0: + # for sanity check + for j in range(im_size): + for i in range(re_size): + r = func(samples[j, i])[()] + assert numpy.array_equal(r, result[j, i], equal_nan=True), ( + (j, i), + samples[j, i], + r, + result[j, i], + ) ulp = fa.utils.diff_ulp(result, expected, flush_subnormals=not include_subnormal, equal_nan=True)