Skip to content

Commit

Permalink
Adds tests for inplace division behavior
Browse files Browse the repository at this point in the history
  • Loading branch information
ndgrigorian committed Oct 10, 2023
1 parent eab67fb commit b5c3ee7
Show file tree
Hide file tree
Showing 2 changed files with 137 additions and 6 deletions.
71 changes: 70 additions & 1 deletion dpctl/tests/elementwise/test_divide.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,16 @@

import dpctl
import dpctl.tensor as dpt
from dpctl.tensor._type_utils import _can_cast
from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported

from .utils import _all_dtypes, _compare_dtypes, _usm_types
from .utils import (
_all_dtypes,
_compare_dtypes,
_complex_fp_dtypes,
_real_fp_dtypes,
_usm_types,
)


@pytest.mark.parametrize("op1_dtype", _all_dtypes)
Expand Down Expand Up @@ -187,3 +194,65 @@ def __sycl_usm_array_interface__(self):
c = Canary()
with pytest.raises(ValueError):
dpt.divide(a, c)


@pytest.mark.parametrize("dtype", _real_fp_dtypes + _complex_fp_dtypes)
def test_divide_inplace_python_scalar(dtype):
q = get_queue_or_skip()
skip_if_dtype_not_supported(dtype, q)
X = dpt.zeros((10, 10), dtype=dtype, sycl_queue=q)
dt_kind = X.dtype.kind
if dt_kind == "f":
X /= float(1)
elif dt_kind == "c":
X /= complex(1)


@pytest.mark.parametrize("op1_dtype", _all_dtypes)
@pytest.mark.parametrize("op2_dtype", _all_dtypes)
def test_divide_inplace_dtype_matrix(op1_dtype, op2_dtype):
q = get_queue_or_skip()
skip_if_dtype_not_supported(op1_dtype, q)
skip_if_dtype_not_supported(op2_dtype, q)

sz = 127
ar1 = dpt.ones(sz, dtype=op1_dtype, sycl_queue=q)
ar2 = dpt.ones_like(ar1, dtype=op2_dtype, sycl_queue=q)

dev = q.sycl_device
_fp16 = dev.has_aspect_fp16
_fp64 = dev.has_aspect_fp64
# out array only valid if it is inexact
if (
_can_cast(ar2.dtype, ar1.dtype, _fp16, _fp64)
and dpt.dtype(op1_dtype).kind in "fc"
):
ar1 /= ar2
assert dpt.all(ar1 == 1)

ar3 = dpt.ones(sz, dtype=op1_dtype, sycl_queue=q)[::-1]
ar4 = dpt.ones(2 * sz, dtype=op2_dtype, sycl_queue=q)[::2]
ar3 /= ar4
assert dpt.all(ar3 == 1)
else:
with pytest.raises(TypeError):
ar1 /= ar2
dpt.divide(ar1, ar2, out=ar1)

# out is second arg
ar1 = dpt.ones(sz, dtype=op1_dtype, sycl_queue=q)
ar2 = dpt.ones_like(ar1, dtype=op2_dtype, sycl_queue=q)
if (
_can_cast(ar1.dtype, ar2.dtype, _fp16, _fp64)
and dpt.dtype(op2_dtype).kind in "fc"
):
dpt.divide(ar1, ar2, out=ar2)
assert dpt.all(ar2 == 1)

ar3 = dpt.ones(sz, dtype=op1_dtype, sycl_queue=q)[::-1]
ar4 = dpt.ones(2 * sz, dtype=op2_dtype, sycl_queue=q)[::2]
dpt.divide(ar3, ar4, out=ar4)
dpt.all(ar4 == 1)
else:
with pytest.raises(TypeError):
dpt.divide(ar1, ar2, out=ar2)
72 changes: 67 additions & 5 deletions dpctl/tests/elementwise/test_floor_divide.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,19 @@

import dpctl
import dpctl.tensor as dpt
from dpctl.tensor._type_utils import _can_cast
from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported

from .utils import _compare_dtypes, _no_complex_dtypes, _usm_types
from .utils import (
_compare_dtypes,
_integral_dtypes,
_no_complex_dtypes,
_usm_types,
)


