From 5e1d1335f998b9c324e1b3cec936e9a973e9d1a1 Mon Sep 17 00:00:00 2001 From: Anton Volkov Date: Mon, 22 May 2023 06:58:02 -0500 Subject: [PATCH] Reuse OneDPL implementation of std::nth_element() for partition of 1D array --- dpnp/backend/kernels/dpnp_krnl_sorting.cpp | 25 +++++++++- dpnp/backend/src/dpnp_fptr.hpp | 53 ++++++++++++++++++++++ tests/test_sort.py | 20 ++++---- 3 files changed, 87 insertions(+), 11 deletions(-) diff --git a/dpnp/backend/kernels/dpnp_krnl_sorting.cpp b/dpnp/backend/kernels/dpnp_krnl_sorting.cpp index 614bb94f070..01bc26cdf8f 100644 --- a/dpnp/backend/kernels/dpnp_krnl_sorting.cpp +++ b/dpnp/backend/kernels/dpnp_krnl_sorting.cpp @@ -1,5 +1,5 @@ //***************************************************************************** -// Copyright (c) 2016-2020, Intel Corporation +// Copyright (c) 2016-2023, Intel Corporation // All rights reserved. // // Redistribution and use in source and binary forms, with or without @@ -160,6 +160,24 @@ DPCTLSyclEventRef dpnp_partition_c(DPCTLSyclQueueRef q_ref, sycl::queue q = *(reinterpret_cast(q_ref)); + if (ndim == 1) // 1d array with C-contiguous data + { + _DataType* arr = static_cast<_DataType*>(array1_in); + _DataType* result = static_cast<_DataType*>(result1); + + auto policy = oneapi::dpl::execution::make_device_policy>(q); + + // fill the result array with data from input one + q.memcpy(result, arr, size * sizeof(_DataType)).wait(); + + // make a partial sorting such that: + // 1. result[0 <= i < kth] <= result[kth] + // 2. result[kth <= i < size] >= result[kth] + // event-blocking call, no need for wait() + std::nth_element(policy, result, result + kth, result + size, dpnp_less_comp()); + return event_ref; + } + DPNPC_ptr_adapter<_DataType> input1_ptr(q_ref, array1_in, size, true); DPNPC_ptr_adapter<_DataType> input2_ptr(q_ref, array2_in, size, true); DPNPC_ptr_adapter<_DataType> result1_ptr(q_ref, result1, size, true, true); @@ -181,7 +199,7 @@ DPCTLSyclEventRef dpnp_partition_c(DPCTLSyclQueueRef q_ref, size_t ind = j - ind_begin; matrix[ind] = arr2[j]; } - std::partial_sort(matrix, matrix + shape_[ndim - 1], matrix + shape_[ndim - 1]); + std::partial_sort(matrix, matrix + shape_[ndim - 1], matrix + shape_[ndim - 1], dpnp_less_comp()); for (size_t j = ind_begin; j < ind_end + 1; ++j) { size_t ind = j - ind_begin; @@ -492,10 +510,13 @@ void func_map_init_sorting(func_map_t& fmap) fmap[DPNPFuncName::DPNP_FN_PARTITION][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_partition_default_c}; fmap[DPNPFuncName::DPNP_FN_PARTITION][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_partition_default_c}; + fmap[DPNPFuncName::DPNP_FN_PARTITION_EXT][eft_BLN][eft_BLN] = {eft_BLN, (void*)dpnp_partition_ext_c}; fmap[DPNPFuncName::DPNP_FN_PARTITION_EXT][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_partition_ext_c}; fmap[DPNPFuncName::DPNP_FN_PARTITION_EXT][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_partition_ext_c}; fmap[DPNPFuncName::DPNP_FN_PARTITION_EXT][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_partition_ext_c}; fmap[DPNPFuncName::DPNP_FN_PARTITION_EXT][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_partition_ext_c}; + fmap[DPNPFuncName::DPNP_FN_PARTITION_EXT][eft_C64][eft_C64] = {eft_C64, (void*)dpnp_partition_ext_c>}; + fmap[DPNPFuncName::DPNP_FN_PARTITION_EXT][eft_C128][eft_C128] = {eft_C128, (void*)dpnp_partition_ext_c>}; fmap[DPNPFuncName::DPNP_FN_SEARCHSORTED][eft_INT][eft_INT] = { eft_INT, (void*)dpnp_searchsorted_default_c}; diff --git a/dpnp/backend/src/dpnp_fptr.hpp b/dpnp/backend/src/dpnp_fptr.hpp index 9f8c102bca5..cb33e70185e 100644 --- a/dpnp/backend/src/dpnp_fptr.hpp +++ b/dpnp/backend/src/dpnp_fptr.hpp @@ -187,6 +187,59 @@ constexpr auto both_types_are_any_of = std::conjunction_v, is_ template constexpr auto none_of_both_types = !std::disjunction_v, is_any>; + +/** + * @brief If the type _Tp is a reference type, provides the member typedef type which is the type referred to by _Tp + * with its topmost cv-qualifiers removed. Otherwise type is _Tp with its topmost cv-qualifiers removed. + * + * @note std::remove_cvref is only available since c++20 + */ +template +using dpnp_remove_cvref_t = typename std::remove_cv_t>; + + +/** + * @brief "<" comparison with complex types support. + * + * @note return a result of lexicographical "<" comparison for complex types. + */ +class dpnp_less_comp +{ +public: + template + bool operator()(_Xp&& __x, _Yp&& __y) const + { + if constexpr (both_types_are_same, dpnp_remove_cvref_t<_Yp>, std::complex, std::complex>) + { + bool ret = false; + _Xp a = std::forward<_Xp>(__x); + _Yp b = std::forward<_Yp>(__y); + + if (a.real() < b.real()) + { + ret = (a.imag() == a.imag() || b.imag() != b.imag()); + } + else if (a.real() > b.real()) + { + ret = (b.imag() != b.imag() && a.imag() == a.imag()); + } + else if (a.real() == b.real() || (a.real() != a.real() && b.real() != b.real())) + { + ret = (a.imag() < b.imag() || (b.imag() != b.imag() && a.imag() == a.imag())); + } + else + { + ret = (b.real() != b.real()); + } + return ret; + } + else + { + return std::forward<_Xp>(__x) < std::forward<_Yp>(__y); + } + } +}; + /** * FPTR interface initialization functions */ diff --git a/tests/test_sort.py b/tests/test_sort.py index aa633c0c3ad..975c654cbb9 100644 --- a/tests/test_sort.py +++ b/tests/test_sort.py @@ -1,16 +1,18 @@ import pytest +from .helper import get_all_dtypes import dpnp import numpy +from numpy.testing import ( + assert_array_equal +) @pytest.mark.parametrize("kth", [0, 1], ids=['0', '1']) -@pytest.mark.parametrize("dtype", - [numpy.float64, numpy.float32, numpy.int64, numpy.int32], - ids=['float64', 'float32', 'int64', 'int32']) +@pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True)) @pytest.mark.parametrize("array", [[3, 4, 2, 1], [[1, 0], [3, 0]], @@ -25,11 +27,11 @@ '[[[1, -3], [3, 0]], [[5, 2], [0, 1]], [[1, 0], [0, 1]]]', '[[[[8, 2], [3, 0]], [[5, 2], [0, 1]]], [[[1, 3], [3, 1]], [[5, 2], [0, 1]]]]']) def test_partition(array, dtype, kth): - a = numpy.array(array, dtype) - ia = dpnp.array(array, dtype) - expected = numpy.partition(a, kth) - result = dpnp.partition(ia, kth) - numpy.testing.assert_array_equal(expected, result) + a = dpnp.array(array, dtype) + p = dpnp.partition(a, kth) + + assert (p[0:kth] <= p[kth]).all() + assert (p[kth] <= p[kth + 1:]).all() @pytest.mark.usefixtures("allow_fall_back_on_numpy") @@ -77,4 +79,4 @@ def test_searchsorted(array, dtype, v_, side): iv = dpnp.array(v_, dtype) expected = numpy.searchsorted(a, v, side=side) result = dpnp.searchsorted(ia, iv, side=side) - numpy.testing.assert_array_equal(expected, result) + assert_array_equal(expected, result)