Skip to content

Commit

Permalink
Reuse OneDPL implementation of std::nth_element() for partition of 1D…
Browse files Browse the repository at this point in the history
… array
  • Loading branch information
antonwolfy committed May 22, 2023
1 parent 25e6b9a commit 5e1d133
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 11 deletions.
25 changes: 23 additions & 2 deletions dpnp/backend/kernels/dpnp_krnl_sorting.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
//*****************************************************************************
// Copyright (c) 2016-2020, Intel Corporation
// Copyright (c) 2016-2023, Intel Corporation
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
Expand Down Expand Up @@ -160,6 +160,24 @@ DPCTLSyclEventRef dpnp_partition_c(DPCTLSyclQueueRef q_ref,

sycl::queue q = *(reinterpret_cast<sycl::queue*>(q_ref));

if (ndim == 1) // 1d array with C-contiguous data
{
_DataType* arr = static_cast<_DataType*>(array1_in);
_DataType* result = static_cast<_DataType*>(result1);

auto policy = oneapi::dpl::execution::make_device_policy<dpnp_partition_c_kernel<_DataType>>(q);

// fill the result array with a copy of the input array's data
q.memcpy(result, arr, size * sizeof(_DataType)).wait();

// partially sort the result in place such that:
// 1. result[i] <= result[kth] for every 0 <= i < kth
// 2. result[i] >= result[kth] for every kth <= i < size
// event-blocking call, no need for wait()
std::nth_element(policy, result, result + kth, result + size, dpnp_less_comp());
return event_ref;
}

DPNPC_ptr_adapter<_DataType> input1_ptr(q_ref, array1_in, size, true);
DPNPC_ptr_adapter<_DataType> input2_ptr(q_ref, array2_in, size, true);
DPNPC_ptr_adapter<_DataType> result1_ptr(q_ref, result1, size, true, true);
Expand All @@ -181,7 +199,7 @@ DPCTLSyclEventRef dpnp_partition_c(DPCTLSyclQueueRef q_ref,
size_t ind = j - ind_begin;
matrix[ind] = arr2[j];
}
std::partial_sort(matrix, matrix + shape_[ndim - 1], matrix + shape_[ndim - 1]);
std::partial_sort(matrix, matrix + shape_[ndim - 1], matrix + shape_[ndim - 1], dpnp_less_comp());
for (size_t j = ind_begin; j < ind_end + 1; ++j)
{
size_t ind = j - ind_begin;
Expand Down Expand Up @@ -492,10 +510,13 @@ void func_map_init_sorting(func_map_t& fmap)
fmap[DPNPFuncName::DPNP_FN_PARTITION][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_partition_default_c<float>};
fmap[DPNPFuncName::DPNP_FN_PARTITION][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_partition_default_c<double>};

fmap[DPNPFuncName::DPNP_FN_PARTITION_EXT][eft_BLN][eft_BLN] = {eft_BLN, (void*)dpnp_partition_ext_c<bool>};
fmap[DPNPFuncName::DPNP_FN_PARTITION_EXT][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_partition_ext_c<int32_t>};
fmap[DPNPFuncName::DPNP_FN_PARTITION_EXT][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_partition_ext_c<int64_t>};
fmap[DPNPFuncName::DPNP_FN_PARTITION_EXT][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_partition_ext_c<float>};
fmap[DPNPFuncName::DPNP_FN_PARTITION_EXT][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_partition_ext_c<double>};
fmap[DPNPFuncName::DPNP_FN_PARTITION_EXT][eft_C64][eft_C64] = {eft_C64, (void*)dpnp_partition_ext_c<std::complex<float>>};
fmap[DPNPFuncName::DPNP_FN_PARTITION_EXT][eft_C128][eft_C128] = {eft_C128, (void*)dpnp_partition_ext_c<std::complex<double>>};

fmap[DPNPFuncName::DPNP_FN_SEARCHSORTED][eft_INT][eft_INT] = {
eft_INT, (void*)dpnp_searchsorted_default_c<int32_t, int64_t>};
Expand Down
53 changes: 53 additions & 0 deletions dpnp/backend/src/dpnp_fptr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,59 @@ constexpr auto both_types_are_any_of = std::conjunction_v<is_any<T1, Ts...>, is_
template <typename T1, typename T2, typename... Ts>
constexpr auto none_of_both_types = !std::disjunction_v<is_any<T1, Ts...>, is_any<T2, Ts...>>;


/**
 * @brief Strips the reference and then the topmost cv-qualifiers from a type.
 *
 * If @p T is a reference type, the alias names the referred-to type with its
 * topmost const/volatile qualifiers removed. Otherwise it names @p T itself
 * with its topmost const/volatile qualifiers removed. Qualifiers that are not
 * topmost (e.g. the pointee of `const int*`) are left untouched.
 *
 * @tparam T Type to strip.
 *
 * @note std::remove_cvref is only available since c++20
 */
template <typename T>
using dpnp_remove_cvref_t = std::remove_cv_t<std::remove_reference_t<T>>;


/**
* @brief "<" comparison with complex types support.
*
* @note return a result of lexicographical "<" comparison for complex types.
*/
class dpnp_less_comp
{
public:
template <typename _Xp, typename _Yp>
bool operator()(_Xp&& __x, _Yp&& __y) const
{
if constexpr (both_types_are_same<dpnp_remove_cvref_t<_Xp>, dpnp_remove_cvref_t<_Yp>, std::complex<float>, std::complex<double>>)
{
bool ret = false;
_Xp a = std::forward<_Xp>(__x);
_Yp b = std::forward<_Yp>(__y);

if (a.real() < b.real())
{
ret = (a.imag() == a.imag() || b.imag() != b.imag());
}
else if (a.real() > b.real())
{
ret = (b.imag() != b.imag() && a.imag() == a.imag());
}
else if (a.real() == b.real() || (a.real() != a.real() && b.real() != b.real()))
{
ret = (a.imag() < b.imag() || (b.imag() != b.imag() && a.imag() == a.imag()));
}
else
{
ret = (b.real() != b.real());
}
return ret;
}
else
{
return std::forward<_Xp>(__x) < std::forward<_Yp>(__y);
}
}
};

/**
* FPTR interface initialization functions
*/
Expand Down
20 changes: 11 additions & 9 deletions tests/test_sort.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
import pytest
from .helper import get_all_dtypes

import dpnp

import numpy
from numpy.testing import (
assert_array_equal
)


@pytest.mark.parametrize("kth",
[0, 1],
ids=['0', '1'])
@pytest.mark.parametrize("dtype",
[numpy.float64, numpy.float32, numpy.int64, numpy.int32],
ids=['float64', 'float32', 'int64', 'int32'])
@pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True))
@pytest.mark.parametrize("array",
[[3, 4, 2, 1],
[[1, 0], [3, 0]],
Expand All @@ -25,11 +27,11 @@
'[[[1, -3], [3, 0]], [[5, 2], [0, 1]], [[1, 0], [0, 1]]]',
'[[[[8, 2], [3, 0]], [[5, 2], [0, 1]]], [[[1, 3], [3, 1]], [[5, 2], [0, 1]]]]'])
def test_partition(array, dtype, kth):
a = numpy.array(array, dtype)
ia = dpnp.array(array, dtype)
expected = numpy.partition(a, kth)
result = dpnp.partition(ia, kth)
numpy.testing.assert_array_equal(expected, result)
a = dpnp.array(array, dtype)
p = dpnp.partition(a, kth)

assert (p[0:kth] <= p[kth]).all()
assert (p[kth] <= p[kth + 1:]).all()


@pytest.mark.usefixtures("allow_fall_back_on_numpy")
Expand Down Expand Up @@ -77,4 +79,4 @@ def test_searchsorted(array, dtype, v_, side):
iv = dpnp.array(v_, dtype)
expected = numpy.searchsorted(a, v, side=side)
result = dpnp.searchsorted(ia, iv, side=side)
numpy.testing.assert_array_equal(expected, result)
assert_array_equal(expected, result)

0 comments on commit 5e1d133

Please sign in to comment.