
Implement fixes for rstar #52

Merged
merged 24 commits on Dec 14, 2022
4 changes: 3 additions & 1 deletion Project.toml
@@ -1,11 +1,12 @@
name = "MCMCDiagnosticTools"
uuid = "be115224-59cd-429b-ad48-344e309966f0"
authors = ["David Widmann"]
version = "0.2.0"
version = "0.2.1"

[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
+DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
@@ -18,6 +19,7 @@ Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
[compat]
AbstractFFTs = "0.5, 1"
DataAPI = "1.6"
+DataStructures = "0.18.3"
Distributions = "0.25"
MLJModelInterface = "1.6"
SpecialFunctions = "0.8, 0.9, 0.10, 1, 2"
2 changes: 2 additions & 0 deletions src/MCMCDiagnosticTools.jl
@@ -2,6 +2,7 @@ module MCMCDiagnosticTools

using AbstractFFTs: AbstractFFTs
using DataAPI: DataAPI
+using DataStructures: DataStructures
using Distributions: Distributions
using MLJModelInterface: MLJModelInterface
using SpecialFunctions: SpecialFunctions
@@ -22,6 +23,7 @@ export mcse
export rafterydiag
export rstar

include("utils.jl")
include("bfmi.jl")
include("discretediag.jl")
include("ess.jl")
27 changes: 15 additions & 12 deletions src/rstar.jl
@@ -4,7 +4,8 @@
classifier::MLJModelInterface.Supervised,
samples,
chain_indices::AbstractVector{Int};
-subset::Real=0.8,
+subset::Real=0.7,
+split_chains::Int=2,
verbosity::Int=0,
)

@@ -23,26 +24,25 @@ function rstar(
classifier::MLJModelInterface.Supervised,
x,
y::AbstractVector{Int};
-subset::Real=0.8,
+subset::Real=0.7,
+split_chains::Int=2,
verbosity::Int=0,
)
# checks
MLJModelInterface.nrows(x) != length(y) && throw(DimensionMismatch())
0 < subset < 1 || throw(ArgumentError("`subset` must be a number in (0, 1)"))

+ysplit = split_chain_indices(y, split_chains)

# randomly sub-select training and testing set
-N = length(y)
-Ntrain = round(Int, N * subset)
-0 < Ntrain < N ||
+train_ids, test_ids = shuffle_split_stratified(rng, ysplit, subset)
+0 < length(train_ids) < length(y) ||
throw(ArgumentError("training and test data subsets must not be empty"))
-ids = Random.randperm(rng, N)
-train_ids = view(ids, 1:Ntrain)
-test_ids = view(ids, (Ntrain + 1):N)

xtable = _astable(x)

# train classifier on training data
-ycategorical = MLJModelInterface.categorical(y)
+ycategorical = MLJModelInterface.categorical(ysplit)
xtrain = MLJModelInterface.selectrows(xtable, train_ids)
fitresult, _ = MLJModelInterface.fit(
classifier, verbosity, xtrain, ycategorical[train_ids]
@@ -79,7 +79,8 @@ end
rng::Random.AbstractRNG=Random.default_rng(),
classifier::MLJModelInterface.Supervised,
samples::AbstractArray{<:Real,3};
-subset::Real=0.8,
+subset::Real=0.7,
+split_chains::Int=2,
verbosity::Int=0,
)

@@ -91,8 +92,10 @@ This implementation is an adaptation of algorithms 1 and 2 described by Lambert an

The `classifier` has to be a supervised classifier of the MLJ framework (see the
[MLJ documentation](https://alan-turing-institute.github.io/MLJ.jl/dev/list_of_supported_models/#model_list)
-for a list of supported models). It is trained with a `subset` of the samples. The training
-of the classifier can be inspected by adjusting the `verbosity` level.
+for a list of supported models). It is trained with a `subset` of the samples from each
+chain. Each chain is split into `split_chains` separate chains to additionally check for
+within-chain convergence. The training of the classifier can be inspected by adjusting the
+`verbosity` level.

If the classifier is deterministic, i.e., if it predicts a class, the value of the ``R^*``
statistic is returned (algorithm 1). If the classifier is probabilistic, i.e., if it outputs
probabilities of classes, the Poisson binomial distribution of the ``R^*`` statistic is
returned (algorithm 2).
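For orientation, a minimal usage sketch of the changed signature (illustrative only; `classifier` stands for any supervised MLJ classifier, such as the XGBoost-based ones used in the tests below):

using MCMCDiagnosticTools
ndraws, nchains = 1_000, 4
samples = randn(ndraws * nchains, 2)            # stacked draws of 2 parameters
chain_indices = repeat(1:nchains; inner=ndraws) # chain label for each draw
# Each chain is first split in two (the new split_chains=2 default); then a
# stratified 70% of each split chain's draws is used to train the classifier:
dist = rstar(classifier, samples, chain_indices; subset=0.7, split_chains=2)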
99 changes: 99 additions & 0 deletions src/utils.jl
@@ -0,0 +1,99 @@
"""
unique_indices(x) -> (unique, indices)

Return the result of `unique(collect(x))` along with a vector of the same length whose
elements are the indices in `x` at which the corresponding unique element in `unique` is
found.
"""
function unique_indices(x)
inds = eachindex(x)
T = eltype(inds)
ind_map = DataStructures.SortedDict{eltype(x),Vector{T}}()
for i in inds
xi = x[i]
inds_xi = get!(ind_map, xi) do
return T[]
end
push!(inds_xi, i)
end
unique = collect(keys(ind_map))
indices = collect(values(ind_map))
return unique, indices
end
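# Worked example (illustrative; not part of the diff):
#     unique_indices([2, 1, 2, 3]) == ([1, 2, 3], [[2], [1, 3], [4]])
# The SortedDict keys give the sorted unique values; e.g. the value 2 occurs
# at positions 1 and 3 of the input.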

"""
split_chain_indices(
Review comment (Member):

Isn't there some existing splitting functionality for ess? Is the plan to merge these eventually?

Reply (Member, Author):

Not quite merge, because there are two different types of splitting we can consider. This approach supports ragged chains; as a result it is more complex, but it does not discard any draws (instead it divides the remainder across the earlier splits).

For ess/rhat, we don't support ragged chains, so we would discard draws if necessary to keep the chains the same length after splitting. That implementation is much simpler and can be done in a non-allocating way with just `reshape` and `view` on a 3d array. This will be part of #22.

Reply (Member, Author):

The existing splitting functionality `copy_split!` will go away.
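A minimal sketch of the non-allocating splitting described in the reply above (illustrative only; assumes a draws × chains × parameters array):

x = reshape(collect(1:8), 4, 2, 1)             # 4 draws, 2 chains, 1 parameter
xsplit = reshape(view(x, 1:4, :, :), 2, 4, 1)  # halve each chain: 4 split chains
# chain 1 becomes split chains 1 and 2 (columns [1, 2] and [3, 4]); with an
# odd number of draws, the view would first drop the excess draw.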

chain_inds::AbstractVector{Int},
split::Int=2,
) -> AbstractVector{Int}

Split each chain in `chain_inds` into `split` chains.

For each chain in `chain_inds`, all entries are assumed to correspond to draws that have
been ordered by iteration number. The result is a vector of the same length as `chain_inds`
where each entry is the new index of the chain that the corresponding draw belongs to.
"""
function split_chain_indices(c::AbstractVector{Int}, split::Int=2)
cnew = similar(c)
if split == 1
copyto!(cnew, c)
return cnew
end
_, indices = unique_indices(c)
chain_ind = 1
for inds in indices
ndraws_per_split, rem = divrem(length(inds), split)
# here we can't use Iterators.partition because it's greedy. e.g. we can't partition
# 4 items across 3 partitions because Iterators.partition(1:4, 1) == [[1], [2], [3], [4]]
# and Iterators.partition(1:4, 2) == [[1, 2], [3, 4]]. But we would want
# [[1, 2], [3], [4]].
i = j = 0
ndraws_this_split = ndraws_per_split + (j < rem)
for ind in inds
cnew[ind] = chain_ind
if (i += 1) == ndraws_this_split
i = 0
j += 1
ndraws_this_split = ndraws_per_split + (j < rem)
chain_ind += 1
end
end
end
return cnew
end
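# Worked example (illustrative; not part of the diff): two chains of three
# draws each, split in half:
#     split_chain_indices([1, 1, 1, 2, 2, 2], 2) == [1, 1, 2, 3, 3, 4]
# The remainder draw from divrem(3, 2) goes to the earlier split, so new
# chain 1 gets two draws and new chain 2 gets one.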

"""
shuffle_split_stratified(
rng::Random.AbstractRNG,
group_ids::AbstractVector,
frac::Real,
) -> (inds1, inds2)

Randomly split the indices of `group_ids` into two groups, where a fraction `frac` of the
indices from each group is in `inds1` and the remainder is in `inds2`.

This is used, for example, to split data into training and test data while preserving the
class balances.
"""
function shuffle_split_stratified(
rng::Random.AbstractRNG, group_ids::AbstractVector, frac::Real
)
_, indices = unique_indices(group_ids)
T = eltype(eltype(indices))
N1_tot = sum(x -> round(Int, length(x) * frac), indices)
N2_tot = length(group_ids) - N1_tot
inds1 = Vector{T}(undef, N1_tot)
inds2 = Vector{T}(undef, N2_tot)
items_in_1 = items_in_2 = 0
for inds in indices
N = length(inds)
N1 = round(Int, N * frac)
N2 = N - N1
Random.shuffle!(rng, inds)
copyto!(inds1, items_in_1 + 1, inds, 1, N1)
copyto!(inds2, items_in_2 + 1, inds, N1 + 1, N2)
items_in_1 += N1
items_in_2 += N2
end
return inds1, inds2
end
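# Worked example (illustrative; not part of the diff): for
#     group_ids = [1, 1, 1, 1, 2, 2, 2, 2]
# and frac = 0.5, each group contributes round(Int, 4 * 0.5) = 2 shuffled
# indices to inds1 and its remaining 2 indices to inds2, preserving the
# class balance in both subsets.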
17 changes: 14 additions & 3 deletions test/rstar.jl
@@ -30,7 +30,7 @@ const xgboost_deterministic = Pipeline(XGBoostClassifier(); operation=predict_mo
@test dist isa LocationScale
@test dist.ρ isa PoissonBinomial
@test minimum(dist) == 0
-@test maximum(dist) == 3
+@test maximum(dist) == 6
end
@test mean(dist) ≈ 1 rtol = 0.2
wrapper === Vector && break
@@ -48,7 +48,7 @@ const xgboost_deterministic = Pipeline(XGBoostClassifier(); operation=predict_mo
@test dist isa LocationScale
@test dist.ρ isa PoissonBinomial
@test minimum(dist) == 0
-@test maximum(dist) == 4
+@test maximum(dist) == 8
end
@test mean(dist) ≈ 1 rtol = 0.15

Expand All @@ -58,7 +58,7 @@ const xgboost_deterministic = Pipeline(XGBoostClassifier(); operation=predict_mo
100 .* cos.(1:N) 100 .* sin.(1:N)
])
chain_indices = repeat(1:2; inner=N)
-dist = rstar(classifier, samples, chain_indices)
+dist = rstar(classifier, samples, chain_indices; split_chains=1)

# Mean of the statistic should be close to 2, i.e., the classifier should be able to
# learn an almost perfect decision boundary between chains.
@@ -71,6 +71,17 @@ const xgboost_deterministic = Pipeline(XGBoostClassifier(); operation=predict_mo
@test maximum(dist) == 2
end
@test mean(dist) ≈ 2 rtol = 0.15

+# Compute the R⋆ statistic for identical chains that individually have not mixed.
+samples = ones(sz)
+samples[div(N, 2):end, :] .= 2
+chain_indices = repeat(1:4; outer=div(N, 4))
+dist = rstar(classifier, samples, chain_indices; split_chains=1)
+# without splitting, the classifier cannot distinguish between the chains
+@test mean(dist) ≈ 1 rtol = 0.15
+dist = rstar(classifier, samples, chain_indices)
+# with split chains, it can learn an almost perfect decision boundary
+@test mean(dist) ≈ 2 rtol = 0.15
end
wrapper === Vector && continue

4 changes: 4 additions & 0 deletions test/runtests.jl
@@ -10,6 +10,10 @@ using Test
Random.seed!(1)

@testset "MCMCDiagnosticTools.jl" begin
@testset "utils" begin
include("utils.jl")
end

@testset "Bayesian fraction of missing information" begin
include("bfmi.jl")
end
62 changes: 62 additions & 0 deletions test/utils.jl
@@ -0,0 +1,62 @@
using MCMCDiagnosticTools
using Test
using Random

@testset "unique_indices" begin
@testset "indices=$(eachindex(inds))" for inds in [
rand(11:14, 100), transpose(rand(11:14, 10, 10))
]
unique, indices = @inferred MCMCDiagnosticTools.unique_indices(inds)
@test unique isa Vector{Int}
if eachindex(inds) isa CartesianIndices{2}
@test indices isa Vector{Vector{CartesianIndex{2}}}
else
@test indices isa Vector{Vector{Int}}
end
@test issorted(unique)
@test issetequal(union(indices...), eachindex(inds))
for i in eachindex(unique, indices)
@test all(inds[indices[i]] .== unique[i])
end
end
end

@testset "split_chain_indices" begin
c = [2, 2, 1, 3, 4, 3, 4, 1, 2, 1, 4, 3, 3, 2, 4, 3, 4, 1, 4, 1]
@test @inferred(MCMCDiagnosticTools.split_chain_indices(c, 1)) == c

cnew = @inferred MCMCDiagnosticTools.split_chain_indices(c, 2)
@test issetequal(Base.unique(cnew), 1:maximum(cnew)) # check no indices skipped
unique, indices = MCMCDiagnosticTools.unique_indices(c)
uniquenew, indicesnew = MCMCDiagnosticTools.unique_indices(cnew)
for (i, inew) in enumerate(1:2:7)
@test length(indicesnew[inew]) ≥ length(indicesnew[inew + 1])
@test indices[i] == vcat(indicesnew[inew], indicesnew[inew + 1])
end

cnew = MCMCDiagnosticTools.split_chain_indices(c, 3)
@test issetequal(Base.unique(cnew), 1:maximum(cnew)) # check no indices skipped
unique, indices = MCMCDiagnosticTools.unique_indices(c)
uniquenew, indicesnew = MCMCDiagnosticTools.unique_indices(cnew)
for (i, inew) in enumerate(1:3:11)
@test length(indicesnew[inew]) ≥
length(indicesnew[inew + 1]) ≥
length(indicesnew[inew + 2])
@test indices[i] ==
vcat(indicesnew[inew], indicesnew[inew + 1], indicesnew[inew + 2])
end
end

@testset "shuffle_split_stratified" begin
rng = Random.default_rng()
c = rand(1:4, 100)
unique, indices = MCMCDiagnosticTools.unique_indices(c)
@testset "frac=$frac" for frac in [0.3, 0.5, 0.7]
inds1, inds2 = @inferred(MCMCDiagnosticTools.shuffle_split_stratified(rng, c, frac))
@test issetequal(vcat(inds1, inds2), eachindex(c))
for inds in indices
common_inds = intersect(inds1, inds)
@test length(common_inds) == round(frac * length(inds))
end
end
end