Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dedicated kernels for in-place dpt.divide and dpt.floor_divide #1431

Merged
merged 3 commits into from
Oct 11, 2023
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Adds tests for inplace division behavior
  • Loading branch information
ndgrigorian committed Oct 10, 2023
commit b5c3ee7abfe943a5dabd28c3bade6bc169342c0a
71 changes: 70 additions & 1 deletion dpctl/tests/elementwise/test_divide.py
Original file line number Diff line number Diff line change
@@ -21,9 +21,16 @@

import dpctl
import dpctl.tensor as dpt
from dpctl.tensor._type_utils import _can_cast
from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported

from .utils import _all_dtypes, _compare_dtypes, _usm_types
from .utils import (
_all_dtypes,
_compare_dtypes,
_complex_fp_dtypes,
_real_fp_dtypes,
_usm_types,
)


@pytest.mark.parametrize("op1_dtype", _all_dtypes)
@@ -187,3 +194,65 @@ def __sycl_usm_array_interface__(self):
c = Canary()
with pytest.raises(ValueError):
dpt.divide(a, c)


@pytest.mark.parametrize("dtype", _real_fp_dtypes + _complex_fp_dtypes)
def test_divide_inplace_python_scalar(dtype):
q = get_queue_or_skip()
skip_if_dtype_not_supported(dtype, q)
X = dpt.zeros((10, 10), dtype=dtype, sycl_queue=q)
dt_kind = X.dtype.kind
if dt_kind == "f":
X /= float(1)
elif dt_kind == "c":
X /= complex(1)


@pytest.mark.parametrize("op1_dtype", _all_dtypes)
@pytest.mark.parametrize("op2_dtype", _all_dtypes)
def test_divide_inplace_dtype_matrix(op1_dtype, op2_dtype):
q = get_queue_or_skip()
skip_if_dtype_not_supported(op1_dtype, q)
skip_if_dtype_not_supported(op2_dtype, q)

sz = 127
ar1 = dpt.ones(sz, dtype=op1_dtype, sycl_queue=q)
ar2 = dpt.ones_like(ar1, dtype=op2_dtype, sycl_queue=q)

dev = q.sycl_device
_fp16 = dev.has_aspect_fp16
_fp64 = dev.has_aspect_fp64
# out array only valid if it is inexact
if (
_can_cast(ar2.dtype, ar1.dtype, _fp16, _fp64)
and dpt.dtype(op1_dtype).kind in "fc"
):
ar1 /= ar2
assert dpt.all(ar1 == 1)

ar3 = dpt.ones(sz, dtype=op1_dtype, sycl_queue=q)[::-1]
ar4 = dpt.ones(2 * sz, dtype=op2_dtype, sycl_queue=q)[::2]
ar3 /= ar4
assert dpt.all(ar3 == 1)
else:
with pytest.raises(TypeError):
ar1 /= ar2
dpt.divide(ar1, ar2, out=ar1)

# out is second arg
ar1 = dpt.ones(sz, dtype=op1_dtype, sycl_queue=q)
ar2 = dpt.ones_like(ar1, dtype=op2_dtype, sycl_queue=q)
if (
_can_cast(ar1.dtype, ar2.dtype, _fp16, _fp64)
and dpt.dtype(op2_dtype).kind in "fc"
):
dpt.divide(ar1, ar2, out=ar2)
assert dpt.all(ar2 == 1)

ar3 = dpt.ones(sz, dtype=op1_dtype, sycl_queue=q)[::-1]
ar4 = dpt.ones(2 * sz, dtype=op2_dtype, sycl_queue=q)[::2]
dpt.divide(ar3, ar4, out=ar4)
dpt.all(ar4 == 1)
else:
with pytest.raises(TypeError):
dpt.divide(ar1, ar2, out=ar2)
72 changes: 67 additions & 5 deletions dpctl/tests/elementwise/test_floor_divide.py
Original file line number Diff line number Diff line change
@@ -21,13 +21,19 @@

import dpctl
import dpctl.tensor as dpt
from dpctl.tensor._type_utils import _can_cast
from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported

from .utils import _compare_dtypes, _no_complex_dtypes, _usm_types
from .utils import (
_compare_dtypes,
_integral_dtypes,
_no_complex_dtypes,
_usm_types,
)


