diff --git a/src/PooledArrays.jl b/src/PooledArrays.jl index 9ba0f67..4892dfb 100644 --- a/src/PooledArrays.jl +++ b/src/PooledArrays.jl @@ -30,13 +30,13 @@ end mutable struct PooledArray{T, R<:Integer, N, RA} <: AbstractArray{T, N} refs::RA - pool::Vector{T} - invpool::Dict{T,R} + pool::Vector{Union{Missing,T}} + invpool::Dict{Union{Missing,T},R} # refcount[] is 1 if only one PooledArray holds a reference to pool and invpool refcount::Threads.Atomic{Int} - function PooledArray{T,R,N,RA}(rs::RefArray{RA}, invpool::Dict{T, R}, - pool::Vector{T}=_invert(invpool), + function PooledArray{T,R,N,RA}(rs::RefArray{RA}, invpool::Dict{Union{Missing,T}, R}, + pool::Vector{Union{Missing,T}}=_invert(invpool), refcount::Threads.Atomic{Int}=Threads.Atomic{Int}(1)) where {T,R,N,RA<:AbstractArray{R, N}} # this is a quick but incomplete consistency check if length(pool) != length(invpool) @@ -45,11 +45,22 @@ mutable struct PooledArray{T, R<:Integer, N, RA} <: AbstractArray{T, N} if length(rs.a) > 0 # 0 indicates #undef # refs mustn't overflow pool - minref, maxref = extrema(rs.a) - if (minref < 0 || maxref > length(invpool)) - throw(ArgumentError("Reference array points beyond the end of the pool")) + + ipl = length(invpool) + for v in rs.a + if (v < 0 || v > ipl) + throw(ArgumentError("Reference array points beyond the end of the pool")) + end + if !(T >: Missing) + if v == 1 + throw(ArgumentError("Missing value in a pool that does not allow it")) + end + end end end + if pool[1] !== missing + throw(ArgumentError("First entry in pool must be a missing value")) + end pa = new{T,R,N,RA}(rs.a, pool, invpool, refcount) finalizer(x -> Threads.atomic_sub!(x.refcount, 1), pa) return pa @@ -76,7 +87,8 @@ const PooledArrOrSub = Union{SubArray{T, N, <:PooledArray{T, R}}, ############################################################################## # Echo inner constructor as an outer constructor -PooledArray(refs::RefArray{RA}, invpool::Dict{T,R}, pool::Vector{T}=_invert(invpool), +PooledArray{T}(refs::RefArray{RA}, invpool::Dict{Union{Missing,T},R}, + pool::Vector{Union{Missing,T}}=_invert(invpool), refcount::Threads.Atomic{Int}=Threads.Atomic{Int}(1)) where {T,R,RA<:AbstractArray{R}} = PooledArray{T,R,ndims(RA),RA}(refs, invpool, pool, refcount) @@ -89,9 +101,9 @@ function _our_copy(x::SubArray{<:Any, 0}) return y end -function PooledArray(d::PooledArrOrSub) +function PooledArray(d::PooledArrOrSub{T}) where {T} Threads.atomic_add!(refcount(d), 1) - return PooledArray(RefArray(_our_copy(DataAPI.refarray(d))), + return PooledArray{T}(RefArray(_our_copy(DataAPI.refarray(d))), DataAPI.invrefpool(d), DataAPI.refpool(d), refcount(d)) end @@ -100,9 +112,9 @@ function _label(xs::AbstractArray, ::Type{I}=DEFAULT_POOLED_REF_TYPE, start = 1, labels = Array{I}(undef, size(xs)), - invpool::Dict{T,I} = Dict{T, I}(), - pool::Vector{T} = T[], - nlabels = 0, + invpool::Dict{Union{Missing,T},I} = Dict{Union{Missing,T}, I}(missing => one(I)), + pool::Vector{Union{Missing,T}} = Union{Missing,T}[missing], + nlabels = 1, ) where {T, I<:Integer} @inbounds for i in start:length(xs) @@ -176,7 +188,7 @@ end function PooledArray{T}(d::AbstractArray; signed::Bool=false, compress::Bool=false) where {T} R = signed ? (compress ? Int8 : DEFAULT_SIGNED_REF_TYPE) : (compress ? UInt8 : DEFAULT_POOLED_REF_TYPE) refs, invpool, pool = _label(d, T, R) - return PooledArray(RefArray(refs), invpool, pool) + return PooledArray{T}(RefArray(refs), invpool, pool) end PooledArray(d::AbstractArray{T}, r::Type) where {T} = PooledArray{T}(d, r) @@ -213,7 +225,7 @@ Base.copy(pa::PooledArrOrSub) = PooledArray(pa) # here we do not allow dest to be SubArray as copy! is intended to replace whole arrays # slow path will be used for SubArray -function copy!(dest::PooledArray{T, R, N}, +function copy!(dest::Union{PooledArray{T, R, N}, PooledArray{Union{Missing,T}, R, N}}, src::PooledArrOrSub{T, N, R}) where {T, N, R} copy!(dest.refs, DataAPI.refarray(src)) src_refcount = refcount(src) @@ -232,10 +244,12 @@ function copy!(dest::PooledArray{T, R, N}, end # this is needed as Julia Base uses a special path for this case we want to avoid -Base.copyto!(dest::PooledArrOrSub{T, N, R}, src::PooledArrOrSub{T, N, R}) where {T, N, R} = +Base.copyto!(dest::Union{PooledArrOrSub{T, N, R}, PooledArrOrSub{Union{Missing,T}, N, R}}, + src::PooledArrOrSub{T, N, R}) where {T, N, R} = copyto!(dest, 1, src, 1, length(src)) -function Base.copyto!(dest::PooledArrOrSub{T, N, R}, doffs::Union{Signed, Unsigned}, +function Base.copyto!(dest::Union{PooledArrOrSub{T, N, R}, PooledArrOrSub{Union{Missing,T}, N, R}}, + doffs::Union{Signed, Unsigned}, src::PooledArrOrSub{T, N, R}, soffs::Union{Signed, Unsigned}, n::Union{Signed, Unsigned}) where {T, N, R} n == 0 && return dest @@ -277,9 +291,9 @@ function Base.resize!(pa::PooledArray{T,R,1}, n::Integer) where {T,R} return pa end -function Base.reverse(x::PooledArray) +function Base.reverse(x::PooledArray{T}) where {T} Threads.atomic_add!(x.refcount, 1) - PooledArray(RefArray(reverse(x.refs)), x.invpool, x.pool, x.refcount) + PooledArray{T}(RefArray(reverse(x.refs)), x.invpool, x.pool, x.refcount) end function Base.permute!!(x::PooledArray, p::AbstractVector{T}) where T<:Integer @@ -293,7 +307,7 @@ function Base.invpermute!!(x::PooledArray, p::AbstractVector{T}) where T<:Intege end Base.similar(pa::PooledArray{T,R}, S::Type, dims::Dims) where {T,R} = - PooledArray(RefArray(zeros(R, dims)), Dict{S,R}()) + PooledArray{S}(RefArray(zeros(R, dims)), Dict{Union{Missing,S},R}(missing=>one(R))) Base.findall(pdv::PooledVector{Bool}) = findall(convert(Vector{Bool}, pdv)) @@ -304,6 +318,10 @@ Base.findall(pdv::PooledVector{Bool}) = findall(convert(Vector{Bool}, pdv)) ## ############################################################################## +# TODO ensure proper translation: +# 1. if missing is mapped to something else - re-introduce the missing +# 2. make sure missing ends up as the first entry in the result +# 3. correctly calculate the eltype of the result (so that Missing does not affect it if it was not allowed originally) function Base.map(f, x::PooledArray{T,R}) where {T,R<:Integer} ks = collect(keys(x.invpool)) vs = collect(values(x.invpool)) @@ -392,6 +410,8 @@ Base.sort(pa::PooledArray; kw...) = pa[sortperm(pa; kw...)] ## ############################################################################## +# TODO: correctly handle types in conversions + function Base.convert(::Type{PooledArray{S,R1,N}}, pa::PooledArray{T,R2,N}) where {S,T,R1<:Integer,R2<:Integer,N} invpool_conv = convert(Dict{S,R1}, pa.invpool) @assert invpool_conv !== pa.invpool @@ -446,27 +466,29 @@ Base.convert(::Type{Array}, pa::PooledArray{T, R, N}) where {T, R, N} = convert( # We need separate functions due to dispatch ambiguities -Base.@propagate_inbounds function Base.getindex(A::PooledArray, I::Int) +Base.@propagate_inbounds function Base.getindex(A::PooledArray{T}, I::Int)::T where T idx = DataAPI.refarray(A)[I] - iszero(idx) && throw(UndefRefError()) + # if eltype(A) does not allow Missing then also 1 is an error + idx <= !(eltype(A) >: Missing) && throw(UndefRefError()) return @inbounds DataAPI.refpool(A)[idx] end # we handle fast only the case when the first index is an abstract vector # this is to make sure other indexing synraxes use standard dispatch from Base # the reason is that creation of DataAPI.refarray(A) is unfortunately slow -Base.@propagate_inbounds function Base.getindex(A::PooledArrOrSub, +Base.@propagate_inbounds function Base.getindex(A::PooledArrOrSub{T}, I1::AbstractVector, - I2::Union{Real, AbstractVector}...) + I2::Union{Real, AbstractVector}...) where {T} # make sure we do not increase A.refcount in case creation of newrefs fails newrefs = DataAPI.refarray(A)[I1, I2...] @assert newrefs isa AbstractArray Threads.atomic_add!(refcount(A), 1) - return PooledArray(RefArray(newrefs), DataAPI.invrefpool(A), DataAPI.refpool(A), refcount(A)) + return PooledArray{T}(RefArray(newrefs), DataAPI.invrefpool(A), DataAPI.refpool(A), refcount(A)) end Base.@propagate_inbounds function Base.isassigned(pa::PooledArrOrSub, I::Int...) - !iszero(DataAPI.refarray(pa)[I...]) + # if eltype(A) does not allow Missing then also 1 is an error + !(DataAPI.refarray(pa)[I...] <= !(eltype(pa) >: Missing)) end ############################################################################## @@ -511,6 +533,9 @@ end Base.IndexStyle(::Type{<:PooledArray}) = IndexLinear() Base.@propagate_inbounds function Base.setindex!(x::PooledArray, val, ind::Int) + if val === missing && !(eltype(x) >: Missing) + throw(ArgumentError("PooledArray element type does not allow storing missing values")) + end x.refs[ind] = getpoolidx(x, val) return x end @@ -548,6 +573,8 @@ Base.empty!(pv::PooledVector) = (empty!(pv.refs); pv) Base.deleteat!(pv::PooledVector, inds) = (deleteat!(pv.refs, inds); pv) +# TODO: review vcat for possible Missing related issues + function _vcat!(c, a, b) copyto!(c, 1, a, 1, length(a)) return copyto!(c, length(a)+1, b, 1, length(b))