From ba251e8d551dfa0b4d8a86932e06b4d1162d8b9f Mon Sep 17 00:00:00 2001 From: Lilith Orion Hafner Date: Fri, 16 Jun 2023 15:49:24 -0500 Subject: [PATCH] Fix sorting bugs (esp `MissingOptimization`) that come up when using SortingAlgorithms.TimSort (#50171) --- base/sort.jl | 28 ++++++++++++++-------------- test/sorting.jl | 40 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+), 14 deletions(-) diff --git a/base/sort.jl b/base/sort.jl index 99f2ed3e1aeb8..90f8755d3b1a4 100644 --- a/base/sort.jl +++ b/base/sort.jl @@ -44,6 +44,7 @@ export # not exported by Base SMALL_ALGORITHM, SMALL_THRESHOLD +abstract type Algorithm end ## functions requiring only ordering ## @@ -436,7 +437,7 @@ for (sym, exp, type) in [ (:mn, :(throw(ArgumentError("mn is needed but has not been computed"))), :(eltype(v))), (:mx, :(throw(ArgumentError("mx is needed but has not been computed"))), :(eltype(v))), (:scratch, nothing, :(Union{Nothing, Vector})), # could have different eltype - (:allow_legacy_dispatch, true, Bool)] + (:legacy_dispatch_entry, nothing, Union{Nothing, Algorithm})] usym = Symbol(:_, sym) @eval function $usym(v, o, kw) # using missing instead of nothing because scratch could === nothing. @@ -499,8 +500,6 @@ internal or recursive calls. """ function _sort! end -abstract type Algorithm end - """ MissingOptimization(next) <: Algorithm @@ -524,12 +523,12 @@ struct WithoutMissingVector{T, U} <: AbstractVector{T} new{nonmissingtype(eltype(data)), typeof(data)}(data) end end -Base.@propagate_inbounds function Base.getindex(v::WithoutMissingVector, i) +Base.@propagate_inbounds function Base.getindex(v::WithoutMissingVector, i::Integer) out = v.data[i] @assert !(out isa Missing) out::eltype(v) end -Base.@propagate_inbounds function Base.setindex!(v::WithoutMissingVector, x, i) +Base.@propagate_inbounds function Base.setindex!(v::WithoutMissingVector, x, i::Integer) v.data[i] = x v end @@ -590,8 +589,9 @@ function _sort!(v::AbstractVector, a::MissingOptimization, o::Ordering, kw) # we can assume v is equal to eachindex(o.data) which allows a copying partition # without allocations. lo_i, hi_i = lo, hi - for i in eachindex(o.data) # equal to copy(v) - x = o.data[i] + cv = eachindex(o.data) # equal to copy(v) + for i in lo:hi + x = o.data[cv[i]] if ismissing(x) == (o.order == Reverse) # should x go at the beginning/end? v[lo_i] = i lo_i += 1 @@ -2149,25 +2149,25 @@ end # Support 3-, 5-, and 6-argument versions of sort! for calling into the internals in the old way sort!(v::AbstractVector, a::Algorithm, o::Ordering) = sort!(v, firstindex(v), lastindex(v), a, o) function sort!(v::AbstractVector, lo::Integer, hi::Integer, a::Algorithm, o::Ordering) - _sort!(v, a, o, (; lo, hi, allow_legacy_dispatch=false)) + _sort!(v, a, o, (; lo, hi, legacy_dispatch_entry=a)) v end sort!(v::AbstractVector, lo::Integer, hi::Integer, a::Algorithm, o::Ordering, _) = sort!(v, lo, hi, a, o) function sort!(v::AbstractVector, lo::Integer, hi::Integer, a::Algorithm, o::Ordering, scratch::Vector) - _sort!(v, a, o, (; lo, hi, scratch, allow_legacy_dispatch=false)) + _sort!(v, a, o, (; lo, hi, scratch, legacy_dispatch_entry=a)) v end # Support dispatch on custom algorithms in the old way # sort!(::AbstractVector, ::Integer, ::Integer, ::MyCustomAlgorithm, ::Ordering) = ... function _sort!(v::AbstractVector, a::Algorithm, o::Ordering, kw) - @getkw lo hi scratch allow_legacy_dispatch - if allow_legacy_dispatch + @getkw lo hi scratch legacy_dispatch_entry + if legacy_dispatch_entry === a + # This error prevents infinite recursion for unknown algorithms + throw(ArgumentError("Base.Sort._sort!(::$(typeof(v)), ::$(typeof(a)), ::$(typeof(o)), ::Any) is not defined")) + else sort!(v, lo, hi, a, o) scratch - else - # This error prevents infinite recursion for unknown algorithms - throw(ArgumentError("Base.Sort._sort!(::$(typeof(v)), ::$(typeof(a)), ::$(typeof(o))) is not defined")) end end diff --git a/test/sorting.jl b/test/sorting.jl index cf98182307088..147a70a5db7d9 100644 --- a/test/sorting.jl +++ b/test/sorting.jl @@ -1025,6 +1025,46 @@ Base.similar(A::MyArray49392, ::Type{T}, dims::Dims{N}) where {T, N} = MyArray49 @test all(sort!(y, dims=2) .== sort!(x,dims=2)) end +@testset "MissingOptimization fastpath for Perm ordering when lo:hi ≠ eachindex(v)" begin + v = [rand() < .5 ? missing : rand() for _ in 1:100] + ix = collect(1:100) + sort!(ix, 1, 10, Base.Sort.DEFAULT_STABLE, Base.Order.Perm(Base.Order.Forward, v)) + @test issorted(v[ix[1:10]]) +end + +struct NonScalarIndexingOfWithoutMissingVectorAlg <: Base.Sort.Algorithm end +function Base.Sort._sort!(v::AbstractVector, ::NonScalarIndexingOfWithoutMissingVectorAlg, o::Base.Order.Ordering, kw) + Base.Sort.@getkw lo hi + first_half = v[lo:lo+(hi-lo)÷2] + second_half = v[lo+(hi-lo)÷2+1:hi] + whole = v[lo:hi] + all(vcat(first_half, second_half) .=== whole) || error() + out = Base.Sort._sort!(whole, Base.Sort.DEFAULT_STABLE, o, (;kw..., lo=1, hi=length(whole))) + v[lo:hi] .= whole + out +end + +@testset "Non-scaler indexing of WithoutMissingVector" begin + @testset "Unit test" begin + wmv = Base.Sort.WithoutMissingVector(Union{Missing, Int}[1, 7, 2, 9]) + @test wmv[[1, 3]] == [1, 2] + @test wmv[1:3] == [1, 7, 2] + end + @testset "End to end" begin + alg = Base.Sort.InitialOptimizations(NonScalarIndexingOfWithoutMissingVectorAlg()) + @test issorted(sort(rand(100); alg)) + @test issorted(sort([rand() < .5 ? missing : randstring() for _ in 1:100]; alg)) + end +end + +struct DispatchLoopTestAlg <: Base.Sort.Algorithm end +function Base.sort!(v::AbstractVector, lo::Integer, hi::Integer, ::DispatchLoopTestAlg, order::Base.Order.Ordering) + sort!(view(v, lo:hi); order) +end +@testset "Support dispatch from the old style to the new style and back" begin + @test issorted(sort!(rand(100), Base.Sort.InitialOptimizations(DispatchLoopTestAlg()), Base.Order.Forward)) +end + # This testset is at the end of the file because it is slow. @testset "searchsorted" begin numTypes = [ Int8, Int16, Int32, Int64, Int128,