Skip to content

Commit

Permalink
Merge pull request #112 from honno/use-dtype-eq
Browse files Browse the repository at this point in the history
Implement `EqualityMapping` and use for relevant dtype helpers
  • Loading branch information
asmeurer authored Apr 20, 2022
2 parents 9816011 + ed23bfa commit 63ebadb
Show file tree
Hide file tree
Showing 3 changed files with 180 additions and 75 deletions.
30 changes: 26 additions & 4 deletions array_api_tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
from functools import wraps

from hypothesis import strategies as st
from hypothesis.extra.array_api import make_strategies_namespace
from hypothesis.extra import array_api

from ._array_module import mod as _xp

__all__ = ["xps"]

xps = make_strategies_namespace(_xp)


# We monkey patch floats() to always disable subnormals as they are out-of-scope

Expand All @@ -23,5 +21,29 @@ def floats(*a, **kw):

st.floats = floats


# We do the same with xps.from_dtype() - this is not strictly necessary, as
# the underlying floats() will never generate subnormals. We only do this
# because internal logic in xps.from_dtype() assumes xp.finfo() has its
# attributes as scalar floats, which is expected behaviour but disrupts many
# unrelated tests.
try:
__from_dtype = array_api._from_dtype

@wraps(__from_dtype)
def _from_dtype(*a, **kw):
kw["allow_subnormal"] = False
return __from_dtype(*a, **kw)

array_api._from_dtype = _from_dtype
except AttributeError:
# Ignore monkey patching if Hypothesis changes the private API
pass


xps = array_api.make_strategies_namespace(_xp)


from . import _version
__version__ = _version.get_versions()['version']

__version__ = _version.get_versions()["version"]
188 changes: 117 additions & 71 deletions array_api_tests/dtype_helpers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from collections.abc import Mapping
from functools import lru_cache
from typing import NamedTuple, Tuple, Union
from typing import Any, NamedTuple, Sequence, Tuple, Union
from warnings import warn

from . import _array_module as xp
Expand Down Expand Up @@ -36,6 +37,49 @@
]


class EqualityMapping(Mapping):
"""
Mapping that uses equality for indexing
Typical mappings (e.g. the built-in dict) use hashing for indexing. This
isn't ideal for the Array API, as no __hash__() method is specified for
dtype objects - but __eq__() is!
See https://data-apis.org/array-api/latest/API_specification/data_types.html#data-type-objects
"""

def __init__(self, key_value_pairs: Sequence[Tuple[Any, Any]]):
keys = [k for k, _ in key_value_pairs]
for i, key in enumerate(keys):
if not (key == key): # specifically checking __eq__, not __neq__
raise ValueError("Key {key!r} does not have equality with itself")
other_keys = keys[:]
other_keys.pop(i)
for other_key in other_keys:
if key == other_key:
raise ValueError("Key {key!r} has equality with key {other_key!r}")
self._key_value_pairs = key_value_pairs

def __getitem__(self, key):
for k, v in self._key_value_pairs:
if key == k:
return v
else:
raise KeyError(f"{key!r} not found")

def __iter__(self):
return (k for k, _ in self._key_value_pairs)

def __len__(self):
return len(self._key_value_pairs)

def __str__(self):
return "{" + ", ".join(f"{k!r}: {v!r}" for k, v in self._key_value_pairs) + "}"

def __repr__(self):
return f"EqualityMapping({self})"


_uint_names = ("uint8", "uint16", "uint32", "uint64")
_int_names = ("int8", "int16", "int32", "int64")
_float_names = ("float32", "float64")
Expand All @@ -51,14 +95,16 @@
bool_and_all_int_dtypes = (xp.bool,) + all_int_dtypes


dtype_to_name = {getattr(xp, name): name for name in _dtype_names}
dtype_to_name = EqualityMapping([(getattr(xp, name), name) for name in _dtype_names])


dtype_to_scalars = {
xp.bool: [bool],
**{d: [int] for d in all_int_dtypes},
**{d: [int, float] for d in float_dtypes},
}
dtype_to_scalars = EqualityMapping(
[
(xp.bool, [bool]),
*[(d, [int]) for d in all_int_dtypes],
*[(d, [int, float]) for d in float_dtypes],
]
)


