Skip to content

Commit

Permalink
Improve implementation of dpnp.kron to avoid unnecessary copy for n…
Browse files Browse the repository at this point in the history
…on-contiguous arrays (#2059)

* remove continuity check in dpnp.kron

* add additional assert check

* update CHANGELOG.md
  • Loading branch information
vtavana committed Sep 19, 2024
1 parent e2a6c6d commit 3c45e13
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 6 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ In addition, this release completes implementation of `dpnp.fft` module and adds
* `dpnp` uses pybind11 2.13.6 [#2041](https://github.com/IntelPython/dpnp/pull/2041)
* Updated `dpnp.fft` backend to depend on `INTEL_MKL_VERSION` flag to ensures that the appropriate code segment is executed based on the version of OneMKL [#2035](https://github.com/IntelPython/dpnp/pull/2035)
* Use `dpctl::tensor::alloc_utils::sycl_free_noexcept` instead of `sycl::free` in `host_task` tasks associated with life-time management of temporary USM allocations [#2058](https://github.com/IntelPython/dpnp/pull/2058)
* Improved implementation of `dpnp.kron` to avoid unnecessary copy for non-contiguous arrays [#2059](https://github.com/IntelPython/dpnp/pull/2059)

### Fixed

Expand Down
6 changes: 1 addition & 5 deletions dpnp/dpnp_utils/dpnp_utils_linearalgebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,10 +676,6 @@ def dpnp_kron(a, b, a_ndim, b_ndim):

a_shape = a.shape
b_shape = b.shape
if not a.flags.contiguous:
a = dpnp.reshape(a, a_shape)
if not b.flags.contiguous:
b = dpnp.reshape(b, b_shape)

# Equalise the shapes by prepending smaller one with 1s
a_shape = (1,) * max(0, b_ndim - a_ndim) + a_shape
Expand All @@ -693,7 +689,7 @@ def dpnp_kron(a, b, a_ndim, b_ndim):
ndim = max(b_ndim, a_ndim)
a_arr = dpnp.expand_dims(a_arr, axis=tuple(range(1, 2 * ndim, 2)))
b_arr = dpnp.expand_dims(b_arr, axis=tuple(range(0, 2 * ndim, 2)))
result = dpnp.multiply(a_arr, b_arr, order="C")
result = dpnp.multiply(a_arr, b_arr)

# Reshape back
return result.reshape(tuple(numpy.multiply(a_shape, b_shape)))
Expand Down
30 changes: 29 additions & 1 deletion tests/test_product.py
Original file line number Diff line number Diff line change
Expand Up @@ -741,7 +741,7 @@ def test_kron_input_dtype_matrix(self, dtype1, dtype2):
@pytest.mark.parametrize(
"stride", [3, -1, -2, -4], ids=["3", "-1", "-2", "-4"]
)
def test_kron_strided(self, dtype, stride):
def test_kron_strided1(self, dtype, stride):
a = numpy.arange(20, dtype=dtype)
b = numpy.arange(20, dtype=dtype)
ia = dpnp.array(a)
Expand All @@ -751,6 +751,34 @@ def test_kron_strided(self, dtype, stride):
expected = numpy.kron(a[::stride], b[::stride])
assert_dtype_allclose(result, expected)

@pytest.mark.parametrize("stride", [2, -1, -2], ids=["2", "-1", "-2"])
def test_kron_strided2(self, stride):
a = numpy.arange(48).reshape(6, 8)
b = numpy.arange(480).reshape(6, 8, 10)
ia = dpnp.array(a)
ib = dpnp.array(b)

result = dpnp.kron(
ia[::stride, ::stride], ib[::stride, ::stride, ::stride]
)
expected = numpy.kron(
a[::stride, ::stride], b[::stride, ::stride, ::stride]
)
assert_dtype_allclose(result, expected)

@pytest.mark.parametrize("order", ["C", "F", "A"])
def test_kron_order(self, order):
a = numpy.arange(48).reshape(6, 8, order=order)
b = numpy.arange(480).reshape(6, 8, 10, order=order)
ia = dpnp.array(a)
ib = dpnp.array(b)

result = dpnp.kron(ia, ib)
expected = numpy.kron(a, b)
assert result.flags["C_CONTIGUOUS"] == expected.flags["C_CONTIGUOUS"]
assert result.flags["F_CONTIGUOUS"] == expected.flags["F_CONTIGUOUS"]
assert_dtype_allclose(result, expected)


class TestMultiDot:
def setup_method(self):
Expand Down

0 comments on commit 3c45e13

Please sign in to comment.