Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement EqualityMapping and use for relevant dtype helpers #112

Merged
merged 6 commits into from
Apr 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 26 additions & 4 deletions array_api_tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
from functools import wraps

from hypothesis import strategies as st
from hypothesis.extra.array_api import make_strategies_namespace
from hypothesis.extra import array_api

from ._array_module import mod as _xp

__all__ = ["xps"]

xps = make_strategies_namespace(_xp)


# We monkey patch floats() to always disable subnormals as they are out-of-scope

Expand All @@ -23,5 +21,29 @@ def floats(*a, **kw):

st.floats = floats


# We do the same with xps.from_dtype() - this is not strictly necessary, as
# the underlying floats() will never generate subnormals. We only do this
# because internal logic in xps.from_dtype() assumes xp.finfo() has its
# attributes as scalar floats, which is expected behaviour but disrupts many
# unrelated tests.
try:
__from_dtype = array_api._from_dtype

@wraps(__from_dtype)
def _from_dtype(*a, **kw):
kw["allow_subnormal"] = False
return __from_dtype(*a, **kw)

array_api._from_dtype = _from_dtype
except AttributeError:
# Ignore monkey patching if Hypothesis changes the private API
pass


xps = array_api.make_strategies_namespace(_xp)


from . import _version
__version__ = _version.get_versions()['version']

__version__ = _version.get_versions()["version"]
188 changes: 117 additions & 71 deletions array_api_tests/dtype_helpers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from collections.abc import Mapping
from functools import lru_cache
from typing import NamedTuple, Tuple, Union
from typing import Any, NamedTuple, Sequence, Tuple, Union
from warnings import warn

from . import _array_module as xp
Expand Down Expand Up @@ -36,6 +37,49 @@
]


class EqualityMapping(Mapping):
"""
Mapping that uses equality for indexing

Typical mappings (e.g. the built-in dict) use hashing for indexing. This
isn't ideal for the Array API, as no __hash__() method is specified for
dtype objects - but __eq__() is!

See https://data-apis.org/array-api/latest/API_specification/data_types.html#data-type-objects
"""

def __init__(self, key_value_pairs: Sequence[Tuple[Any, Any]]):
keys = [k for k, _ in key_value_pairs]
for i, key in enumerate(keys):
if not (key == key): # specifically checking __eq__, not __neq__
raise ValueError("Key {key!r} does not have equality with itself")
other_keys = keys[:]
other_keys.pop(i)
for other_key in other_keys:
if key == other_key:
raise ValueError("Key {key!r} has equality with key {other_key!r}")
self._key_value_pairs = key_value_pairs

def __getitem__(self, key):
for k, v in self._key_value_pairs:
if key == k:
return v
else:
raise KeyError(f"{key!r} not found")

def __iter__(self):
return (k for k, _ in self._key_value_pairs)

def __len__(self):
return len(self._key_value_pairs)

def __str__(self):
return "{" + ", ".join(f"{k!r}: {v!r}" for k, v in self._key_value_pairs) + "}"

def __repr__(self):
return f"EqualityMapping({self})"


_uint_names = ("uint8", "uint16", "uint32", "uint64")
_int_names = ("int8", "int16", "int32", "int64")
_float_names = ("float32", "float64")
Expand All @@ -51,14 +95,16 @@
bool_and_all_int_dtypes = (xp.bool,) + all_int_dtypes


dtype_to_name = {getattr(xp, name): name for name in _dtype_names}
dtype_to_name = EqualityMapping([(getattr(xp, name), name) for name in _dtype_names])


dtype_to_scalars = {
xp.bool: [bool],
**{d: [int] for d in all_int_dtypes},
**{d: [int, float] for d in float_dtypes},
}
dtype_to_scalars = EqualityMapping(
[
(xp.bool, [bool]),
*[(d, [int]) for d in all_int_dtypes],
*[(d, [int, float]) for d in float_dtypes],
]
)


def is_int_dtype(dtype):
Expand Down Expand Up @@ -90,31 +136,32 @@ class MinMax(NamedTuple):
max: Union[int, float]