@pytest.mark.parametrize("op1_dtype", _no_complex_dtypes)
@pytest.mark.parametrize("op2_dtype", _no_complex_dtypes)
@pytest.mark.parametrize("op1_dtype", _no_complex_dtypes[1:])
@pytest.mark.parametrize("op2_dtype", _no_complex_dtypes[1:])
def test_floor_divide_dtype_matrix(op1_dtype, op2_dtype):
q = get_queue_or_skip()
skip_if_dtype_not_supported(op1_dtype, q)
@@ -133,7 +139,7 @@ def test_floor_divide_broadcasting():
assert (dpt.asnumpy(r2) == expected2.astype(r2.dtype)).all()


@pytest.mark.parametrize("arr_dt", _no_complex_dtypes)
@pytest.mark.parametrize("arr_dt", _no_complex_dtypes[1:])
def test_floor_divide_python_scalar(arr_dt):
q = get_queue_or_skip()
skip_if_dtype_not_supported(arr_dt, q)
@@ -204,7 +210,7 @@ def test_floor_divide_gh_1247():
)


@pytest.mark.parametrize("dtype", _no_complex_dtypes[1:9])
@pytest.mark.parametrize("dtype", _integral_dtypes)
def test_floor_divide_integer_zero(dtype):
q = get_queue_or_skip()
skip_if_dtype_not_supported(dtype, q)
@@ -255,3 +261,59 @@ def test_floor_divide_special_cases():
res = dpt.floor_divide(x, y)
res_np = np.floor_divide(dpt.asnumpy(x), dpt.asnumpy(y))
np.testing.assert_array_equal(dpt.asnumpy(res), res_np)


@pytest.mark.parametrize("dtype", _no_complex_dtypes[1:])
def test_divide_inplace_python_scalar(dtype):
q = get_queue_or_skip()
skip_if_dtype_not_supported(dtype, q)
X = dpt.zeros((10, 10), dtype=dtype, sycl_queue=q)
dt_kind = X.dtype.kind
if dt_kind in "ui":
X //= int(1)
elif dt_kind == "f":
X //= float(1)


@pytest.mark.parametrize("op1_dtype", _no_complex_dtypes[1:])
@pytest.mark.parametrize("op2_dtype", _no_complex_dtypes[1:])
def test_floor_divide_inplace_dtype_matrix(op1_dtype, op2_dtype):
q = get_queue_or_skip()
skip_if_dtype_not_supported(op1_dtype, q)
skip_if_dtype_not_supported(op2_dtype, q)

sz = 127
ar1 = dpt.ones(sz, dtype=op1_dtype, sycl_queue=q)
ar2 = dpt.ones_like(ar1, dtype=op2_dtype, sycl_queue=q)

dev = q.sycl_device
_fp16 = dev.has_aspect_fp16
_fp64 = dev.has_aspect_fp64
# out array only valid if it is inexact
if _can_cast(ar2.dtype, ar1.dtype, _fp16, _fp64):
ar1 //= ar2
assert dpt.all(ar1 == 1)

ar3 = dpt.ones(sz, dtype=op1_dtype, sycl_queue=q)[::-1]
ar4 = dpt.ones(2 * sz, dtype=op2_dtype, sycl_queue=q)[::2]
ar3 //= ar4
assert dpt.all(ar3 == 1)
else:
with pytest.raises(TypeError):
ar1 //= ar2
dpt.floor_divide(ar1, ar2, out=ar1)

# out is second arg
ar1 = dpt.ones(sz, dtype=op1_dtype, sycl_queue=q)
ar2 = dpt.ones_like(ar1, dtype=op2_dtype, sycl_queue=q)
if _can_cast(ar1.dtype, ar2.dtype, _fp16, _fp64):
dpt.floor_divide(ar1, ar2, out=ar2)
assert dpt.all(ar2 == 1)

ar3 = dpt.ones(sz, dtype=op1_dtype, sycl_queue=q)[::-1]
ar4 = dpt.ones(2 * sz, dtype=op2_dtype, sycl_queue=q)[::2]
dpt.floor_divide(ar3, ar4, out=ar4)
dpt.all(ar4 == 1)
else:
with pytest.raises(TypeError):
dpt.floor_divide(ar1, ar2, out=ar2)