Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify weights #526

Merged
merged 36 commits into from
Oct 1, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
1fb0ea9
Consistent comment style
lbittarello Sep 22, 2019
516dee3
Remove values as a parameter of weights
lbittarello Sep 22, 2019
bb738fc
Deprecate values for weights
lbittarello Sep 22, 2019
74ed6f9
dimension as keyword to sum for consistency
lbittarello Sep 22, 2019
2efced6
Clean up
lbittarello Sep 22, 2019
18545aa
Drop unnecessary sum methods
lbittarello Sep 22, 2019
94bfe94
Update src/weights.jl
lbittarello Sep 22, 2019
a42f1e3
Update src/weights.jl
lbittarello Sep 22, 2019
a061a2e
Fix boundary check
lbittarello Sep 22, 2019
20eaae5
Remove space around keyword arguments
lbittarello Sep 22, 2019
10b4f57
Merge https://github.com/lbittarello/StatsBase.jl
lbittarello Sep 22, 2019
4957613
Patches
lbittarello Sep 22, 2019
d25dbf4
Indexing by array returns new weight vector
lbittarello Sep 22, 2019
904abc5
Update src/weights.jl
lbittarello Sep 23, 2019
be6d6c7
Update src/weights.jl
lbittarello Sep 23, 2019
f98e4c0
Update src/weights.jl
lbittarello Sep 23, 2019
fec6f5e
Update src/weights.jl
lbittarello Sep 23, 2019
5277140
Update src/weights.jl
lbittarello Sep 23, 2019
5fe88cc
Update src/weights.jl
lbittarello Sep 23, 2019
6c707f6
Update src/weights.jl
lbittarello Sep 23, 2019
dea5fc4
Update src/deprecates.jl
lbittarello Sep 23, 2019
ea6eaf7
Update src/weights.jl
lbittarello Sep 23, 2019
baa4b9c
Update src/weights.jl
lbittarello Sep 23, 2019
b22f5bd
Update src/weights.jl
lbittarello Sep 28, 2019
b225e84
Update src/weights.jl
lbittarello Sep 28, 2019
4b0fc7c
Update src/weights.jl
lbittarello Sep 28, 2019
b7ea5ac
Simplify sum
lbittarello Sep 29, 2019
cd2b28c
Simplify sum and mean for unit weights
lbittarello Sep 29, 2019
9e8c5ee
Update src/weights.jl
lbittarello Sep 29, 2019
2e22c43
Update src/weights.jl
lbittarello Sep 29, 2019
aa7c962
Update src/deprecates.jl
lbittarello Sep 29, 2019
3dfb0f6
Update src/weights.jl
lbittarello Sep 29, 2019
5da60a6
Update src/weights.jl
lbittarello Sep 29, 2019
79cde89
Update src/weights.jl
lbittarello Sep 29, 2019
d65ca65
Update src/weights.jl
lbittarello Sep 29, 2019
6f974ea
Required parentheses added
lbittarello Sep 29, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/deprecates.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,7 @@ end
@deprecate wmedian(v::RealVector, w::RealVector) median(v, weights(w))

@deprecate quantile(v::AbstractArray{<:Real}) quantile(v, [.0, .25, .5, .75, 1.0])

### Deprecated September 2019
# The positional `dims` argument of weighted `sum` is superseded by the `dims` keyword.
@deprecate sum(A::AbstractArray, w::AbstractWeights, dims::Int) sum(A, w, dims=dims)
# `values` on a weight vector is superseded by an explicit conversion to `Vector`.
@deprecate values(wv::AbstractWeights) convert(Vector, wv)
95 changes: 45 additions & 50 deletions src/weights.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
###### Weight vector #####
##### Weight vector #####
abstract type AbstractWeights{S<:Real, T<:Real, V<:AbstractVector{T}} <: AbstractVector{T} end

