Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Index testing improvements #124

Merged
merged 10 commits into from
May 24, 2022
1 change: 0 additions & 1 deletion array_api_tests/dtype_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,6 @@ def is_float_dtype(dtype):
# See https://github.com/numpy/numpy/issues/18434
if dtype is None:
return False
# TODO: Return True for float dtypes that aren't part of the spec e.g. np.float16
return dtype in float_dtypes


Expand Down
10 changes: 5 additions & 5 deletions array_api_tests/meta/test_pytest_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@ def test_assert_dtype():
ph.assert_dtype("single_bool_func", [xp.uint8], xp.bool, xp.bool)


def test_assert_array():
ph.assert_array("int zeros", xp.asarray(0), xp.asarray(0))
ph.assert_array("pos zeros", xp.asarray(0.0), xp.asarray(0.0))
def test_assert_array_elements():
ph.assert_array_elements("int zeros", xp.asarray(0), xp.asarray(0))
ph.assert_array_elements("pos zeros", xp.asarray(0.0), xp.asarray(0.0))
with raises(AssertionError):
ph.assert_array("mixed sign zeros", xp.asarray(0.0), xp.asarray(-0.0))
ph.assert_array_elements("mixed sign zeros", xp.asarray(0.0), xp.asarray(-0.0))
with raises(AssertionError):
ph.assert_array("mixed sign zeros", xp.asarray(-0.0), xp.asarray(0.0))
ph.assert_array_elements("mixed sign zeros", xp.asarray(-0.0), xp.asarray(0.0))
2 changes: 2 additions & 0 deletions array_api_tests/meta/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,13 +91,15 @@ def test_roll_ndindex(shape, shifts, axes, expected):
((), "x"),
(42, "x[42]"),
((42,), "x[42]"),
((42, 7), "x[42, 7]"),
(slice(None, 2), "x[:2]"),
(slice(2, None), "x[2:]"),
(slice(0, 2), "x[0:2]"),
(slice(0, 2, -1), "x[0:2:-1]"),
(slice(None, None, -1), "x[::-1]"),
(slice(None, None), "x[:]"),
(..., "x[...]"),
((None, 42), "x[None, 42]"),
],
)
def test_fmt_idx(idx, expected):
Expand Down
24 changes: 13 additions & 11 deletions array_api_tests/pytest_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
"assert_keepdimable_shape",
"assert_0d_equals",
"assert_fill",
"assert_array",
"assert_array_elements",
]


