Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify weights #526

Merged
merged 36 commits into from
Oct 1, 2019
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
1fb0ea9
Consistent comment style
lbittarello Sep 22, 2019
516dee3
Remove values as a parameter of weights
lbittarello Sep 22, 2019
bb738fc
Deprecate values for weights
lbittarello Sep 22, 2019
74ed6f9
dimension as keyword to sum for consistency
lbittarello Sep 22, 2019
2efced6
Clean up
lbittarello Sep 22, 2019
18545aa
Drop unnecessary sum methods
lbittarello Sep 22, 2019
94bfe94
Update src/weights.jl
lbittarello Sep 22, 2019
a42f1e3
Update src/weights.jl
lbittarello Sep 22, 2019
a061a2e
Fix boundary check
lbittarello Sep 22, 2019
20eaae5
Remove space around keyword arguments
lbittarello Sep 22, 2019
10b4f57
Merge https://github.com/lbittarello/StatsBase.jl
lbittarello Sep 22, 2019
4957613
Patches
lbittarello Sep 22, 2019
d25dbf4
Indexing by array returns new weight vector
lbittarello Sep 22, 2019
904abc5
Update src/weights.jl
lbittarello Sep 23, 2019
be6d6c7
Update src/weights.jl
lbittarello Sep 23, 2019
f98e4c0
Update src/weights.jl
lbittarello Sep 23, 2019
fec6f5e
Update src/weights.jl
lbittarello Sep 23, 2019
5277140
Update src/weights.jl
lbittarello Sep 23, 2019
5fe88cc
Update src/weights.jl
lbittarello Sep 23, 2019
6c707f6
Update src/weights.jl
lbittarello Sep 23, 2019
dea5fc4
Update src/deprecates.jl
lbittarello Sep 23, 2019
ea6eaf7
Update src/weights.jl
lbittarello Sep 23, 2019
baa4b9c
Update src/weights.jl
lbittarello Sep 23, 2019
b22f5bd
Update src/weights.jl
lbittarello Sep 28, 2019
b225e84
Update src/weights.jl
lbittarello Sep 28, 2019
4b0fc7c
Update src/weights.jl
lbittarello Sep 28, 2019
b7ea5ac
Simplify sum
lbittarello Sep 29, 2019
cd2b28c
Simplify sum and mean for unit weights
lbittarello Sep 29, 2019
9e8c5ee
Update src/weights.jl
lbittarello Sep 29, 2019
2e22c43
Update src/weights.jl
lbittarello Sep 29, 2019
aa7c962
Update src/deprecates.jl
lbittarello Sep 29, 2019
3dfb0f6
Update src/weights.jl
lbittarello Sep 29, 2019
5da60a6
Update src/weights.jl
lbittarello Sep 29, 2019
79cde89
Update src/weights.jl
lbittarello Sep 29, 2019
d65ca65
Update src/weights.jl
lbittarello Sep 29, 2019
6f974ea
Required parentheses added
lbittarello Sep 29, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/deprecates.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,6 @@ end
@deprecate wmedian(v::RealVector, w::RealVector) median(v, weights(w))

@deprecate quantile(v::AbstractArray{<:Real}) quantile(v, [.0, .25, .5, .75, 1.0])

### Deprecated September 2019
@deprecate values(wv::AbstractWeights) convert(Vector, wv)
60 changes: 23 additions & 37 deletions src/weights.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
###### Weight vector #####
abstract type AbstractWeights{S<:Real, T<:Real, V<:AbstractVector{T}} <: AbstractVector{T} end
##### Weight vector #####

abstract type AbstractWeights{S<:Real, T<:Real} <: AbstractVector{T} end

"""
@weights name
Expand All @@ -9,20 +10,21 @@ and stores the `values` (`V<:RealVector`) and `sum` (`S<:Real`).
"""
macro weights(name)
return quote
mutable struct $name{S<:Real, T<:Real, V<:AbstractVector{T}} <: AbstractWeights{S, T, V}
values::V
mutable struct $name{S<:Real, T<:Real} <: AbstractWeights{S, T}
values::AbstractVector{T}
lbittarello marked this conversation as resolved.
Show resolved Hide resolved
sum::S
end
$(esc(name))(vs) = $(esc(name))(vs, sum(vs))
end
end

length(wv::AbstractWeights) = length(wv.values)
values(wv::AbstractWeights) = wv.values
sum(wv::AbstractWeights) = wv.sum
isempty(wv::AbstractWeights) = isempty(wv.values)
size(wv::AbstractWeights) = size(wv.values)

convert(::Type{Vector}, wv::AbstractWeights) = wv.values
lbittarello marked this conversation as resolved.
Show resolved Hide resolved

Base.getindex(wv::AbstractWeights, i) = getindex(wv.values, i)