"""
Expand All @@ -18,12 +18,24 @@ macro weights(name)
end

length(wv::AbstractWeights) = length(wv.values)
values(wv::AbstractWeights) = wv.values
sum(wv::AbstractWeights) = wv.sum
isempty(wv::AbstractWeights) = isempty(wv.values)
size(wv::AbstractWeights) = size(wv.values)

Base.getindex(wv::AbstractWeights, i) = getindex(wv.values, i)
# Converting a weight vector to `Vector` delegates to converting its `values` field.
function Base.convert(::Type{Vector}, wv::AbstractWeights)
    return convert(Vector, wv.values)
end

# Scalar indexing: check the index against the weight vector (the check is wrapped
# in `@boundscheck`), then read the element from the `values` field.
@propagate_inbounds function Base.getindex(wv::AbstractWeights, i::Integer)
    @boundscheck checkbounds(wv, i)
    return @inbounds wv.values[i]
end

# Array indexing: the selected weights are wrapped in a new weight vector of the
# same type `W`, whose stored sum is recomputed from the selected elements.
@propagate_inbounds function Base.getindex(
    wv::W, i::AbstractArray
) where {W<:AbstractWeights}
    @boundscheck checkbounds(wv, i)
    selected = @inbounds wv.values[i]
    return W(selected, sum(selected))
end

# `wv[:]` builds a new weight vector of the same type from a copy of the values,
# passing along `sum(wv)` as its sum.
function Base.getindex(wv::W, ::Colon) where {W<:AbstractWeights}
    return W(copy(wv.values), sum(wv))
end

@propagate_inbounds function Base.setindex!(wv::AbstractWeights, v::Real, i::Int)
s = v - wv[i]
Expand Down Expand Up @@ -247,7 +259,7 @@ eweights(n::Integer, λ::Real) = eweights(1:n, λ)
eweights(t::AbstractVector, r::AbstractRange, λ::Real) =
eweights(something.(indexin(t, r)), λ)

# NOTE: No variance correction is implemented for exponential weights
# NOTE: no variance correction is implemented for exponential weights

struct UnitWeights{T<:Real} <: AbstractWeights{Int, T, V where V<:Vector{T}}
len::Int
Expand All @@ -260,23 +272,24 @@ Construct a `UnitWeights` vector with length `s` and weight elements of type `T`
All weight elements are identically one.
""" UnitWeights

values(wv::UnitWeights{T}) where T = fill(one(T), length(wv))
# Basic queries on `UnitWeights`, all derived from the stored length `len`.
function sum(wv::UnitWeights{T}) where {T}
    # The total is the length, expressed in the weight element type `T`.
    return convert(T, length(wv))
end
isempty(wv::UnitWeights) = iszero(wv.len)
length(wv::UnitWeights) = wv.len
size(wv::UnitWeights) = (length(wv),)

# Materialise unit weights as a `Vector` of ones of element type `T`.
function Base.convert(::Type{Vector}, wv::UnitWeights{T}) where {T}
    return ones(T, length(wv))
end

# Scalar indexing into unit weights: after the bounds check, the result is `one(T)`
# regardless of the index.
@propagate_inbounds function Base.getindex(wv::UnitWeights{T}, i::Integer) where {T}
    @boundscheck checkbounds(wv, i)
    return one(T)
end

@propagate_inbounds function Base.getindex(wv::UnitWeights{T}, i::AbstractArray{<:Int}) where T
@boundscheck checkbounds(wv, i)
fill(one(T), size(i))
UnitWeights{T}(length(i))
end

Base.getindex(wv::UnitWeights{T}, ::Colon) where T = fill(one(T), length(wv))
# `wv[:]` on unit weights returns a new `UnitWeights{T}` with the same `len`.
function Base.getindex(wv::UnitWeights{T}, ::Colon) where {T}
    return UnitWeights{T}(wv.len)
end

"""
uweights(s::Integer)
Expand Down Expand Up @@ -315,7 +328,7 @@ This definition is equivalent to the correction applied to unweighted data.
corrected ? (1 / (w.len - 1)) : (1 / w.len)
end

##### Equality tests #####
##### Equality tests #####

for w in (AnalyticWeights, FrequencyWeights, ProbabilityWeights, Weights)
@eval begin
Expand All @@ -341,22 +354,7 @@ Compute the weighted sum of an array `v` with weights `w`, optionally over the d
"""
function wsum(v::AbstractVector, w::AbstractVector)
    # Vector case: the weighted sum is the dot product of values and weights.
    return dot(v, w)
end

