diff --git a/Project.toml b/Project.toml index 2e81febfa..095184736 100644 --- a/Project.toml +++ b/Project.toml @@ -24,6 +24,7 @@ julia = "1" [extras] Calculus = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9" +Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" @@ -32,4 +33,6 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Calculus", "Distributed", "FiniteDifferences", "ForwardDiff", "JSON", "StaticArrays", "Test"] +test = ["Calculus", "Distances", "Distributed", "FiniteDifferences", "ForwardDiff", "JSON", "StaticArrays", "Test"] + + diff --git a/src/samplers.jl b/src/samplers.jl index 794f2bff4..885a9f35f 100644 --- a/src/samplers.jl +++ b/src/samplers.jl @@ -24,7 +24,9 @@ for fname in ["aliastable.jl", "vonmises.jl", "vonmisesfisher.jl", "discretenonparametric.jl", - "categorical.jl"] + "categorical.jl", + "resampler.jl", +] include(joinpath("samplers", fname)) end diff --git a/src/samplers/resampler.jl b/src/samplers/resampler.jl new file mode 100644 index 000000000..d3f28b1cc --- /dev/null +++ b/src/samplers/resampler.jl @@ -0,0 +1,53 @@ +""" + WeightedResampler(obs::AbstractArray, wv::AbstractWeights) + +A WeightedResampler is a subtype of Distributions.Sampleable which randomly selects +observations from the raw input data (`obs`) based on the weights (`wv`) provided. + +This type supports univariate, multivariate and matrixvariate forms, so `obs` can +be a vector of values, matrix of values or a vector of matrices. +""" +struct WeightedResampler{F<:VariateForm, S<:ValueSupport, T<:AbstractArray} <: Sampleable{F, S} + obs::T + wv::AbstractWeights +end + +function WeightedResampler(obs::T, wv::AbstractWeights) where T<:AbstractArray + F = _variate_form(T) + S = _value_support(eltype(T)) + + _validate(obs, wv) + WeightedResampler{F, S, T}(obs, wv) +end + +_variate_form(::Type{<:AbstractVector}) = Univariate +_variate_form(::Type{<:AbstractMatrix}) = Multivariate +_variate_form(::Type{<:AbstractVector{<:AbstractMatrix}}) = Matrixvariate + +_value_support(::Type{Int}) = Discrete +_value_support(::Type{Float64}) = Continuous +_value_support(T::Type{<:AbstractMatrix}) = _value_support(eltype(T)) + +_validate(obs::AbstractVector, wv::AbstractWeights) = _validate(length(obs), length(wv)) +_validate(obs::AbstractMatrix, wv::AbstractWeights) = _validate(size(obs, 2), length(wv)) + +function _validate(nobs::Int, nwv::Int) + if nobs != nwv + throw(DimensionMismatch("Length of the weights vector ($nwv) must match the " * + "number of observations ($nobs).")) + end +end + +Base.length(s::WeightedResampler{Multivariate}) = size(s.obs, 1) + +function Base.rand(rng::AbstractRNG, s::WeightedResampler{<:Union{Univariate,Matrixvariate}}) + i = sample(rng, s.wv) + return s.obs[i] +end + +function _rand!(rng::AbstractRNG, s::WeightedResampler{Multivariate}, x::AbstractVector{<:Real}) + j = sample(rng, s.wv) + for i in 1:length(s) + @inbounds x[i] = s.obs[i, j] + end +end diff --git a/test/samplers.jl b/test/samplers.jl index a1b77c709..847b795d0 100644 --- a/test/samplers.jl +++ b/test/samplers.jl @@ -1,6 +1,7 @@ # Testing of samplers -using Distributions +using Distributions +using Distances using Test @@ -18,7 +19,8 @@ import Distributions: GammaGSSampler, GammaMTSampler, GammaIPSampler, - PoissBinAliasSampler + PoissBinAliasSampler, + WeightedResampler n_tsamples = 10^6 @@ -117,3 +119,110 @@ for S in [GammaGSSampler, GammaIPSampler] test_samples(S(d), d, n_tsamples, rng=rng) end end + +@testset "WeightedResampler" begin + rng = MersenneTwister(1234) + + @testset "Univariate" begin + obs = collect(1:12) + + @testset "Equally Weighted" begin + # Constant analytic weights + wv = aweights(ones(12)) + + s = WeightedResampler(obs, wv) + X = rand(rng, s, 100000) + + # The mean values of the samples should roughly match the mean of the + # original observation + @test isapprox(mean(X), mean(obs); atol=0.01) + end + + @testset "Linearly Weighted" begin + # Linearly increasing analytic weights + wv = aweights(collect(1/12:1/12:1.0)) + + s = WeightedResampler(obs, wv) + X = rand(rng, s, 100000) + + # The mean of the samples should not match the mean of the + # original observation + @test !isapprox(mean(X), mean(obs); atol=0.01) + + # 12 should be sampled the most + @test mode(X) == 12 + end + end + + @testset "Multivariate" begin + v = [1.2, 0.7, -0.3, 5.4, -2.8] + # Define different observations via arbitrary operations on v + obs = hcat( + v, reverse(v), sort(v), sin.(v), cos.(v), tan.(v), + v / 100, v * 2, abs.(v), log.(abs.(v)), v .^ 2, v * 10, + ) + + @testset "Equally Weighted" begin + # Constant analytic weights + wv = aweights(ones(12)) + + s = WeightedResampler(obs, wv) + X = rand(rng, s, 100000) + + # The mean values of each variable in the samples should roughly match + # the means of the original observation + @test nrmsd(mean(X; dims=2), mean(obs; dims=2)) < 0.001 + end + + @testset "Linearly Weighted" begin + # Linearly increasing analytic weights + wv = aweights(collect(0.083:0.083:1.0)) + + s = WeightedResampler(obs, wv) + X = rand(rng, s, 100000) + + # The mean values of each variable of the samples should not match the + # means of the original observation + @test nrmsd(mean(X; dims=2), mean(obs; dims=2)) > 0.1 + + # v * 10 should be sampled the most + @test vec(mapslices(mode, X; dims=2)) == v * 10 + end + end + @testset "Matrixvariate" begin + # NOTE: Since we've already testing the sampling behaviour we just want to + # check that we've implement the Distributions API correctly for the + # Matrixvariate case + s = WeightedResampler([rand(4, 3) for i in 1:10], aweights(rand(10))) + X = rand(s) + end + + @testset "DimensionMismatch" begin + @test_throws DimensionMismatch WeightedResampler(rand(10), aweights(collect(1:12))) + end + + # Explicitly test the _function for the resampler + @testset "_variate_form" begin + @test Distributions._variate_form(Vector) == Univariate + @test Distributions._variate_form(Matrix) == Multivariate + @test Distributions._variate_form(Vector{Matrix}) == Matrixvariate + @test_throws MethodError Distributions._variate_form(Float64) + @test_throws MethodError Distributions._variate_form(Array{Float64, 3}) + end + + @testset "_value_support" begin + @test Distributions._value_support(Int) == Discrete + @test Distributions._value_support(Float64) == Continuous + @test Distributions._value_support(Matrix{Float64}) == Continuous + @test_throws MethodError Distributions._value_support(String) + @test_throws MethodError Distributions._value_support(Vector{Float64}) + end + + @testset "_validate" begin + Distributions._validate(4, 4) + Distributions._validate(ones(4), aweights(rand(4))) + Distributions._validate(ones(3, 4), aweights(rand(4))) + @test_throws DimensionMismatch Distributions._validate(ones(3), aweights(rand(4))) + @test_throws DimensionMismatch Distributions._validate(ones(4, 3), aweights(rand(4))) + end +end