Skip to content

Commit

Permalink
Simplify weights (#526)
Browse files Browse the repository at this point in the history
  • Loading branch information
lbittarello authored and nalimilan committed Oct 1, 2019
1 parent 95b794a commit b039107
Show file tree
Hide file tree
Showing 3 changed files with 151 additions and 147 deletions.
4 changes: 4 additions & 0 deletions src/deprecates.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,7 @@ end
# `wmedian` with a raw weight vector is replaced by `median` with the weights wrapped by `weights`.
@deprecate wmedian(v::RealVector, w::RealVector) median(v, weights(w))

# `quantile` without explicit probabilities is replaced by a call that spells out the five
# probabilities 0, 0.25, 0.5, 0.75 and 1.
@deprecate quantile(v::AbstractArray{<:Real}) quantile(v, [.0, .25, .5, .75, 1.0])

### Deprecated September 2019
# The positional `dims` argument of the weighted `sum` is replaced by the `dims` keyword.
@deprecate sum(A::AbstractArray, w::AbstractWeights, dims::Int) sum(A, w, dims=dims)
# `values` on a weight vector is replaced by an explicit conversion to `Vector`.
@deprecate values(wv::AbstractWeights) convert(Vector, wv)
95 changes: 45 additions & 50 deletions src/weights.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
###### Weight vector #####
##### Weight vector #####
# Abstract supertype of all weight vectors; every weight vector is itself an `AbstractVector{T}`.
#   S: a `Real` type (presumably the type of the stored sum of the weights; confirm against
#      the concrete weight types)
#   T: the `Real` element type of the weights
#   V: an `AbstractVector{T}` type (presumably the type of the underlying storage; confirm)
abstract type AbstractWeights{S<:Real, T<:Real, V<:AbstractVector{T}} <: AbstractVector{T} end

"""
Expand All @@ -18,12 +18,24 @@ macro weights(name)
end

# Number of weights, taken from the length of the `values` field.
function length(wv::AbstractWeights)
    return length(wv.values)
end
values(wv::AbstractWeights) = wv.values
# Sum of the weights: read from the `sum` field instead of being recomputed.
function sum(wv::AbstractWeights)
    return wv.sum
end

# A weight vector is empty exactly when its `values` field is empty.
function isempty(wv::AbstractWeights)
    return isempty(wv.values)
end

# Size of the weight vector, forwarded from its `values` field.
function size(wv::AbstractWeights)
    return size(wv.values)
end

Base.getindex(wv::AbstractWeights, i) = getindex(wv.values, i)
# Convert a weight vector to a `Vector` by converting its `values` field.
function Base.convert(::Type{Vector}, wv::AbstractWeights)
    return convert(Vector, wv.values)
end

# Scalar indexing: return the `i`-th entry of `wv.values`.
@propagate_inbounds function Base.getindex(wv::AbstractWeights, i::Integer)
    # The index is validated against `wv` here, so the access into the
    # underlying `values` vector is marked `@inbounds`.
    @boundscheck checkbounds(wv, i)
    return @inbounds getindex(wv.values, i)
end

# Array indexing: build a new weight vector of the same type `W` from the
# selected entries, with its sum computed from those entries.
@propagate_inbounds function Base.getindex(wv::W, i::AbstractArray) where {W<:AbstractWeights}
    @boundscheck checkbounds(wv, i)
    selected = @inbounds wv.values[i]
    return W(selected, sum(selected))
end

# Colon indexing: return a weight vector of the same type `W` holding a copy
# of `wv.values`; the sum is taken from `sum(wv)` rather than recomputed from the copy.
function Base.getindex(wv::W, ::Colon) where {W<:AbstractWeights}
    return W(copy(wv.values), sum(wv))
end

@propagate_inbounds function Base.setindex!(wv::AbstractWeights, v::Real, i::Int)
s = v - wv[i]
Expand Down Expand Up @@ -247,7 +259,7 @@ eweights(n::Integer, λ::Real) = eweights(1:n, λ)
# Exponential weights for the elements of `t`, located by their position in `r`.
function eweights(t::AbstractVector, r::AbstractRange, λ::Real)
    # `indexin` yields `nothing` for elements of `t` that are absent from `r`;
    # `something` turns such a `nothing` into an error.
    positions = something.(indexin(t, r))
    return eweights(positions, λ)
end

# NOTE: No variance correction is implemented for exponential weights
# NOTE: no variance correction is implemented for exponential weights

struct UnitWeights{T<:Real} <: AbstractWeights{Int, T, V where V<:Vector{T}}
len::Int
Expand All @@ -260,23 +272,24 @@ Construct a `UnitWeights` vector with length `s` and weight elements of type `T`
All weight elements are identically one.
""" UnitWeights

values(wv::UnitWeights{T}) where T = fill(one(T), length(wv))
# Every weight is one, so the sum is the length converted to the element type `T`.
function sum(wv::UnitWeights{T}) where {T}
    return convert(T, length(wv))
end

# Empty exactly when the stored length `len` is zero.
function isempty(wv::UnitWeights)
    return iszero(wv.len)
end

# The length is stored directly in the `len` field.
function length(wv::UnitWeights)
    return wv.len
end
# Size of a `UnitWeights` vector as a one-element tuple `(len,)`.
# A tuple literal is used instead of `Tuple(length(wv))`: the latter treats the
# integer as an iterable and goes through the generic iterator-to-tuple
# conversion, which allocates and does not infer as a 1-tuple.
size(wv::UnitWeights) = (length(wv),)

# Materialize unit weights as a `Vector` of `length(wv)` ones of type `T`.
function Base.convert(::Type{Vector}, wv::UnitWeights{T}) where {T}
    return ones(T, length(wv))
end

# Scalar indexing into unit weights: after the bounds check, every valid
# index yields `one(T)`.
@propagate_inbounds function Base.getindex(wv::UnitWeights{T}, i::Integer) where {T}
    @boundscheck checkbounds(wv, i)
    return one(T)
end

@propagate_inbounds function Base.getindex(wv::UnitWeights{T}, i::AbstractArray{<:Int}) where T
@boundscheck checkbounds(wv, i)
fill(one(T), size(i))
UnitWeights{T}(length(i))
end

Base.getindex(wv::UnitWeights{T}, ::Colon) where T = fill(one(T), length(wv))
Base.getindex(wv::UnitWeights{T}, ::Colon) where {T} = UnitWeights{T}(wv.len)

"""
uweights(s::Integer)
Expand Down Expand Up @@ -315,7 +328,7 @@ This definition is equivalent to the correction applied to unweighted data.
corrected ? (1 / (w.len - 1)) : (1 / w.len)
end

##### Equality tests #####
#### Equality tests #####

for w in (AnalyticWeights, FrequencyWeights, ProbabilityWeights, Weights)
@eval begin
Expand All @@ -341,22 +354,7 @@ Compute the weighted sum of an array `v` with weights `w`, optionally over the d
"""
function wsum(v::AbstractVector, w::AbstractVector)
    # For two vectors the weighted sum is their dot product.
    return dot(v, w)
end

function wsum(v::AbstractArray, w::AbstractVector)
    # Arrays of other shapes are flattened with `vec` before taking the dot product.
    flat = vec(v)
    return dot(flat, w)
end

# Note: the methods for BitArray and SparseMatrixCSC are to avoid ambiguities
Base.sum(v::BitArray, w::AbstractWeights) = wsum(v, values(w))
Base.sum(v::SparseArrays.SparseMatrixCSC, w::AbstractWeights) = wsum(v, values(w))
Base.sum(v::AbstractArray, w::AbstractWeights) = dot(v, values(w))

for v in (AbstractArray{<:Number}, BitArray, SparseArrays.SparseMatrixCSC, AbstractArray)
@eval begin
function Base.sum(v::$v, w::UnitWeights)
if length(v) != length(w)
throw(DimensionMismatch("Inconsistent array dimension."))
end
return sum(v)
end
end
end
# With `dims` given as a colon, defer to the two-argument `wsum`.
function wsum(v::AbstractArray, w::AbstractVector, dims::Colon)
    return wsum(v, w)
end

## wsum along dimension
#
Expand Down Expand Up @@ -392,7 +390,6 @@ end
# (d) A is a general dense array with eltype <: BlasReal:
# dim <= 2: delegate to (a) and (b)
# otherwise, decompose A into multiple pages
#

function _wsum1!(R::AbstractArray, A::AbstractVector, w::AbstractVector, init::Bool)
r = wsum(A, w)
Expand Down Expand Up @@ -455,7 +452,8 @@ function _wsumN!(R::StridedArray{T}, A::DenseArray{T,N}, w::StridedVector{T}, di
return R
end

# General Cartesian-based weighted sum across dimensions
## general Cartesian-based weighted sum across dimensions

@generated function _wsum_general!(R::AbstractArray{RT}, f::supertype(typeof(abs)),
A::AbstractArray{T,N}, w::AbstractVector{WT}, dim::Int, init::Bool) where {T,RT,WT,N}
quote
Expand Down Expand Up @@ -512,7 +510,6 @@ end
end
end


# N = 1
# One-dimensional dense input with a BLAS real element type: forward to
# `_wsum1!`; the `dim` argument is not passed on.
function _wsum!(R::StridedArray{T}, A::DenseArray{T,1}, w::StridedVector{T},
                dim::Int, init::Bool) where {T<:BlasReal}
    return _wsum1!(R, A, w, init)
end
Expand All @@ -533,7 +530,6 @@ _wsum!(R::AbstractArray, A::AbstractArray, w::AbstractVector, dim::Int, init::Bo
# Result type of a weighted sum: the type obtained by adding two products of a
# zero of type `T` and a zero of type `W`.
function wsumtype(::Type{T}, ::Type{W}) where {T,W}
    term = zero(T) * zero(W)
    return typeof(term + term)
end

# When values and weights share the same BLAS real type, that type is returned directly.
function wsumtype(::Type{T}, ::Type{T}) where {T<:BlasReal}
    return T
end


"""
wsum!(R, A, w, dim; init=true)
Expand All @@ -559,19 +555,21 @@ function wsum(A::AbstractArray{<:Number}, w::UnitWeights, dim::Int)
return sum(A, dims=dim)
end

# extended sum! and wsum
## extended sum! and wsum

Base.sum!(R::AbstractArray, A::AbstractArray, w::AbstractWeights{<:Real}, dim::Int; init::Bool=true) =
wsum!(R, A, values(w), dim; init=init)
wsum!(R, A, w, dim; init=init)

Base.sum(A::AbstractArray{<:Number}, w::AbstractWeights{<:Real}, dim::Int) = wsum(A, values(w), dim)
Base.sum(A::AbstractArray, w::AbstractWeights{<:Real}; dims::Union{Colon,Int}=:) =
wsum(A, w, dims)

function Base.sum(A::AbstractArray{<:Number}, w::UnitWeights, dim::Int)
size(A, dim) != length(w) && throw(DimensionMismatch("Inconsistent array dimension."))
return sum(A, dims=dim)
function Base.sum(A::AbstractArray, w::UnitWeights; dims::Union{Colon,Int}=:)
a = (dims === :) ? length(A) : size(A, dims)
a != length(w) && throw(DimensionMismatch("Inconsistent array dimension."))
return sum(A, dims=dims)
end

###### Weighted means #####
##### Weighted means #####

"""
wmean(v, w::AbstractVector)
Expand All @@ -589,9 +587,10 @@ end
Compute the weighted mean of array `A` with weight vector `w`
(of type `AbstractWeights`) along dimension `dims`, and write results to `R`.
"""
mean!(R::AbstractArray, A::AbstractArray, w::AbstractWeights;
dims::Union{Nothing,Int}=nothing) = _mean!(R, A, w, dims)
_mean!(R::AbstractArray, A::AbstractArray, w::AbstractWeights, dims::Nothing) = throw(ArgumentError("dims argument must be provided"))
function mean!(R::AbstractArray, A::AbstractArray, w::AbstractWeights;
               dims::Union{Nothing,Int}=nothing)
    # Keyword arguments do not take part in dispatch, so `dims` is handed to
    # `_mean!` positionally.
    return _mean!(R, A, w, dims)
end

# `dims` left at its `nothing` default: reject the call.
function _mean!(R::AbstractArray, A::AbstractArray, w::AbstractWeights, dims::Nothing)
    throw(ArgumentError("dims argument must be provided"))
end

# Weighted sum along `dims` (via `Base.sum!`), then scaled by the inverse of
# the total weight using `rmul!`.
function _mean!(R::AbstractArray, A::AbstractArray, w::AbstractWeights, dims::Int)
    weighted_total = Base.sum!(R, A, w, dims)
    return rmul!(weighted_total, inv(sum(w)))
end

Expand All @@ -611,24 +610,21 @@ w = rand(n)
mean(x, weights(w))
```
"""
mean(A::AbstractArray, w::AbstractWeights; dims::Union{Nothing,Int}=nothing) =
mean(A::AbstractArray, w::AbstractWeights; dims::Union{Colon,Int}=:) =
_mean(A, w, dims)
_mean(A::AbstractArray, w::AbstractWeights, dims::Nothing) =
_mean(A::AbstractArray, w::AbstractWeights, dims::Colon) =
sum(A, w) / sum(w)
# Weighted mean along `dims`: allocate an output array with element type
# `wmeantype(T, W)` and the reduced axes of `A`, then let `_mean!` fill it.
function _mean(A::AbstractArray{T}, w::AbstractWeights{W}, dims::Int) where {T,W}
    out = similar(A, wmeantype(T, W), Base.reduced_indices(axes(A), dims))
    return _mean!(out, A, w, dims)
end

function _mean(A::AbstractArray, w::UnitWeights, dims::Nothing)
length(A) != length(w) && throw(DimensionMismatch("Inconsistent array dimension."))
return mean(A)
end

function _mean(A::AbstractArray, w::UnitWeights, dims::Int)
size(A, dims) != length(w) && throw(DimensionMismatch("Inconsistent array dimension."))
function mean(A::AbstractArray, w::UnitWeights; dims::Union{Colon,Int}=:)
a = (dims === :) ? length(A) : size(A, dims)
a != length(w) && throw(DimensionMismatch("Inconsistent array dimension."))
return mean(A, dims=dims)
end

###### Weighted quantile #####
##### Weighted quantile #####

"""
quantile(v, w::AbstractWeights, p)
Expand Down Expand Up @@ -723,9 +719,8 @@ end

# Single-probability convenience method: wrap `p` in a one-element vector,
# call the vector method, and return the first entry of its result.
function quantile(v::RealVector, w::AbstractWeights{<:Real}, p::Number)
    results = quantile(v, w, [p])
    return results[1]
end

##### Weighted median #####


###### Weighted median #####
"""
median(v::RealVector, w::AbstractWeights)
Expand Down
Loading

0 comments on commit b039107

Please sign in to comment.