[WIP] draft subsampling bootstrap for mcse #1974

Draft · wants to merge 6 commits into base: main
122 changes: 100 additions & 22 deletions arviz/stats/diagnostics.py
@@ -341,40 +341,70 @@ def rhat(data, *, var_names=None, method="rank", dask_kwargs=None):
)


def mcse(data, *, var_names=None, method="mean", prob=None, dask_kwargs=None):
"""Calculate Markov Chain Standard Error statistic.
def mcse(
data,
*,
var_names=None,
method="mean",
prob=None,
func=None,
mcse_kwargs=None,
func_kwargs=None,
dask_kwargs=None,
):
r"""Calculate Markov Chain Standard Error statistic.

Parameters
----------
data : obj
data : InferenceData-like or 2D array-like
Any object that can be converted to an :class:`arviz.InferenceData` object
Refer to documentation of :func:`arviz.convert_to_dataset` for details
For ndarray: shape = (chain, draw).
For n-dimensional ndarray transform first to dataset with ``az.convert_to_dataset``.
var_names : list
var_names : list of str, optional
Names of variables to include in the MCSE report
method : str
Select mcse method. Valid methods are:
method : {'mean', 'sd', 'median', 'quantile', 'func'}, optional
The method to use when estimating the MCSE.
- "mean"
- "sd"
- "median"
- "quantile"
- "func"

prob : float
Methods "mean", "sd", "median" and "quantile" are described in [1]_.

prob : float, optional
Quantile information.
func : callable, optional
Member Author commented:
we could also consider allowing some strings here, e.g. using "circmean" expands to stats.circmean as func here and also fills the mcse_kwargs with {"var_func": stats.circvar}
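A minimal sketch of how that string expansion could be wired in, purely illustrative: the ``_FUNC_ALIASES`` table and ``_expand_func_alias`` helper are hypothetical names, not part of this PR.

```python
# Hypothetical helper showing how string aliases for `func` could be expanded.
from scipy import stats

_FUNC_ALIASES = {
    "circmean": (stats.circmean, {"var_func": stats.circvar}),
}


def _expand_func_alias(func, mcse_kwargs):
    """Resolve a string alias to a callable and fill default mcse_kwargs."""
    if isinstance(func, str):
        func, defaults = _FUNC_ALIASES[func]
        for key, value in defaults.items():
            mcse_kwargs.setdefault(key, value)
    return func, mcse_kwargs
```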

Summary function whose MCSE should be calculated. Only used when
method is "func".
TODO: add call signature info, something like ``func(ary, **func_kwargs)``
func_kwargs : dict, optional
Keyword arguments passed to *func* when calling it.
dask_kwargs : dict, optional
Dask related kwargs passed to :func:`~arviz.wrap_xarray_ufunc`.

Returns
-------
xarray.Dataset
Return the msce dataset
Dataset with the MCSE results

Other Parameters
----------------
mcse_kwargs : dict, optional
Extra keyword arguments passed to the MCSE estimation method.

See Also
--------
ess : Calculate estimate of the effective sample size (ess).
summary : Create a data frame with summary statistics.
plot_mcse : Plot quantile or local Monte Carlo Standard Error.

References
----------
.. [1] Vehtari, Aki, et al. "Rank-normalization, folding, and localization: an improved
$\hat{R}$ for assessing convergence of MCMC (with discussion)."
Bayesian analysis 16.2 (2021): 667-718. https://doi.org/10.1214/20-BA1221

Examples
--------
@@ -398,6 +428,7 @@ def mcse(data, *, var_names=None, method="mean", prob=None, dask_kwargs=None):
"sd": _mcse_sd,
"median": _mcse_median,
"quantile": _mcse_quantile,
"func": _mcse_func_sbm,
}
if method not in methods:
raise TypeError(
@@ -410,32 +441,44 @@ def mcse(data, *, var_names=None, method="mean", prob=None, dask_kwargs=None):
if method == "quantile" and prob is None:
raise TypeError("Quantile (prob) information needs to be defined.")

if method == "func" and func is None:
raise TypeError("func argument needs to be defined.")

mcse_kwargs = {} if mcse_kwargs is None else mcse_kwargs
if prob is not None:
mcse_kwargs.setdefault("prob", prob)
elif func is not None:
mcse_kwargs.setdefault("func", func)
mcse_kwargs.setdefault("func_kwargs", func_kwargs)

if isinstance(data, np.ndarray):
data = np.atleast_2d(data)
if len(data.shape) < 3:
if prob is not None:
return mcse_func(data, prob=prob) # pylint: disable=unexpected-keyword-arg

return mcse_func(data)

msg = (
"Only uni-dimensional ndarray variables are supported."
" Please transform first to dataset with `az.convert_to_dataset`."
)
raise TypeError(msg)
if data.size < 1000 and method == "func":
warnings.warn(
"Not enough samples for reliable estimate of MCSE for arbitrary functions"
)
return mcse_func(data, **mcse_kwargs)
else:
msg = (
"Only uni-dimensional ndarray variables are supported."
" Please transform first to dataset with `az.convert_to_dataset`."
)
raise TypeError(msg)

dataset = convert_to_dataset(data, group="posterior")
if (dataset.dims["chain"] * dataset.dims["draw"]) < 1000 and method == "func":
warnings.warn("Not enough samples for reliable estimate of MCSE for arbitrary functions")
var_names = _var_names(var_names, dataset)

dataset = dataset if var_names is None else dataset[var_names]

ufunc_kwargs = {"ravel": False}
func_kwargs = {} if prob is None else {"prob": prob}
return _wrap_xarray_ufunc(
mcse_func,
dataset,
ufunc_kwargs=ufunc_kwargs,
func_kwargs=func_kwargs,
func_kwargs=mcse_kwargs,
dask_kwargs=dask_kwargs,
)
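For orientation, a usage sketch of the API as drafted in this PR. Note that ``method="func"``, ``func``, ``func_kwargs`` and ``mcse_kwargs`` exist only on this branch, and the trimmed-mean summary is just an arbitrary example choice.

```python
import arviz as az
from scipy import stats

idata = az.load_arviz_data("centered_eight")

# Existing behaviour: MCSE of the posterior mean.
az.mcse(idata, var_names=["mu"], method="mean")

# Drafted in this PR: MCSE of an arbitrary summary via the subsampling bootstrap,
# here a 10% trimmed mean passed through func/func_kwargs.
az.mcse(
    idata,
    var_names=["mu"],
    method="func",
    func=stats.trim_mean,
    func_kwargs={"proportiontocut": 0.1},
)
```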

@@ -813,13 +856,48 @@ def _mcse_mean(ary):
return np.nan
ess = _ess_mean(ary)
if _numba_flag:
sd = _sqrt(svar(np.ravel(ary), ddof=1), np.zeros(1))
sd = _sqrt(svar(np.ravel(ary), ddof=1), 0)
else:
sd = np.std(ary, ddof=1)
mcse_mean_value = sd / np.sqrt(ess)
return mcse_mean_value


def _mcse_func_sbm(ary, func, b=None, var_func=np.var, func_kwargs=None):
"""Compute the Markov Chain error on an arbitrary function."""
ary = np.asarray(ary)
if _not_valid(ary, shape_kwargs=dict(min_draws=10, min_chains=1)):
return np.nan
n = ary.size
if b is None:
b = int(np.sqrt(n))
if func_kwargs is None:
func_kwargs = {}
func_estimate_sd = _sbm(ary, func, b=b, var_func=var_func, func_kwargs=func_kwargs)
mcse_func_value = func_estimate_sd / np.sqrt(n)
return mcse_func_value
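As a rough sanity check one might run against this branch's private helper: for ``func=np.mean`` the SBM estimate should land near the analytic sd/√ess, and for iid white noise both are about 1/√n.

```python
import numpy as np
import arviz as az
from arviz.stats.diagnostics import _mcse_func_sbm  # private helper from this branch

rng = np.random.default_rng(0)
ary = rng.normal(size=(4, 2500))  # 4 chains x 2500 iid draws

sbm = _mcse_func_sbm(ary, np.mean)  # subsampling bootstrap estimate
analytic = np.std(ary, ddof=1) / np.sqrt(az.ess(ary, method="mean"))

print(sbm, analytic)  # both should be close to 1 / sqrt(10000) = 0.01
```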


def _sbm(ary, func, b, var_func, func_kwargs):
"""Subsampling bootstrap method.

References
----------
.. [1] Doss, Charles R., et al. "Markov chain Monte Carlo estimation of quantiles."
*Electronic Journal of Statistics* 8.2 (2014): 2448-2478.
https://doi.org/10.1214/14-EJS957

"""
flat_ary = np.ravel(ary)
Contributor commented:
Is this sensitive to the order in which the ravel is done?

Member Author commented:
I think it technically is, but there should be no difference (hopefully) if the model has converged. It is also not clear to me how multiple chains should be handled when implementing this algorithm; I started with this flatten approach, but I can test a couple of options.

Member commented:
@OriolAbril have you tested alternative approaches for handling multiple chains yet?

Member commented:
I benchmarked a few different approaches for estimating the mcse of the mean:

  • ess: using ess to estimate the mcse, for reference
  • sbm_stack_chains: estimating the mcse with SBM by concatenating the chains (as done here)
  • sbm_stack_draws: estimating the mcse with SBM by interleaving chains (i.e. flattening on the other dimension)
  • sbm_separate_chains: estimating the mcse of each chain with SBM, then summing the variances and dividing by nchains^2
  • sbm_shuffle: flattening the chains and shuffling the draws before running SBM

I performed the same benchmark as in https://avehtari.github.io/rhat_ess/ess_comparison.html, transforming the chains to target different stationary distributions. Here's the result:

[image: mcse_methods_rmse]

In all cases SBM underestimates the MCSE; this is particularly severe when autocorrelation is high and sample sizes are low. sbm_stack_chains is consistently better than the alternatives though. I didn't even bother plotting sbm_shuffle, since it was apparent pretty quickly that it was far worse than the others.
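For concreteness, a sketch of what the two leading variants above might look like. The function names come from the comment; the implementations are a reading of the bullet descriptions, not the actual benchmark code.

```python
import numpy as np


def sbm_stack_chains(ary, func, b, var_func=np.var):
    """SBM on the chains concatenated end to end (the approach in this PR)."""
    flat = np.ravel(ary)  # row-major: all of chain 0, then chain 1, ...
    n = len(flat)
    estimates = np.array([func(flat[i : i + b]) for i in range(n - b)])
    return np.sqrt(b * var_func(estimates)) / np.sqrt(n)


def sbm_separate_chains(ary, func, b, var_func=np.var):
    """SBM per chain, then sum the variances and divide by nchains**2."""
    ary = np.atleast_2d(ary)
    per_chain = np.array([sbm_stack_chains(chain, func, b, var_func) for chain in ary])
    return np.sqrt(np.sum(per_chain**2) / len(ary) ** 2)
```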

n = len(flat_ary)
func_estimates = np.empty(n - b)
for i in range(n - b):
sub_ary = flat_ary[i : i + b]
func_estimates[i] = func(sub_ary, **func_kwargs)
func_estimate_sd = np.sqrt(b * var_func(func_estimates))
Member Author commented:
we should probably decide API-wise if we want to keep this or instead move to std_func and multiply that by the square root of b.
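The two formulations are numerically equivalent whenever ``std_func`` is the square root of ``var_func`` (e.g. NumPy's defaults), so the decision is only about the API surface:

```python
import numpy as np

estimates = np.random.default_rng(1).normal(size=500)
b = 100
# sqrt(b * var(x)) == sqrt(b) * std(x) when var/std use the same ddof
assert np.isclose(np.sqrt(b * np.var(estimates)), np.sqrt(b) * np.std(estimates))
```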

return func_estimate_sd


def _mcse_sd(ary):
"""Compute the Markov Chain sd error."""
_numba_flag = Numba.numba_flag