diff --git a/stl/inc/algorithm b/stl/inc/algorithm index 80a0ce5455..6f461d9e22 100644 --- a/stl/inc/algorithm +++ b/stl/inc/algorithm @@ -6452,22 +6452,22 @@ void inplace_merge(_ExPo&&, _BidIt _First, _BidIt _Mid, _BidIt _Last) noexcept / // FUNCTION TEMPLATE sort template -_CONSTEXPR20 _BidIt _Insertion_sort_unchecked(_BidIt _First, const _BidIt _Last, _Pr _Pred) { +_CONSTEXPR20 _BidIt _Insertion_sort_unchecked(const _BidIt _First, const _BidIt _Last, _Pr _Pred) { // insertion sort [_First, _Last) if (_First != _Last) { - for (_BidIt _Next = _First; ++_Next != _Last;) { // order next element - _BidIt _Next1 = _Next; - _Iter_value_t<_BidIt> _Val = _STD move(*_Next); + for (_BidIt _Mid = _First; ++_Mid != _Last;) { // order next element + _BidIt _Hole = _Mid; + _Iter_value_t<_BidIt> _Val = _STD move(*_Mid); if (_DEBUG_LT_PRED(_Pred, _Val, *_First)) { // found new earliest element, move to front - _Move_backward_unchecked(_First, _Next, ++_Next1); + _Move_backward_unchecked(_First, _Mid, ++_Hole); *_First = _STD move(_Val); } else { // look for insertion point after first - for (_BidIt _First1 = _Next1; _DEBUG_LT_PRED(_Pred, _Val, *--_First1); _Next1 = _First1) { - *_Next1 = _STD move(*_First1); // move hole down + for (_BidIt _Prev = _Hole; _DEBUG_LT_PRED(_Pred, _Val, *--_Prev); _Hole = _Prev) { + *_Hole = _STD move(*_Prev); // move hole down } - *_Next1 = _STD move(_Val); // insert element in hole + *_Hole = _STD move(_Val); // insert element in hole } } } @@ -6531,6 +6531,7 @@ _CONSTEXPR20 pair<_RanIt, _RanIt> _Partition_by_median_guess_unchecked(_RanIt _F for (;;) { // partition for (; _Gfirst < _Last; ++_Gfirst) { if (_DEBUG_LT_PRED(_Pred, *_Pfirst, *_Gfirst)) { + continue; } else if (_Pred(*_Gfirst, *_Pfirst)) { break; } else if (_Plast != _Gfirst) { @@ -6543,6 +6544,7 @@ _CONSTEXPR20 pair<_RanIt, _RanIt> _Partition_by_median_guess_unchecked(_RanIt _F for (; _First < _Glast; --_Glast) { if (_DEBUG_LT_PRED(_Pred, *_Prev_iter(_Glast), *_Pfirst)) { + continue; } else if (_Pred(*_Pfirst, *_Prev_iter(_Glast))) { break; } else if (--_Pfirst != _Prev_iter(_Glast)) { @@ -6628,6 +6630,156 @@ void sort(_ExPo&& _Exec, const _RanIt _First, const _RanIt _Last) noexcept /* te // order [_First, _Last) _STD sort(_STD forward<_ExPo>(_Exec), _First, _Last, less{}); } + +#ifdef __cpp_lib_concepts +namespace ranges { + // clang-format off + template + requires sortable<_It, _Pr, _Pj> + constexpr void _Insertion_sort_common(const _It _First, const _It _Last, _Pr _Pred, _Pj _Proj) { + // insertion sort [_First, _Last) + + if (_First == _Last) { // empty range is sorted + return; + } + + for (auto _Mid = _First; ++_Mid != _Last;) { // order next element + iter_value_t<_It> _Val = _RANGES iter_move(_Mid); + auto _Hole = _Mid; + + for (auto _Prev = _Hole;;) { + --_Prev; + if (!_STD invoke(_Pred, _STD invoke(_Proj, _Val), _STD invoke(_Proj, *_Prev))) { + break; + } + *_Hole = _RANGES iter_move(_Prev); // move hole down + if (--_Hole == _First) { + break; + } + } + + *_Hole = _STD move(_Val); // insert element in hole + } + } + + template + requires sortable<_It, _Pr, _Pj> + constexpr void _Med3_unchecked(_It _First, _It _Mid, _It _Last, _Pr _Pred, _Pj _Proj) { + // sort median of three elements to middle + if (_STD invoke(_Pred, _STD invoke(_Proj, *_Mid), _STD invoke(_Proj, *_First))) { + _RANGES iter_swap(_Mid, _First); + } + + if (!_STD invoke(_Pred, _STD invoke(_Proj, *_Last), _STD invoke(_Proj, *_Mid))) { + return; + } + + // swap middle and last, then test first again + _RANGES iter_swap(_Last, _Mid); + + if (_STD invoke(_Pred, _STD invoke(_Proj, *_Mid), _STD invoke(_Proj, *_First))) { + _RANGES iter_swap(_Mid, _First); + } + } + + template + requires sortable<_It, _Pr, _Pj> + constexpr void _Guess_median_unchecked(_It _First, _It _Mid, _It _Last, _Pr _Pred, _Pj _Proj) { + // sort median element to middle + using _Diff = iter_difference_t<_It>; + const _Diff _Count = _Last - _First; + if (_Count > 40) { // Tukey's ninther + const _Diff _Step = (_Count + 1) >> 3; // +1 can't overflow because range was made inclusive in caller + const _Diff _Two_step = _Step << 1; // note: intentionally discards low-order bit + _Med3_unchecked(_First, _First + _Step, _First + _Two_step, _Pred, _Proj); + _Med3_unchecked(_Mid - _Step, _Mid, _Mid + _Step, _Pred, _Proj); + _Med3_unchecked(_Last - _Two_step, _Last - _Step, _Last, _Pred, _Proj); + _Med3_unchecked(_First + _Step, _Mid, _Last - _Step, _Pred, _Proj); + } else { + _Med3_unchecked(_First, _Mid, _Last, _Pred, _Proj); + } + } + + template + requires sortable<_It, _Pr, _Pj> + _NODISCARD constexpr subrange<_It> _Partition_by_median_guess_unchecked( + _It _First, _It _Last, _Pr _Pred, _Pj _Proj) { + // Choose a pivot, partition [_First, _Last) into elements less than pivot, elements equal to pivot, and + // elements greater than pivot; return the equal partition as a subrange. + + _It _Mid = _First + ((_Last - _First) >> 1); // shift for codegen + _RANGES _Guess_median_unchecked(_First, _Mid, _RANGES prev(_Last), _Pred, _Proj); + _It _Pfirst = _Mid; + _It _Plast = _RANGES next(_Pfirst); + + while (_First < _Pfirst + && !_STD invoke(_Pred, _STD invoke(_Proj, *_RANGES prev(_Pfirst)), _STD invoke(_Proj, *_Pfirst)) + && !_STD invoke(_Pred, _STD invoke(_Proj, *_Pfirst), _STD invoke(_Proj, *_RANGES prev(_Pfirst)))) { + --_Pfirst; + } + + while (_Plast < _Last + && !_STD invoke(_Pred, _STD invoke(_Proj, *_Plast), _STD invoke(_Proj, *_Pfirst)) + && !_STD invoke(_Pred, _STD invoke(_Proj, *_Pfirst), _STD invoke(_Proj, *_Plast))) { + ++_Plast; + } + + _It _Gfirst = _Plast; + _It _Glast = _Pfirst; + + for (;;) { // partition + for (; _Gfirst < _Last; ++_Gfirst) { + if (_STD invoke(_Pred, _STD invoke(_Proj, *_Pfirst), _STD invoke(_Proj, *_Gfirst))) { + continue; + } else if (_STD invoke(_Pred, _STD invoke(_Proj, *_Gfirst), _STD invoke(_Proj, *_Pfirst))) { + break; + } else if (_Plast != _Gfirst) { + _RANGES iter_swap(_Plast, _Gfirst); + ++_Plast; + } else { + ++_Plast; + } + } + + for (; _First < _Glast; --_Glast) { + if (_STD invoke(_Pred, _STD invoke(_Proj, *_RANGES prev(_Glast)), _STD invoke(_Proj, *_Pfirst))) { + continue; + } else if (_STD invoke( + _Pred, _STD invoke(_Proj, *_Pfirst), _STD invoke(_Proj, *_RANGES prev(_Glast)))) { + break; + } else if (--_Pfirst != _RANGES prev(_Glast)) { + _RANGES iter_swap(_Pfirst, _RANGES prev(_Glast)); + } + } + + if (_Glast == _First && _Gfirst == _Last) { + return {_STD move(_Pfirst), _STD move(_Plast)}; + } + + if (_Glast == _First) { // no room at bottom, rotate pivot upward + if (_Plast != _Gfirst) { + _RANGES iter_swap(_Pfirst, _Plast); + } + + ++_Plast; + _RANGES iter_swap(_Pfirst, _Gfirst); + ++_Pfirst; + ++_Gfirst; + } else if (_Gfirst == _Last) { // no room at top, rotate pivot downward + if (--_Glast != --_Pfirst) { + _RANGES iter_swap(_Glast, _Pfirst); + } + + _RANGES iter_swap(_Pfirst, --_Plast); + } else { + _RANGES iter_swap(_Gfirst, --_Glast); + ++_Gfirst; + } + } + } + // clang-format on +} // namespace ranges +#endif // __cpp_lib_concepts #endif // _HAS_CXX17 // FUNCTION TEMPLATE stable_sort @@ -6949,7 +7101,7 @@ _CONSTEXPR20 void nth_element(_RanIt _First, _RanIt _Nth, _RanIt _Last, _Pr _Pre if (_UMid.second <= _UNth) { _UFirst = _UMid.second; } else if (_UMid.first <= _UNth) { - return; // Nth inside fat pivot, done + return; // _Nth is in the subrange of elements equal to the pivot; done } else { _ULast = _UMid.first; } @@ -6977,6 +7129,82 @@ void nth_element(_ExPo&&, _RanIt _First, _RanIt _Nth, _RanIt _Last) noexcept /* // not parallelized at present, parallelism expected to be feasible in a future release _STD nth_element(_First, _Nth, _Last); } + +#ifdef __cpp_lib_concepts +namespace ranges { + // VARIABLE ranges::nth_element + class _Nth_element_fn : private _Not_quite_object { + public: + using _Not_quite_object::_Not_quite_object; + + // clang-format off + template _Se, class _Pr = ranges::less, class _Pj = identity> + requires sortable<_It, _Pr, _Pj> + constexpr _It operator()(_It _First, _It _Nth, _Se _Last, _Pr _Pred = {}, _Pj _Proj = {}) const { + _Adl_verify_range(_First, _Nth); + _Adl_verify_range(_Nth, _Last); + auto _UNth = _Get_unwrapped(_Nth); + auto _UFinal = _Get_final_iterator_unwrapped<_It>(_UNth, _STD move(_Last)); + _Seek_wrapped(_Nth, _UFinal); + + _Nth_element_common(_Get_unwrapped(_STD move(_First)), _STD move(_UNth), _STD move(_UFinal), + _Pass_fn(_Pred), _Pass_fn(_Proj)); + return _Nth; + } + + template + requires sortable, _Pr, _Pj> + constexpr borrowed_iterator_t<_Rng> operator()( + _Rng&& _Range, iterator_t<_Rng> _Nth, _Pr _Pred = {}, _Pj _Proj = {}) const { + _Adl_verify_range(_RANGES begin(_Range), _Nth); + _Adl_verify_range(_Nth, _RANGES end(_Range)); + auto _UNth = _Get_unwrapped(_Nth); + auto _UFinal = [&] { + if constexpr (common_range<_Rng>) { + return _Uend(_Range); + } else if constexpr (sized_range<_Rng>) { + return _RANGES next(_Ubegin(_Range), _RANGES distance(_Range)); + } else { + return _RANGES next(_UNth, _Uend(_Range)); + } + }(); + _Seek_wrapped(_Nth, _UFinal); + + _Nth_element_common( + _Ubegin(_Range), _STD move(_UNth), _STD move(_UFinal), _Pass_fn(_Pred), _Pass_fn(_Proj)); + return _Nth; + } + // clang-format on + private: + template + static constexpr void _Nth_element_common(_It _First, _It _Nth, _It _Last, _Pr _Pred, _Pj _Proj) { + _STL_INTERNAL_STATIC_ASSERT(random_access_iterator<_It>); + _STL_INTERNAL_STATIC_ASSERT(sortable<_It, _Pr, _Pj>); + + if (_Nth == _Last) { + return; // nothing to do + } + + while (_ISORT_MAX < _Last - _First) { // divide and conquer, ordering partition containing Nth + subrange<_It> _Mid = _RANGES _Partition_by_median_guess_unchecked(_First, _Last, _Pred, _Proj); + + if (_Mid.end() <= _Nth) { + _First = _Mid.end(); + } else if (_Mid.begin() <= _Nth) { + return; // _Nth is in the subrange of elements equal to the pivot; done + } else { + _Last = _Mid.begin(); + } + } + + // sort any remainder + _RANGES _Insertion_sort_common(_STD move(_First), _STD move(_Last), _STD move(_Pred), _STD move(_Proj)); + } + }; + + inline constexpr _Nth_element_fn nth_element{_Not_quite_object::_Construct_tag{}}; +} // namespace ranges +#endif // __cpp_lib_concepts #endif // _HAS_CXX17 // FUNCTION TEMPLATE includes diff --git a/tests/std/test.lst b/tests/std/test.lst index 56b6f4f719..ecae9f9325 100644 --- a/tests/std/test.lst +++ b/tests/std/test.lst @@ -265,6 +265,7 @@ tests\P0896R4_ranges_alg_minmax tests\P0896R4_ranges_alg_mismatch tests\P0896R4_ranges_alg_move tests\P0896R4_ranges_alg_none_of +tests\P0896R4_ranges_alg_nth_element tests\P0896R4_ranges_alg_partition tests\P0896R4_ranges_alg_partition_copy tests\P0896R4_ranges_alg_partition_point diff --git a/tests/std/tests/P0896R4_ranges_alg_nth_element/env.lst b/tests/std/tests/P0896R4_ranges_alg_nth_element/env.lst new file mode 100644 index 0000000000..f3ccc8613c --- /dev/null +++ b/tests/std/tests/P0896R4_ranges_alg_nth_element/env.lst @@ -0,0 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +RUNALL_INCLUDE ..\concepts_matrix.lst diff --git a/tests/std/tests/P0896R4_ranges_alg_nth_element/test.cpp b/tests/std/tests/P0896R4_ranges_alg_nth_element/test.cpp new file mode 100644 index 0000000000..36b3f93665 --- /dev/null +++ b/tests/std/tests/P0896R4_ranges_alg_nth_element/test.cpp @@ -0,0 +1,82 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include +#include +#include +#include +#include + +#include + +using namespace std; + +// Validate dangling story +STATIC_ASSERT(same_as{}, nullptr_to)), ranges::dangling>); +STATIC_ASSERT(same_as{}, nullptr_to)), int*>); + +using P = pair; + +struct instantiator { + static constexpr int keys[] = {7, 6, 5, 4, 3, 2, 1, 0}; + + template + static constexpr void call() { +#if !defined(__clang__) && !defined(__EDG__) // TRANSITION, VSO-938163 +#pragma warning(suppress : 4127) // conditional expression is constant + if (!ranges::contiguous_range || !is_constant_evaluated()) +#endif // TRANSITION, VSO-938163 + { + using ranges::nth_element, ranges::all_of, ranges::find, ranges::iterator_t, ranges::less, ranges::none_of, + ranges::size; + + P input[size(keys)]; + const auto init = [&] { + for (size_t j = 0; j < size(keys); ++j) { + input[j] = P{keys[j], static_cast(10 + j)}; + } + }; + + // Validate range overload + for (int i = 0; i < int{size(keys)}; ++i) { + init(); + const R wrapped{input}; + const auto nth = wrapped.begin() + i; + const same_as> auto result = nth_element(wrapped, nth, less{}, get_first); + assert(result == wrapped.end()); + assert((*nth == P{i, static_cast(10 + (find(keys, i) - keys))})); + if (nth != wrapped.end()) { + assert(all_of(wrapped.begin(), nth, [&](auto&& x) { return get_first(x) <= get_first(*nth); })); + assert(all_of(nth, wrapped.end(), [&](auto&& x) { return get_first(*nth) <= get_first(x); })); + } + } + + // Validate iterator overload + for (int i = 0; i < int{size(keys)}; ++i) { + init(); + const R wrapped{input}; + const auto nth = wrapped.begin() + i; + const same_as> auto result = + nth_element(wrapped.begin(), nth, wrapped.end(), less{}, get_first); + assert(result == wrapped.end()); + assert((input[i] == P{i, static_cast(10 + (find(keys, i) - keys))})); + if (nth != wrapped.end()) { + assert(all_of(wrapped.begin(), nth, [&](auto&& x) { return get_first(x) <= get_first(*nth); })); + assert(all_of(nth, wrapped.end(), [&](auto&& x) { return get_first(*nth) <= get_first(x); })); + } + } + + { + // Validate empty range + const R range{}; + const same_as> auto result = nth_element(range, range.begin(), less{}, get_first); + assert(result == range.end()); + } + } + } +}; + +int main() { + STATIC_ASSERT((test_random(), true)); + test_random(); +}