Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
vtavana committed Oct 12, 2023
1 parent cfefd68 commit 0e50f82
Showing 1 changed file with 9 additions and 12 deletions.
21 changes: 9 additions & 12 deletions tests/test_mathematical.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
assert_dtype_allclose,
get_all_dtypes,
get_complex_dtypes,
get_float_complex_dtypes,
get_float_dtypes,
has_support_aspect64,
is_cpu_device,
Expand Down Expand Up @@ -966,7 +967,7 @@ def test_invalid_out(self, out):


class TestDivide:
@pytest.mark.parametrize("dtype", get_float_dtypes() + get_complex_dtypes())
@pytest.mark.parametrize("dtype", get_float_complex_dtypes())
def test_divide(self, dtype):
array1_data = numpy.arange(10)
array2_data = numpy.arange(5, 15)
Expand All @@ -983,12 +984,11 @@ def test_divide(self, dtype):
np_array2 = numpy.array(array2_data, dtype=dtype)
expected = numpy.divide(np_array1, np_array2, out=out)

tol = 1e-07
assert_allclose(expected, result, rtol=tol, atol=tol)
assert_allclose(out, dp_out, rtol=tol, atol=tol)
assert_dtype_allclose(result, expected)
assert_dtype_allclose(dp_out, out)

@pytest.mark.usefixtures("suppress_divide_invalid_numpy_warnings")
@pytest.mark.parametrize("dtype", get_float_dtypes() + get_complex_dtypes())
@pytest.mark.parametrize("dtype", get_float_complex_dtypes())
def test_out_dtypes(self, dtype):
size = 10

Expand All @@ -1010,12 +1010,10 @@ def test_out_dtypes(self, dtype):
dp_out = dpnp.empty(size, dtype=dtype)

result = dpnp.divide(dp_array1, dp_array2, out=dp_out)

tol = 1e-07
assert_allclose(expected, result, rtol=tol, atol=tol)
assert_dtype_allclose(result, expected)

@pytest.mark.usefixtures("suppress_divide_invalid_numpy_warnings")
@pytest.mark.parametrize("dtype", get_float_dtypes() + get_complex_dtypes())
@pytest.mark.parametrize("dtype", get_float_complex_dtypes())
def test_out_overlap(self, dtype):
size = 15
# DPNP
Expand All @@ -1026,10 +1024,9 @@ def test_out_overlap(self, dtype):
np_a = numpy.arange(2 * size, dtype=dtype)
numpy.divide(np_a[size::], np_a[::2], out=np_a[:size:])

tol = 1e-07
assert_allclose(np_a, dp_a, rtol=tol, atol=tol)
assert_dtype_allclose(dp_a, np_a)

@pytest.mark.parametrize("dtype", get_float_dtypes() + get_complex_dtypes())
@pytest.mark.parametrize("dtype", get_float_complex_dtypes())
def test_inplace_strided_out(self, dtype):
size = 21

Expand Down

0 comments on commit 0e50f82

Please sign in to comment.