-
Notifications
You must be signed in to change notification settings - Fork 11
/
weights.jl
72 lines (54 loc) · 1.94 KB
/
weights.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
#==========================================================================================#
# UNIT WEIGHTS
# Weight vector type whose `values` field holds the weights and whose `sum` field holds
# their stored total (computed from `values` unless given explicitly).
struct UnitWeights{S <: Real, T <: Real, V <: AbstractVector{T}} <: AbstractWeights{S, T, V}
    values::V
    sum::S
end

# Convenience wrappers: a vector is wrapped as-is, any other array is flattened with `vec`
# (no copy) before wrapping.
function uweights(vs::AbstractVector)
    return UnitWeights(vs)
end
function uweights(vs::AbstractArray)
    return UnitWeights(vec(vs))
end

# Outer constructors. Without an explicit total, `sum(vs)` is stored.
function UnitWeights(vs::AbstractVector{T}) where {T <: Real}
    return UnitWeights(vs, sum(vs))
end
function UnitWeights(vs::AbstractVector{T}, s::S) where {S <: Real, T <: Real}
    return UnitWeights{S, T, typeof(vs)}(vs, s)
end
#==========================================================================================#
# CONVERT WEIGHTS
# Private helper shared by the AnalyticWeights, ProbabilityWeights and Weights methods of
# `parse_weights`. It keeps the entries of `w` where `msng` is `true` and converts them to
# floating point. It then rescales them in place so they sum to `n`, the number of kept
# entries, and returns the tuple `(v, n)`.
function _rescale_to_count(w::AbstractWeights, msng::BitVector)
    n = sum(msng)          # number of `true` entries in the mask
    # Logical indexing yields a new array (assumed to hold for weights' `getindex` as it
    # does for plain arrays), so the in-place rescale below does not modify `w`.
    v = float(w[msng])
    # NOTE(review): if the kept weights sum to zero while n > 0, `s` is Inf and every
    # entry of `v` becomes NaN — confirm callers never pass all-zero weights.
    s = n / sum(v)
    v .= v .* s
    return v, n
end

# `parse_weights(w, msng)` restricts `w` to the entries where the mask `msng` is `true`.
# Despite its name, the mask selects entries to keep rather than entries to drop. The
# result has the same weights type as `w`.

# Analytic weights: kept values are rescaled to sum to the number of kept entries, which
# is also stored as the total.
function parse_weights(w::AnalyticWeights, msng::BitVector)
    v, n = _rescale_to_count(w, msng)
    return AnalyticWeights(v, n)
end

# Frequency weights: kept values are not rescaled, only converted to floats. Their total
# is stored as an `Int`, so this throws an InexactError if the kept frequencies do not
# add up to an integer.
function parse_weights(w::FrequencyWeights, msng::BitVector)
    v = w[msng]
    s = Int(sum(v))
    return FrequencyWeights(float(v), s)
end

# Probability weights: rescaled exactly like analytic weights.
function parse_weights(w::ProbabilityWeights, msng::BitVector)
    v, n = _rescale_to_count(w, msng)
    return ProbabilityWeights(v, n)
end

# Unit weights: values are subset and converted to floats but not rescaled. The stored
# total is the number of kept entries.
function parse_weights(w::UnitWeights, msng::BitVector)
    return UnitWeights(float(w[msng]), sum(msng))
end

# Generic `Weights`: rescaled exactly like analytic weights.
function parse_weights(w::Weights, msng::BitVector)
    v, n = _rescale_to_count(w, msng)
    return Weights(v, n)
end
#==========================================================================================#
# REWEIGHTING
# Combine existing weights with probability weights `v`.
# Unit weights contribute nothing, so `v` itself is returned (not a copy).
function reweight(::UnitWeights, v::ProbabilityWeights)
    return v
end
# Any other weights are multiplied elementwise with `v`, and the product is wrapped as
# new probability weights.
function reweight(w::AbstractWeights, v::ProbabilityWeights)
    combined = w .* v
    return pweights(combined)
end
#==========================================================================================#
# OPERATIONS
# With unit weights the weighted reductions fall back to their unweighted versions. The
# weight values themselves are never consulted.
function Base.sum(v::AbstractArray, ::UnitWeights)
    return sum(v)
end
# NOTE(review): extending `Base.mean` only works on Julia 0.6 and earlier. From Julia 0.7
# on, `mean` lives in the Statistics standard library.
function Base.mean(v::AbstractArray, ::UnitWeights)
    return mean(v)
end
# Two UnitWeights are equal when both their stored totals and their values agree.
# The totals are checked first, so a mismatch returns `false` before the values are
# compared.
function Base.:(==)(x::UnitWeights, y::UnitWeights)
    x.sum == y.sum || return false
    return x.values == y.values
end
# Same comparison using `isequal` semantics for both the totals and the values
# (e.g. `isequal(NaN, NaN)` is `true`, unlike `==`).
function Base.isequal(x::UnitWeights, y::UnitWeights)
    isequal(x.sum, y.sum) || return false
    return isequal(x.values, y.values)
end