-
Notifications
You must be signed in to change notification settings - Fork 418
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
New Distribution: MvDiscreteNonParametric #1424
base: master
Are you sure you want to change the base?
Changes from all commits
34181c4
0cc7518
67b1890
7fba9da
6f43b18
c64339a
e156052
d11d300
24612da
288401b
e8cf084
68c4d4d
cbb6ae7
67a15dd
5644219
3baece6
10ba493
236380a
2947ca9
eeba4f2
b208df3
b2f0ec0
ad6d01e
001d73d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,79 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
struct GeneralDiscreteNonParametric{VF,T,P <: Real,Ts <: AbstractVector{T},Ps <: AbstractVector{P},} <: Distribution{VF,Discrete} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
support::Ts | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
p::Ps | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
function GeneralDiscreteNonParametric{VF,T,P,Ts,Ps}( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
support::Ts, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
p::Ps; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
check_args=true, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) where {VF,T,P <: Real,Ts <: AbstractVector{T},Ps <: AbstractVector{P}} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if check_args | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
length(support) == length(p) || | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
error("length of `support` and `p` must be equal") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
isprobvec(p) || error("`p` must be a probability vector") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
allunique(support) || error("`support` must contain only unique values") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
new{VF,T,P,Ts,Ps}(support, p) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
function rand(rng::AbstractRNG, d::GeneralDiscreteNonParametric) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
x = support(d) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
p = probs(d) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
n = length(p) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
draw = rand(rng, float(eltype(p))) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
cp = p[1] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
i = 1 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
while cp <= draw && i < n | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
@inbounds cp += p[i +=1] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return x[i] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+20
to
+31
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. While this method is correct, it messes with the
Suggested change
or let's remove it for now and add it in a follow-up PR that updates
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
support(d::MvDiscreteNonParametric) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Get a sorted AbstractVector defining the support of `d`. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+33
to
+36
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No need for a docstring,
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
support(d::GeneralDiscreteNonParametric) = d.support | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
probs(d::MvDiscreteNonParametric) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Get the vector of probabilities associated with the support of `d`. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+40
to
+41
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
probs(d::GeneralDiscreteNonParametric) = d.p | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Base.length(d::GeneralDiscreteNonParametric) = length(first(d.support)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+45
to
+46
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
IMO it is better to be a bit conservative here and only define it for |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
function _rand!( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
rng::AbstractRNG, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
d::GeneralDiscreteNonParametric, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
x::AbstractVector{T}, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) where {T<:Real} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+48
to
+52
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
length(x) == length(d) || throw(DimensionMismatch("Invalid argument dimension.")) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+53
to
+54
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The dimensions are already checked before
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
s = d.support | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
p = d.p | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
n = length(p) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
draw = Base.rand(rng, float(eltype(p))) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No need to qualify
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
cp = p[1] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
i = 1 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
while cp <= draw && i < n | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
@inbounds cp += p[i+=1] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
copyto!(x, s[i]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return x | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
function _logpdf(d::GeneralDiscreteNonParametric, x::AbstractVector{T}) where {T<:Real} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+69
to
+70
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Again, the dispatches are only used by
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
s = support(d) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
p = probs(d) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
for i = 1:length(p) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if s[i] == x | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return log(p[i]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+73
to
+77
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In general, this could be optimized with julia> sort([[0.0, 1.0, 0.0], [1.0, 0.0, 0.0], [0.5, 0.0, 0.5]])
3-element Vector{Vector{Float64}}:
[0.0, 1.0, 0.0]
[0.5, 0.0, 0.5]
[1.0, 0.0, 0.0]
julia> searchsortedfirst(ans, [0.5, 0.0, 0.5]) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I actually have been implementing an algorithm that does this more efficiently. But I only got the 2D case working. My idea is to submit a future PR improving this. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why do we need a custom algorithm? Can't we just use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I meant for computing the empirical cdf. Not exactly the sorting of the support. |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return log(zero(eltype(p))) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
end |
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,55 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
const MvDiscreteNonParametric{ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
T<:AbstractVector{<:Real}, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
P<:Real, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Ts<:AbstractVector{T}, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Ps<:AbstractVector{P}, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} = GeneralDiscreteNonParametric{Multivariate,T,P,Ts,Ps} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+1
to
+6
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What's the motivation for introducing this type-alias? It seems we could work with I suggest removing this file completely and moving the definitions to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we just drop the |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
MvDiscreteNonParametric( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
support::AbstractVector, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
p::AbstractVector{<:Real}=fill(inv(length(support)), length(support)), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Construct a multivariate discrete nonparametric probability distribution with `support` and corresponding | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
probabilities `p`. If the probability vector argument is not passed, then | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
equal probability is assigned to each entry in the support. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Examples | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
```julia | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# rows correspond to samples | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
x = collect(eachrow(rand(10,2))) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
μ = MvDiscreteNonParametric(x) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# columns correspond to samples | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
y = collect(eachcol(rand(7,12))) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
ν = MvDiscreteNonParametric(y) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
``` | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
function MvDiscreteNonParametric( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
support::AbstractArray{<:AbstractVector{<:Real}}, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This should be
Suggested change
in the multivariate case, and
Suggested change
only in the more general case of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should I just use the second option and change the function to |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
p::AbstractVector{<:Real} = fill(inv(length(support)), length(support)), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not sure if we should add a default definition for
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return MvDiscreteNonParametric{eltype(support),eltype(p),typeof(support),typeof(p)}( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
support, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
p, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Base.eltype(::Type{<:MvDiscreteNonParametric{T}}) where T = Base.eltype(T) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This should be defined more generally:
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
function mean(d::MvDiscreteNonParametric) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return StatsBase.mean(hcat(d.support...), Weights(d.p, one(eltype(d.p))),dims=2) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
function var(d::MvDiscreteNonParametric) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
x = hcat(support(d)...) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
p = probs(d) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return StatsBase.var(x, Weights(p, one(eltype(p))), 2,corrected = false) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
function cov(d::MvDiscreteNonParametric) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
x = hcat(support(d)...) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
p = probs(d) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return cov(x, Weights(p, one(eltype(p))), 2,corrected = false) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+41
to
+55
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I requested to add support for vectors of vectors to StatsBase but it might take a bit until it is available. IMO it is fine to have possibly less optimized implementations until then. However, we should avoid splatting and use
Suggested change
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -69,19 +69,6 @@ Base.isapprox(c1::D, c2::D) where D<:DiscreteNonParametric = | |
|
||
# Sampling | ||
|
||
function rand(rng::AbstractRNG, d::DiscreteNonParametric) | ||
x = support(d) | ||
p = probs(d) | ||
n = length(p) | ||
draw = rand(rng, float(eltype(p))) | ||
cp = p[1] | ||
i = 1 | ||
while cp <= draw && i < n | ||
@inbounds cp += p[i +=1] | ||
end | ||
return x[i] | ||
end | ||
|
||
Comment on lines
-72
to
-84
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I guess this was removed accidentally? I guess it would be cleaner to not make any changes to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yeah. my bad |
||
sampler(d::DiscreteNonParametric) = | ||
DiscreteNonParametricSampler(support(d), probs(d)) | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
using Distributions | ||
using StatsBase | ||
using LinearAlgebra | ||
using Random | ||
using Test | ||
|
||
|
||
@testset "GeneralDiscreteNonParametric" begin | ||
|
||
@testset "Declaring MvDiscreteNonParametric" begin | ||
|
||
Random.seed!(7) | ||
n = 4 | ||
m = 2 | ||
A = collect(eachrow(rand(n, m))) | ||
p = normalize!(rand(n), 1) | ||
|
||
# Passing probabilities | ||
μ = @inferred(MvDiscreteNonParametric(A, p)) | ||
@test support(μ) == A | ||
@test length(μ) == m | ||
# @test size(μ) == (m, n) | ||
@test probs(μ) == p | ||
|
||
# Without passing probabilities | ||
μ = @inferred(MvDiscreteNonParametric(A)) | ||
@test support(μ) == A | ||
@test length(μ) == m | ||
# @test size(μ) == (m, n) | ||
@test probs(μ) == fill(1 / n, n) | ||
|
||
# Array of arrays without ArraysOfArrays.jl | ||
n, m = 3, 2 | ||
p = ([3 / 5, 1 / 5, 1 / 5]) | ||
A = [[1,0],[1,1],[0,1]] | ||
μ = @inferred(MvDiscreteNonParametric(A, p)) | ||
|
||
@test support(μ) == A | ||
@test length(μ) == m | ||
@test probs(μ) == p | ||
|
||
end | ||
|
||
|
||
@testset "Functionalities" begin | ||
|
||
function variance(d) | ||
v = zeros(length(d)) | ||
for i in 1:length(d) | ||
s = hcat(μ.support...)[i,:] | ||
mₛ = mean(d)[i] | ||
v[i] = sum(abs2.(s .- mₛ), Weights(d.p)) | ||
end | ||
return v | ||
end | ||
|
||
function covariance(d) | ||
n = length(d) | ||
v = zeros(n, n) | ||
for i in 1:n, j in 1:n | ||
s = hcat(μ.support...)[i,:] | ||
mₛ = mean(d)[i] | ||
|
||
u = hcat(μ.support...)[j,:] | ||
mᵤ = mean(d)[j] | ||
|
||
v[i,j] = sum((s .- mₛ) .* (u .- mᵤ), Weights(d.p)) | ||
end | ||
return v | ||
end | ||
|
||
Random.seed!(7) | ||
n, m = 7, 9 | ||
|
||
A = collect(eachrow(rand(n, m))) | ||
p = normalize!(rand(n), 1) | ||
μ = @inferred(MvDiscreteNonParametric(A, p)) | ||
|
||
@test mean(μ) ≈ mean(hcat(μ.support...), Weights(p), dims=2)[:] | ||
@test var(μ) ≈ variance(μ) | ||
@test cov(μ) ≈ covariance(μ) | ||
@test pdf(μ, μ.support) ≈ μ.p | ||
@test pdf(μ, zeros(m)) == 0.0 | ||
# @test entropy(μ) == entropy(μ.p) | ||
# @test entropy(μ, 2) == entropy(μ.p, 2) | ||
|
||
end | ||
|
||
@testset "Sampling" begin | ||
Random.seed!(7) | ||
A = collect(eachrow(rand(3, 2))) | ||
μ = MvDiscreteNonParametric(A, [0.9,0.1,0.0]) | ||
|
||
for i in 1:3 | ||
samples = rand(μ, 10000) | ||
@test abs(mean([s == A[i] for s in eachcol(samples)]) - μ.p[i]) < 0.05 | ||
end | ||
end | ||
|
||
end | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe let's keep it unexported until everything is figured out?