Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change dimension ordering #50

Merged
merged 47 commits into from
Dec 12, 2022
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
3cdde3e
Update discretediag
sethaxen Nov 19, 2022
7764bb2
Update ess_rhat
sethaxen Nov 19, 2022
f2b3f5e
Update gelmandiag
sethaxen Nov 19, 2022
10e3814
Update rstar
sethaxen Nov 19, 2022
86297e9
Add 3d array method for rstar
sethaxen Nov 19, 2022
d5609e2
Add 3d array method for mcse
sethaxen Nov 19, 2022
79adf82
Update docstrings
sethaxen Nov 19, 2022
1370c7a
Change dimension order in tests
sethaxen Nov 19, 2022
29cba7f
Update rstar tests with permuted dims
sethaxen Nov 19, 2022
201bdec
Test rstar with 3d array
sethaxen Nov 19, 2022
163ea88
Test mcse with 3d array
sethaxen Nov 19, 2022
66d2c70
Increment version number
sethaxen Nov 19, 2022
9215e2f
Bump compat
sethaxen Nov 19, 2022
b733691
Run formatter
sethaxen Nov 19, 2022
06b4d9f
Update seed
sethaxen Nov 19, 2022
0917971
Add Compat as dependency
sethaxen Nov 21, 2022
3230e08
Use eachslice
sethaxen Nov 21, 2022
61653f9
Replace mapslices with an explicit loop
sethaxen Nov 21, 2022
1d0a086
Run formatter
sethaxen Nov 21, 2022
f1e05db
Avoid explicit axis check
sethaxen Nov 21, 2022
e7bd124
Remove type-instability
sethaxen Nov 21, 2022
1941044
Accept any table to rstar
sethaxen Nov 21, 2022
1dbc4f2
Update rstar documentation
sethaxen Nov 21, 2022
87d5f88
Release type constraint
sethaxen Nov 21, 2022
78d256b
Support rstar taking matrices or vectors
sethaxen Nov 21, 2022
aa01a1e
Update rstar tests
sethaxen Nov 21, 2022
e04fd46
Add type consistency check
sethaxen Nov 21, 2022
1a561f0
Don't permutedims in discretediag
sethaxen Nov 23, 2022
e6ce9b2
Revert all changes to mcse
sethaxen Nov 25, 2022
d1924cf
Reorder dimensions to (draw, chain, params)
sethaxen Dec 1, 2022
ea01e3f
Apply suggestions from code review
sethaxen Dec 12, 2022
d7437a0
Split rstar docstring
sethaxen Dec 12, 2022
ac5f2cb
Convert to table once
sethaxen Dec 12, 2022
be46258
Clean up language
sethaxen Dec 12, 2022
586b8f1
Use correct variable name
sethaxen Dec 12, 2022
365267f
Allow Tables v1
sethaxen Dec 12, 2022
9438e34
Bump Julia compat to v1.3
sethaxen Dec 12, 2022
a5e784c
Remove compat dependency
sethaxen Dec 12, 2022
76317b9
Remove all special-casing for less than v1.2
sethaxen Dec 12, 2022
859f658
Test on v1.3
sethaxen Dec 12, 2022
c48380c
Use PackageSpec for v1.3
sethaxen Dec 12, 2022
7666803
Merge test Project.tomls
sethaxen Dec 12, 2022
d3942c3
Move rstar test file out of own directory
sethaxen Dec 12, 2022
5e0bd5b
Fix version numbers
sethaxen Dec 12, 2022
11b7eb4
Merge branch 'main' into unifydimorder
sethaxen Dec 12, 2022
ff7e5e0
Apply suggestions from code review
sethaxen Dec 12, 2022
7ae2231
Merge branch 'main' into unifydimorder
sethaxen Dec 12, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MCMCDiagnosticTools"
uuid = "be115224-59cd-429b-ad48-344e309966f0"
authors = ["David Widmann"]
version = "0.1.5"
version = "0.2.0"