dtype_ranges = {
xp.int8: MinMax(-128, +127),
xp.int16: MinMax(-32_768, +32_767),
xp.int32: MinMax(-2_147_483_648, +2_147_483_647),
xp.int64: MinMax(-9_223_372_036_854_775_808, +9_223_372_036_854_775_807),
xp.uint8: MinMax(0, +255),
xp.uint16: MinMax(0, +65_535),
xp.uint32: MinMax(0, +4_294_967_295),
xp.uint64: MinMax(0, +18_446_744_073_709_551_615),
xp.float32: MinMax(-3.4028234663852886e+38, 3.4028234663852886e+38),
xp.float64: MinMax(-1.7976931348623157e+308, 1.7976931348623157e+308),
}
dtype_ranges = EqualityMapping(
[
(xp.int8, MinMax(-128, +127)),
(xp.int16, MinMax(-32_768, +32_767)),
(xp.int32, MinMax(-2_147_483_648, +2_147_483_647)),
(xp.int64, MinMax(-9_223_372_036_854_775_808, +9_223_372_036_854_775_807)),
(xp.uint8, MinMax(0, +255)),
(xp.uint16, MinMax(0, +65_535)),
(xp.uint32, MinMax(0, +4_294_967_295)),
(xp.uint64, MinMax(0, +18_446_744_073_709_551_615)),
(xp.float32, MinMax(-3.4028234663852886e38, 3.4028234663852886e38)),
(xp.float64, MinMax(-1.7976931348623157e308, 1.7976931348623157e308)),
]
)

dtype_nbits = {
**{d: 8 for d in [xp.int8, xp.uint8]},
**{d: 16 for d in [xp.int16, xp.uint16]},
**{d: 32 for d in [xp.int32, xp.uint32, xp.float32]},
**{d: 64 for d in [xp.int64, xp.uint64, xp.float64]},
}
dtype_nbits = EqualityMapping(
[(d, 8) for d in [xp.int8, xp.uint8]]
+ [(d, 16) for d in [xp.int16, xp.uint16]]
+ [(d, 32) for d in [xp.int32, xp.uint32, xp.float32]]
+ [(d, 64) for d in [xp.int64, xp.uint64, xp.float64]]
)


dtype_signed = {
**{d: True for d in int_dtypes},
**{d: False for d in uint_dtypes},
}
dtype_signed = EqualityMapping(
[(d, True) for d in int_dtypes] + [(d, False) for d in uint_dtypes]
)


if isinstance(xp.asarray, _UndefinedStub):
Expand All @@ -137,52 +184,51 @@ class MinMax(NamedTuple):
default_uint = xp.uint64


