From 3c45e1346276b8966dca426eeb2c335b3864d01e Mon Sep 17 00:00:00 2001 From: vtavana <120411540+vtavana@users.noreply.github.com> Date: Wed, 18 Sep 2024 22:21:28 -0500 Subject: [PATCH] Improve implementation of `dpnp.kron` to avoid unnecessary copy for non-contiguous arrays (#2059) * remove continuity check in dpnp.kron * add additional assert check * update CHANGELOG.md --- CHANGELOG.md | 1 + dpnp/dpnp_utils/dpnp_utils_linearalgebra.py | 6 +---- tests/test_product.py | 30 ++++++++++++++++++++- 3 files changed, 31 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4343f59df32..5639698d6fc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -104,6 +104,7 @@ In addition, this release completes implementation of `dpnp.fft` module and adds * `dpnp` uses pybind11 2.13.6 [#2041](https://github.com/IntelPython/dpnp/pull/2041) * Updated `dpnp.fft` backend to depend on `INTEL_MKL_VERSION` flag to ensures that the appropriate code segment is executed based on the version of OneMKL [#2035](https://github.com/IntelPython/dpnp/pull/2035) * Use `dpctl::tensor::alloc_utils::sycl_free_noexcept` instead of `sycl::free` in `host_task` tasks associated with life-time management of temporary USM allocations [#2058](https://github.com/IntelPython/dpnp/pull/2058) +* Improved implementation of `dpnp.kron` to avoid unnecessary copy for non-contiguous arrays [#2059](https://github.com/IntelPython/dpnp/pull/2059) ### Fixed diff --git a/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py b/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py index 1da6c3f4400..e15bd93d7bb 100644 --- a/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py +++ b/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py @@ -676,10 +676,6 @@ def dpnp_kron(a, b, a_ndim, b_ndim): a_shape = a.shape b_shape = b.shape - if not a.flags.contiguous: - a = dpnp.reshape(a, a_shape) - if not b.flags.contiguous: - b = dpnp.reshape(b, b_shape) # Equalise the shapes by prepending smaller one with 1s a_shape = (1,) * max(0, b_ndim - a_ndim) + a_shape @@ -693,7 +689,7 @@ def dpnp_kron(a, b, a_ndim, b_ndim): ndim = max(b_ndim, a_ndim) a_arr = dpnp.expand_dims(a_arr, axis=tuple(range(1, 2 * ndim, 2))) b_arr = dpnp.expand_dims(b_arr, axis=tuple(range(0, 2 * ndim, 2))) - result = dpnp.multiply(a_arr, b_arr, order="C") + result = dpnp.multiply(a_arr, b_arr) # Reshape back return result.reshape(tuple(numpy.multiply(a_shape, b_shape))) diff --git a/tests/test_product.py b/tests/test_product.py index a15e82f6d90..7fa26b13bc6 100644 --- a/tests/test_product.py +++ b/tests/test_product.py @@ -741,7 +741,7 @@ def test_kron_input_dtype_matrix(self, dtype1, dtype2): @pytest.mark.parametrize( "stride", [3, -1, -2, -4], ids=["3", "-1", "-2", "-4"] ) - def test_kron_strided(self, dtype, stride): + def test_kron_strided1(self, dtype, stride): a = numpy.arange(20, dtype=dtype) b = numpy.arange(20, dtype=dtype) ia = dpnp.array(a) @@ -751,6 +751,34 @@ def test_kron_strided(self, dtype, stride): expected = numpy.kron(a[::stride], b[::stride]) assert_dtype_allclose(result, expected) + @pytest.mark.parametrize("stride", [2, -1, -2], ids=["2", "-1", "-2"]) + def test_kron_strided2(self, stride): + a = numpy.arange(48).reshape(6, 8) + b = numpy.arange(480).reshape(6, 8, 10) + ia = dpnp.array(a) + ib = dpnp.array(b) + + result = dpnp.kron( + ia[::stride, ::stride], ib[::stride, ::stride, ::stride] + ) + expected = numpy.kron( + a[::stride, ::stride], b[::stride, ::stride, ::stride] + ) + assert_dtype_allclose(result, expected) + + @pytest.mark.parametrize("order", ["C", "F", "A"]) + def test_kron_order(self, order): + a = numpy.arange(48).reshape(6, 8, order=order) + b = numpy.arange(480).reshape(6, 8, 10, order=order) + ia = dpnp.array(a) + ib = dpnp.array(b) + + result = dpnp.kron(ia, ib) + expected = numpy.kron(a, b) + assert result.flags["C_CONTIGUOUS"] == expected.flags["C_CONTIGUOUS"] + assert result.flags["F_CONTIGUOUS"] == expected.flags["F_CONTIGUOUS"] + assert_dtype_allclose(result, expected) + class TestMultiDot: def setup_method(self):