[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
Expand Down
10 changes: 6 additions & 4 deletions src/discretediag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -423,9 +423,11 @@ function discretediag_sub(
end

"""
discretediag(chains::AbstractArray{<:Real,3}; frac=0.3, method=:weiss, nsim=1_000)
discretediag(samples::AbstractArray{<:Real,3}; frac=0.3, method=:weiss, nsim=1_000)

Compute discrete diagnostic where `method` can be one of `:weiss`, `:hangartner`,
Compute discrete diagnostic on `samples` with shape `(parameters, draws, chains)`.

`method` can be one of `:weiss`, `:hangartner`,
`:DARBOOT`, `:MCBOOT`, `:billinsgley`, and `:billingsleyBOOT`.

# References
Expand All @@ -441,9 +443,9 @@ function discretediag(
)
0 < frac < 1 || throw(ArgumentError("`frac` must be in (0,1)"))

num_iters = size(chains, 1)
num_iters = size(chains, 2)
between_chain_vals, within_chain_vals, _, _ = discretediag_sub(
chains, frac, method, nsim, num_iters, num_iters
permutedims(chains, (2, 1, 3)), frac, method, nsim, num_iters, num_iters
)
sethaxen marked this conversation as resolved.
Show resolved Hide resolved

return between_chain_vals, within_chain_vals
Expand Down
8 changes: 4 additions & 4 deletions src/ess.jl
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ end
)

Estimate the effective sample size and the potential scale reduction of the `samples` of
shape (draws, parameters, chains) with the `method` and a maximum lag of `maxlag`.
shape `(parameters, draws, chains)` with the `method` and a maximum lag of `maxlag`.

See also: [`ESSMethod`](@ref), [`FFTESSMethod`](@ref), [`BDAESSMethod`](@ref)
"""
Expand All @@ -211,8 +211,8 @@ function ess_rhat(
maxlag::Int=250,
)
# compute size of matrices (each chain is split!)
niter = size(chains, 1) ÷ 2
nparams = size(chains, 2)
niter = size(chains, 2) ÷ 2
nparams = size(chains, 1)
nchains = 2 * size(chains, 3)
ntotal = niter * nchains

Expand All @@ -238,7 +238,7 @@ function ess_rhat(
rhat = Vector{T}(undef, nparams)

# for each parameter
for (i, chains_slice) in enumerate((view(chains, :, i, :) for i in axes(chains, 2)))
for (i, chains_slice) in enumerate((selectdim(chains, 1, i) for i in axes(chains, 1)))
sethaxen marked this conversation as resolved.
Show resolved Hide resolved
# check that no values are missing
if any(x -> x === missing, chains_slice)
rhat[i] = missing
Expand Down
18 changes: 10 additions & 8 deletions src/gelmandiag.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
function _gelmandiag(psi::AbstractArray{<:Real,3}; alpha::Real=0.05)
niters, nparams, nchains = size(psi)
nparams, niters, nchains = size(psi)
nchains > 1 || error("Gelman diagnostic requires at least 2 chains")

rfixed = (niters - 1) / niters
rrandomscale = (nchains + 1) / (nchains * niters)

S2 = map(Statistics.cov, (view(psi, :, :, i) for i in axes(psi, 3)))
S2 = map(x -> Statistics.cov(x; dims=2), (view(psi, :, :, i) for i in axes(psi, 3)))
devmotion marked this conversation as resolved.
Show resolved Hide resolved
W = Statistics.mean(S2)

psibar = dropdims(Statistics.mean(psi; dims=1); dims=1)'
psibar = dropdims(Statistics.mean(psi; dims=2); dims=2)'
B = niters .* Statistics.cov(psibar)

w = LinearAlgebra.diag(W)
Expand Down Expand Up @@ -52,9 +52,10 @@ function _gelmandiag(psi::AbstractArray{<:Real,3}; alpha::Real=0.05)
end

"""
gelmandiag(chains::AbstractArray{<:Real,3}; alpha::Real=0.95)
gelmandiag(samples::AbstractArray{<:Real,3}; alpha::Real=0.95)

Compute the Gelman, Rubin and Brooks diagnostics [^Gelman1992] [^Brooks1998]. Values of the
Compute the Gelman, Rubin and Brooks diagnostics [^Gelman1992] [^Brooks1998] on `samples`
with shape `(parameters, draws, chains)`. Values of the
diagnostic’s potential scale reduction factor (PSRF) that are close to one suggest
convergence. As a rule-of-thumb, convergence is rejected if the 97.5 percentile of a PSRF
is greater than 1.2.
Expand All @@ -70,12 +71,13 @@ function gelmandiag(chains::AbstractArray{<:Real,3}; kwargs...)
end

"""
gelmandiag_multivariate(chains::AbstractArray{<:Real,3}; alpha::Real=0.05)
gelmandiag_multivariate(samples::AbstractArray{<:Real,3}; alpha::Real=0.05)

Compute the multivariate Gelman, Rubin and Brooks diagnostics.
Compute the multivariate Gelman, Rubin and Brooks diagnostics on `samples` with shape
`(parameters, draws, chains)`.
"""
function gelmandiag_multivariate(chains::AbstractArray{<:Real,3}; kwargs...)
niters, nparams, nchains = size(chains)
nparams, niters, nchains = size(chains)
if nparams < 2
error(
"computation of the multivariate potential scale reduction factor requires ",
Expand Down
9 changes: 7 additions & 2 deletions src/mcse.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
"""
mcse(x::AbstractVector{<:Real}; method::Symbol=:imse, kwargs...)
mcse(samples::AbstractVector{<:Real}; method::Symbol=:imse, kwargs...)
mcse(samples::AbstractArray{<:Real,3}; method::Symbol=:imse, kwargs...)

Compute the Monte Carlo standard error (MCSE) of samples `x`.
Compute the Monte Carlo standard error (MCSE) of `samples` of shape `(draws,)` or
`(parameters, draws, chains)`
The optional argument `method` describes how the errors are estimated. Possible options are:

- `:bm` for batch means [^Glynn1991]
Expand All @@ -23,6 +25,9 @@ function mcse(x::AbstractVector{<:Real}; method::Symbol=:imse, kwargs...)
throw(ArgumentError("unsupported MCSE method $method"))
end
end
function mcse(x::AbstractArray{<:Real,3}; kwargs...)
return dropdims(mapslices(xi -> mcse(vec(xi); kwargs...), x; dims=(2, 3)); dims=(2, 3))
sethaxen marked this conversation as resolved.
Show resolved Hide resolved
end

function mcse_bm(x::AbstractVector{<:Real}; size::Int=floor(Int, sqrt(length(x))))
n = length(x)
Expand Down
43 changes: 31 additions & 12 deletions src/rstar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,16 @@
rstar(
rng=Random.GLOBAL_RNG,
classifier,
samples::AbstractMatrix,
chain_indices::AbstractVector{Int};
samples::AbstractArray,
[chain_indices::AbstractVector{Int}];
sethaxen marked this conversation as resolved.
Show resolved Hide resolved
subset::Real=0.8,
verbosity::Int=0,
)

Compute the ``R^*`` convergence statistic of the `samples` with shape (draws, parameters)
and corresponding chains `chain_indices` with the `classifier`.
Compute the ``R^*`` convergence statistic of the `samples` with the `classifier`.

Either `samples` has shape `(parameters, draws, chains)`, or `samples` has shape
`(parameters, draws)` and `chain_indices` must be provided.

This implementation is an adaption of algorithms 1 and 2 described by Lambert and Vehtari.

Expand All @@ -32,16 +34,14 @@ is returned (algorithm 2).
```jldoctest rstar; setup = :(using Random; Random.seed!(100))
julia> using MLJBase, MLJXGBoostInterface, Statistics

julia> samples = fill(4.0, 300, 2);

julia> chain_indices = repeat(1:3; outer=100);
julia> samples = fill(4.0, 2, 100, 3);
```

One can compute the distribution of the ``R^*`` statistic (algorithm 2) with the
probabilistic classifier.

```jldoctest rstar
julia> distribution = rstar(XGBoostClassifier(), samples, chain_indices);
julia> distribution = rstar(XGBoostClassifier(), samples);

julia> isapprox(mean(distribution), 1; atol=0.1)
true
Expand All @@ -54,7 +54,7 @@ predicting the mode. In MLJ this corresponds to a pipeline of models.
```jldoctest rstar
julia> xgboost_deterministic = Pipeline(XGBoostClassifier(); operation=predict_mode);

julia> value = rstar(xgboost_deterministic, samples, chain_indices);
julia> value = rstar(xgboost_deterministic, samples);

julia> isapprox(value, 1; atol=0.2)
true
Expand All @@ -73,7 +73,7 @@ function rstar(
verbosity::Int=0,
)
# checks
size(x, 1) != length(y) && throw(DimensionMismatch())
size(x, 2) != length(y) && throw(DimensionMismatch())
0 < subset < 1 || throw(ArgumentError("`subset` must be a number in (0, 1)"))

# randomly sub-select training and testing set
Expand All @@ -88,11 +88,11 @@ function rstar(
# train classifier on training data
ycategorical = MLJModelInterface.categorical(y)
fitresult, _ = MLJModelInterface.fit(
classifier, verbosity, Tables.table(x[train_ids, :]), ycategorical[train_ids]
classifier, verbosity, Tables.table(x[:, train_ids]'), ycategorical[train_ids]
devmotion marked this conversation as resolved.
Show resolved Hide resolved
)

# compute predictions on test data
xtest = Tables.table(x[test_ids, :])
xtest = Tables.table(x[:, test_ids]')
predictions = _predict(classifier, fitresult, xtest)

# compute statistic
Expand All @@ -114,6 +114,17 @@ function _predict(model::MLJModelInterface.Model, fitresult, x)
end
end

function rstar(
rng::Random.AbstractRNG,
classifier::MLJModelInterface.Supervised,
x::AbstractArray{<:Any,3};
kwargs...
sethaxen marked this conversation as resolved.
Show resolved Hide resolved
)
samples = reshape(x, size(x, 1), :)
chain_inds = repeat(axes(x, 3); inner=size(x, 2))
return rstar(rng, classifier, samples, chain_inds; kwargs...)
end

function rstar(
classif::MLJModelInterface.Supervised,
x::AbstractMatrix,
Expand All @@ -123,6 +134,14 @@ function rstar(
return rstar(Random.GLOBAL_RNG, classif, x, y; kwargs...)
sethaxen marked this conversation as resolved.
Show resolved Hide resolved
end

function rstar(
classif::MLJModelInterface.Supervised,
x::AbstractArray{<:Any,3};
kwargs...,
)
sethaxen marked this conversation as resolved.
Show resolved Hide resolved
return rstar(Random.GLOBAL_RNG, classif, x; kwargs...)
sethaxen marked this conversation as resolved.
Show resolved Hide resolved
end

# R⋆ for deterministic predictions (algorithm 1)
function _rstar(predictions::AbstractVector{T}, ytest::AbstractVector{T}) where {T}
length(predictions) == length(ytest) ||
Expand Down
2 changes: 1 addition & 1 deletion test/discretediag.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
@testset "discretediag.jl" begin
nparams = 4
nchains = 2
samples = rand(-100:100, 100, nparams, nchains)
samples = rand(-100:100, nparams, 100, nchains)

@testset "results" begin
for method in
Expand Down
10 changes: 5 additions & 5 deletions test/ess.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
end

@testset "ESS and R̂ (IID samples)" begin
rawx = randn(10_000, 40, 10)
rawx = randn(40, 10_000, 10)

# Repeat tests with different scales
for scale in (1, 50, 100)
Expand All @@ -58,7 +58,7 @@
end

@testset "ESS and R̂ (identical samples)" begin
x = ones(10_000, 40, 10)
x = ones(40, 10_000, 10)

ess_standard, rhat_standard = ess_rhat(x)
ess_standard2, rhat_standard2 = ess_rhat(x; method=ESSMethod())
Expand All @@ -75,15 +75,15 @@
end

@testset "ESS and R̂ (single sample)" begin # check that issue #137 is fixed
x = rand(1, 5, 3)
x = rand(5, 1, 3)

for method in (ESSMethod(), FFTESSMethod(), BDAESSMethod())
# analyze array
ess_array, rhat_array = ess_rhat(x; method=method)

@test length(ess_array) == size(x, 2)
@test length(ess_array) == size(x, 1)
@test all(ismissing, ess_array) # since min(maxlag, niter - 1) = 0
@test length(rhat_array) == size(x, 2)
@test length(rhat_array) == size(x, 1)
@test all(ismissing, rhat_array)
end
end
Expand Down
4 changes: 2 additions & 2 deletions test/gelmandiag.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
@testset "gelmandiag.jl" begin
nparams = 4
nchains = 2
samples = randn(100, nparams, nchains)
samples = randn(nparams, 100, nchains)

@testset "results" begin
result = @inferred(gelmandiag(samples))
Expand All @@ -24,6 +24,6 @@

@testset "exceptions" begin
@test_throws ErrorException gelmandiag(samples[:, :, 1:1])
@test_throws ErrorException gelmandiag_multivariate(samples[:, 1:1, :])
@test_throws ErrorException gelmandiag_multivariate(samples[1:1, :, :])
end
end
24 changes: 21 additions & 3 deletions test/mcse.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
@testset "mcse.jl" begin
samples = randn(100)

@testset "results" begin
@testset "results 1d" begin
samples = randn(100)
result = @inferred(mcse(samples))
@test result isa Float64
@test result > 0
Expand All @@ -13,13 +12,32 @@
end
end

@testset "results 3d" begin
sethaxen marked this conversation as resolved.
Show resolved Hide resolved
nparams = 2
nchains = 4
samples = randn(nparams, 100, nchains)
result = mcse(samples) # mapslices is not type-inferrable
@test result isa Vector{Float64}
@test length(result) == nparams
@test all(r -> r > 0, result)

for method in (:imse, :ipse, :bm)
result = mcse(samples) # mapslices is not type-inferrable
@test result isa Vector{Float64}
@test length(result) == nparams
@test all(r -> r > 0, result)
end
end

@testset "warning" begin
samples = randn(100)
for size in (51, 75, 100, 153)
@test_logs (:warn,) mcse(samples; method=:bm, size=size)
end
end

@testset "exception" begin
samples = randn(100)
@test_throws ArgumentError mcse(samples; method=:somemethod)
end
end
1 change: 1 addition & 0 deletions test/rstar/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ MCMCDiagnosticTools = "be115224-59cd-429b-ad48-344e309966f0"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
MLJLIBSVMInterface = "61c7150f-6c77-4bb1-949c-13197eac2a52"
MLJXGBoostInterface = "54119dfa-1dab-4055-a167-80440f4f7a91"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
Expand Down
Loading