function wsum(v::AbstractArray, w::AbstractVector)
    # General arrays are flattened with `vec` before taking the dot product with `w`.
    return dot(vec(v), w)
end

# Note: the methods for BitArray and SparseMatrixCSC are to avoid ambiguities
nalimilan marked this conversation as resolved.
Show resolved Hide resolved
Base.sum(v::BitArray, w::AbstractWeights) = wsum(v, values(w))
Base.sum(v::SparseArrays.SparseMatrixCSC, w::AbstractWeights) = wsum(v, values(w))
Base.sum(v::AbstractArray, w::AbstractWeights) = dot(v, values(w))

for v in (AbstractArray{<:Number}, BitArray, SparseArrays.SparseMatrixCSC, AbstractArray)
@eval begin
function Base.sum(v::$v, w::UnitWeights)
if length(v) != length(w)
throw(DimensionMismatch("Inconsistent array dimension."))
end
return sum(v)
end
end
end
# `dims = :` means "sum over everything": forward to the two-argument method.
function wsum(v::AbstractArray, w::AbstractVector, dims::Colon)
    return wsum(v, w)
end

## wsum along dimension
#
Expand Down Expand Up @@ -392,7 +390,6 @@ end
# (d) A is a general dense array with eltype <: BlasReal:
# dim <= 2: delegate to (a) and (b)
# otherwise, decompose A into multiple pages
#

