Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a WeightedResampler #890

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,12 @@ julia = "1"

[extras]
Calculus = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Calculus", "Distributed", "ForwardDiff", "JSON", "StaticArrays", "Test"]
test = ["Calculus", "Distances", "Distributed", "ForwardDiff", "JSON", "StaticArrays", "Test"]
4 changes: 3 additions & 1 deletion src/samplers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ for fname in ["aliastable.jl",
"vonmises.jl",
"vonmisesfisher.jl",
"discretenonparametric.jl",
"categorical.jl"]
"categorical.jl",
"resampler.jl",
]

include(joinpath("samplers", fname))
end
53 changes: 53 additions & 0 deletions src/samplers/resampler.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
"""
rofinn marked this conversation as resolved.
Show resolved Hide resolved
WeightedResampler(obs::AbstractArray, wv::AbstractWeights)

A WeightedResampler is a subtype of Distributions.Sampleable which randomly selects
observations from the raw input data (`obs`) based on the weights (`wv`) provided.

This type supports univariate, multivariate and matrixvariate forms, so `obs` can
be a vector of values, matrix of values or a vector of matrices.
"""
struct WeightedResampler{F<:VariateForm, S<:ValueSupport, T<:AbstractArray} <: Sampleable{F, S}
obs::T
wv::AbstractWeights
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Woops, this should be a type parameter!

end

function WeightedResampler(obs::T, wv::AbstractWeights) where T<:AbstractArray
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there any reason to restrict weights to AbstractWeights rather than AbstractVector? At JuliaLang/julia#31395, I made functions accept any array since there's no ambiguity.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Idk but maybe @rofinn does

But also it'd be an easy follow up to loosen it

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ping @rofinn ?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we merge as is and open an issue about this? It'd be a non breaking change to loosen the constraint in future. Or just loosen it now and fix it if there's a bug report

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess merging this as-is is OK. I've noticed that the weights are passed directly to sample, so if we loosen the signature we should also ensure we convert the argument to AbstractWeights, which is what sample expects (currently at least). But I still think loosening this is a good idea, though not a requirement.

F = _variate_form(T)
S = _value_support(eltype(T))

_validate(obs, wv)
WeightedResampler{F, S, T}(obs, wv)
end

_variate_form(::Type{<:AbstractVector}) = Univariate
rofinn marked this conversation as resolved.
Show resolved Hide resolved
_variate_form(::Type{<:AbstractMatrix}) = Multivariate
_variate_form(::Type{<:AbstractVector{<:AbstractMatrix}}) = Matrixvariate

_value_support(::Type{Int}) = Discrete
_value_support(::Type{Float64}) = Continuous
_value_support(T::Type{<:AbstractMatrix}) = _value_support(eltype(T))

_validate(obs::AbstractVector, wv::AbstractWeights) = _validate(length(obs), length(wv))
_validate(obs::AbstractMatrix, wv::AbstractWeights) = _validate(size(obs, 2), length(wv))

function _validate(nobs::Int, nwv::Int)
if nobs != nwv
throw(DimensionMismatch("Length of the weights vector ($nwv) must match the " *
"number of observations ($nobs)."))
end
end

Base.length(s::WeightedResampler{Multivariate}) = size(s.obs, 1)

function Base.rand(rng::AbstractRNG, s::WeightedResampler{<:Union{Univariate,Matrixvariate}})
i = sample(rng, s.wv)
return s.obs[i]
end

function _rand!(rng::AbstractRNG, s::WeightedResampler{Multivariate}, x::AbstractVector{<:Real})
j = sample(rng, s.wv)
for i in 1:length(s)
@inbounds x[i] = s.obs[i, j]
end
end
113 changes: 111 additions & 2 deletions test/samplers.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Testing of samplers

using Distributions
using Distributions
using Distances
using Test


Expand All @@ -18,7 +19,8 @@ import Distributions:
GammaGSSampler,
GammaMTSampler,
GammaIPSampler,
PoissBinAliasSampler
PoissBinAliasSampler,
WeightedResampler

