From a0b74175d179c3fc89178d7ea2e3a4a1635fda6b Mon Sep 17 00:00:00 2001
From: Rory Finnegan
Date: Tue, 17 Jul 2018 16:07:04 -0700
Subject: [PATCH] Add exponential weights

Co-authored-by: Alex Arslan
---
Notes (this section is ignored when the patch is applied):
- Line structure was restored from a whitespace-collapsed copy of this patch. The
  indentation and comment alignment of context lines is a reconstruction, so apply
  with `git apply --ignore-whitespace` if a hunk is rejected.
- The `index` hashes and the commit id are the ones originally generated; they do
  not reflect the documentation and test additions made during review.

 docs/src/weights.md | 14 ++++++++++-
 src/StatsBase.jl    |  2 ++
 src/weights.jl      | 59 ++++++++++++++++++++++++++++++++++++++++++++-
 test/weights.jl     | 27 +++++++++++++++++++++
 4 files changed, 100 insertions(+), 2 deletions(-)

diff --git a/docs/src/weights.md b/docs/src/weights.md
index e8322bfaa..424d207f9 100644
--- a/docs/src/weights.md
+++ b/docs/src/weights.md
@@ -41,6 +41,16 @@ w = ProbabilityWeights([0.2, 0.1, 0.3])
 w = pweights([0.2, 0.1, 0.3])
 ```
 
+### `ExponentialWeights`
+
+Exponential weights are a common form of temporal weights which assign exponentially decreasing
+weight to past observations.
+
+```julia
+w = ExponentialWeights([0.173, 0.2092, 0.2530, 0.3059])
+w = eweights(4, 0.173) # same weights (up to rounding), built from length and rate parameter
+```
+
 ### `Weights`
 
 The `Weights` type describes a generic weights vector which does not support all operations possible for `FrequencyWeights`, `AnalyticWeights` and `ProbabilityWeights`.
@@ -66,9 +76,11 @@ The following constructors are provided:
 AnalyticWeights
 FrequencyWeights
 ProbabilityWeights
+ExponentialWeights
 Weights
 aweights
 fweights
 pweights
+eweights
 weights
-```
\ No newline at end of file
+```
diff --git a/src/StatsBase.jl b/src/StatsBase.jl
index 8b5873396..1215b5468 100644
--- a/src/StatsBase.jl
+++ b/src/StatsBase.jl
@@ -29,10 +29,12 @@ module StatsBase
     AnalyticWeights, # to represent an analytic/precision/reliability weight vector
     FrequencyWeights, # to representing a frequency/case/repeat weight vector
     ProbabilityWeights, # to representing a probability/sampling weight vector
+    ExponentialWeights, # to represent an exponential weight vector
     weights, # construct a generic Weights vector
     aweights, # construct an AnalyticWeights vector
     fweights, # construct a FrequencyWeights vector
     pweights, # construct a ProbabilityWeights vector
+    eweights, # construct an ExponentialWeights vector
     wsum, # weighted sum with vector as second argument
     wsum!, # weighted sum across dimensions with provided storage
     wmean, # weighted mean
diff --git a/src/weights.jl b/src/weights.jl
index f633fe14f..2fe2788a1 100644
--- a/src/weights.jl
+++ b/src/weights.jl
@@ -193,9 +193,66 @@ pweights(vs::RealArray) = ProbabilityWeights(vec(vs))
     end
 end
 
+@weights ExponentialWeights
+
+@doc """
+    ExponentialWeights(vs, wsum=sum(vs))
+
+Construct an `ExponentialWeights` vector with weight values `vs`.
+A precomputed sum may be provided as `wsum`.
+
+Exponential weights are a common form of temporal weights which assign exponentially
+decreasing weight to past observations, which in this case corresponds to the front of
+the vector. That is, newer observations are assumed to be at the end.
+""" ExponentialWeights
+
+"""
+    eweights(n, λ)
+
+Construct an [`ExponentialWeights`](@ref) vector with length `n`,
+where each element in position ``i`` is set to ``λ (1 - λ)^{1 - i}``.
+
+``λ`` is a smoothing factor or rate parameter such that ``0 < λ \\leq 1``.
+As this value approaches 0, the resulting weights will be almost equal,
+while values closer to 1 will put greater weight on the tail elements of the vector.
+
+The weights are not normalized: they grow geometrically with ``i`` by a factor of
+``1 / (1 - λ)``, so the last elements can overflow for long vectors. In particular,
+``λ = 1`` only yields finite weights when `n == 1`.
+
+An `ArgumentError` is thrown if `n < 1` or if ``λ`` lies outside of ``(0, 1]``.
+
+# Examples
+
+```julia-repl
+julia> eweights(10, 0.3)
+10-element ExponentialWeights{Float64,Float64,Array{Float64,1}}:
+ 0.3
+ 0.42857142857142855
+ 0.6122448979591837
+ 0.8746355685131197
+ 1.249479383590171
+ 1.7849705479859588
+ 2.549957925694227
+ 3.642797036706039
+ 5.203995766722913
+ 7.434279666747019
+```
+"""
+function eweights(n::Integer, λ::Real)
+    n > 0 || throw(ArgumentError("cannot construct exponential weights of length < 1"))
+    0 < λ <= 1 || throw(ArgumentError("smoothing factor must be between 0 and 1"))
+    # the exponent `1 - i` is never positive, so the weights increase towards the end
+    w0 = map(i -> λ * (1 - λ)^(1 - i), 1:n)
+    s = sum(w0)
+    return ExponentialWeights{typeof(s), eltype(w0), typeof(w0)}(w0, s)
+end
+
+# NOTE: No variance correction is implemented for exponential weights
+
 ##### Equality tests #####
 
-for w in (AnalyticWeights, FrequencyWeights, ProbabilityWeights, Weights)
+for w in (AnalyticWeights, FrequencyWeights, ProbabilityWeights, ExponentialWeights, Weights)
     @eval begin
         Base.isequal(x::$w, y::$w) = isequal(x.sum, y.sum) && isequal(x.values, y.values)
         Base.:(==)(x::$w, y::$w) = (x.sum == y.sum) && (x.values == y.values)
diff --git a/test/weights.jl b/test/weights.jl
index d4c4935ff..b29364b53 100644
--- a/test/weights.jl
+++ b/test/weights.jl
@@ -2,6 +2,8 @@ using StatsBase
 using LinearAlgebra, Random, SparseArrays, Test
 
 @testset "StatsBase.Weights" begin
+# NOTE: Do not add eweights here, as its methods don't match those of the others, so the
+# tests below don't make sense for it
 weight_funcs = (weights, aweights, fweights, pweights)
 
 # Construction
@@ -497,4 +499,29 @@ end
     @test wquantile(data[1], w, 0.5) ≈ answer atol = 1e-5
 end
 
+@testset "ExponentialWeights" begin
+    @testset "Basic Usage" begin
+        θ = 5.25
+        λ = 1 - exp(-1 / θ) # simple conversion for the more common/readable method
+
+        v = [λ*(1-λ)^(1-i) for i = 1:4]
+        w = ExponentialWeights(v)
+
+        @test round.(w, digits=4) == [0.1734, 0.2098, 0.2539, 0.3071]
+        @test eweights(4, λ) ≈ w
+        @test eweights(4, λ).sum ≈ sum(v)
+
+        # λ = 1 is the upper bound of the accepted range
+        @test eweights(1, 1.0) == ExponentialWeights([1.0])
+    end
+
+    @testset "Failure Conditions" begin
+        @test_throws ArgumentError eweights(0, 0.3)
+        @test_throws ArgumentError eweights(-1, 0.3)
+        @test_throws ArgumentError eweights(1, 0.0)
+        @test_throws ArgumentError eweights(1, -0.3)
+        @test_throws ArgumentError eweights(1, 1.1)
+    end
+end
+
 end # @testset StatsBase.Weights