Negative ESS #40

Closed
mschauer opened this issue Jun 28, 2022 · 8 comments · Fixed by #58

Comments

@mschauer

In this example something seems to go wrong (check ESS parameter 4): (data at https://gist.github.com/mschauer/96ffd226d91160b1c9288200c13a77d5)



julia> MCMCChains.Chains(readdlm("badsamples.dat")')
Chains MCMC chain (2000×25×1 reshape(adjoint(::Matrix{Float64}), 2000, 25, 1) with eltype Float64):

Iterations        = 1:1:2000
Number of chains  = 1
Samples per chain = 2000
parameters        = param_1, param_2, param_3, param_4, param_5, param_6, param_7, param_8, param_9, param_10, param_11, param_12, param_13, param_14, param_15, param_16, param_17, param_18, param_19, param_20, param_21, param_22, param_23, param_24, param_25

Summary Statistics
  parameters      mean       std   naive_se      mcse            ess      rhat 
      Symbol   Float64   Float64    Float64   Float64        Float64   Float64 

     param_1    3.4085    1.2065     0.0270    0.0183       955.7380    0.9995
     param_2   -0.5922    0.0681     0.0015    0.0008      4861.8138    1.0003
     param_3    0.0352    0.0087     0.0002    0.0001     10913.2548    0.9995
     param_4   -0.3874    0.0859     0.0019    0.0007   -169293.0177    0.9996
     param_5    0.0045    0.0038     0.0001    0.0001      3396.2860    0.9995
     param_6   -0.2336    0.0598     0.0013    0.0005     23391.6533    0.9997
@mschauer
Author

The samples are slightly antithetic, as seen in a scatter plot of pairs of consecutive samples:

[Image: scatter plot of pairs of consecutive samples]
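To make "antithetic" concrete, here is a minimal sketch using a simulated AR(1) chain with a negative coefficient rather than the data from the gist. The helper names `ar1_chain` and `autocor_lag` and the value `phi = -0.3` are illustrative assumptions, not taken from the issue. Consecutive draws are negatively correlated, and for an exact AR(1) process the effective sample size then exceeds the number of draws.

```julia
using Random, Statistics

# Simulate an AR(1) chain x[t] = phi * x[t-1] + noise.
# A negative phi gives antithetic (negatively autocorrelated) draws.
function ar1_chain(rng, n, phi)
    x = zeros(n)
    x[1] = randn(rng) / sqrt(1 - phi^2)  # draw from the stationary distribution
    for t in 2:n
        x[t] = phi * x[t - 1] + randn(rng)
    end
    return x
end

# Sample autocorrelation at a given lag (normalised by n at every lag).
function autocor_lag(x, lag)
    n = length(x)
    m = mean(x)
    c0 = sum(abs2, x .- m) / n
    ck = sum((x[t] - m) * (x[t + lag] - m) for t in 1:(n - lag)) / n
    return ck / c0
end

rng = MersenneTwister(1)
n, phi = 2000, -0.3          # hypothetical values chosen for illustration
x = ar1_chain(rng, n, phi)
rho1 = autocor_lag(x, 1)     # negative for an antithetic chain

# For an exact AR(1) process tau = (1 + phi) / (1 - phi),
# so ESS = n / tau is larger than n whenever phi < 0 (about 3714 here).
ess_theory = n * (1 - phi) / (1 + phi)
```

An ESS above the number of draws is therefore not a bug in itself for such chains; the problem in this issue is the negative estimate.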

@devmotion
Member

Maybe I can inspect it later today but I don't have time right now. Some quick thoughts:

  • Interestingly, the rhat value seems more reasonable (i.e., not as surprising), even though both are calculated in the same function
  • It might be easier to rerun the analysis with MCMCDiagnosticTools.ess_rhat directly and without MCMCChains objects
  • Did you compare all different algorithms of MCMCDiagnosticTools.ess_rhat?
  • Does it make a difference if you increase the number of max lags?
  • Maybe one of the first few terms here is already negative and the summation is aborted too quickly?
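The last point can be illustrated with a simplified sketch of a Geyer-style truncated sum over pairs of autocorrelations. This is not the MCMCDiagnosticTools implementation; `geyer_tau` and `naive_ess` are hypothetical helpers, and the autocorrelation values are made up for illustration. The point is only that if the sum stops after the first pair and the lag-1 autocorrelation is below -0.5, the estimated autocorrelation time becomes negative and so does the ESS.

```julia
# rho[1] is the lag-0 autocorrelation, rho[2] the lag-1 autocorrelation, and so on.
# tau = -1 + 2 * (sum of consecutive pairs), always keeping the first pair and
# stopping at the first later pair whose sum is not positive.
function geyer_tau(rho::AbstractVector{<:Real})
    tau = -1.0
    k = 1
    while k + 1 <= length(rho)
        pair = rho[k] + rho[k + 1]
        if k > 1 && pair <= 0
            break  # truncate the sum here
        end
        tau += 2 * pair
        k += 2
    end
    return tau
end

# Unguarded ESS estimate: no lower bound on tau.
naive_ess(n, rho) = n / geyer_tau(rho)

# Mildly antithetic autocorrelations: tau is small but positive.
rho_mild = [1.0, -0.3, 0.09, -0.027]
ess_mild = naive_ess(2000, rho_mild)

# Lag-1 autocorrelation below -0.5 and a negative second pair:
# the sum stops after the first pair, tau = 1 + 2 * rho_1 < 0, and ESS is negative.
rho_bad = [1.0, -0.51, 0.2, -0.3]
ess_bad = naive_ess(2000, rho_bad)
```

When tau is close to zero, small changes in the estimated autocorrelations swing the ESS between very large positive and very large negative values.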

@devmotion
Member

devmotion commented Jun 30, 2022

OK, so I did some experiments. It's basically as I expected, it seems. The maxlag does not matter: the estimates are the same even if maxlag = 1, i.e., if only the initial pair of rhos is included. As expected, the default and the FFT algorithm yield basically the same estimates (the only difference is that the autocovariance is computed manually in one case and with FFT in the other). The BDAESSMethod (whose only difference is that the autocorrelation is estimated with the variogram estimator discussed in the BDA book) yields estimates of the same order (even more negative, though). These are, again unsurprisingly, similar to the estimates of MCMCDiagnostics (not identical, since MCMCDiagnosticTools implements some of the improvements in the recent Vehtari paper which are not included in MCMCDiagnostics AFAIK).

ArviZ seems to return generally more stable (read: less surprising) estimates but I don't know their implementation. The R package posterior, on which the implementation in MCMCDiagnosticTools is partially based, has some additional modifications that are supposed to improve the stability of the estimates according to the comments in the code. When I implemented ess_rhat I omitted these on purpose to keep the code simpler and since I did not know any reference for these modifications and how relevant they are in practice. However, maybe we should add them here. I'll check if they help:

* Adding an additional term that supposedly improves estimates in the antithetic case: https://github.com/stan-dev/posterior/blob/c8b5739ee889d0f5ce03a2918bda8bea1d0164e5/R/convergence.R#L749-L751 and https://github.com/stan-dev/posterior/blob/c8b5739ee889d0f5ce03a2918bda8bea1d0164e5/R/convergence.R#L766-L767

* Lower bounding tau, i.e., upper bounding ESS by `#samples * log10(#samples) = #draws * #chains * log10(#draws * #chains)`: https://github.com/stan-dev/posterior/blob/c8b5739ee889d0f5ce03a2918bda8bea1d0164e5/R/convergence.R#L768-L774

In general, we want to add the other improvements in the Vehtari paper as well (see #22). I wonder if that would be helpful for this example as well.

@sethaxen
Member

> OK, so I did some experiments.

I looked into this a little as well. The comparable method in posterior is ess_basic. In ArviZ, the comparable call is ess with method=:mean. These posterior and ArviZ methods produce identical ESS estimates.

julia> using MCMCDiagnosticTools, MCMCChains, ArviZ, RCall, DelimitedFiles, PyCall, Plots

julia> R"require('posterior')";

julia> chns = MCMCChains.Chains(readdlm("badsamples.dat")');

julia> ess_df = ess_rhat(chns);

julia> ess_mcmcdt = [only(ess_df[k].nt.ess) for k in keys(chns)];

julia> ess_ds = ArviZ.ess(chns; method=:mean);

julia> ess_arviz = [ess_ds[k].values[] for k in keys(chns)];

julia> ess_post = [R"ess_basic($(chns[k]))"[1] for k in keys(chns)];

julia> maximum(abs, ess_post - ess_arviz)  # posterior and ArviZ completely agree
7.275957614183426e-12

julia> scatter(ess_mcmcdt, ess_arviz; xlabel="MCMCDT", ylabel="ArviZ", legend=false)

[Image: scatter plot of ESS estimates, MCMCDT (x-axis) vs. ArviZ (y-axis)]

> In general, we want to add the other improvements in the Vehtari paper as well (see #22). I wonder if that would be helpful for this example as well.

The ess_bulk method in posterior and ess with method=:bulk (the default) in ArviZ both implement the approach recommended in Vehtari's paper. While it is an improved method, applying it would not resolve the issues here (one very negative ESS estimate, and 5 ESS estimates much higher than those returned by ArviZ or posterior):

julia> ess_ds_bulk = ArviZ.ess(chns; method=:bulk);

julia> ess_arviz_bulk = [ess_ds_bulk[k].values[] for k in keys(chns)];

julia> ess_post_bulk = [R"ess_bulk($(chns[k]))"[1] for k in keys(chns)];

julia> maximum(abs, ess_post_bulk - ess_arviz_bulk)  # ArviZ and posterior also identical
4.547473508864641e-12

julia> scatter(ess_mcmcdt, ess_arviz; xlabel="MCMCDT", ylabel="ArviZ", label="basic", msw=0, alpha=0.5)

julia> scatter!(ess_mcmcdt, ess_arviz_bulk; label="bulk", xlims=(0, Inf), msw=0, alpha=0.5, legend=:bottomright)

[Image: scatter plot of MCMCDT ESS estimates (x-axis) vs. ArviZ basic and bulk ESS estimates (y-axis)]

> When I implemented ess_rhat I omitted these on purpose to keep the code simpler and since I did not know any reference for these modifications and how relevant they are in practice. However, maybe we should add them here. I'll check if they help:
>
> * Adding an additional term that supposedly improves estimates in the antithetic case: https://github.com/stan-dev/posterior/blob/c8b5739ee889d0f5ce03a2918bda8bea1d0164e5/R/convergence.R#L749-L751 and https://github.com/stan-dev/posterior/blob/c8b5739ee889d0f5ce03a2918bda8bea1d0164e5/R/convergence.R#L766-L767
>
> * Lower bounding tau, i.e., upper bounding ESS by `#samples * log10(#samples) = #draws * #chains * log10(#draws * #chains)`: https://github.com/stan-dev/posterior/blob/c8b5739ee889d0f5ce03a2918bda8bea1d0164e5/R/convergence.R#L768-L774

Yeah, it's unfortunate how outdated the texts get as improvements are made to the methods. ArviZ made these changes when adding the other ESS approaches referenced in #22, and I'm pretty sure its implementation is just a straight port from posterior.

I agree we should see if these make a difference here.
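For reference, the second modification (lower bounding tau) can be sketched as follows. This is a minimal illustration based on the description quoted above, not the posterior or ArviZ code itself; `bounded_ess` is a hypothetical helper, and `n` stands for the total number of draws across chains.

```julia
# Lower-bound tau by 1 / log10(n), which caps the ESS at n * log10(n)
# and rules out zero or negative estimates.
function bounded_ess(n::Integer, tau::Real)
    tau_bounded = max(tau, 1 / log10(n))
    return n / tau_bounded
end

# A slightly negative tau (as in the antithetic case) is clamped to the cap.
ess_clamped = bounded_ess(2000, -0.02)

# A tau above the bound is left unchanged.
ess_regular = bounded_ess(2000, 0.5)
```

With 2000 draws the cap is 2000 * log10(2000), roughly 6602, so an estimate like the one for param_4 would be reported as the cap instead of a large negative number.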

@mschauer
Author

https://github.com/stan-dev/posterior/blob/c8b5739ee889d0f5ce03a2918bda8bea1d0164e5/R/convergence.R#L766-L767

Looks like this was already contained in the initial commit of the GitHub code by @paul-buerkner

@sethaxen
Member

It seems several of these may have first appeared in this Stan PR: stan-dev/stan#2774. I'll look through the discussion later.

@devmotion
Member

Ah yes, sorry for the confusion. My comment above was incorrect: when I implemented ess_rhat I actually compared my code with the C++ version, not the R implementation.

@sethaxen
Member

sethaxen commented Jan 3, 2023

I ran down the original sources for the modifications in #40 (comment). The original source is actually the Rank-normalized ESS/R-hat paper:
[Image: excerpt from the paper]

These were first introduced to rstan in stan-dev/rstan#618, patterned after the reference implementation in the paper repo: https://github.com/avehtari/rhat_ess/blob/44d5b15173d724bfc028065877183d16961fd2f3/code/monitornew.R

Based on this, I believe we should make the corresponding modifications to our implementations.
