Add a WeightedResampler #890
base: master
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
""" | ||
|
||
WeightedResampler(obs::AbstractArray, wv::AbstractWeights) | ||
|
||
A WeightedResampler is a subtype of Distributions.Sampleable which randomly selects | ||
observations from the raw input data (`obs`) based on the weights (`wv`) provided. | ||
|
||
This type supports univariate, multivariate and matrixvariate forms, so `obs` can | ||
be a vector of values, matrix of values or a vector of matrices. | ||
""" | ||
struct WeightedResampler{F<:VariateForm, S<:ValueSupport, T<:AbstractArray} <: Sampleable{F, S} | ||
    obs::T | ||
    wv::AbstractWeights | ||
Woops, this should be a type parameter! |
||
end | ||
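The review comment on the `wv` field points out that `wv::AbstractWeights` is an abstractly typed field. A minimal sketch of the usual remedy follows, assuming the fix is to add a fourth type parameter. The struct name `ParamResampler` and the parameter name `W` are hypothetical and are not taken from this PR.

```julia
using Distributions: Sampleable, VariateForm, ValueSupport, Univariate, Continuous
using StatsBase: AbstractWeights, Weights

# Hypothetical variant of the struct in this diff: the weights type becomes
# the type parameter W, so the field `wv` has a concrete type per instance.
struct ParamResampler{F<:VariateForm, S<:ValueSupport, T<:AbstractArray, W<:AbstractWeights} <: Sampleable{F, S}
    obs::T
    wv::W
end

obs = [1.0, 2.0]
wv = Weights([0.2, 0.8])
s = ParamResampler{Univariate, Continuous, typeof(obs), typeof(wv)}(obs, wv)

# The field type is now concrete, which lets the compiler specialize on it.
wv_is_concrete = isconcretetype(fieldtype(typeof(s), :wv))
```

With the field declared as `wv::AbstractWeights`, the same `fieldtype` query would return the abstract type, so accesses to `s.wv` could not be specialized at compile time.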
|
||
function WeightedResampler(obs::T, wv::AbstractWeights) where T<:AbstractArray | ||
Is there any reason to restrict weights to

ping @rofinn @nickrobinson251

Idk but maybe @rofinn does But also it'd be an easy follow up to loosen it

ping @rofinn ?

Can we merge as is and open an issue about this? It'd be a non breaking change to loosen the constraint in future. Or just loosen it now and fix it if there's a bug report

I guess merging this as-is is OK. I've noticed that the weights are passed directly to |
||
    F = _variate_form(T) | ||
    S = _value_support(eltype(T)) | ||
|
||
    _validate(obs, wv) | ||
    WeightedResampler{F, S, T}(obs, wv) | ||
end | ||
|
||
_variate_form(::Type{<:AbstractVector}) = Univariate | ||
|
||
_variate_form(::Type{<:AbstractMatrix}) = Multivariate | ||
_variate_form(::Type{<:AbstractVector{<:AbstractMatrix}}) = Matrixvariate | ||
|
||
_value_support(::Type{Int}) = Discrete | ||
_value_support(::Type{Float64}) = Continuous | ||
_value_support(T::Type{<:AbstractMatrix}) = _value_support(eltype(T)) | ||
|
||
_validate(obs::AbstractVector, wv::AbstractWeights) = _validate(length(obs), length(wv)) | ||
_validate(obs::AbstractMatrix, wv::AbstractWeights) = _validate(size(obs, 2), length(wv)) | ||
|
||
function _validate(nobs::Int, nwv::Int) | ||
    if nobs != nwv | ||
        throw(DimensionMismatch("Length of the weights vector ($nwv) must match the " * | ||
            "number of observations ($nobs).")) | ||
    end | ||
end | ||
|
||
Base.length(s::WeightedResampler{Multivariate}) = size(s.obs, 1) | ||
|
||
function Base.rand(rng::AbstractRNG, s::WeightedResampler{<:Union{Univariate,Matrixvariate}}) | ||
    i = sample(rng, s.wv) | ||
    return s.obs[i] | ||
end | ||
|
||
function _rand!(rng::AbstractRNG, s::WeightedResampler{Multivariate}, x::AbstractVector{<:Real}) | ||
    j = sample(rng, s.wv) | ||
    for i in 1:length(s) | ||
        @inbounds x[i] = s.obs[i, j] | ||
    end | ||
end |
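
For readers following the multivariate `_rand!` method above, here is a small standalone sketch of the same idea, using only `StatsBase.sample` rather than the new type. Observations are stored as columns, one column index is drawn according to the weights, and that column is copied into the output vector. The data, the seed and the variable names are illustrative assumptions.

```julia
using Random
using StatsBase: Weights, sample

rng = MersenneTwister(42)

# 2 variables (rows) × 3 observations (columns), matching the diff's
# convention that size(obs, 2) must equal length(wv).
obs = [10 20 30;
       11 21 31]
wv = Weights([0.0, 1.0, 0.0])   # all of the weight on the second observation

j = sample(rng, wv)             # weighted draw of a column index in 1:length(wv)

# Copy the selected observation into a preallocated vector, as _rand! does.
x = Vector{Int}(undef, size(obs, 1))
for i in 1:size(obs, 1)
    x[i] = obs[i, j]
end
```

Because all of the weight sits on the second observation, the draw selects column 2 and `x` ends up holding that column.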