Add rank-normalized ESS and other variants #22

Closed · sethaxen opened this issue Sep 26, 2021 · 14 comments

@sethaxen (Member)

ESS is defined in the context of a specific estimate. For example, our current ESS implementations are all in the context of estimating the mean, so they will have problems whenever the mean/variance is not finite, e.g. for the Cauchy distribution. Vehtari et al. proposed several variants of ESS/R-hat for diagnosing various problems that can manifest in posterior draws.

The different variants covered in the paper fall into the following categories:

Splitting:

  • split
  • non-split

Pre-processing:

  • rank-normalization to satisfy assumption of normality
  • folding to avoid R-hat being fooled when chains have same location but different scales

Quantity being estimated:

  • mean
  • std
  • quantile
  • quantile interval
  • mad

Some notes:

  • splitting seems universally better than non-splitting, so I don't think we should support non-splitting.
  • rank-normalization seems in general to be more robust than non-normalization, except when one computes an MCSE, for which one needs the un-rank-normalized version. So to me it makes sense for bulk-ESS (rank-normalized, unfolded ESS for the estimate of the mean) to be the default ESS method, which would also be consistent with ArviZ and Stan.
  • Estimating ESSs for specific estimates generally requires pre- and post-processing of ESS estimates. Some examples are given below (ignore the methods; this is not an API proposal):
# bulk-ESS: ESS of the rank-normalized draws
ess_rhat(x, ::Val{:bulk}; kwargs...) = ess_rhat(rank_normalize(x); kwargs...)
# tail-ESS: minimum of the ESS for the two tail quantiles of a central interval
function ess_rhat(x, ::Val{:tail}, p=0.90; kwargs...)
    return min(ess_rhat(x, Val(:quantile), (1 - p)/2; kwargs...),
               ess_rhat(x, Val(:quantile), (1 + p)/2; kwargs...))
end
# folded ESS: fold about the median before rank-normalizing
ess_rhat(x, ::Val{:fold}; kwargs...) = ess_rhat(rank_normalize(fold(x)); kwargs...)
# quantile-ESS: ESS of the rank-normalized indicator of being below the p-quantile
ess_rhat(x, ::Val{:quantile}, p; kwargs...) = ess_rhat(rank_normalize(qindicator(x, p)); kwargs...)
ess_rhat(x, ::Val{:median}; kwargs...) = ess_rhat(x, Val(:quantile), 0.5; kwargs...)
# mad-ESS: median-ESS of the folded draws
ess_rhat(x, ::Val{:mad}; kwargs...) = ess_rhat(fold(x), Val(:median); kwargs...)
# quantile-interval ESS: ESS of the indicator of being inside the interval
function ess_rhat(x, ::Val{:qinterval}, pl, pu; kwargs...)
    return ess_rhat(rank_normalize(qindicatorrange(x, pl, pu)); kwargs...)
end
ess_rhat(x, ::Val{:mean}; kwargs...) = ess_rhat(x; kwargs...)
ess_rhat(x, ::Val{:std}; kwargs...) = min.(ess_rhat(x; kwargs...), ess_rhat(x .^ 2; kwargs...))
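For concreteness, here is roughly what the helper transformations used above could look like. These are sketches only: rank_normalize, fold, qindicator, and qindicatorrange are illustrative names, and the rank normalization follows the fractional-rank/normal-quantile recipe from the paper.

using Statistics, StatsBase, StatsFuns

# rank-normalize: map each draw to the normal quantile of its fractional rank,
# using the (r - 3/8)/(S + 1/4) offsets from the paper; ranks are pooled
# across all chains
function rank_normalize(x::AbstractArray)
    S = length(x)
    z = norminvcdf.((tiedrank(vec(x)) .- 3//8) ./ (S + 1//4))
    return reshape(z, size(x))
end

# fold draws about the median so scale differences appear as location differences
fold(x::AbstractArray) = abs.(x .- median(x))

# indicator of draws at or below the p-quantile
qindicator(x::AbstractArray, p) = x .<= quantile(vec(x), p)

# indicator of draws inside the (pl, pu) quantile interval
function qindicatorrange(x::AbstractArray, pl, pu)
    ql = quantile(vec(x), pl)
    qu = quantile(vec(x), pu)
    return ql .<= x .<= qu
end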

The current AbstractESSMethod approach doesn't give us the flexibility to specify these variants. It also uses different types more or less just to specify how the autocovariance is computed, whereas we can see there are more knobs a user might like to turn. Now would seem to be a good time to revisit this API.

@sethaxen (Member, Author)

Relates to #10 and #4

@sethaxen (Member, Author)

Any thoughts on this @devmotion?

@devmotion (Member)

As mentioned somewhere else (probably in the discussion in ParetoSmooth.jl?), it would be great to improve the implementation further. We already compute (only) the split version, and the implementation was designed to be modular, but only with respect to the different algorithms for computing the autocorrelation terms.

Design-wise it seems a bit easier and more modular if the orthogonal parts are kept separate. For instance, rank-normalization can be implemented as a separate function, and if desired one can just call ess_rhat with the resulting values instead of the original chains. Folding and the quantile transformation to 0s and 1s can likewise be implemented as separate functions and combined easily.

(For completeness: in contrast to the observation in the paper, we found that the FFT-based approach is generally the slowest, by a large margin, consistent with the observation in MCMCDiagnostics. One reason seems to be that usually only a few lags are needed before the autocorrelation terms are sufficiently small and computation can be stopped (using the same approach based on Geyer's recommendation as in the FFT case). Another reason might be that the FFTs were not implemented in pure Julia but performed with FFTW, which might affect benchmarks (e.g., in StatsFuns the pure-Julia implementations seem to be significantly faster than calling into Rmath). We did not notice any numerical issues with the other algorithms for the autocorrelation estimates compared with the FFT approach, but of course such problems may exist and the FFT algorithm may be more accurate, as stated by the authors.)
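To illustrate the early stopping, here is a simplified single-chain sketch of Geyer's initial positive sequence rule (a sketch only, not the actual implementation; it uses StatsBase.autocor to estimate each lag directly):

using StatsBase: autocor

# Sum autocorrelations in adjacent pairs and stop as soon as a paired sum is
# non-positive. Well-mixing chains stop after a handful of lags, so most of
# the autocorrelation function is never computed.
function ess_geyer(x::AbstractVector)
    n = length(x)
    τ = -1.0                             # τ = -1 + 2 Σₘ (ρ̂₂ₘ + ρ̂₂ₘ₊₁)
    k = 0
    while k + 1 < n
        Γ = sum(autocor(x, [k, k + 1]))  # ρ̂ₖ + ρ̂ₖ₊₁ (note ρ̂₀ = 1)
        Γ > 0 || break                   # end of the initial positive sequence
        τ += 2Γ
        k += 2
    end
    return n / τ                         # ESS ≈ n / τ
end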

@ParadaCarleton (Member) commented Mar 22, 2022

splitting seems universally better than non-splitting, so I don't think we should support non-splitting.

I don't think that's true so much as splitting measures something different from non-splitting. Splitting makes R-hat a combined estimator of stationarity and mixing, while not splitting means R-hat only measures the latter. As an example, split R-hat will be worse at detecting multiple modes, because each chain will appear to have a stationary mean.

I think it's probably better as a default, although I've considered whether some other trend-based estimate of stationarity would be better than splitting. I suspect that directly estimating an exponential trend and then using that to estimate the variance inflation would work better.
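For reference, splitting just treats each of nsplit consecutive segments of a chain as its own chain. A minimal sketch for a draws × chains matrix (split_chains is an illustrative name, not part of the package):

# Split each chain into `nsplit` consecutive segments, treated as new chains.
# For a draws × chains matrix, a column-major reshape does exactly this.
function split_chains(x::AbstractMatrix; nsplit::Int=2)
    ndraws = size(x, 1) ÷ nsplit        # drop leftover draws if not divisible
    return reshape(x[1:(ndraws * nsplit), :], ndraws, :)
end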

devmotion mentioned this issue Jun 30, 2022
@sethaxen (Member, Author) commented Dec 5, 2022

I've been working locally on a prototype design that I think works well. A few notes.

First, I propose we decouple ess_rhat into ess and rhat. Just because a user wants a specific variant of rhat does not mean they want that same variant for ess, and vice versa. While almost all of the work to compute rhat is also performed in ess, the reverse is not true, and there are variants like nested R-hat (#23) for which no corresponding ESS method yet exists. Also, tail-ESS and tail-Rhat use different transformations. We can, of course, keep an ess_rhat convenience function that doesn't duplicate the shared work.

I propose the following interface:

# by default return rank-normalized (bulk) ESS
ess(x; kwargs...) = ess_bulk(x; kwargs...)
# ESS for estimator f; by default split each chain into 2
ess(f, x; nsplit::Int=2, method, kwargs...)
# bulk-ESS
ess_bulk(x; kwargs...) = ess(mean, rank_normalize(x); kwargs...)
# tail-ESS
ess_tail(x; kwargs...)
# convert x into a proxy expectand, i.e. one whose mean-ESS approximates
# the ESS of the estimator f; used in the default ess(f, x)
as_expectand(f, x, sample_dims)

rhat(x; kwargs...) = max.(rhat_bulk(x; kwargs...), rhat_tail(x; kwargs...))
rhat_bulk(x; kwargs...) = rhat(mean, rank_normalize(x); kwargs...)
rhat_tail(x; kwargs...) = rhat(mean, rank_normalize(fold(x)); kwargs...)
rhat(f, x; nsplit::Int=2)

rank_normalize(x; dims)
fold(x; dims) = abs.(x .- median(x; dims=dims))
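For illustration, a hypothetical usage sketch under this interface (array shapes and defaults assumed, not final):

using Statistics: mean, median

x = randn(1000, 4, 10)   # draws × chains × parameters
ess(x)                   # bulk-ESS (the default)
ess(median, x)           # ESS for the median estimator
ess(mean, x; nsplit=4)   # split each chain into 4 segments before estimating
rhat(x)                  # max of rank-normalized bulk- and tail-R-hat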

This API is lightweight and flexible enough to cover all specialized ESS methods in ArviZ, posterior, and brms, as well as those in the split-Rhat paper. Splitting is done by default but is easily disabled or increased, and a user can easily add folding or rank-normalization with the convenience functions. The recommended ESS variants from the split-Rhat paper that are not connected to specific expectations (bulk-ESS, tail-ESS, and the corresponding R-hats) are provided through convenience functions, so the user doesn't need to know which transformations to apply.

It's possible there are estimators whose ESS is best approximated by combining the mean-ESS of multiple proxy expectands (std-ESS was an example of this, but a better method is now used that requires only a single proxy; see issues linked in #39), so I propose we go with this API for now and rework it later if necessary.

While we could support fancier splitting by passing a split_chains function to these methods, I propose we wait to add that complexity until there's a paper demonstrating a different splitting approach that improves ESS/R-hat estimates.

@ParadaCarleton (Member)

I think that sounds great. Feel free to send me links to any relevant branches in Slack -- I was thinking of implementing something like this myself.

@sethaxen (Member, Author)

I performed some informal benchmarks. For realistic chains with maxlag=typemax(Int), ess_rhat for the mean is about 12x slower than a simple mean-R-hat for highly autocorrelated chains and about 4x slower for uncorrelated chains. For reference, just rank-normalizing or folding the draws (without even computing R-hat or ESS) is 93x and 14x slower, respectively, than computing mean-R-hat itself.

The recommended R-hat is the maximum of bulk- and tail-R-hat, so it folds once and both rank-normalizes and computes mean-R-hat twice; altogether it's about 200x the cost of mean-R-hat alone with the existing implementations. In terms of speeding up the implementations, then, it makes sense to focus our efforts on

  1. speeding up the transformations if possible (perhaps by transforming one parameter at a time), and
  2. only performing each transformation once.

The latter point is a good argument for keeping standalone rhat, rhat_tail, rhat_bulk, and ess_tail methods, but also for having ess_rhat and ess_rhat_bulk methods that share transformations between the computations.
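For example, an ess_rhat_bulk along these lines could transform once and reuse the result (sketch only; _ess_rhat_basic is an assumed internal helper, not an existing function):

# Rank-normalize once, then compute both bulk-ESS and bulk-R-hat from the
# same transformed draws instead of transforming separately for each.
function ess_rhat_bulk(x; kwargs...)
    z = rank_normalize(x)
    return _ess_rhat_basic(z; kwargs...)  # returns (ess, rhat) for the mean
end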

@sethaxen (Member, Author)

My (likely) final design is the following API. First, there are the methods users are most likely to call, which should have the most complete documentation:

  • rhat(x): default R-hat; computes maximum of rhat_bulk and rhat_tail
  • ess(x): default bulk-ESS
  • ess(f, x): ESS for estimator f
  • ess_rhat(x): the default ess(x) and rhat(x) together, computed more efficiently than calling both

Then there are the less common methods, most of which are used by the above ones, which should be only lightly documented:

  • rhat(mean, x): mean (or basic) R-hat
  • rhat_bulk: bulk-R-hat
  • rhat_tail: tail-R-hat
  • ess_rhat_bulk(x): bulk-ESS and bulk-R-hat
  • ess_bulk(x): bulk-ESS; included solely for completeness
  • ess_tail(x): tail-ESS

The idea is that calling rhat, ess, or ess_rhat without an estimator should give you the current recommended general-purpose diagnostic and clearly document what that is.

@ParadaCarleton (Member)

Hmm, is there maybe a way to make this interface cleaner or more generalizable? E.g., rhat_bulk(x) could be replaced with rhat(x; transform=rank_norm), where we provide a rank_norm transform for robustness. As a bonus, this extends to other kinds of transforms, e.g. if users want rhat for log-transformed data.
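For concreteness, a minimal sketch of that keyword form (hypothetical; _rhat_basic and rank_norm are assumed names, not existing functions):

# R-hat with a user-supplied transform applied to the draws first
rhat(x; transform=identity) = _rhat_basic(transform(x))

# then, e.g.:
rhat(x; transform=rank_norm)       # robust, rank-normalized variant
rhat(x; transform=y -> log.(y))    # diagnostics for log-transformed draws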

@sethaxen (Member, Author)

I don't see the benefit of a transform keyword when we would just be applying the transform to x as a whole; they could just call rhat(rank_norm(x)) or rhat(log.(x)), which they can do anyway. The main reason to have rhat_bulk and rhat_tail is that the user should not need to know which transform is used to construct each of those diagnostics (it's not useful for interpreting the diagnostic), and requiring them to know or remember it poses a barrier to computing the diagnostic.

Also, ess_tail is constructed from a combination of two ESS estimates, so it can't be computed with such a transformation.

Lastly, rank_normalize doesn't always convey robustness. E.g., we don't want the user combining rank_normalize with computing the ESS for std; they may compute an ESS that way, but it won't be the ESS for std.

I do agree, though, that it's not ideal how many methods we have; e.g., for ArviZ and MCMCChains to completely extend our methods for their storage types, they would need to implement 10 methods (yikes!).

By comparison, posterior's API is:

  • ess_basic()
  • ess_bulk()
  • ess_tail()
  • ess_quantile()
  • ess_sd()
  • mcse_mean()
  • mcse_quantile()
  • mcse_sd()
  • rhat_basic()
  • rhat()

which is not too dissimilar to what I'm proposing here,

and Python ArviZ's API is:

  • rhat(x, method)
  • ess(x, method)

which is nice in its simplicity but requires each method to be named, so e.g. "mean" and "std" are methods, as is "bulk", instead of passing an estimator as we do in mcse and ess_rhat now.

@sethaxen (Member, Author)

I suppose we could adopt Python ArviZ's syntax for the base rhat, ess, and ess_rhat methods, with a kind or method keyword that we then use to dispatch to the other methods. This would be easy both for users to use and for package developers to support.

@sethaxen (Member, Author)

Here's my best idea for a keyword API that works for ess and rhat and is also consistent with mcse (a rough dispatch sketch follows the list):

  • Instead of our current mcse(estimator, x) and ess_rhat(estimator, x) syntax, ess and mcse would accept an estimator keyword specifying the estimator.
  • For ess, only if estimator is not provided, a type keyword may be provided specifying :bulk, :tail, or :basic/:mean.
  • For rhat, the type keyword accepts :bulk, :tail, :basic/:mean, and :rank (the maximum of :bulk and :tail).
  • ess_rhat would have the same type options as rhat, where :rank would return the :bulk ESS and the :rank R-hat.
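Something like the following (hypothetical sketch; _ess_estimator is an assumed internal helper, not an existing function):

using Statistics: mean

function ess(x; estimator=nothing, type=nothing, kwargs...)
    if estimator !== nothing
        # an estimator fully determines the method, so type is disallowed
        type === nothing || throw(ArgumentError("estimator and type are mutually exclusive"))
        return _ess_estimator(estimator, x; kwargs...)
    end
    kind = something(type, :bulk)   # default to bulk-ESS
    kind === :bulk && return ess_bulk(x; kwargs...)
    kind === :tail && return ess_tail(x; kwargs...)
    kind in (:basic, :mean) && return _ess_estimator(mean, x; kwargs...)
    throw(ArgumentError("unknown type: $kind"))
end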

Suggestions, @devmotion and @ParadaCarleton?

@sethaxen (Member, Author)

This week I will update the API along the lines of my previous comment, so we can hopefully make the breaking release.

@sethaxen (Member, Author)

Fixed by #72
