Change dimension ordering (#50)
* Update discretediag

* Update ess_rhat

* Update gelmandiag

* Update rstar

* Add 3d array method for rstar

* Add 3d array method for mcse

* Update docstrings

* Change dimension order in tests

* Update rstar tests with permuted dims

* Test rstar with 3d array

* Test mcse with 3d array

* Increment version number

* Bump compat

* Run formatter

* Update seed

Necessary because chain_inds are not identical to those in the previous example (now repeating with inner instead of outer)

* Add Compat as dependency

* Use eachslice

* Replace mapslices with an explicit loop

* Run formatter

* Avoid explicit axis check

* Remove type-instability

* Accept any table to rstar

* Update rstar documentation

* Release type constraint

* Support rstar taking matrices or vectors

* Update rstar tests

* Add type consistency check

* Don't permutedims in discretediag

* Revert all changes to mcse

* Reorder dimensions to (draw, chain, params)

* Apply suggestions from code review

Co-authored-by: David Widmann <devmotion@users.noreply.github.com>

* Split rstar docstring

* Convert to table once

* Clean up language

* Use correct variable name

* Allow Tables v1

* Bump Julia compat to v1.3

* Remove compat dependency

* Remove all special-casing for less than v1.2

* Test on v1.3

* Use PackageSpec for v1.3

* Merge test Project.tomls

* Move rstar test file out of own directory

* Fix version numbers

* Apply suggestions from code review

Co-authored-by: David Widmann <devmotion@users.noreply.github.com>

Co-authored-by: David Widmann <devmotion@users.noreply.github.com>
sethaxen and devmotion authored Dec 12, 2022
1 parent 892b00c commit 8d74357
Showing 17 changed files with 262 additions and 203 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/CI.yml
@@ -14,7 +14,7 @@ jobs:
fail-fast: false
matrix:
version:
- '1.0'
- '1.3'
- '1'
- 'nightly'
os:
4 changes: 2 additions & 2 deletions Project.toml
@@ -1,7 +1,7 @@
name = "MCMCDiagnosticTools"
uuid = "be115224-59cd-429b-ad48-344e309966f0"
authors = ["David Widmann"]
version = "0.1.5"
version = "0.2.0"

[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
@@ -23,7 +23,7 @@ MLJModelInterface = "1.6"
SpecialFunctions = "0.8, 0.9, 0.10, 1, 2"
StatsBase = "0.33"
Tables = "1"
julia = "1"
julia = "1.3"

[extras]
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
2 changes: 1 addition & 1 deletion docs/Project.toml
@@ -8,7 +8,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
Documenter = "0.27"
MCMCDiagnosticTools = "0.1"
MCMCDiagnosticTools = "0.2"
MLJBase = "0.19, 0.20, 0.21"
MLJXGBoostInterface = "0.1, 0.2, 0.3"
julia = "1.3"
14 changes: 8 additions & 6 deletions src/discretediag.jl
@@ -372,7 +372,7 @@ function discretediag_sub(
start_iter::Int,
step_size::Int,
)
num_iters, num_vars, num_chains = size(c)
num_iters, num_chains, num_vars = size(c)

## Between-chain diagnostic
length_results = length(start_iter:step_size:num_iters)
@@ -384,7 +384,7 @@
pvalue=Vector{Float64}(undef, num_vars),
)
for j in 1:num_vars
X = convert(AbstractMatrix{Int}, c[:, j, :])
X = convert(AbstractMatrix{Int}, c[:, :, j])
result = diag_all(X, method, nsim, start_iter, step_size)

plot_vals_stat[:, j] .= result.stat ./ result.df
@@ -403,7 +403,7 @@
)
for k in 1:num_chains
for j in 1:num_vars
x = convert(AbstractVector{Int}, c[:, j, k])
x = convert(AbstractVector{Int}, c[:, k, j])

idx1 = 1:round(Int, frac * num_iters)
idx2 = round(Int, num_iters - frac * num_iters + 1):num_iters
@@ -423,14 +423,16 @@
end

"""
discretediag(chains::AbstractArray{<:Real,3}; frac=0.3, method=:weiss, nsim=1_000)
discretediag(samples::AbstractArray{<:Real,3}; frac=0.3, method=:weiss, nsim=1_000)
Compute discrete diagnostic where `method` can be one of `:weiss`, `:hangartner`,
Compute discrete diagnostic on `samples` with shape `(draws, chains, parameters)`.
`method` can be one of `:weiss`, `:hangartner`,
`:DARBOOT`, `:MCBOOT`, `:billingsley`, and `:billingsleyBOOT`.
# References
Benjamin E. Deonovic, & Brian J. Smith. (2017). Convergence diagnostics for MCMC draws of a categorical variable.
Benjamin E. Deonovic, & Brian J. Smith. (2017). Convergence diagnostics for MCMC draws of a categorical variable.
"""
function discretediag(
chains::AbstractArray{<:Real,3}; frac::Real=0.3, method::Symbol=:weiss, nsim::Int=1000
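For code written against v0.1, where sample arrays were ordered `(draws, parameters, chains)`, a single `permutedims` converts to the layout this release expects. A minimal migration sketch (array sizes are hypothetical):

```julia
# old layout assumed by MCMCDiagnosticTools v0.1: (draws, parameters, chains)
old = randn(1000, 4, 3)

# new layout expected by v0.2: (draws, chains, parameters)
new = permutedims(old, (1, 3, 2))

size(new)  # (1000, 3, 4)
```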
8 changes: 4 additions & 4 deletions src/ess.jl
@@ -201,7 +201,7 @@ end
)
Estimate the effective sample size and the potential scale reduction of the `samples` of
shape (draws, parameters, chains) with the `method` and a maximum lag of `maxlag`.
shape `(draws, chains, parameters)` with the `method` and a maximum lag of `maxlag`.
See also: [`ESSMethod`](@ref), [`FFTESSMethod`](@ref), [`BDAESSMethod`](@ref)
"""
@@ -212,8 +212,8 @@ function ess_rhat(
)
# compute size of matrices (each chain is split!)
niter = size(chains, 1) ÷ 2
nparams = size(chains, 2)
nchains = 2 * size(chains, 3)
nparams = size(chains, 3)
nchains = 2 * size(chains, 2)
ntotal = niter * nchains

# do not compute estimates if there is only one sample or lag
@@ -238,7 +238,7 @@
rhat = Vector{T}(undef, nparams)

# for each parameter
for (i, chains_slice) in enumerate((view(chains, :, i, :) for i in axes(chains, 2)))
for (i, chains_slice) in enumerate(eachslice(chains; dims=3))
# check that no values are missing
if any(x -> x === missing, chains_slice)
rhat[i] = missing
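With the new ordering, `eachslice(chains; dims=3)` yields one `(draws, chains)` matrix per parameter, replacing the explicit `view` generator. A small sketch of the iteration pattern (sizes are hypothetical):

```julia
chains = randn(250, 4, 2)  # (draws, chains, parameters)

for (i, slice) in enumerate(eachslice(chains; dims=3))
    # `slice` is the (draws, chains) matrix for parameter i
    @assert size(slice) == (250, 4)
end
```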
19 changes: 11 additions & 8 deletions src/gelmandiag.jl
@@ -1,14 +1,15 @@
function _gelmandiag(psi::AbstractArray{<:Real,3}; alpha::Real=0.05)
niters, nparams, nchains = size(psi)
niters, nchains, nparams = size(psi)
nchains > 1 || error("Gelman diagnostic requires at least 2 chains")

rfixed = (niters - 1) / niters
rrandomscale = (nchains + 1) / (nchains * niters)

S2 = map(Statistics.cov, (view(psi, :, :, i) for i in axes(psi, 3)))
# `eachslice(psi; dims=2)` breaks type inference
S2 = map(x -> Statistics.cov(x; dims=1), (view(psi, :, i, :) for i in axes(psi, 2)))
W = Statistics.mean(S2)

psibar = dropdims(Statistics.mean(psi; dims=1); dims=1)'
psibar = dropdims(Statistics.mean(psi; dims=1); dims=1)
B = niters .* Statistics.cov(psibar)

w = LinearAlgebra.diag(W)
@@ -52,9 +53,10 @@ function _gelmandiag(psi::AbstractArray{<:Real,3}; alpha::Real=0.05)
end

"""
gelmandiag(chains::AbstractArray{<:Real,3}; alpha::Real=0.95)
gelmandiag(samples::AbstractArray{<:Real,3}; alpha::Real=0.95)
Compute the Gelman, Rubin and Brooks diagnostics [^Gelman1992] [^Brooks1998]. Values of the
Compute the Gelman, Rubin and Brooks diagnostics [^Gelman1992] [^Brooks1998] on `samples`
with shape `(draws, chains, parameters)`. Values of the
diagnostic’s potential scale reduction factor (PSRF) that are close to one suggest
convergence. As a rule-of-thumb, convergence is rejected if the 97.5 percentile of a PSRF
is greater than 1.2.
@@ -70,12 +72,13 @@ function gelmandiag(chains::AbstractArray{<:Real,3}; kwargs...)
end

"""
gelmandiag_multivariate(chains::AbstractArray{<:Real,3}; alpha::Real=0.05)
gelmandiag_multivariate(samples::AbstractArray{<:Real,3}; alpha::Real=0.05)
Compute the multivariate Gelman, Rubin and Brooks diagnostics.
Compute the multivariate Gelman, Rubin and Brooks diagnostics on `samples` with shape
`(draws, chains, parameters)`.
"""
function gelmandiag_multivariate(chains::AbstractArray{<:Real,3}; kwargs...)
niters, nparams, nchains = size(chains)
niters, nchains, nparams = size(chains)
if nparams < 2
error(
"computation of the multivariate potential scale reduction factor requires ",
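As the in-diff comment notes, `eachslice(psi; dims=2)` broke type inference here, so the per-chain covariances are computed over `view`s instead. A sketch of the within-chain covariance step (illustrative sizes, not taken from the package):

```julia
using Statistics

psi = randn(500, 4, 3)  # (draws, chains, parameters)

# parameter covariance within each chain, then averaged across chains
S2 = map(x -> cov(x; dims=1), (view(psi, :, i, :) for i in axes(psi, 2)))
W = mean(S2)

size(W)  # (3, 3): parameters × parameters
```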
2 changes: 1 addition & 1 deletion src/rafterydiag.jl
@@ -38,7 +38,7 @@ function rafterydiag(
dichot = Int[(x .<= StatsBase.quantile(x, q))...]
kthin = 0
bic = 1.0
local test , ntest
local test, ntest
while bic >= 0.0
kthin += 1
test = dichot[1:kthin:nx]
154 changes: 94 additions & 60 deletions src/rstar.jl
@@ -1,15 +1,91 @@
"""
rstar(
rng=Random.GLOBAL_RNG,
classifier,
samples::AbstractMatrix,
rng::Random.AbstractRNG=Random.default_rng(),
classifier::MLJModelInterface.Supervised,
samples,
chain_indices::AbstractVector{Int};
subset::Real=0.8,
verbosity::Int=0,
)
Compute the ``R^*`` convergence statistic of the `samples` with shape (draws, parameters)
and corresponding chains `chain_indices` with the `classifier`.
Compute the ``R^*`` convergence statistic of the table `samples` with the `classifier`.
`samples` must be either an `AbstractMatrix`, an `AbstractVector`, or a table
(i.e. implements the Tables.jl interface) whose rows are draws and whose columns are
parameters.
`chain_indices` indicates the chain ids of each row of `samples`.
This method supports ragged chains, i.e. chains of nonequal lengths.
"""
function rstar(
rng::Random.AbstractRNG,
classifier::MLJModelInterface.Supervised,
x,
y::AbstractVector{Int};
subset::Real=0.8,
verbosity::Int=0,
)
# checks
MLJModelInterface.nrows(x) != length(y) && throw(DimensionMismatch())
0 < subset < 1 || throw(ArgumentError("`subset` must be a number in (0, 1)"))

# randomly sub-select training and testing set
N = length(y)
Ntrain = round(Int, N * subset)
0 < Ntrain < N ||
throw(ArgumentError("training and test data subsets must not be empty"))
ids = Random.randperm(rng, N)
train_ids = view(ids, 1:Ntrain)
test_ids = view(ids, (Ntrain + 1):N)

xtable = _astable(x)

# train classifier on training data
ycategorical = MLJModelInterface.categorical(y)
xtrain = MLJModelInterface.selectrows(xtable, train_ids)
fitresult, _ = MLJModelInterface.fit(
classifier, verbosity, xtrain, ycategorical[train_ids]
)

# compute predictions on test data
xtest = MLJModelInterface.selectrows(xtable, test_ids)
predictions = _predict(classifier, fitresult, xtest)

# compute statistic
ytest = ycategorical[test_ids]
result = _rstar(predictions, ytest)

return result
end

_astable(x::AbstractVecOrMat) = Tables.table(x)
_astable(x) = Tables.istable(x) ? x : throw(ArgumentError("Argument is not a valid table"))

# Workaround for https://github.com/JuliaAI/MLJBase.jl/issues/863
# `MLJModelInterface.predict` sometimes returns predictions and sometimes predictions + additional information
# TODO: Remove once the upstream issue is fixed
function _predict(model::MLJModelInterface.Model, fitresult, x)
y = MLJModelInterface.predict(model, fitresult, x)
return if :predict in MLJModelInterface.reporting_operations(model)
first(y)
else
y
end
end

"""
rstar(
rng::Random.AbstractRNG=Random.default_rng(),
classifier::MLJModelInterface.Supervised,
samples::AbstractArray{<:Real,3};
subset::Real=0.8,
verbosity::Int=0,
)
Compute the ``R^*`` convergence statistic of the `samples` with the `classifier`.
`samples` is an array of draws with the shape `(draws, chains, parameters)`.
This implementation is an adaption of algorithms 1 and 2 described by Lambert and Vehtari.
@@ -29,19 +105,17 @@ is returned (algorithm 2).
# Examples
```jldoctest rstar; setup = :(using Random; Random.seed!(100))
```jldoctest rstar; setup = :(using Random; Random.seed!(101))
julia> using MLJBase, MLJXGBoostInterface, Statistics
julia> samples = fill(4.0, 300, 2);
julia> chain_indices = repeat(1:3; outer=100);
julia> samples = fill(4.0, 100, 3, 2);
```
One can compute the distribution of the ``R^*`` statistic (algorithm 2) with the
probabilistic classifier.
```jldoctest rstar
julia> distribution = rstar(XGBoostClassifier(), samples, chain_indices);
julia> distribution = rstar(XGBoostClassifier(), samples);
julia> isapprox(mean(distribution), 1; atol=0.1)
true
@@ -54,7 +128,7 @@ predicting the mode. In MLJ this corresponds to a pipeline of models.
```jldoctest rstar
julia> xgboost_deterministic = Pipeline(XGBoostClassifier(); operation=predict_mode);
julia> value = rstar(xgboost_deterministic, samples, chain_indices);
julia> value = rstar(xgboost_deterministic, samples);
julia> isapprox(value, 1; atol=0.2)
true
@@ -67,60 +141,20 @@ Lambert, B., & Vehtari, A. (2020). ``R^*``: A robust MCMC convergence diagnostic
function rstar(
rng::Random.AbstractRNG,
classifier::MLJModelInterface.Supervised,
x::AbstractMatrix,
y::AbstractVector{Int};
subset::Real=0.8,
verbosity::Int=0,
x::AbstractArray{<:Any,3};
kwargs...,
)
# checks
size(x, 1) != length(y) && throw(DimensionMismatch())
0 < subset < 1 || throw(ArgumentError("`subset` must be a number in (0, 1)"))

# randomly sub-select training and testing set
N = length(y)
Ntrain = round(Int, N * subset)
0 < Ntrain < N ||
throw(ArgumentError("training and test data subsets must not be empty"))
ids = Random.randperm(rng, N)
train_ids = view(ids, 1:Ntrain)
test_ids = view(ids, (Ntrain + 1):N)

# train classifier on training data
ycategorical = MLJModelInterface.categorical(y)
fitresult, _ = MLJModelInterface.fit(
classifier, verbosity, Tables.table(x[train_ids, :]), ycategorical[train_ids]
)

# compute predictions on test data
xtest = Tables.table(x[test_ids, :])
predictions = _predict(classifier, fitresult, xtest)

# compute statistic
ytest = ycategorical[test_ids]
result = _rstar(predictions, ytest)

return result
samples = reshape(x, :, size(x, 3))
chain_inds = repeat(axes(x, 2); inner=size(x, 1))
return rstar(rng, classifier, samples, chain_inds; kwargs...)
end

# Workaround for https://github.com/JuliaAI/MLJBase.jl/issues/863
# `MLJModelInterface.predict` sometimes returns predictions and sometimes predictions + additional information
# TODO: Remove once the upstream issue is fixed
function _predict(model::MLJModelInterface.Model, fitresult, x)
y = MLJModelInterface.predict(model, fitresult, x)
return if :predict in MLJModelInterface.reporting_operations(model)
first(y)
else
y
end
function rstar(classif::MLJModelInterface.Supervised, x, y::AbstractVector{Int}; kwargs...)
return rstar(Random.default_rng(), classif, x, y; kwargs...)
end

function rstar(
classif::MLJModelInterface.Supervised,
x::AbstractMatrix,
y::AbstractVector{Int};
kwargs...,
)
return rstar(Random.GLOBAL_RNG, classif, x, y; kwargs...)
function rstar(classif::MLJModelInterface.Supervised, x::AbstractArray{<:Any,3}; kwargs...)
return rstar(Random.default_rng(), classif, x; kwargs...)
end

# R⋆ for deterministic predictions (algorithm 1)
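The new 3d-array method flattens the draws into a matrix, synthesizes the chain indices, and dispatches to the table method. The conversion can be sketched as follows (hypothetical sizes):

```julia
x = randn(100, 3, 2)  # (draws, chains, parameters)

samples = reshape(x, :, size(x, 3))                # 300×2 matrix: rows are draws
chain_inds = repeat(axes(x, 2); inner=size(x, 1))  # [1,…,1, 2,…,2, 3,…,3]

# column-major reshape stacks all draws of chain 1, then chain 2, …,
# which is why `inner` (not `outer`) repetition matches the rows —
# the reason the test seed had to change in this PR
size(samples), length(chain_inds)  # ((300, 2), 300)
```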
14 changes: 13 additions & 1 deletion test/Project.toml
@@ -1,10 +1,22 @@
[deps]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
MCMCDiagnosticTools = "be115224-59cd-429b-ad48-344e309966f0"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
MLJLIBSVMInterface = "61c7150f-6c77-4bb1-949c-13197eac2a52"
MLJXGBoostInterface = "54119dfa-1dab-4055-a167-80440f4f7a91"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
Distributions = "0.25"
FFTW = "1.1"
julia = "1"
MCMCDiagnosticTools = "0.2"
MLJBase = "0.19, 0.20, 0.21"
MLJLIBSVMInterface = "0.1, 0.2"
MLJXGBoostInterface = "0.1, 0.2, 0.3"
Tables = "1"
julia = "1.3"

2 comments on commit 8d74357

@sethaxen (Member, Author)

@JuliaRegistrator

Registration pull request created: JuliaRegistries/General/74017

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.2.0 -m "<description of version>" 8d7435795164d0faeae2e3c83f94f125c214ebcc
git push origin v0.2.0