n_tsamples = 10^6

Expand Down Expand Up @@ -117,3 +119,110 @@ for S in [GammaGSSampler, GammaIPSampler]
test_samples(S(d), d, n_tsamples, rng=rng)
end
end

@testset "WeightedResampler" begin
rng = MersenneTwister(1234)

@testset "Univariate" begin
obs = collect(1:12)

@testset "Equally Weighted" begin
# Constant analytic weights
wv = aweights(ones(12))

s = WeightedResampler(obs, wv)
X = rand(rng, s, 100000)

# The mean values of the samples should roughly match the mean of the
# original observation
@test isapprox(mean(X), mean(obs); atol=0.01)
end

@testset "Linearly Weighted" begin
# Linearly increasing analytic weights
wv = aweights(collect(1/12:1/12:1.0))

s = WeightedResampler(obs, wv)
X = rand(rng, s, 100000)

# The mean of the samples should not match the mean of the
# original observation
@test !isapprox(mean(X), mean(obs); atol=0.01)

# 12 should be sampled the most
@test mode(X) == 12
end
end

@testset "Multivariate" begin
v = [1.2, 0.7, -0.3, 5.4, -2.8]
# Define different observations via arbitrary operations on v
obs = hcat(
v, reverse(v), sort(v), sin.(v), cos.(v), tan.(v),
v / 100, v * 2, abs.(v), log.(abs.(v)), v .^ 2, v * 10,
)

@testset "Equally Weighted" begin
# Constant analytic weights
wv = aweights(ones(12))

s = WeightedResampler(obs, wv)
X = rand(rng, s, 100000)

# The mean values of each variable in the samples should roughly match
# the means of the original observation
@test nrmsd(mean(X; dims=2), mean(obs; dims=2)) < 0.001
end

@testset "Linearly Weighted" begin
# Linearly increasing analytic weights
wv = aweights(collect(0.083:0.083:1.0))

s = WeightedResampler(obs, wv)
X = rand(rng, s, 100000)

# The mean values of each variable of the samples should not match the
# means of the original observation
@test nrmsd(mean(X; dims=2), mean(obs; dims=2)) > 0.1

# v * 10 should be sampled the most
@test vec(mapslices(mode, X; dims=2)) == v * 10
end
end
@testset "Matrixvariate" begin
# NOTE: Since we've already testing the sampling behaviour we just want to
# check that we've implement the Distributions API correctly for the
# Matrixvariate case
s = WeightedResampler([rand(4, 3) for i in 1:10], aweights(rand(10)))
X = rand(s)
end

@testset "DimensionMismatch" begin
@test_throws DimensionMismatch WeightedResampler(rand(10), aweights(collect(1:12)))
end

# Explicitly test the _function for the resampler
@testset "_variate_form" begin
@test Distributions._variate_form(Vector) == Univariate
@test Distributions._variate_form(Matrix) == Multivariate
@test Distributions._variate_form(Vector{Matrix}) == Matrixvariate
@test_throws MethodError Distributions._variate_form(Float64)
@test_throws MethodError Distributions._variate_form(Array{Float64, 3})
end

@testset "_value_support" begin
@test Distributions._value_support(Int) == Discrete
@test Distributions._value_support(Float64) == Continuous
@test Distributions._value_support(Matrix{Float64}) == Continuous
@test_throws MethodError Distributions._value_support(String)
@test_throws MethodError Distributions._value_support(Vector{Float64})
end

@testset "_validate" begin
Distributions._validate(4, 4)
Distributions._validate(ones(4), aweights(rand(4)))
Distributions._validate(ones(3, 4), aweights(rand(4)))
@test_throws DimensionMismatch Distributions._validate(ones(3), aweights(rand(4)))
@test_throws DimensionMismatch Distributions._validate(ones(4, 3), aweights(rand(4)))
end
end