Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement ranges::nth_element #1063

Merged
merged 5 commits into from
Aug 1, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
246 changes: 237 additions & 9 deletions stl/inc/algorithm
Original file line number Diff line number Diff line change
Expand Up @@ -6452,22 +6452,22 @@ void inplace_merge(_ExPo&&, _BidIt _First, _BidIt _Mid, _BidIt _Last) noexcept /

// FUNCTION TEMPLATE sort
template <class _BidIt, class _Pr>
_CONSTEXPR20 _BidIt _Insertion_sort_unchecked(_BidIt _First, const _BidIt _Last, _Pr _Pred) {
_CONSTEXPR20 _BidIt _Insertion_sort_unchecked(const _BidIt _First, const _BidIt _Last, _Pr _Pred) {
// insertion sort [_First, _Last)
if (_First != _Last) {
for (_BidIt _Next = _First; ++_Next != _Last;) { // order next element
_BidIt _Next1 = _Next;
_Iter_value_t<_BidIt> _Val = _STD move(*_Next);
for (_BidIt _Mid = _First; ++_Mid != _Last;) { // order next element
_BidIt _Hole = _Mid;
_Iter_value_t<_BidIt> _Val = _STD move(*_Mid);

if (_DEBUG_LT_PRED(_Pred, _Val, *_First)) { // found new earliest element, move to front
_Move_backward_unchecked(_First, _Next, ++_Next1);
_Move_backward_unchecked(_First, _Mid, ++_Hole);
*_First = _STD move(_Val);
} else { // look for insertion point after first
for (_BidIt _First1 = _Next1; _DEBUG_LT_PRED(_Pred, _Val, *--_First1); _Next1 = _First1) {
*_Next1 = _STD move(*_First1); // move hole down
for (_BidIt _Prev = _Hole; _DEBUG_LT_PRED(_Pred, _Val, *--_Prev); _Hole = _Prev) {
*_Hole = _STD move(*_Prev); // move hole down
}

*_Next1 = _STD move(_Val); // insert element in hole
*_Hole = _STD move(_Val); // insert element in hole
}
}
}
Expand Down Expand Up @@ -6531,6 +6531,7 @@ _CONSTEXPR20 pair<_RanIt, _RanIt> _Partition_by_median_guess_unchecked(_RanIt _F
for (;;) { // partition
for (; _Gfirst < _Last; ++_Gfirst) {
if (_DEBUG_LT_PRED(_Pred, *_Pfirst, *_Gfirst)) {
continue;
} else if (_Pred(*_Gfirst, *_Pfirst)) {
break;
} else if (_Plast != _Gfirst) {
Expand All @@ -6543,6 +6544,7 @@ _CONSTEXPR20 pair<_RanIt, _RanIt> _Partition_by_median_guess_unchecked(_RanIt _F

for (; _First < _Glast; --_Glast) {
if (_DEBUG_LT_PRED(_Pred, *_Prev_iter(_Glast), *_Pfirst)) {
continue;
} else if (_Pred(*_Pfirst, *_Prev_iter(_Glast))) {
break;
} else if (--_Pfirst != _Prev_iter(_Glast)) {
Expand Down Expand Up @@ -6628,6 +6630,156 @@ void sort(_ExPo&& _Exec, const _RanIt _First, const _RanIt _Last) noexcept /* te
// order [_First, _Last)
_STD sort(_STD forward<_ExPo>(_Exec), _First, _Last, less{});
}

#ifdef __cpp_lib_concepts
namespace ranges {
// clang-format off
template <bidirectional_iterator _It, class _Pr, class _Pj>
requires sortable<_It, _Pr, _Pj>
constexpr void _Insertion_sort_common(const _It _First, const _It _Last, _Pr _Pred, _Pj _Proj) {
// insertion sort [_First, _Last)

if (_First == _Last) { // empty range is sorted
return;
}

for (auto _Mid = _First; ++_Mid != _Last;) { // order next element
iter_value_t<_It> _Val = _RANGES iter_move(_Mid);
auto _Hole = _Mid;

for (auto _Prev = _Hole;;) {
--_Prev;
if (!_STD invoke(_Pred, _STD invoke(_Proj, _Val), _STD invoke(_Proj, *_Prev))) {
break;
}
*_Hole = _RANGES iter_move(_Prev); // move hole down
if (--_Hole == _First) {
break;
}
}

*_Hole = _STD move(_Val); // insert element in hole
}
}

template <random_access_iterator _It, class _Pr, class _Pj>
requires sortable<_It, _Pr, _Pj>
constexpr void _Med3_unchecked(_It _First, _It _Mid, _It _Last, _Pr _Pred, _Pj _Proj) {
// sort median of three elements to middle
CaseyCarter marked this conversation as resolved.
Show resolved Hide resolved
if (_STD invoke(_Pred, _STD invoke(_Proj, *_Mid), _STD invoke(_Proj, *_First))) {
_RANGES iter_swap(_Mid, _First);
}

if (!_STD invoke(_Pred, _STD invoke(_Proj, *_Last), _STD invoke(_Proj, *_Mid))) {
return;
}

// swap middle and last, then test first again
_RANGES iter_swap(_Last, _Mid);

if (_STD invoke(_Pred, _STD invoke(_Proj, *_Mid), _STD invoke(_Proj, *_First))) {
_RANGES iter_swap(_Mid, _First);
}
}

template <random_access_iterator _It, class _Pr, class _Pj>
requires sortable<_It, _Pr, _Pj>
constexpr void _Guess_median_unchecked(_It _First, _It _Mid, _It _Last, _Pr _Pred, _Pj _Proj) {
// sort median element to middle
using _Diff = iter_difference_t<_It>;
const _Diff _Count = _Last - _First;
if (_Count > 40) { // Tukey's ninther
CaseyCarter marked this conversation as resolved.
Show resolved Hide resolved
const _Diff _Step = (_Count + 1) >> 3; // +1 can't overflow because range was made inclusive in caller
const _Diff _Two_step = _Step << 1; // note: intentionally discards low-order bit
_Med3_unchecked(_First, _First + _Step, _First + _Two_step, _Pred, _Proj);
_Med3_unchecked(_Mid - _Step, _Mid, _Mid + _Step, _Pred, _Proj);
_Med3_unchecked(_Last - _Two_step, _Last - _Step, _Last, _Pred, _Proj);
_Med3_unchecked(_First + _Step, _Mid, _Last - _Step, _Pred, _Proj);
} else {
_Med3_unchecked(_First, _Mid, _Last, _Pred, _Proj);
}
}

template <random_access_iterator _It, class _Pr, class _Pj>
requires sortable<_It, _Pr, _Pj>
_NODISCARD constexpr subrange<_It> _Partition_by_median_guess_unchecked(
_It _First, _It _Last, _Pr _Pred, _Pj _Proj) {
// Choose a pivot, partition [_First, _Last) into elements less than pivot, elements equal to pivot, and
// elements greater than pivot; return the equal partition as a subrange.

_It _Mid = _First + ((_Last - _First) >> 1); // shift for codegen
_RANGES _Guess_median_unchecked(_First, _Mid, _RANGES prev(_Last), _Pred, _Proj);
_It _Pfirst = _Mid;
_It _Plast = _RANGES next(_Pfirst);

while (_First < _Pfirst
&& !_STD invoke(_Pred, _STD invoke(_Proj, *_RANGES prev(_Pfirst)), _STD invoke(_Proj, *_Pfirst))
&& !_STD invoke(_Pred, _STD invoke(_Proj, *_Pfirst), _STD invoke(_Proj, *_RANGES prev(_Pfirst)))) {
--_Pfirst;
}

while (_Plast < _Last
&& !_STD invoke(_Pred, _STD invoke(_Proj, *_Plast), _STD invoke(_Proj, *_Pfirst))
&& !_STD invoke(_Pred, _STD invoke(_Proj, *_Pfirst), _STD invoke(_Proj, *_Plast))) {
++_Plast;
}

_It _Gfirst = _Plast;
_It _Glast = _Pfirst;

for (;;) { // partition
for (; _Gfirst < _Last; ++_Gfirst) {
if (_STD invoke(_Pred, _STD invoke(_Proj, *_Pfirst), _STD invoke(_Proj, *_Gfirst))) {
CaseyCarter marked this conversation as resolved.
Show resolved Hide resolved
continue;
} else if (_STD invoke(_Pred, _STD invoke(_Proj, *_Gfirst), _STD invoke(_Proj, *_Pfirst))) {
break;
} else if (_Plast != _Gfirst) {
_RANGES iter_swap(_Plast, _Gfirst);
++_Plast;
} else {
++_Plast;
}
}

for (; _First < _Glast; --_Glast) {
if (_STD invoke(_Pred, _STD invoke(_Proj, *_RANGES prev(_Glast)), _STD invoke(_Proj, *_Pfirst))) {
CaseyCarter marked this conversation as resolved.
Show resolved Hide resolved
continue;
} else if (_STD invoke(
_Pred, _STD invoke(_Proj, *_Pfirst), _STD invoke(_Proj, *_RANGES prev(_Glast)))) {
break;
} else if (--_Pfirst != _RANGES prev(_Glast)) {
_RANGES iter_swap(_Pfirst, _RANGES prev(_Glast));
}
}

if (_Glast == _First && _Gfirst == _Last) {
return {_STD move(_Pfirst), _STD move(_Plast)};
}

if (_Glast == _First) { // no room at bottom, rotate pivot upward
if (_Plast != _Gfirst) {
_RANGES iter_swap(_Pfirst, _Plast);
}

++_Plast;
_RANGES iter_swap(_Pfirst, _Gfirst);
++_Pfirst;
++_Gfirst;
} else if (_Gfirst == _Last) { // no room at top, rotate pivot downward
if (--_Glast != --_Pfirst) {
_RANGES iter_swap(_Glast, _Pfirst);
}

_RANGES iter_swap(_Pfirst, --_Plast);
} else {
_RANGES iter_swap(_Gfirst, --_Glast);
++_Gfirst;
}
}
}
// clang-format on
} // namespace ranges
#endif // __cpp_lib_concepts
#endif // _HAS_CXX17

// FUNCTION TEMPLATE stable_sort
Expand Down Expand Up @@ -6949,7 +7101,7 @@ _CONSTEXPR20 void nth_element(_RanIt _First, _RanIt _Nth, _RanIt _Last, _Pr _Pre
if (_UMid.second <= _UNth) {
_UFirst = _UMid.second;
} else if (_UMid.first <= _UNth) {
return; // Nth inside fat pivot, done
return; // _Nth is in the subrange of elements equal to the pivot; done
} else {
_ULast = _UMid.first;
}
Expand Down Expand Up @@ -6977,6 +7129,82 @@ void nth_element(_ExPo&&, _RanIt _First, _RanIt _Nth, _RanIt _Last) noexcept /*
// not parallelized at present, parallelism expected to be feasible in a future release
_STD nth_element(_First, _Nth, _Last);
}

#ifdef __cpp_lib_concepts
namespace ranges {
// VARIABLE ranges::nth_element
class _Nth_element_fn : private _Not_quite_object {
CaseyCarter marked this conversation as resolved.
Show resolved Hide resolved
public:
using _Not_quite_object::_Not_quite_object;

// clang-format off
template <random_access_iterator _It, sentinel_for<_It> _Se, class _Pr = ranges::less, class _Pj = identity>
requires sortable<_It, _Pr, _Pj>
constexpr _It operator()(_It _First, _It _Nth, _Se _Last, _Pr _Pred = {}, _Pj _Proj = {}) const {
_Adl_verify_range(_First, _Nth);
_Adl_verify_range(_Nth, _Last);
auto _UNth = _Get_unwrapped(_Nth);
auto _UFinal = _Get_final_iterator_unwrapped<_It>(_UNth, _STD move(_Last));
_Seek_wrapped(_Nth, _UFinal);

_Nth_element_common(_Get_unwrapped(_STD move(_First)), _STD move(_UNth), _STD move(_UFinal),
_Pass_fn(_Pred), _Pass_fn(_Proj));
return _Nth;
}

template <random_access_range _Rng, class _Pr = ranges::less, class _Pj = identity>
requires sortable<iterator_t<_Rng>, _Pr, _Pj>
constexpr borrowed_iterator_t<_Rng> operator()(
_Rng&& _Range, iterator_t<_Rng> _Nth, _Pr _Pred = {}, _Pj _Proj = {}) const {
_Adl_verify_range(_RANGES begin(_Range), _Nth);
_Adl_verify_range(_Nth, _RANGES end(_Range));
auto _UNth = _Get_unwrapped(_Nth);
auto _UFinal = [&] {
if constexpr (common_range<_Rng>) {
return _Uend(_Range);
} else if constexpr (sized_range<_Rng>) {
return _RANGES next(_Ubegin(_Range), _RANGES distance(_Range));
} else {
return _RANGES next(_UNth, _Uend(_Range));
}
}();
_Seek_wrapped(_Nth, _UFinal);

_Nth_element_common(
_Ubegin(_Range), _STD move(_UNth), _STD move(_UFinal), _Pass_fn(_Pred), _Pass_fn(_Proj));
return _Nth;
}
// clang-format on
private:
template <class _It, class _Pr, class _Pj>
static constexpr void _Nth_element_common(_It _First, _It _Nth, _It _Last, _Pr _Pred, _Pj _Proj) {
_STL_INTERNAL_STATIC_ASSERT(random_access_iterator<_It>);
_STL_INTERNAL_STATIC_ASSERT(sortable<_It, _Pr, _Pj>);

if (_Nth == _Last) {
return; // nothing to do
}

while (_ISORT_MAX < _Last - _First) { // divide and conquer, ordering partition containing Nth
subrange<_It> _Mid = _RANGES _Partition_by_median_guess_unchecked(_First, _Last, _Pred, _Proj);

if (_Mid.end() <= _Nth) {
_First = _Mid.end();
} else if (_Mid.begin() <= _Nth) {
return; // _Nth is in the subrange of elements equal to the pivot; done
} else {
_Last = _Mid.begin();
}
}

// sort any remainder
_RANGES _Insertion_sort_common(_STD move(_First), _STD move(_Last), _STD move(_Pred), _STD move(_Proj));
CaseyCarter marked this conversation as resolved.
Show resolved Hide resolved
}
};

inline constexpr _Nth_element_fn nth_element{_Not_quite_object::_Construct_tag{}};
} // namespace ranges
#endif // __cpp_lib_concepts
#endif // _HAS_CXX17

// FUNCTION TEMPLATE includes
Expand Down
1 change: 1 addition & 0 deletions tests/std/test.lst
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,7 @@ tests\P0896R4_ranges_alg_minmax
tests\P0896R4_ranges_alg_mismatch
tests\P0896R4_ranges_alg_move
tests\P0896R4_ranges_alg_none_of
tests\P0896R4_ranges_alg_nth_element
tests\P0896R4_ranges_alg_partition
tests\P0896R4_ranges_alg_partition_copy
tests\P0896R4_ranges_alg_partition_point
Expand Down
4 changes: 4 additions & 0 deletions tests/std/tests/P0896R4_ranges_alg_nth_element/env.lst
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

RUNALL_INCLUDE ..\concepts_matrix.lst
82 changes: 82 additions & 0 deletions tests/std/tests/P0896R4_ranges_alg_nth_element/test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include <algorithm>
#include <cassert>
#include <concepts>
#include <ranges>
#include <utility>

#include <range_algorithm_support.hpp>

using namespace std;

// Validate dangling story
STATIC_ASSERT(same_as<decltype(ranges::nth_element(borrowed<false>{}, nullptr_to<int>)), ranges::dangling>);
STATIC_ASSERT(same_as<decltype(ranges::nth_element(borrowed<true>{}, nullptr_to<int>)), int*>);

using P = pair<int, int>;

struct instantiator {
static constexpr int keys[] = {7, 6, 5, 4, 3, 2, 1, 0};

template <ranges::random_access_range R>
static constexpr void call() {
#if !defined(__clang__) && !defined(__EDG__) // TRANSITION, VSO-938163
#pragma warning(suppress : 4127) // conditional expression is constant
if (!ranges::contiguous_range<R> || !is_constant_evaluated())
#endif // TRANSITION, VSO-938163
{
using ranges::nth_element, ranges::all_of, ranges::find, ranges::iterator_t, ranges::less, ranges::none_of,
ranges::size;

P input[size(keys)];
const auto init = [&] {
for (size_t j = 0; j < size(keys); ++j) {
input[j] = P{keys[j], static_cast<int>(10 + j)};
}
};

// Validate range overload
for (int i = 0; i < int{size(keys)}; ++i) {
init();
const R wrapped{input};
const auto nth = wrapped.begin() + i;
const same_as<iterator_t<R>> auto result = nth_element(wrapped, nth, less{}, get_first);
assert(result == wrapped.end());
assert((*nth == P{i, static_cast<int>(10 + (find(keys, i) - keys))}));
if (nth != wrapped.end()) {
assert(all_of(wrapped.begin(), nth, [&](auto&& x) { return get_first(x) <= get_first(*nth); }));
assert(all_of(nth, wrapped.end(), [&](auto&& x) { return get_first(*nth) <= get_first(x); }));
}
}

// Validate iterator overload
for (int i = 0; i < int{size(keys)}; ++i) {
init();
const R wrapped{input};
const auto nth = wrapped.begin() + i;
const same_as<iterator_t<R>> auto result =
nth_element(wrapped.begin(), nth, wrapped.end(), less{}, get_first);
assert(result == wrapped.end());
assert((input[i] == P{i, static_cast<int>(10 + (find(keys, i) - keys))}));
if (nth != wrapped.end()) {
assert(all_of(wrapped.begin(), nth, [&](auto&& x) { return get_first(x) <= get_first(*nth); }));
assert(all_of(nth, wrapped.end(), [&](auto&& x) { return get_first(*nth) <= get_first(x); }));
}
}

{
// Validate empty range
const R range{};
const same_as<iterator_t<R>> auto result = nth_element(range, range.begin(), less{}, get_first);
assert(result == range.end());
}
}
}
};

int main() {
STATIC_ASSERT((test_random<instantiator, P>(), true));
test_random<instantiator, P>();
}