@pytest.mark.parametrize("op1_dtype", _no_complex_dtypes)
@pytest.mark.parametrize("op2_dtype", _no_complex_dtypes)
@pytest.mark.parametrize("op1_dtype", _no_complex_dtypes[1:])
@pytest.mark.parametrize("op2_dtype", _no_complex_dtypes[1:])
def test_floor_divide_dtype_matrix(op1_dtype, op2_dtype):
q = get_queue_or_skip()
skip_if_dtype_not_supported(op1_dtype, q)
Expand Down Expand Up @@ -133,7 +139,7 @@ def test_floor_divide_broadcasting():
assert (dpt.asnumpy(r2) == expected2.astype(r2.dtype)).all()


@pytest.mark.parametrize("arr_dt", _no_complex_dtypes)
@pytest.mark.parametrize("arr_dt", _no_complex_dtypes[1:])
def test_floor_divide_python_scalar(arr_dt):
q = get_queue_or_skip()
skip_if_dtype_not_supported(arr_dt, q)
Expand Down Expand Up @@ -204,7 +210,7 @@ def test_floor_divide_gh_1247():
)


@pytest.mark.parametrize("dtype", _no_complex_dtypes[1:9])
@pytest.mark.parametrize("dtype", _integral_dtypes)
def test_floor_divide_integer_zero(dtype):
q = get_queue_or_skip()
skip_if_dtype_not_supported(dtype, q)
Expand Down Expand Up @@ -255,3 +261,59 @@ def test_floor_divide_special_cases():
res = dpt.floor_divide(x, y)
res_np = np.floor_divide(dpt.asnumpy(x), dpt.asnumpy(y))
np.testing.assert_array_equal(dpt.asnumpy(res), res_np)


@pytest.mark.parametrize("dtype", _no_complex_dtypes[1:])
def test_divide_inplace_python_scalar(dtype):
q = get_queue_or_skip()
skip_if_dtype_not_supported(dtype, q)
X = dpt.zeros((10, 10), dtype=dtype, sycl_queue=q)
dt_kind = X.dtype.kind
if dt_kind in "ui":
X //= int(1)
elif dt_kind == "f":
X //= float(1)


@pytest.mark.parametrize("op1_dtype", _no_complex_dtypes[1:])
@pytest.mark.parametrize("op2_dtype", _no_complex_dtypes[1:])
def test_floor_divide_inplace_dtype_matrix(op1_dtype, op2_dtype):
q = get_queue_or_skip()
skip_if_dtype_not_supported(op1_dtype, q)
skip_if_dtype_not_supported(op2_dtype, q)

sz = 127
ar1 = dpt.ones(sz, dtype=op1_dtype, sycl_queue=q)
ar2 = dpt.ones_like(ar1, dtype=op2_dtype, sycl_queue=q)

dev = q.sycl_device
_fp16 = dev.has_aspect_fp16
_fp64 = dev.has_aspect_fp64
# out array only valid if it is inexact
if _can_cast(ar2.dtype, ar1.dtype, _fp16, _fp64):
ar1 //= ar2
assert dpt.all(ar1 == 1)

ar3 = dpt.ones(sz, dtype=op1_dtype, sycl_queue=q)[::-1]
ar4 = dpt.ones(2 * sz, dtype=op2_dtype, sycl_queue=q)[::2]
ar3 //= ar4
assert dpt.all(ar3 == 1)
else:
with pytest.raises(TypeError):
ar1 //= ar2
dpt.floor_divide(ar1, ar2, out=ar1)

# out is second arg
ar1 = dpt.ones(sz, dtype=op1_dtype, sycl_queue=q)
ar2 = dpt.ones_like(ar1, dtype=op2_dtype, sycl_queue=q)
if _can_cast(ar1.dtype, ar2.dtype, _fp16, _fp64):
dpt.floor_divide(ar1, ar2, out=ar2)
assert dpt.all(ar2 == 1)

ar3 = dpt.ones(sz, dtype=op1_dtype, sycl_queue=q)[::-1]
ar4 = dpt.ones(2 * sz, dtype=op2_dtype, sycl_queue=q)[::2]
dpt.floor_divide(ar3, ar4, out=ar4)
dpt.all(ar4 == 1)
else:
with pytest.raises(TypeError):
dpt.floor_divide(ar1, ar2, out=ar2)

0 comments on commit b5c3ee7

Please sign in to comment.