From 2517622fb950f99c3afbc00be1dd93286ac9d282 Mon Sep 17 00:00:00 2001 From: Milan Bouchet-Valat Date: Tue, 15 Mar 2022 14:12:27 +0100 Subject: [PATCH] Improve array promotion rules PR #384 added promotion rules that recursively promote the eltype and reftype, but this has the drawback that the resulting type is often not a supertype of the input types, forcing a conversion to a `CategoricalArray` which may have a `Union` type. Instead, do the same as `Array` and `AbstractArray` fallbacks, which simply call `promote_typejoin` if eltypes do not match. --- src/array.jl | 23 +++++++++------- test/13_arraycommon.jl | 60 ++++++++++++++++-------------------------- 2 files changed, 36 insertions(+), 47 deletions(-) diff --git a/src/array.jl b/src/array.jl index c95c97b8..40b34998 100644 --- a/src/array.jl +++ b/src/array.jl @@ -304,16 +304,12 @@ CategoricalMatrix(A::CategoricalArray{T, 2, R}; ## Promotion methods -Base.promote_rule(::Type{<:CategoricalArray{S}}, - ::Type{<:CategoricalArray{T}}) where {S, T} = - CategoricalArray{cat_promote_type(S, T)} -Base.promote_rule(::Type{<:CategoricalArray{S, N}}, - ::Type{<:CategoricalArray{T, N}}) where {S, T, N} = - CategoricalArray{cat_promote_type(S, T), N} -Base.promote_rule(::Type{<:CategoricalArray{S, N, R1}}, - ::Type{<:CategoricalArray{T, N, R2}}) where - {S, T, N, R1<:Integer, R2<:Integer} = - CategoricalArray{cat_promote_type(S, T), N, promote_type(R1, R2)} +# Identical behavior to the Array method +# Needed to prevent promote_result from returning an Array +# Note that eltype returns Any if a type parameter is omitted +Base.promote_rule(x::Type{<:CategoricalArray}, + y::Type{<:CategoricalArray}) = + Base.el_same(promote_type(eltype(x), eltype(y)), x, y) ## Conversion methods @@ -348,6 +344,13 @@ convert(::Type{CategoricalMatrix}, A::CategoricalMatrix) = A convert(::Type{CategoricalArray{T, N, R}}, A::AbstractArray{S, N}) where {S, T, N, R} = _convert(CategoricalArray{T, N, R}, A) +convert(::Type{CategoricalArray{T, N, 
R, V, C, U}}, + A::CategoricalArray{T, N, R, V, C, U}) where {T, N, R, V, C, U} = A +# V, C and U are not used since they are recomputed from T and R +convert(::Type{CategoricalArray{T, N, R, V, C, U}}, + A::AbstractArray{S, N}) where {S, T, N, R, V, C, U} = + _convert(CategoricalArray{T, N, R}, A) + function _convert(::Type{CategoricalArray{T, N, R}}, A::AbstractArray{S, N}; levels::Union{AbstractVector, Nothing}=nothing) where {S, T, N, R} check_supported_eltype(T, T) diff --git a/test/13_arraycommon.jl b/test/13_arraycommon.jl index fdb6ec5b..dd14dcc2 100644 --- a/test/13_arraycommon.jl +++ b/test/13_arraycommon.jl @@ -2225,53 +2225,39 @@ end end @testset "promotion" begin - @test promote_type(CategoricalVector{Int}, - CategoricalVector{String}) == - CategoricalVector{Union{Int, String}} - @test promote_type(CategoricalVector{Int, UInt32}, - CategoricalVector{String, UInt32}) == - CategoricalVector{Union{Int, String}, UInt32} - @test promote_type(CategoricalArray{Int, UInt32}, - CategoricalArray{String, UInt32}) == - CategoricalArray{Union{Int, String}, UInt32} - @test promote_type(CategoricalVector{Int, UInt32}, - CategoricalMatrix{String, UInt32}) == - CategoricalArray{Union{Int, String}} - @test promote_type(CategoricalVector{Int, UInt8}, - CategoricalVector{String, UInt16}) == - CategoricalVector{Union{Int, String}, UInt16} - - @test promote_type(CategoricalVector{Int8}, - CategoricalVector{Float64}) == - CategoricalVector{Float64} - @test promote_type(CategoricalVector{Int8, UInt32}, - CategoricalVector{Float64, UInt32}) == - CategoricalVector{Float64, UInt32} - @test promote_type(CategoricalArray{Int8, UInt32}, - CategoricalArray{Float64, UInt32}) == - CategoricalArray{Float64, UInt32} - @test promote_type(CategoricalVector{Int8, UInt32}, - CategoricalMatrix{Float64, UInt32}) == - CategoricalArray{Float64} - @test promote_type(CategoricalVector{Int8, UInt8}, - CategoricalVector{Float64, UInt16}) == - CategoricalVector{Float64, UInt16} - @test 
[CategoricalVector([1, 2]), CategoricalVector(["a", "b"])] isa - Vector{CategoricalVector{Union{Int, String}, UInt32}} + Vector{CategoricalVector{<:Any, UInt32, <:Any, <:Any, Union{}}} @test [CategoricalVector([1, missing]), CategoricalVector(["a", "b"])] isa - Vector{CategoricalVector{Union{Int, String, Missing}, UInt32}} + Vector{CategoricalVector{<:Any, UInt32}} + @test [CategoricalVector([1, missing]), + CategoricalVector([1, 2])] isa + Vector{CategoricalVector{Union{Missing, Int}, UInt32, Int, + CategoricalValue{Int, UInt32}, Missing}} @test [CategoricalVector([1, missing]), CategoricalVector(["a", missing])] isa - Vector{CategoricalVector{Union{Int, String, Missing}, UInt32}} + Vector{CategoricalVector{<:Any, UInt32, <:Any, <:Any, Missing}} @test [CategoricalVector([Int8(1), missing]), CategoricalVector([Int16(2)])] isa - Vector{CategoricalVector{Union{Int16, Missing}, UInt32}} + Vector{CategoricalVector{<:Any, UInt32}} @test [CategoricalVector([1, 2]), CategoricalMatrix(["a" "b"])] isa - Vector{CategoricalArray{Union{Int, String}}} + Vector{CategoricalArray{<:Any, <:Any, UInt32, <:Any, <:Any, Union{}}} + @test [CategoricalVector([1, 2]), + CategoricalMatrix([1 2])] isa + Vector{CategoricalArray{Int, <:Any, UInt32, Int, + CategoricalValue{Int, UInt32}, Union{}}} + @test [CategoricalVector([1, 2]), + CategoricalMatrix([1 missing])] isa + Vector{CategoricalArray{<:Any, <:Any, UInt32, Int, + CategoricalValue{Int, UInt32}}} + @test [categorical([1, 2], compress=true), + CategoricalVector([1, 2])] isa + Vector{CategoricalVector{Int, UInt32, Int, CategoricalValue{Int, UInt32}, Union{}}} + @test [categorical([1, 2], compress=true), + CategoricalVector(["a", "b"])] isa + Vector{CategoricalVector{<:Any, <:Integer, <:Any, <:Any, Union{}}} end end