Skip to content

Commit

Permalink
Merge pull request #116 from honno/minor-fixes
Browse files Browse the repository at this point in the history
Minor fixes
  • Loading branch information
honno authored Apr 21, 2022
2 parents 63ebadb + 4d3405a commit a2f7bd5
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 29 deletions.
33 changes: 5 additions & 28 deletions array_api_tests/test_creation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,43 +409,20 @@ def test_full_like(x, fill_value, kw):
finite_kw = {"allow_nan": False, "allow_infinity": False}


def int_stops(
start: int, num, dtype: DataType, endpoint: bool
) -> st.SearchStrategy[int]:
min_gap = num
if endpoint:
min_gap += 1
m, M = dh.dtype_ranges[dtype]
max_pos_gap = M - start
max_neg_gap = start - m
max_pos_mul = max_pos_gap // min_gap
max_neg_mul = max_neg_gap // min_gap
return st.one_of(
st.integers(0, max_pos_mul).map(lambda n: start + min_gap * n),
st.integers(0, max_neg_mul).map(lambda n: start - min_gap * n),
)


@given(
num=hh.sizes,
dtype=st.none() | xps.numeric_dtypes(),
dtype=st.none() | xps.floating_dtypes(),
endpoint=st.booleans(),
data=st.data(),
)
def test_linspace(num, dtype, endpoint, data):
_dtype = dh.default_float if dtype is None else dtype

start = data.draw(xps.from_dtype(_dtype, **finite_kw), label="start")
if dh.is_float_dtype(_dtype):
stop = data.draw(xps.from_dtype(_dtype, **finite_kw), label="stop")
# avoid overflow errors
assume(not ah.isnan(ah.asarray(stop - start, dtype=_dtype)))
assume(not ah.isnan(ah.asarray(start - stop, dtype=_dtype)))
else:
if num == 0:
stop = start
else:
stop = data.draw(int_stops(start, num, _dtype, endpoint), label="stop")
stop = data.draw(xps.from_dtype(_dtype, **finite_kw), label="stop")
# avoid overflow errors
assume(not ah.isnan(ah.asarray(stop - start, dtype=_dtype)))
assume(not ah.isnan(ah.asarray(start - stop, dtype=_dtype)))

kw = data.draw(
hh.specified_kwargs(
Expand Down
2 changes: 1 addition & 1 deletion array_api_tests/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ def test_matmul(x1, x2):
@given(
x=finite_matrices(),
kw=kwargs(keepdims=booleans(),
ord=sampled_from([-float('inf'), -2, -2, 1, 2, float('inf'), 'fro', 'nuc']))
ord=sampled_from([-float('inf'), -2, -1, 1, 2, float('inf'), 'fro', 'nuc']))
)
def test_matrix_norm(x, kw):
res = linalg.matrix_norm(x, **kw)
Expand Down

0 comments on commit a2f7bd5

Please sign in to comment.