Support different dimension permutations #5

Open
sethaxen opened this issue Jun 13, 2021 · 18 comments

Comments

@sethaxen
Member

The functions currently assume draws are in a single array with shape (draws, chains, params), like MCMCChains.Chains stores. We should consider how to support different permutations (could be as simple as recommending users use PermutedDimsArray). For example, ArviZ defaults to (chains, draws, params). Not certain about Soss's SampleChains.
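
For reference, Julia's Base provides `permutedims(x, perm)` (an eager copy) and `PermutedDimsArray(x, perm)` (a lazy view); for an ArviZ-style (chains, draws, params) array, `perm = (2, 1, 3)` would give (draws, chains, params). The sketch below only illustrates the same index bookkeeping in plain Python with nested lists, with 0-based axes; `permute3` is a hypothetical helper, not part of any package discussed in this thread.

```python
def permute3(x, perm):
    """Return a copy of the 3-D nested list `x` with its axes reordered.

    `perm` holds 0-based axis numbers: output axis m is input axis perm[m],
    the same rule Julia's permutedims(x, perm) uses with 1-based axes.
    """
    shape = (len(x), len(x[0]), len(x[0][0]))
    out_shape = [shape[a] for a in perm]
    out = []
    for i in range(out_shape[0]):
        plane = []
        for j in range(out_shape[1]):
            row = []
            for k in range(out_shape[2]):
                # Position (i, j, k) along output axes 0, 1, 2 is the
                # position along input axes perm[0], perm[1], perm[2].
                src = [0, 0, 0]
                src[perm[0]], src[perm[1]], src[perm[2]] = i, j, k
                row.append(x[src[0]][src[1]][src[2]])
            plane.append(row)
        out.append(plane)
    return out


# ArviZ-style layout: x[chain][draw][param] (2 chains, 3 draws, 2 params).
x = [[[10 * c + d + 0.5 * p for p in range(2)] for d in range(3)] for c in range(2)]
y = permute3(x, (1, 0, 2))  # now indexed y[draw][chain][param]
```

After the call, `y[d][c][p] == x[c][d][p]` for every index, which is exactly what recommending a permuted view to users amounts to.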

@devmotion
Member

> The functions currently assume draws are in a single array with shape (draws, chains, params)

They either work with vectors of draws or arrays of shape (draws, params, chains).

@sethaxen
Member Author

How would one pass a vector of draws to e.g. ess_rhat?

@devmotion
Member

No, it's not supported by ess_rhat. All functions either work with vectors or arrays of shape (draws, chains, params), but not both.

@ParadaCarleton
Member

> The functions currently assume draws are in a single array with shape (draws, chains, params), like MCMCChains.Chains stores. We should consider how to support different permutations (could be as simple as recommending users use PermutedDimsArray). For example, ArviZ defaults to (chains, draws, params). Not certain about Soss's SampleChains.

And I think Stan defaults to (params, chains, draws), giving us a head start on our apparent task of going through every possible permutation of indices.

@devmotion
Member

I don't think this package should support different permutations. IMO we need a clearly documented and consistent choice for such situations but then users have to permute the arrays if their data is in a different format.

And just to reiterate, not all statistics work with 3d arrays of samples. Some work just with a vector of scalar-valued samples (one parameter, one chain) and I don't think this should be changed. Also the Rstar statistic works with a matrix of samples of size (draws, params) and a corresponding vector of chain indices (this is more general than a 3d array).
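
To make the (draws, params) matrix plus chain-index format concrete, here is a minimal Python sketch that converts a 3-D array into it, assuming the input is indexed x[draw][param][chain], i.e. the (draws, params, chains) layout. `to_matrix_and_chains` is a hypothetical name, not the package's Rstar API, and chain labels are 1-based only to mirror Julia.

```python
def to_matrix_and_chains(x):
    """Flatten a 3-D nested list indexed x[draw][param][chain].

    Returns (rows, chain_ids): one row of parameter values per
    (chain, draw) pair, plus a parallel list of 1-based chain labels.
    """
    n_draws, n_params, n_chains = len(x), len(x[0]), len(x[0][0])
    rows, chain_ids = [], []
    for c in range(n_chains):
        for d in range(n_draws):
            rows.append([x[d][p][c] for p in range(n_params)])
            chain_ids.append(c + 1)
    return rows, chain_ids
```

Because every row carries its own chain label, chains of different lengths also fit this format, which a rectangular 3-D array cannot represent; that is one sense in which it is more general.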

@devmotion
Member

I should add that I don't think we have to stick with the current convention but it would be good to be consistent for 3d arrays. I guess the choice should be motivated by what is the most convenient and efficient layout. Since Julia uses column-major order and Python row-major, probably it differs from what one would choose in Python.
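
The column-major versus row-major point can be checked with a little stride arithmetic. This Python sketch (the `strides` helper is hypothetical) computes how many elements apart neighboring entries along each axis sit in a dense array: in column-major order the first axis is contiguous, and in row-major order the last one is.

```python
def strides(shape, order):
    """Element strides for a dense array with the given shape.

    order="F": column-major (Julia, Fortran); the first axis has stride 1.
    order="C": row-major (C, NumPy's default); the last axis has stride 1.
    """
    n = len(shape)
    result = [0] * n
    step = 1
    axes = range(n) if order == "F" else range(n - 1, -1, -1)
    for a in axes:
        result[a] = step
        step *= shape[a]
    return tuple(result)


# 4 params, 1000 draws, 2 chains in each layout.
julia_pdc = strides((4, 1000, 2), "F")   # column-major (params, draws, chains)
python_cdp = strides((2, 1000, 4), "C")  # row-major (chains, draws, params)
```

Both give stride 1 for params, 4 for draws and 4000 for chains, so a column-major (params, draws, chains) array and a row-major (chains, draws, params) array, the ArviZ order, share the same physical layout, with each draw's parameters stored side by side.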

@ParadaCarleton
Member

> I should add that I don't think we have to stick with the current convention but it would be good to be consistent for 3d arrays. I guess the choice should be motivated by what is the most convenient and efficient layout. Since Julia uses column-major order and Python row-major, probably it differs from what one would choose in Python.

I mean, being able to work with different permutations of indices comes as a free side effect of using something like AxisKeys.jl, which I think we should be using anyways to avoid bugs from messing up the order we're indexing in. We can always provide a free "Fallback" that assumes dimensions are ordered in some clearly specified way.

@devmotion
Member

Please no, let's just stick with generic AbstractArrays - one main motivation here is to get rid of the Chains/AxisArrays mess and just be as generic as possible.

@ParadaCarleton
Member

I think the best arrangement would be (parameters, draws, chains). Operations on chains are usually done in parallel, e.g. sampling a different chain on each core, so it's not necessary to have chains located close to each other in memory. Parameters are sampled together, so they should be located close together in memory: after one parameter is written, the next one can be written to the adjacent location. Iterations should be located somewhat close to each other, but they aren't always accessed together the way the parameters of a single iteration usually are. Please correct me if I'm wrong; I'm not a computer scientist.

@ParadaCarleton
Member

ParadaCarleton commented Jun 23, 2021

> No, it's not supported by ess_rhat. All functions either work with vectors or arrays of shape (draws, chains, params), but not both.

I believe ess_rhat currently uses (draws, params, chains). (An arrangement I find extremely counterintuitive, since chains and draws aren't together.)

@devmotion
Member

>> No, it's not supported by ess_rhat. All functions either work with vectors or arrays of shape (draws, chains, params), but not both.

> I believe ess_rhat currently uses (draws, params, chains). (An arrangement I find extremely counterintuitive, since chains and draws aren't together.)

Ah yes, this is what I wanted to write and what is mentioned in the documentation. The reason for this layout is that it is the one used in MCMCChains.Chains. However, I don't know what motivated the design choice there, this was decided before I got involved in MCMCChains. Maybe @cpfiffer knows why it was preferred over (parameters, draws, chains)?

@cpfiffer
Member

cpfiffer commented Jun 23, 2021

I think this was just a holdover from Mamba -- it was really something I hadn't considered.

@sethaxen
Member Author

There are two different ways to think about this. One is to reason about what permutation users are likely to pass (which leads to reasoning about what a sensible ordering in memory would be). However, a given PPL may not deliver that ordering (e.g. MCMCChains and IIRC SampleChains). The other is to think about the permutation that would be most efficient for a given function. However, different functions might prefer different permutations, and we should be consistent.

Ideally these two would converge, but I don't know if they do.

@ParadaCarleton
Member

ParadaCarleton commented Jun 23, 2021

I think any differences in terms of actual speed/efficiency are probably pretty small -- writing MCMC samples to memory is not going to be the bottleneck for something like HMC under any reasonable set of circumstances. Because of that, I think choices of index orderings should probably be based on what users are most likely to consider intuitive/reasonable orderings. In that case, I think (params, draws, chains) and (chains, draws, params) are the most intuitive, since they go either from more general to more specific or from more specific to more general. (Multiple parameters are contained in a single draw, and several draws are contained in a single chain.) The former has the advantage that it's easier to leave off the chain index when users are only sampling from a single chain, so I propose we go with that unless anyone has any strong objections.

@ParadaCarleton
Member

Does anybody object, or should I make a pull request reordering axes like this? I've written some PSIS code that assumes this ordering for ess_rhat, and would like to know whether I should rearrange the PSIS code or the ordering for ess_rhat.

@ParadaCarleton
Member

ParadaCarleton commented Jun 30, 2021

@sethaxen @devmotion I can create a PR reordering these axes, and implementing a consistent interface that works with arrays following this standard.

In cases where the interface isn't consistent, there's usually a fix that will make the function easier to work with from a user perspective. For instance, if a function only accepts one chain, then calling it on an array of multiple chains should return a vector with the results of applying it to each individual chain. (We can, of course, keep the original function for users who want to only call it on a single chain.) The goal should be to provide a polished, ArviZ-like interface that "Just works," and lets the user pass a single object to each function, rather than having to figure out how they need to permute, cut up, or use eachslice on their arrays to get the diagnostic they're looking for. This has already been done for a handful of functions, but not all of them.
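
A minimal sketch of such a per-chain wrapper, in Python for illustration only: it assumes a (params, draws, chains) layout indexed x[param][draw][chain], and `chain_slices`, `per_chain` and `chain_means` are hypothetical names, not functions from the package. In Julia the core of it would be roughly `map(f, eachslice(x; dims=3))`.

```python
from statistics import mean


def chain_slices(x):
    """Yield one (params x draws) slice per chain from x[param][draw][chain]."""
    n_params, n_draws, n_chains = len(x), len(x[0]), len(x[0][0])
    for c in range(n_chains):
        yield [[x[p][d][c] for d in range(n_draws)] for p in range(n_params)]


def per_chain(f, x):
    """Apply a single-chain function f to every chain; return one result per chain."""
    return [f(chain) for chain in chain_slices(x)]


def chain_means(chain):
    """Example single-chain statistic: the mean of each parameter's draws."""
    return [mean(draws) for draws in chain]
```

The single-chain function itself is unchanged; the multi-chain method only adds the slicing loop around it.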

@devmotion
Member

In my opinion, we should not "unify" functions that operate on single chains by moving them to an interface that works on 3d arrays of multiple chains with multiple parameters. It just makes it more complicated for downstream packages that use a different layout to use these diagnostics (e.g. vectors of chains of possibly different length, chains based on StructArrays etc. - in particular the StructArray approach is a longstanding issue/idea that we want to explore as an alternative to MCMCChains.Chains) and I don't see any immediate advantages even for the 3d case - you can always apply the function on the different slices of the array and e.g. even in this case sometimes you might want to pool different chains.

In general, I think any changes of the input layouts or dimensions should be left for a 0.2 or even more distant release but not included in the initial 0.1 version since otherwise it will just be more difficult to replace the diagnostics in MCMCChains with this package and the transition will be less clean (it would require many additional changes in TuringLang/MCMCChains.jl#310). So unfortunately currently I would not approve any such PR that changes the input layout of any diagnostics but I think we should consider if we want to change the default permutation of 3d arrays at a later stage.

@ParadaCarleton
Member

Sure, we can handle this later if you want.

As for the other thing, I'm not suggesting that we replace the existing methods or get rid of them, just that we provide additional methods that handle everything for users who use the "default" arrangement, which I expect would be a three-dimensional array. Not every method needs to be more complicated, but every method should accept a 3d array as input (assuming it makes any kind of sense for it to accept that array). If someone wants to work with a matrix and a vector of chain indices, they can use the methods we already have without being bothered by the fact that we have another one that works on arrays. On the other hand, users who already have their data stored in an array shouldn't have to spend even more time cleaning their data than they already do. Figuring out how to, e.g., disassemble an array and convert it into a matrix representation with a bunch of chain indices is going to be a pretty big waste of time for users; why not just have it work out of the box for them?

4 participants