def is_int_dtype(dtype):
Expand Down Expand Up @@ -90,31 +136,32 @@ class MinMax(NamedTuple):
max: Union[int, float]


dtype_ranges = {
xp.int8: MinMax(-128, +127),
xp.int16: MinMax(-32_768, +32_767),
xp.int32: MinMax(-2_147_483_648, +2_147_483_647),
xp.int64: MinMax(-9_223_372_036_854_775_808, +9_223_372_036_854_775_807),
xp.uint8: MinMax(0, +255),
xp.uint16: MinMax(0, +65_535),
xp.uint32: MinMax(0, +4_294_967_295),
xp.uint64: MinMax(0, +18_446_744_073_709_551_615),
xp.float32: MinMax(-3.4028234663852886e+38, 3.4028234663852886e+38),
xp.float64: MinMax(-1.7976931348623157e+308, 1.7976931348623157e+308),
}
dtype_ranges = EqualityMapping(
[
(xp.int8, MinMax(-128, +127)),
(xp.int16, MinMax(-32_768, +32_767)),
(xp.int32, MinMax(-2_147_483_648, +2_147_483_647)),
(xp.int64, MinMax(-9_223_372_036_854_775_808, +9_223_372_036_854_775_807)),
(xp.uint8, MinMax(0, +255)),
(xp.uint16, MinMax(0, +65_535)),
(xp.uint32, MinMax(0, +4_294_967_295)),
(xp.uint64, MinMax(0, +18_446_744_073_709_551_615)),
(xp.float32, MinMax(-3.4028234663852886e38, 3.4028234663852886e38)),
(xp.float64, MinMax(-1.7976931348623157e308, 1.7976931348623157e308)),
]
)

dtype_nbits = {
**{d: 8 for d in [xp.int8, xp.uint8]},
**{d: 16 for d in [xp.int16, xp.uint16]},
**{d: 32 for d in [xp.int32, xp.uint32, xp.float32]},
**{d: 64 for d in [xp.int64, xp.uint64, xp.float64]},
}
dtype_nbits = EqualityMapping(
[(d, 8) for d in [xp.int8, xp.uint8]]
+ [(d, 16) for d in [xp.int16, xp.uint16]]
+ [(d, 32) for d in [xp.int32, xp.uint32, xp.float32]]
+ [(d, 64) for d in [xp.int64, xp.uint64, xp.float64]]
)


dtype_signed = {
**{d: True for d in int_dtypes},
**{d: False for d in uint_dtypes},
}
dtype_signed = EqualityMapping(
[(d, True) for d in int_dtypes] + [(d, False) for d in uint_dtypes]
)


if isinstance(xp.asarray, _UndefinedStub):
Expand All @@ -137,52 +184,51 @@ class MinMax(NamedTuple):
default_uint = xp.uint64


