Skip to content

Commit

Permalink
Improve array promotion rules (#387)
Browse files Browse the repository at this point in the history
PR #384 added promotion rules that recursively promote the
eltype and reftype, but this has the drawback that the resulting
type is often not a supertype of the input types, forcing a conversion
to a `CategoricalArray` which may have a `Union` type.
Instead, do the same as `Array` and `AbstractArray` fallbacks, which
simply call `promote_typejoin` if eltypes do not match.
  • Loading branch information
nalimilan authored Mar 17, 2022
1 parent 4d7e344 commit c4ae8cc
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 47 deletions.
23 changes: 13 additions & 10 deletions src/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -304,16 +304,12 @@ CategoricalMatrix(A::CategoricalArray{T, 2, R};

## Promotion methods

Base.promote_rule(::Type{<:CategoricalArray{S}},
::Type{<:CategoricalArray{T}}) where {S, T} =
CategoricalArray{cat_promote_type(S, T)}
Base.promote_rule(::Type{<:CategoricalArray{S, N}},
::Type{<:CategoricalArray{T, N}}) where {S, T, N} =
CategoricalArray{cat_promote_type(S, T), N}
Base.promote_rule(::Type{<:CategoricalArray{S, N, R1}},
::Type{<:CategoricalArray{T, N, R2}}) where
{S, T, N, R1<:Integer, R2<:Integer} =
CategoricalArray{cat_promote_type(S, T), N, promote_type(R1, R2)}
# Identical behavior to the Array method
# Needed to prevent promote_result from returning an Array
# Note that eltype returns Any if a type parameter is omitted
Base.promote_rule(x::Type{<:CategoricalArray},
y::Type{<:CategoricalArray}) =
Base.el_same(promote_type(eltype(x), eltype(y)), x, y)

## Conversion methods

Expand Down Expand Up @@ -348,6 +344,13 @@ convert(::Type{CategoricalMatrix}, A::CategoricalMatrix) = A
convert(::Type{CategoricalArray{T, N, R}}, A::AbstractArray{S, N}) where {S, T, N, R} =
_convert(CategoricalArray{T, N, R}, A)

convert(::Type{CategoricalArray{T, N, R, V, C, U}},
A::CategoricalArray{T, N, R, V, C, U}) where {T, N, R, V, C, U} = A
# V, C and U are not used since they are recomputed from T and R
convert(::Type{CategoricalArray{T, N, R, V, C, U}},
A::AbstractArray{S, N}) where {S, T, N, R, V, C, U} =
_convert(CategoricalArray{T, N, R}, A)

function _convert(::Type{CategoricalArray{T, N, R}}, A::AbstractArray{S, N};
levels::Union{AbstractVector, Nothing}=nothing) where {S, T, N, R}
check_supported_eltype(T, T)
Expand Down
60 changes: 23 additions & 37 deletions test/13_arraycommon.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2225,53 +2225,39 @@ end
end

@testset "promotion" begin
@test promote_type(CategoricalVector{Int},
CategoricalVector{String}) ==
CategoricalVector{Union{Int, String}}
@test promote_type(CategoricalVector{Int, UInt32},
CategoricalVector{String, UInt32}) ==
CategoricalVector{Union{Int, String}, UInt32}
@test promote_type(CategoricalArray{Int, UInt32},
CategoricalArray{String, UInt32}) ==
CategoricalArray{Union{Int, String}, UInt32}
@test promote_type(CategoricalVector{Int, UInt32},
CategoricalMatrix{String, UInt32}) ==
CategoricalArray{Union{Int, String}}
@test promote_type(CategoricalVector{Int, UInt8},
CategoricalVector{String, UInt16}) ==
CategoricalVector{Union{Int, String}, UInt16}

@test promote_type(CategoricalVector{Int8},
CategoricalVector{Float64}) ==
CategoricalVector{Float64}
@test promote_type(CategoricalVector{Int8, UInt32},
CategoricalVector{Float64, UInt32}) ==
CategoricalVector{Float64, UInt32}
@test promote_type(CategoricalArray{Int8, UInt32},
CategoricalArray{Float64, UInt32}) ==
CategoricalArray{Float64, UInt32}
@test promote_type(CategoricalVector{Int8, UInt32},
CategoricalMatrix{Float64, UInt32}) ==
CategoricalArray{Float64}
@test promote_type(CategoricalVector{Int8, UInt8},
CategoricalVector{Float64, UInt16}) ==
CategoricalVector{Float64, UInt16}

@test [CategoricalVector([1, 2]),
CategoricalVector(["a", "b"])] isa
Vector{CategoricalVector{Union{Int, String}, UInt32}}
Vector{CategoricalVector{<:Any, UInt32, <:Any, <:Any, Union{}}}
@test [CategoricalVector([1, missing]),
CategoricalVector(["a", "b"])] isa
Vector{CategoricalVector{Union{Int, String, Missing}, UInt32}}
Vector{CategoricalVector{<:Any, UInt32}}
@test [CategoricalVector([1, missing]),
CategoricalVector([1, 2])] isa
Vector{CategoricalVector{Union{Missing, Int}, UInt32, Int,
CategoricalValue{Int, UInt32}, Missing}}
@test [CategoricalVector([1, missing]),
CategoricalVector(["a", missing])] isa
Vector{CategoricalVector{Union{Int, String, Missing}, UInt32}}
Vector{CategoricalVector{<:Any, UInt32, <:Any, <:Any, Missing}}
@test [CategoricalVector([Int8(1), missing]),
CategoricalVector([Int16(2)])] isa
Vector{CategoricalVector{Union{Int16, Missing}, UInt32}}
Vector{CategoricalVector{<:Any, UInt32}}
@test [CategoricalVector([1, 2]),
CategoricalMatrix(["a" "b"])] isa
Vector{CategoricalArray{Union{Int, String}}}
Vector{CategoricalArray{<:Any, <:Any, UInt32, <:Any, <:Any, Union{}}}
@test [CategoricalVector([1, 2]),
CategoricalMatrix([1 2])] isa
Vector{CategoricalArray{Int, <:Any, UInt32, Int,
CategoricalValue{Int, UInt32}, Union{}}}
@test [CategoricalVector([1, 2]),
CategoricalMatrix([1 missing])] isa
Vector{CategoricalArray{<:Any, <:Any, UInt32, Int,
CategoricalValue{Int, UInt32}}}
@test [categorical([1, 2], compress=true),
CategoricalVector([1, 2])] isa
Vector{CategoricalVector{Int, UInt32, Int, CategoricalValue{Int, UInt32}, Union{}}}
@test [categorical([1, 2], compress=true),
CategoricalVector(["a", "b"])] isa
Vector{CategoricalVector{<:Any, <:Integer, <:Any, <:Any, Union{}}}
end

end

0 comments on commit c4ae8cc

Please sign in to comment.