@propagate_inbounds function Base.setindex!(wv::AbstractWeights, v::Real, i::Int)
Expand Down Expand Up @@ -247,9 +249,9 @@ eweights(n::Integer, λ::Real) = eweights(1:n, λ)
eweights(t::AbstractVector, r::AbstractRange, λ::Real) =
eweights(something.(indexin(t, r)), λ)

# NOTE: No variance correction is implemented for exponential weights
# NOTE: no variance correction is implemented for exponential weights

struct UnitWeights{T<:Real} <: AbstractWeights{Int, T, V where V<:Vector{T}}
struct UnitWeights{T<:Real} <: AbstractWeights{Int, T}
len::Int
end

Expand All @@ -260,12 +262,13 @@ Construct a `UnitWeights` vector with length `s` and weight elements of type `T`
All weight elements are identically one.
""" UnitWeights

values(wv::UnitWeights{T}) where T = fill(one(T), length(wv))
sum(wv::UnitWeights{T}) where T = convert(T, length(wv))
isempty(wv::UnitWeights) = iszero(wv.len)
length(wv::UnitWeights) = wv.len
size(wv::UnitWeights) = Tuple(length(wv))

convert(::Type{Vector}, wv::UnitWeights{T}) where T = ones(T, length(wv))
lbittarello marked this conversation as resolved.
Show resolved Hide resolved

@propagate_inbounds function Base.getindex(wv::UnitWeights{T}, i::Integer) where T
@boundscheck checkbounds(wv, i)
one(T)
Expand Down Expand Up @@ -342,22 +345,6 @@ Compute the weighted sum of an array `v` with weights `w`, optionally over the d
wsum(v::AbstractVector, w::AbstractVector) = dot(v, w)
wsum(v::AbstractArray, w::AbstractVector) = dot(vec(v), w)

# Note: the methods for BitArray and SparseMatrixCSC are to avoid ambiguities
nalimilan marked this conversation as resolved.
Show resolved Hide resolved
Base.sum(v::BitArray, w::AbstractWeights) = wsum(v, values(w))
Base.sum(v::SparseArrays.SparseMatrixCSC, w::AbstractWeights) = wsum(v, values(w))
Base.sum(v::AbstractArray, w::AbstractWeights) = dot(v, values(w))

for v in (AbstractArray{<:Number}, BitArray, SparseArrays.SparseMatrixCSC, AbstractArray)
@eval begin
function Base.sum(v::$v, w::UnitWeights)
if length(v) != length(w)
throw(DimensionMismatch("Inconsistent array dimension."))
end
return sum(v)
end
end
end

## wsum along dimension
#
# Brief explanation of the algorithm:
Expand Down Expand Up @@ -392,7 +379,6 @@ end
# (d) A is a general dense array with eltype <: BlasReal:
# dim <= 2: delegate to (a) and (b)
# otherwise, decompose A into multiple pages
#

function _wsum1!(R::AbstractArray, A::AbstractVector, w::AbstractVector, init::Bool)
r = wsum(A, w)
Expand Down Expand Up @@ -455,7 +441,8 @@ function _wsumN!(R::StridedArray{T}, A::DenseArray{T,N}, w::StridedVector{T}, di
return R
end

# General Cartesian-based weighted sum across dimensions
## general Cartesian-based weighted sum across dimensions

@generated function _wsum_general!(R::AbstractArray{RT}, f::supertype(typeof(abs)),
A::AbstractArray{T,N}, w::AbstractVector{WT}, dim::Int, init::Bool) where {T,RT,WT,N}
quote
Expand Down Expand Up @@ -512,7 +499,6 @@ end
end
end


# N = 1
_wsum!(R::StridedArray{T}, A::DenseArray{T,1}, w::StridedVector{T}, dim::Int, init::Bool) where {T<:BlasReal} =
_wsum1!(R, A, w, init)
Expand All @@ -533,7 +519,6 @@ _wsum!(R::AbstractArray, A::AbstractArray, w::AbstractVector, dim::Int, init::Bo
wsumtype(::Type{T}, ::Type{W}) where {T,W} = typeof(zero(T) * zero(W) + zero(T) * zero(W))
wsumtype(::Type{T}, ::Type{T}) where {T<:BlasReal} = T


"""
wsum!(R, A, w, dim; init=true)

Expand All @@ -559,19 +544,20 @@ function wsum(A::AbstractArray{<:Number}, w::UnitWeights, dim::Int)
return sum(A, dims=dim)
end

# extended sum! and wsum
## extended sum! and wsum

Base.sum!(R::AbstractArray, A::AbstractArray, w::AbstractWeights{<:Real}, dim::Int; init::Bool=true) =
wsum!(R, A, values(w), dim; init=init)

Base.sum(A::AbstractArray{<:Number}, w::AbstractWeights{<:Real}, dim::Int) = wsum(A, values(w), dim)
Base.sum(A::AbstractArray{<:Number}, w::AbstractWeights{<:Real}; dims::Union{Nothing,Int}=nothing) =
dims == nothing ? wsum(A, w.values) : wsum(A, w.values, dims)

