Skip to content

Commit

Permalink
Minor cleanup and refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
ararslan committed Jul 18, 2018
1 parent ba0e222 commit 8719310
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 31 deletions.
68 changes: 38 additions & 30 deletions src/weights.jl
Original file line number Diff line number Diff line change
Expand Up @@ -200,21 +200,35 @@ end
@weights ExponentialWeights false

"""
ExponentialWeights
ExponentialWeights(vs)
# Fields
Construct an `ExponentialWeights` vector with weight values `vs`, which must sum to 1.
* `λ::Float64`: is a smoothing factor or rate paremeter between 0 .. 1.
As this value approaches 0 the resulting weights will be almost equal(),
while values closer to 1 will put higher weight on the end elements of the vector.
Exponential weights are a common form of temporal weights which assign exponentially
greater weight to past observations, which in this case corresponds to the tail end of
the vector.
"""
function ExponentialWeights(vs::V) where {T<:Real, V<:AbstractVector{T}}
s = sum(vs)
s one(T) || throw(ArgumentError("weight values do not sum to 1 (got $s)"))
ExponentialWeights{T, T, V}(vs, s)
end

When called with a desired length `n` (`Int`) a vector of length `n` will
be returned, where each element is set to `λ * (1 - λ)^(1 - i)`.
"""
eweights(n, λ)
# Usage
Construct an [`ExponentialWeights`](@ref) vector with length `n`,
where each element in position ``i`` is set to ``λ (1 - λ)^{1 - i}``.
The entire set of weights are then normalized to sum to 1.
```julia
w = ExponentialWeights(10, 0.3)
``λ`` is a smoothing factor or rate parameter such that ``0 < λ \\leq 1``.
As this value approaches 0, the resulting weights will be almost equal,
while values closer to 1 will put greater weight on the tail elements of the vector.
# Examples
```julia-repl
julia> eweights(10, 0.3)
10-element ExponentialWeights{Float64,Float64,Array{Float64,1}}:
0.012458
0.0177971
Expand All @@ -228,41 +242,35 @@ w = ExponentialWeights(10, 0.3)
0.308721
```
"""
function ExponentialWeights(vs::V) where {T<:Real, V<:AbstractVector{T}}
s = sum(vs)
s one(T) || throw(ArgumentError("weight values do not sum to 1 (got $s)"))
ExponentialWeights{T, T, V}(vs, s)
end

function ExponentialWeights(n::Integer, λ::Real)
n > 0 || throw(ArgumentError("cannot construct weights of length < 1"))
function eweights(n::Integer, λ::Real)
n > 0 || throw(ArgumentError("cannot construct exponential weights of length < 1"))
0 < λ <= 1 || throw(ArgumentError("smoothing factor must be between 0 and 1"))
w0 = map(i -> λ * (1 - λ)^(1 - i), 1:n)
s = sum(w0)
ExponentialWeights(w0 / s)
w0 ./= s
ExponentialWeights{typeof(s), eltype(w0), typeof(w0)}(w0, s)
end

"""
eweights(n, λ)
Construct an `ExponentialWeights` vector with length `n`,
where each element in position ``i`` is set to ``λ * (1 - λ)^(1 - i)``.
The entire set of weights are then normalized so that they sum to 1.0
eweights(vs)
``λ`` is a smoothing factor or rate parameter between 0 and 1.
As this value approaches 0 the resulting weights will be almost equal,
while values closer to 1 will put higher weight on the end elements of the vector.
Construct an [`ExponentialWeights`](@ref) vector using the given array.
"""
eweights(n::Integer, λ::Real) = ExponentialWeights(n, λ)
eweights(v::RealVector) = ExponentialWeights(v)
eweights(v::RealArray) = ExponentialWeights(vec(v))

"""
varcorrection(w::ExponentialWeights, corrected=false)
* `corrected=true`: ``\\frac{1}{1 - \\sum {w^2}}``
* `corrected=false`: ``1.0``
* `corrected=false`: ``1``
"""
@inline function varcorrection(w::ExponentialWeights, corrected::Bool=false)
corrected ? 1 / (1 - sum(x -> x^2, w)) : 1.0
if corrected
1 / (1 - sum(abs2, w.values))
else
1 / one(w.sum) # just 1 promoted to the same type as the other branch
end
end

##### Equality tests #####
Expand Down
5 changes: 4 additions & 1 deletion test/weights.jl
Original file line number Diff line number Diff line change
Expand Up @@ -504,7 +504,8 @@ end
θ = 5.25
λ = 1 - exp(-1 / θ) # simple conversion for the more common/readable method

w = ExponentialWeights(4, λ)
v =*(1-λ)^(1-i) for i = 1:4]
w = ExponentialWeights(v ./ sum(v))

@test round.(w, digits=4) == [0.1837, 0.2222, 0.2688, 0.3253]
@test eweights(4, λ) w
Expand All @@ -513,6 +514,8 @@ end
@testset "Failure Conditions" begin
@test_throws ArgumentError eweights(0, 0.3)
@test_throws ArgumentError eweights(1, 1.1)
@test_throws ArgumentError eweights(rand(4))
@test_throws ArgumentError eweights(rand(4, 4))
@test_throws ArgumentError ExponentialWeights(rand(4))
end

Expand Down

0 comments on commit 8719310

Please sign in to comment.