function _wsum1!(R::AbstractArray, A::AbstractVector, w::AbstractVector, init::Bool)
r = wsum(A, w)
Expand Down Expand Up @@ -455,7 +452,8 @@ function _wsumN!(R::StridedArray{T}, A::DenseArray{T,N}, w::StridedVector{T}, di
return R
end

# General Cartesian-based weighted sum across dimensions
## general Cartesian-based weighted sum across dimensions

@generated function _wsum_general!(R::AbstractArray{RT}, f::supertype(typeof(abs)),
A::AbstractArray{T,N}, w::AbstractVector{WT}, dim::Int, init::Bool) where {T,RT,WT,N}
quote
Expand Down Expand Up @@ -512,7 +510,6 @@ end
end
end


# N = 1
_wsum!(R::StridedArray{T}, A::DenseArray{T,1}, w::StridedVector{T}, dim::Int, init::Bool) where {T<:BlasReal} =
_wsum1!(R, A, w, init)
Expand All @@ -533,7 +530,6 @@ _wsum!(R::AbstractArray, A::AbstractArray, w::AbstractVector, dim::Int, init::Bo
# Result element type of a weighted sum: the type obtained by adding two products
# of a zero of the value type `T` and a zero of the weight type `W`.
function wsumtype(::Type{T}, ::Type{W}) where {T,W}
    product = zero(T) * zero(W)
    return typeof(product + product)
end

# When values and weights share the same `BlasReal` type, that type is returned directly.
function wsumtype(::Type{T}, ::Type{T}) where {T<:BlasReal}
    return T
end


"""
wsum!(R, A, w, dim; init=true)

Expand All @@ -559,19 +555,21 @@ function wsum(A::AbstractArray{<:Number}, w::UnitWeights, dim::Int)
return sum(A, dims=dim)
end

# extended sum! and wsum
## extended sum! and wsum

Base.sum!(R::AbstractArray, A::AbstractArray, w::AbstractWeights{<:Real}, dim::Int; init::Bool=true) =
wsum!(R, A, values(w), dim; init=init)
wsum!(R, A, w, dim; init=init)

Base.sum(A::AbstractArray{<:Number}, w::AbstractWeights{<:Real}, dim::Int) = wsum(A, values(w), dim)
# Weighted `sum` with an optional `dims` keyword (default `:`); the work is done by
# `wsum`, which receives `dims` as a positional argument.
function Base.sum(A::AbstractArray, w::AbstractWeights{<:Real}; dims::Union{Colon,Int}=:)
    return wsum(A, w, dims)
end

function Base.sum(A::AbstractArray{<:Number}, w::UnitWeights, dim::Int)
size(A, dim) != length(w) && throw(DimensionMismatch("Inconsistent array dimension."))
return sum(A, dims=dim)
# Weighted `sum` for unit weights: once the length of `w` is checked against the
# relevant extent of `A`, the result is the plain `sum` of `A` over `dims`.
function Base.sum(A::AbstractArray, w::UnitWeights; dims::Union{Colon,Int}=:)
    # With `dims = :` the weights must match every element; otherwise they must
    # match the size of `A` along the requested dimension.
    expected = dims isa Colon ? length(A) : size(A, dims)
    if expected != length(w)
        throw(DimensionMismatch("Inconsistent array dimension."))
    end
    return sum(A, dims=dims)
end

###### Weighted means #####
##### Weighted means #####

"""
wmean(v, w::AbstractVector)
Expand All @@ -589,9 +587,10 @@ end
Compute the weighted mean of array `A` with weight vector `w`
(of type `AbstractWeights`) along dimension `dims`, and write results to `R`.
"""
mean!(R::AbstractArray, A::AbstractArray, w::AbstractWeights;
dims::Union{Nothing,Int}=nothing) = _mean!(R, A, w, dims)
_mean!(R::AbstractArray, A::AbstractArray, w::AbstractWeights, dims::Nothing) = throw(ArgumentError("dims argument must be provided"))
function mean!(
    R::AbstractArray, A::AbstractArray, w::AbstractWeights;
    dims::Union{Nothing,Int}=nothing,
)
    # Keyword arguments do not take part in dispatch, so hand `dims` to a helper
    # positionally and let the method table separate `nothing` from an `Int`.
    return _mean!(R, A, w, dims)
end

# `dims` was left at its default of `nothing`: the in-place form requires it.
function _mean!(R::AbstractArray, A::AbstractArray, w::AbstractWeights, dims::Nothing)
    throw(ArgumentError("dims argument must be provided"))
end

# Compute the weighted sum with `Base.sum!`, then scale its result by `inv(sum(w))`.
function _mean!(R::AbstractArray, A::AbstractArray, w::AbstractWeights, dims::Int)
    total = Base.sum!(R, A, w, dims)
    return rmul!(total, inv(sum(w)))
end

Expand All @@ -611,24 +610,21 @@ w = rand(n)
mean(x, weights(w))
```
"""
mean(A::AbstractArray, w::AbstractWeights; dims::Union{Nothing,Int}=nothing) =
mean(A::AbstractArray, w::AbstractWeights; dims::Union{Colon,Int}=:) =
_mean(A, w, dims)
_mean(A::AbstractArray, w::AbstractWeights, dims::Nothing) =
_mean(A::AbstractArray, w::AbstractWeights, dims::Colon) =
sum(A, w) / sum(w)
_mean(A::AbstractArray{T}, w::AbstractWeights{W}, dims::Int) where {T,W} =
_mean!(similar(A, wmeantype(T, W), Base.reduced_indices(axes(A), dims)), A, w, dims)

function _mean(A::AbstractArray, w::UnitWeights, dims::Nothing)
length(A) != length(w) && throw(DimensionMismatch("Inconsistent array dimension."))
return mean(A)
end

function _mean(A::AbstractArray, w::UnitWeights, dims::Int)
size(A, dims) != length(w) && throw(DimensionMismatch("Inconsistent array dimension."))
# Weighted `mean` for unit weights: once the length of `w` is checked against the
# relevant extent of `A`, the result is the plain `mean` of `A` over `dims`.
function mean(A::AbstractArray, w::UnitWeights; dims::Union{Colon,Int}=:)
    # With `dims = :` the weights must match every element; otherwise they must
    # match the size of `A` along the requested dimension.
    expected = dims isa Colon ? length(A) : size(A, dims)
    if expected != length(w)
        throw(DimensionMismatch("Inconsistent array dimension."))
    end
    return mean(A, dims=dims)
end

###### Weighted quantile #####
##### Weighted quantile #####

"""
quantile(v, w::AbstractWeights, p)

Expand Down Expand Up @@ -723,9 +719,8 @@ end

# Scalar-probability convenience method: wrap `p` in a one-element vector, call the
# vector method, and return the first element of its result.
function quantile(v::RealVector, w::AbstractWeights{<:Real}, p::Number)
    results = quantile(v, w, [p])
    return results[1]
end

##### Weighted median #####


###### Weighted median #####
"""
median(v::RealVector, w::AbstractWeights)

Expand Down
Loading