function Base.sum(A::AbstractArray{<:Number}, w::UnitWeights, dim::Int)
size(A, dim) != length(w) && throw(DimensionMismatch("Inconsistent array dimension."))
return sum(A, dims=dim)
function Base.sum(A::AbstractArray{<:Number}, w::UnitWeights; dims::Union{Nothing,Int}=nothing)
size(A, dims) != length(w) && throw(DimensionMismatch("Inconsistent array dimension."))
lbittarello marked this conversation as resolved.
Show resolved Hide resolved
return sum(A, dims=dims)
end

###### Weighted means #####
##### Weighted means #####

"""
wmean(v, w::AbstractVector)
Expand Down Expand Up @@ -628,7 +614,8 @@ function _mean(A::AbstractArray, w::UnitWeights, dims::Int)
return mean(A, dims=dims)
end

###### Weighted quantile #####
##### Weighted quantile #####

"""
quantile(v, w::AbstractWeights, p)

Expand Down Expand Up @@ -723,9 +710,8 @@ end

quantile(v::RealVector, w::AbstractWeights{<:Real}, p::Number) = quantile(v, w, [p])[1]

##### Weighted median #####


###### Weighted median #####
"""
median(v::RealVector, w::AbstractWeights)

Expand Down
29 changes: 13 additions & 16 deletions test/weights.jl
Original file line number Diff line number Diff line change
Expand Up @@ -232,9 +232,9 @@ a = reshape(1.0:27.0, 3, 3, 3)
@test sum(1:3, f([1.0, 1.0, 0.5])) ≈ 4.5

for wt in ([1.0, 1.0, 1.0], [1.0, 0.2, 0.0], [0.2, 0.0, 1.0])
@test sum(a, f(wt), 1) ≈ sum(a.*reshape(wt, length(wt), 1, 1), dims = 1)
@test sum(a, f(wt), 2) ≈ sum(a.*reshape(wt, 1, length(wt), 1), dims = 2)
@test sum(a, f(wt), 3) ≈ sum(a.*reshape(wt, 1, 1, length(wt)), dims = 3)
@test sum(a, f(wt), dims = 1) ≈ sum(a.*reshape(wt, length(wt), 1, 1), dims = 1)
@test sum(a, f(wt), dims = 2) ≈ sum(a.*reshape(wt, 1, length(wt), 1), dims = 2)
@test sum(a, f(wt), dims = 3) ≈ sum(a.*reshape(wt, 1, 1, length(wt)), dims = 3)
end
end

Expand All @@ -250,8 +250,6 @@ end
end
end


# Quantile fweights
@testset "Quantile fweights" begin
data = (
[7, 1, 2, 4, 10],
Expand Down Expand Up @@ -429,10 +427,9 @@ end
v = [7, 1, 2, 4, 10]
w = [1, 1/3, 1/3, 1/3, 1]
answer = 6.0
@test quantile(data[1], f(w), 0.5) ≈ answer atol = 1e-5
@test quantile(data[1], f(w), 0.5) ≈ answer atol = 1e-5
end


@testset "Median $f" for f in weight_funcs
data = [4, 3, 2, 1]
wt = [0, 0, 0, 0]
Expand Down Expand Up @@ -470,17 +467,17 @@ end
@test sum([1.0, 2.0, 3.0], wt) ≈ 6.0
@test mean([1.0, 2.0, 3.0], wt) ≈ 2.0

@test sum(a, wt, 1) ≈ sum(a, dims=1)
@test sum(a, wt, 2) ≈ sum(a, dims=2)
@test sum(a, wt, 3) ≈ sum(a, dims=3)
@test sum(a, wt, dims = 1) ≈ sum(a, dims = 1)
lbittarello marked this conversation as resolved.
Show resolved Hide resolved
@test sum(a, wt, dims = 2) ≈ sum(a, dims = 2)
@test sum(a, wt, dims = 3) ≈ sum(a, dims = 3)

@test wsum(a, wt, 1) ≈ sum(a, dims=1)
@test wsum(a, wt, 2) ≈ sum(a, dims=2)
@test wsum(a, wt, 3) ≈ sum(a, dims=3)
@test wsum(a, wt, dims = 1) ≈ sum(a, dims = 1)
@test wsum(a, wt, dims = 2) ≈ sum(a, dims = 2)
@test wsum(a, wt, dims = 3) ≈ sum(a, dims = 3)

@test mean(a, wt, dims=1) ≈ mean(a, dims=1)
@test mean(a, wt, dims=2) ≈ mean(a, dims=2)
@test mean(a, wt, dims=3) ≈ mean(a, dims=3)
@test mean(a, wt, dims = 1) ≈ mean(a, dims = 1)
@test mean(a, wt, dims = 2) ≈ mean(a, dims = 2)
@test mean(a, wt, dims = 3) ≈ mean(a, dims = 3)

@test_throws DimensionMismatch sum(a, wt)
@test_throws DimensionMismatch sum(a, wt, 4)
Expand Down