Skip to content

Commit

Permalink
Fix sorting bugs (esp MissingOptimization) that come up when using …
Browse files Browse the repository at this point in the history
…SortingAlgorithms.TimSort (#50171)
  • Loading branch information
LilithHafner authored Jun 16, 2023
1 parent c5b0a6c commit ba251e8
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 14 deletions.
28 changes: 14 additions & 14 deletions base/sort.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ export # not exported by Base
SMALL_ALGORITHM,
SMALL_THRESHOLD

abstract type Algorithm end

## functions requiring only ordering ##

Expand Down Expand Up @@ -436,7 +437,7 @@ for (sym, exp, type) in [
(:mn, :(throw(ArgumentError("mn is needed but has not been computed"))), :(eltype(v))),
(:mx, :(throw(ArgumentError("mx is needed but has not been computed"))), :(eltype(v))),
(:scratch, nothing, :(Union{Nothing, Vector})), # could have different eltype
(:allow_legacy_dispatch, true, Bool)]
(:legacy_dispatch_entry, nothing, Union{Nothing, Algorithm})]
usym = Symbol(:_, sym)
@eval function $usym(v, o, kw)
# using missing instead of nothing because scratch could === nothing.
Expand Down Expand Up @@ -499,8 +500,6 @@ internal or recursive calls.
"""
function _sort! end

abstract type Algorithm end


"""
MissingOptimization(next) <: Algorithm
Expand All @@ -524,12 +523,12 @@ struct WithoutMissingVector{T, U} <: AbstractVector{T}
new{nonmissingtype(eltype(data)), typeof(data)}(data)
end
end
Base.@propagate_inbounds function Base.getindex(v::WithoutMissingVector, i)
Base.@propagate_inbounds function Base.getindex(v::WithoutMissingVector, i::Integer)
out = v.data[i]
@assert !(out isa Missing)
out::eltype(v)
end
Base.@propagate_inbounds function Base.setindex!(v::WithoutMissingVector, x, i)
Base.@propagate_inbounds function Base.setindex!(v::WithoutMissingVector, x, i::Integer)
v.data[i] = x
v
end
Expand Down Expand Up @@ -590,8 +589,9 @@ function _sort!(v::AbstractVector, a::MissingOptimization, o::Ordering, kw)
# we can assume v is equal to eachindex(o.data) which allows a copying partition
# without allocations.
lo_i, hi_i = lo, hi
for i in eachindex(o.data) # equal to copy(v)
x = o.data[i]
cv = eachindex(o.data) # equal to copy(v)
for i in lo:hi
x = o.data[cv[i]]
if ismissing(x) == (o.order == Reverse) # should x go at the beginning/end?
v[lo_i] = i
lo_i += 1
Expand Down Expand Up @@ -2149,25 +2149,25 @@ end
# Support 3-, 5-, and 6-argument versions of sort! for calling into the internals in the old way
sort!(v::AbstractVector, a::Algorithm, o::Ordering) = sort!(v, firstindex(v), lastindex(v), a, o)
function sort!(v::AbstractVector, lo::Integer, hi::Integer, a::Algorithm, o::Ordering)
_sort!(v, a, o, (; lo, hi, allow_legacy_dispatch=false))
_sort!(v, a, o, (; lo, hi, legacy_dispatch_entry=a))
v
end
sort!(v::AbstractVector, lo::Integer, hi::Integer, a::Algorithm, o::Ordering, _) = sort!(v, lo, hi, a, o)
function sort!(v::AbstractVector, lo::Integer, hi::Integer, a::Algorithm, o::Ordering, scratch::Vector)
_sort!(v, a, o, (; lo, hi, scratch, allow_legacy_dispatch=false))
_sort!(v, a, o, (; lo, hi, scratch, legacy_dispatch_entry=a))
v
end

# Support dispatch on custom algorithms in the old way
# sort!(::AbstractVector, ::Integer, ::Integer, ::MyCustomAlgorithm, ::Ordering) = ...
function _sort!(v::AbstractVector, a::Algorithm, o::Ordering, kw)
@getkw lo hi scratch allow_legacy_dispatch
if allow_legacy_dispatch
@getkw lo hi scratch legacy_dispatch_entry
if legacy_dispatch_entry === a
# This error prevents infinite recursion for unknown algorithms
throw(ArgumentError("Base.Sort._sort!(::$(typeof(v)), ::$(typeof(a)), ::$(typeof(o)), ::Any) is not defined"))
else
sort!(v, lo, hi, a, o)
scratch
else
# This error prevents infinite recursion for unknown algorithms
throw(ArgumentError("Base.Sort._sort!(::$(typeof(v)), ::$(typeof(a)), ::$(typeof(o))) is not defined"))
end
end

Expand Down
40 changes: 40 additions & 0 deletions test/sorting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1025,6 +1025,46 @@ Base.similar(A::MyArray49392, ::Type{T}, dims::Dims{N}) where {T, N} = MyArray49
@test all(sort!(y, dims=2) .== sort!(x,dims=2))
end

@testset "MissingOptimization fastpath for Perm ordering when lo:hi ≠ eachindex(v)" begin
v = [rand() < .5 ? missing : rand() for _ in 1:100]
ix = collect(1:100)
sort!(ix, 1, 10, Base.Sort.DEFAULT_STABLE, Base.Order.Perm(Base.Order.Forward, v))
@test issorted(v[ix[1:10]])
end

struct NonScalarIndexingOfWithoutMissingVectorAlg <: Base.Sort.Algorithm end
function Base.Sort._sort!(v::AbstractVector, ::NonScalarIndexingOfWithoutMissingVectorAlg, o::Base.Order.Ordering, kw)
Base.Sort.@getkw lo hi
first_half = v[lo:lo+(hi-lo)÷2]
second_half = v[lo+(hi-lo)÷2+1:hi]
whole = v[lo:hi]
all(vcat(first_half, second_half) .=== whole) || error()
out = Base.Sort._sort!(whole, Base.Sort.DEFAULT_STABLE, o, (;kw..., lo=1, hi=length(whole)))
v[lo:hi] .= whole
out
end

@testset "Non-scaler indexing of WithoutMissingVector" begin
@testset "Unit test" begin
wmv = Base.Sort.WithoutMissingVector(Union{Missing, Int}[1, 7, 2, 9])
@test wmv[[1, 3]] == [1, 2]
@test wmv[1:3] == [1, 7, 2]
end
@testset "End to end" begin
alg = Base.Sort.InitialOptimizations(NonScalarIndexingOfWithoutMissingVectorAlg())
@test issorted(sort(rand(100); alg))
@test issorted(sort([rand() < .5 ? missing : randstring() for _ in 1:100]; alg))
end
end

struct DispatchLoopTestAlg <: Base.Sort.Algorithm end
function Base.sort!(v::AbstractVector, lo::Integer, hi::Integer, ::DispatchLoopTestAlg, order::Base.Order.Ordering)
sort!(view(v, lo:hi); order)
end
@testset "Support dispatch from the old style to the new style and back" begin
@test issorted(sort!(rand(100), Base.Sort.InitialOptimizations(DispatchLoopTestAlg()), Base.Order.Forward))
end

# This testset is at the end of the file because it is slow.
@testset "searchsorted" begin
numTypes = [ Int8, Int16, Int32, Int64, Int128,
Expand Down

0 comments on commit ba251e8

Please sign in to comment.