Skip to content

Commit

Permalink
Added rand function to generaldiscretenonparametric, but _rand! is no…
Browse files Browse the repository at this point in the history
…t working.
  • Loading branch information
davibarreira committed Dec 19, 2021
1 parent 2947ca9 commit eeba4f2
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 35 deletions.
31 changes: 31 additions & 0 deletions src/multivariate/generaldiscretenonparametric.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
struct GeneralDiscreteNonParametric{VF,T,P <: Real,Ts <: AbstractVector{T},Ps <: AbstractVector{P},} <: Distribution{VF,Discrete}
support::Ts
p::Ps

function GeneralDiscreteNonParametric{VF,T,P,Ts,Ps}(
support::Ts,
p::Ps;
check_args=true,
) where {VF,T,P <: Real,Ts <: AbstractVector{T},Ps <: AbstractVector{P}}
if check_args
length(support) == length(p) ||
error("length of `support` and `p` must be equal")
isprobvec(p) || error("`p` must be a probability vector")
allunique(support) || error("`support` must contain only unique values")
end
new{VF,T,P,Ts,Ps}(support, p)
end
end

function rand(rng::AbstractRNG, d::GeneralDiscreteNonParametric)
x = support(d)
p = probs(d)
n = length(p)
draw = rand(rng, float(eltype(p)))
cp = p[1]
i = 1
while cp <= draw && i < n
@inbounds cp += p[i +=1]
end
return x[i]
end
42 changes: 20 additions & 22 deletions src/multivariate/mvdiscretenonparametric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,31 +51,29 @@ Get the vector of probabilities associated with the support of `d`.
probs(d::MvDiscreteNonParametric) = d.p


# It would be more intuitive if length was the
#
Base.length(d::MvDiscreteNonParametric) = length(first(d.support))
Base.size(d::MvDiscreteNonParametric) = (length(d), length(d.support))

function _rand!(
rng::AbstractRNG,
d::MvDiscreteNonParametric,
x::AbstractVector{T},
) where {T<:Real}

length(x) == length(d) || throw(DimensionMismatch("Invalid argument dimension."))
s = d.support
p = d.p

n = length(p)
draw = Base.rand(rng, float(eltype(p)))
cp = p[1]
i = 1
while cp <= draw && i < n
@inbounds cp += p[i+=1]
end
copyto!(x, s[i])
return x
end
# function _rand!(
# rng::AbstractRNG,
# d::MvDiscreteNonParametric,
# x::AbstractVector{T},
# ) where {T<:Real}

# length(x) == length(d) || throw(DimensionMismatch("Invalid argument dimension."))
# s = d.support
# p = d.p

# n = length(p)
# draw = Base.rand(rng, float(eltype(p)))
# cp = p[1]
# i = 1
# while cp <= draw && i < n
# @inbounds cp += p[i+=1]
# end
# copyto!(x, s[i])
# return x
# end

function _logpdf(d::MvDiscreteNonParametric, x::AbstractVector{T}) where {T<:Real}
s = support(d)
Expand Down
13 changes: 0 additions & 13 deletions src/univariate/discrete/discretenonparametric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,19 +69,6 @@ Base.isapprox(c1::D, c2::D) where D<:DiscreteNonParametric =

# Sampling

function rand(rng::AbstractRNG, d::DiscreteNonParametric)
x = support(d)
p = probs(d)
n = length(p)
draw = rand(rng, float(eltype(p)))
cp = p[1]
i = 1
while cp <= draw && i < n
@inbounds cp += p[i +=1]
end
return x[i]
end

sampler(d::DiscreteNonParametric) =
DiscreteNonParametricSampler(support(d), probs(d))

Expand Down

0 comments on commit eeba4f2

Please sign in to comment.