_numeric_promotions = {
_numeric_promotions = [
# ints
(xp.int8, xp.int8): xp.int8,
(xp.int8, xp.int16): xp.int16,
(xp.int8, xp.int32): xp.int32,
(xp.int8, xp.int64): xp.int64,
(xp.int16, xp.int16): xp.int16,
(xp.int16, xp.int32): xp.int32,
(xp.int16, xp.int64): xp.int64,
(xp.int32, xp.int32): xp.int32,
(xp.int32, xp.int64): xp.int64,
(xp.int64, xp.int64): xp.int64,
((xp.int8, xp.int8), xp.int8),
((xp.int8, xp.int16), xp.int16),
((xp.int8, xp.int32), xp.int32),
((xp.int8, xp.int64), xp.int64),
((xp.int16, xp.int16), xp.int16),
((xp.int16, xp.int32), xp.int32),
((xp.int16, xp.int64), xp.int64),
((xp.int32, xp.int32), xp.int32),
((xp.int32, xp.int64), xp.int64),
((xp.int64, xp.int64), xp.int64),
# uints
(xp.uint8, xp.uint8): xp.uint8,
(xp.uint8, xp.uint16): xp.uint16,
(xp.uint8, xp.uint32): xp.uint32,
(xp.uint8, xp.uint64): xp.uint64,
(xp.uint16, xp.uint16): xp.uint16,
(xp.uint16, xp.uint32): xp.uint32,
(xp.uint16, xp.uint64): xp.uint64,
(xp.uint32, xp.uint32): xp.uint32,
(xp.uint32, xp.uint64): xp.uint64,
(xp.uint64, xp.uint64): xp.uint64,
((xp.uint8, xp.uint8), xp.uint8),
((xp.uint8, xp.uint16), xp.uint16),
((xp.uint8, xp.uint32), xp.uint32),
((xp.uint8, xp.uint64), xp.uint64),
((xp.uint16, xp.uint16), xp.uint16),
((xp.uint16, xp.uint32), xp.uint32),
((xp.uint16, xp.uint64), xp.uint64),
((xp.uint32, xp.uint32), xp.uint32),
((xp.uint32, xp.uint64), xp.uint64),
((xp.uint64, xp.uint64), xp.uint64),
# ints and uints (mixed sign)
(xp.int8, xp.uint8): xp.int16,
(xp.int8, xp.uint16): xp.int32,
(xp.int8, xp.uint32): xp.int64,
(xp.int16, xp.uint8): xp.int16,
(xp.int16, xp.uint16): xp.int32,
(xp.int16, xp.uint32): xp.int64,
(xp.int32, xp.uint8): xp.int32,
(xp.int32, xp.uint16): xp.int32,
(xp.int32, xp.uint32): xp.int64,
(xp.int64, xp.uint8): xp.int64,
(xp.int64, xp.uint16): xp.int64,
(xp.int64, xp.uint32): xp.int64,
((xp.int8, xp.uint8), xp.int16),
((xp.int8, xp.uint16), xp.int32),
((xp.int8, xp.uint32), xp.int64),
((xp.int16, xp.uint8), xp.int16),
((xp.int16, xp.uint16), xp.int32),
((xp.int16, xp.uint32), xp.int64),
((xp.int32, xp.uint8), xp.int32),
((xp.int32, xp.uint16), xp.int32),
((xp.int32, xp.uint32), xp.int64),
((xp.int64, xp.uint8), xp.int64),
((xp.int64, xp.uint16), xp.int64),
((xp.int64, xp.uint32), xp.int64),
# floats
(xp.float32, xp.float32): xp.float32,
(xp.float32, xp.float64): xp.float64,
(xp.float64, xp.float64): xp.float64,
}
promotion_table = {
(xp.bool, xp.bool): xp.bool,
**_numeric_promotions,
**{(d2, d1): res for (d1, d2), res in _numeric_promotions.items()},
}
((xp.float32, xp.float32), xp.float32),
((xp.float32, xp.float64), xp.float64),
((xp.float64, xp.float64), xp.float64),
]
_numeric_promotions += [((d2, d1), res) for (d1, d2), res in _numeric_promotions]
_promotion_table = list(set(_numeric_promotions))
_promotion_table.insert(0, ((xp.bool, xp.bool), xp.bool))
promotion_table = EqualityMapping(_promotion_table)


def result_type(*dtypes: DataType):
Expand Down
37 changes: 37 additions & 0 deletions array_api_tests/meta/test_equality_mapping.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import pytest

from ..dtype_helpers import EqualityMapping


def test_raises_on_distinct_eq_key():
with pytest.raises(ValueError):
EqualityMapping([(float("nan"), "value")])


def test_raises_on_indistinct_eq_keys():
class AlwaysEq:
def __init__(self, hash):
self._hash = hash

def __eq__(self, other):
return True

def __hash__(self):
return self._hash

with pytest.raises(ValueError):
EqualityMapping([(AlwaysEq(0), "value1"), (AlwaysEq(1), "value2")])


def test_key_error():
mapping = EqualityMapping([("key", "value")])
with pytest.raises(KeyError):
mapping["nonexistent key"]


def test_iter():
mapping = EqualityMapping([("key", "value")])
it = iter(mapping)
assert next(it) == "key"
with pytest.raises(StopIteration):
next(it)