_numeric_promotions = {
_numeric_promotions = [
# ints
(xp.int8, xp.int8): xp.int8,
(xp.int8, xp.int16): xp.int16,
(xp.int8, xp.int32): xp.int32,
(xp.int8, xp.int64): xp.int64,
(xp.int16, xp.int16): xp.int16,
(xp.int16, xp.int32): xp.int32,
(xp.int16, xp.int64): xp.int64,
(xp.int32, xp.int32): xp.int32,
(xp.int32, xp.int64): xp.int64,
(xp.int64, xp.int64): xp.int64,
((xp.int8, xp.int8), xp.int8),
((xp.int8, xp.int16), xp.int16),
((xp.int8, xp.int32), xp.int32),
((xp.int8, xp.int64), xp.int64),
((xp.int16, xp.int16), xp.int16),
((xp.int16, xp.int32), xp.int32),
((xp.int16, xp.int64), xp.int64),
((xp.int32, xp.int32), xp.int32),
((xp.int32, xp.int64), xp.int64),
((xp.int64, xp.int64), xp.int64),
# uints
(xp.uint8, xp.uint8): xp.uint8,
(xp.uint8, xp.uint16): xp.uint16,
(xp.uint8, xp.uint32): xp.uint32,
(xp.uint8, xp.uint64): xp.uint64,
(xp.uint16, xp.uint16): xp.uint16,
(xp.uint16, xp.uint32): xp.uint32,
(xp.uint16, xp.uint64): xp.uint64,
(xp.uint32, xp.uint32): xp.uint32,
(xp.uint32, xp.uint64): xp.uint64,
(xp.uint64, xp.uint64): xp.uint64,
((xp.uint8, xp.uint8), xp.uint8),
((xp.uint8, xp.uint16), xp.uint16),
((xp.uint8, xp.uint32), xp.uint32),
((xp.uint8, xp.uint64), xp.uint64),
((xp.uint16, xp.uint16), xp.uint16),
((xp.uint16, xp.uint32), xp.uint32),
((xp.uint16, xp.uint64), xp.uint64),
((xp.uint32, xp.uint32), xp.uint32),
((xp.uint32, xp.uint64), xp.uint64),
((xp.uint64, xp.uint64), xp.uint64),
# ints and uints (mixed sign)
(xp.int8, xp.uint8): xp.int16,
(xp.int8, xp.uint16): xp.int32,
(xp.int8, xp.uint32): xp.int64,
(xp.int16, xp.uint8): xp.int16,
(xp.int16, xp.uint16): xp.int32,
(xp.int16, xp.uint32): xp.int64,
(xp.int32, xp.uint8): xp.int32,
(xp.int32, xp.uint16): xp.int32,
(xp.int32, xp.uint32): xp.int64,
(xp.int64, xp.uint8): xp.int64,
(xp.int64, xp.uint16): xp.int64,
(xp.int64, xp.uint32): xp.int64,
((xp.int8, xp.uint8), xp.int16),
((xp.int8, xp.uint16), xp.int32),
((xp.int8, xp.uint32), xp.int64),
((xp.int16, xp.uint8), xp.int16),
((xp.int16, xp.uint16), xp.int32),
((xp.int16, xp.uint32), xp.int64),
((xp.int32, xp.uint8), xp.int32),
((xp.int32, xp.uint16), xp.int32),
((xp.int32, xp.uint32), xp.int64),
((xp.int64, xp.uint8), xp.int64),
((xp.int64, xp.uint16), xp.int64),
((xp.int64, xp.uint32), xp.int64),
# floats
(xp.float32, xp.float32): xp.float32,
(xp.float32, xp.float64): xp.float64,
(xp.float64, xp.float64): xp.float64,
}
promotion_table = {
(xp.bool, xp.bool): xp.bool,
**_numeric_promotions,
**{(d2, d1): res for (d1, d2), res in _numeric_promotions.items()},
}
((xp.float32, xp.float32), xp.float32),
((xp.float32, xp.float64), xp.float64),
((xp.float64, xp.float64), xp.float64),
]
_numeric_promotions += [((d2, d1), res) for (d1, d2), res in _numeric_promotions]
_promotion_table = list(set(_numeric_promotions))
_promotion_table.insert(0, ((xp.bool, xp.bool), xp.bool))
promotion_table = EqualityMapping(_promotion_table)


def result_type(*dtypes: DataType):
Expand Down
37 changes: 37 additions & 0 deletions array_api_tests/meta/test_equality_mapping.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import pytest

from ..dtype_helpers import EqualityMapping


def test_raises_on_distinct_eq_key():
with pytest.raises(ValueError):
EqualityMapping([(float("nan"), "value")])


def test_raises_on_indistinct_eq_keys():
class AlwaysEq:
def __init__(self, hash):
self._hash = hash

def __eq__(self, other):
return True

def __hash__(self):
return self._hash

with pytest.raises(ValueError):
EqualityMapping([(AlwaysEq(0), "value1"), (AlwaysEq(1), "value2")])


def test_key_error():
mapping = EqualityMapping([("key", "value")])
with pytest.raises(KeyError):
mapping["nonexistent key"]


def test_iter():
mapping = EqualityMapping([("key", "value")])
it = iter(mapping)
assert next(it) == "key"
with pytest.raises(StopIteration):
next(it)

0 comments on commit 63ebadb

Please sign in to comment.