Expand Down Expand Up @@ -301,7 +301,7 @@ def assert_0d_equals(
>>> x = xp.asarray([0, 1, 2])
>>> res = xp.asarray(x, copy=True)
>>> res[0] = 42
>>> assert_0d_equals('__setitem__', 'x[0]', x[0], 'x[0]', res[0])
>>> assert_0d_equals('asarray', 'x[0]', x[0], 'x[0]', res[0])

is equivalent to

Expand Down Expand Up @@ -374,28 +374,30 @@ def assert_fill(
assert xp.all(xp.equal(out, xp.asarray(fill_value, dtype=dtype))), msg


def assert_array(func_name: str, out: Array, expected: Array, /, **kw):
def assert_array_elements(
func_name: str, out: Array, expected: Array, /, *, out_repr: str = "out", **kw
):
"""
Assert array is (strictly) as expected, e.g.
Assert array elements are (strictly) as expected, e.g.

>>> x = xp.arange(5)
>>> out = xp.asarray(x)
>>> assert_array('asarray', out, x)
>>> assert_array_elements('asarray', out, x)

is equivalent to

>>> assert xp.all(out == x)

"""
assert_dtype(func_name, out.dtype, expected.dtype)
assert_shape(func_name, out.shape, expected.shape, **kw)
dh.result_type(out.dtype, expected.dtype) # sanity check
assert_shape(func_name, out.shape, expected.shape, **kw) # sanity check
f_func = f"[{func_name}({fmt_kw(kw)})]"
if dh.is_float_dtype(out.dtype):
for idx in sh.ndindex(out.shape):
at_out = out[idx]
at_expected = expected[idx]
msg = (
f"{sh.fmt_idx('out', idx)}={at_out}, should be {at_expected} "
f"{sh.fmt_idx(out_repr, idx)}={at_out}, should be {at_expected} "
f"{f_func}"
)
if xp.isnan(at_expected):
Expand All @@ -411,6 +413,6 @@ def assert_array(func_name: str, out: Array, expected: Array, /, **kw):
else:
assert at_out == at_expected, msg
else:
assert xp.all(out == expected), (
f"out not as expected {f_func}\n" f"{out=}\n{expected=}"
)
assert xp.all(
out == expected
), f"{out_repr} not as expected {f_func}\n{out_repr}={out!r}\n{expected=}"
2 changes: 2 additions & 0 deletions array_api_tests/shape_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,8 @@ def fmt_i(i: AtomicIndex) -> str:
if i.step is not None:
res += f":{i.step}"
return res
elif i is None:
return "None"
else:
return "..."

Expand Down
137 changes: 87 additions & 50 deletions array_api_tests/test_array_object.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import math
from itertools import product
from typing import List, get_args
from typing import List, Sequence, Tuple, Union, get_args

import pytest
from hypothesis import assume, given, note
Expand All @@ -12,30 +12,29 @@
from . import pytest_helpers as ph
from . import shape_helpers as sh
from . import xps
from .typing import DataType, Param, Scalar, ScalarType, Shape
from .test_operators_and_elementwise_functions import oneway_promotable_dtypes
from .typing import DataType, Index, Param, Scalar, ScalarType, Shape

pytestmark = pytest.mark.ci


def scalar_objects(dtype: DataType, shape: Shape) -> st.SearchStrategy[List[Scalar]]:
def scalar_objects(
dtype: DataType, shape: Shape
) -> st.SearchStrategy[Union[Scalar, List[Scalar]]]:
"""Generates scalars or nested sequences which are valid for xp.asarray()"""
size = math.prod(shape)
return st.lists(xps.from_dtype(dtype), min_size=size, max_size=size).map(
lambda l: sh.reshape(l, shape)
)


@given(hh.shapes(min_side=1), st.data()) # TODO: test 0-sided arrays
def test_getitem(shape, data):
dtype = data.draw(xps.scalar_dtypes(), label="dtype")
obj = data.draw(scalar_objects(dtype, shape), label="obj")
x = xp.asarray(obj, dtype=dtype)
note(f"{x=}")
key = data.draw(xps.indices(shape=shape, allow_newaxis=True), label="key")

out = x[key]
def normalise_key(key: Index, shape: Shape) -> Tuple[Union[int, slice], ...]:
"""
Normalise an indexing key.

ph.assert_dtype("__getitem__", x.dtype, out.dtype)
* If a non-tuple index, wrap as a tuple.
* Represent ellipsis as equivalent slices.
"""
_key = tuple(key) if isinstance(key, tuple) else (key,)
if Ellipsis in _key:
nonexpanding_key = tuple(i for i in _key if i is not None)
Expand All @@ -44,71 +43,109 @@ def test_getitem(shape, data):
slices = tuple(slice(None) for _ in range(start_a, stop_a))
start_pos = _key.index(Ellipsis)
_key = _key[:start_pos] + slices + _key[start_pos + 1 :]
return _key


def get_indexed_axes_and_out_shape(
key: Tuple[Union[int, slice, None], ...], shape: Shape
) -> Tuple[Tuple[Sequence[int], ...], Shape]:
"""
From the (normalised) key and input shape, calculates:

* indexed_axes: For each dimension, the axes which the key indexes.
* out_shape: The resulting shape of indexing an array (of the input shape)
with the key.
"""
axes_indices = []
out_shape = []
a = 0
for i in _key:
for i in key:
if i is None:
out_shape.append(1)
else:
side = shape[a]
if isinstance(i, int):
axes_indices.append([i])
if i < 0:
i += side
axes_indices.append((i,))
else:
assert isinstance(i, slice) # sanity check
side = shape[a]
indices = range(side)[i]
axes_indices.append(indices)
out_shape.append(len(indices))
a += 1
out_shape = tuple(out_shape)
return tuple(axes_indices), tuple(out_shape)


@given(shape=hh.shapes(), dtype=xps.scalar_dtypes(), data=st.data())
def test_getitem(shape, dtype, data):
zero_sided = any(side == 0 for side in shape)
if zero_sided:
x = xp.zeros(shape, dtype=dtype)
else:
obj = data.draw(scalar_objects(dtype, shape), label="obj")
x = xp.asarray(obj, dtype=dtype)
note(f"{x=}")
key = data.draw(xps.indices(shape=shape, allow_newaxis=True), label="key")

out = x[key]

ph.assert_dtype("__getitem__", x.dtype, out.dtype)
_key = normalise_key(key, shape)
axes_indices, out_shape = get_indexed_axes_and_out_shape(_key, shape)
ph.assert_shape("__getitem__", out.shape, out_shape)
assume(all(len(indices) > 0 for indices in axes_indices))
out_obj = []
for idx in product(*axes_indices):
val = obj
for i in idx:
val = val[i]
out_obj.append(val)
out_obj = sh.reshape(out_obj, out_shape)
expected = xp.asarray(out_obj, dtype=dtype)
ph.assert_array("__getitem__", out, expected)


@given(hh.shapes(min_side=1), st.data()) # TODO: test 0-sided arrays
def test_setitem(shape, data):
dtype = data.draw(xps.scalar_dtypes(), label="dtype")
obj = data.draw(scalar_objects(dtype, shape), label="obj")
x = xp.asarray(obj, dtype=dtype)
out_zero_sided = any(side == 0 for side in out_shape)
if not zero_sided and not out_zero_sided:
out_obj = []
for idx in product(*axes_indices):
val = obj
for i in idx:
val = val[i]
out_obj.append(val)
out_obj = sh.reshape(out_obj, out_shape)
expected = xp.asarray(out_obj, dtype=dtype)
ph.assert_array_elements("__getitem__", out, expected)


@given(
shape=hh.shapes(),
dtypes=oneway_promotable_dtypes(dh.all_dtypes),
data=st.data(),
)
def test_setitem(shape, dtypes, data):
zero_sided = any(side == 0 for side in shape)
if zero_sided:
x = xp.zeros(shape, dtype=dtypes.result_dtype)
else:
obj = data.draw(scalar_objects(dtypes.result_dtype, shape), label="obj")
x = xp.asarray(obj, dtype=dtypes.result_dtype)
note(f"{x=}")
# TODO: test setting non-0d arrays
key = data.draw(xps.indices(shape=shape, max_dims=0), label="key")
value = data.draw(
xps.from_dtype(dtype) | xps.arrays(dtype=dtype, shape=()), label="value"
)
key = data.draw(xps.indices(shape=shape), label="key")
_key = normalise_key(key, shape)
axes_indices, out_shape = get_indexed_axes_and_out_shape(_key, shape)
value_strat = xps.arrays(dtype=dtypes.result_dtype, shape=out_shape)
if out_shape == ():
# We can pass scalars if we're only indexing one element
value_strat |= xps.from_dtype(dtypes.result_dtype)
value = data.draw(value_strat, label="value")

res = xp.asarray(x, copy=True)
res[key] = value

ph.assert_dtype("__setitem__", x.dtype, res.dtype, repr_name="x.dtype")
ph.assert_shape("__setitem__", res.shape, x.shape, repr_name="x.shape")
f_res = sh.fmt_idx("x", key)
if isinstance(value, get_args(Scalar)):
msg = f"x[{key}]={res[key]!r}, but should be {value=} [__setitem__()]"
msg = f"{f_res}={res[key]!r}, but should be {value=} [__setitem__()]"
if math.isnan(value):
assert xp.isnan(res[key]), msg
else:
assert res[key] == value, msg
else:
ph.assert_0d_equals(
"__setitem__", "value", value, f"modified x[{key}]", res[key]
)
_key = key if isinstance(key, tuple) else (key,)
assume(all(isinstance(i, int) for i in _key)) # TODO: normalise slices and ellipsis
_key = tuple(i if i >= 0 else s + i for i, s in zip(_key, x.shape))
unaffected_indices = list(sh.ndindex(res.shape))
unaffected_indices.remove(_key)
ph.assert_array_elements("__setitem__", res[key], value, out_repr=f_res)
unaffected_indices = set(sh.ndindex(res.shape)) - set(product(*axes_indices))
for idx in unaffected_indices:
ph.assert_0d_equals(
"__setitem__", f"old x[{idx}]", x[idx], f"modified x[{idx}]", res[idx]
"__setitem__", f"old {f_res}", x[idx], f"modified {f_res}", res[idx]
)


Expand Down
Loading