FR ExpandedDistribution for easier custom distributions #2186

Closed
fritzo opened this issue Nov 22, 2019 · 3 comments · Fixed by #2209

@fritzo
Member

fritzo commented Nov 22, 2019

As of #1377 it is onerous to define custom distributions: each new distribution requires a custom .expand() method, which is difficult to understand and is implemented inconsistently across the reference distributions. This issue proposes to resurrect ReshapedDistribution as a simpler ExpandedDistribution and to use it as a fallback when .expand() is not implemented.

Request: classes deriving from TorchDistribution need not implement .expand()
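A minimal pure-Python sketch of the requested fallback, assuming a scalar base distribution (batch shape `()`) and using nested lists instead of tensors so it runs without PyTorch. `ToyNormal`, `ExpandedDistribution` and `expand` here are hypothetical illustrations, not Pyro's actual classes; the `NotImplementedError` check is modeled on the fact that `torch.distributions.Distribution.expand` raises it by default.

```python
import math
import random


class ToyNormal:
    """Hypothetical custom distribution that does NOT implement .expand()."""

    batch_shape = ()

    def __init__(self, loc, scale):
        self.loc, self.scale = loc, scale

    def sample(self):
        return random.gauss(self.loc, self.scale)

    def log_prob(self, x):
        z = (x - self.loc) / self.scale
        return -0.5 * z * z - math.log(self.scale) - 0.5 * math.log(2 * math.pi)


class ExpandedDistribution:
    """Sketch of a generic wrapper: iid copies of a scalar base over batch_shape."""

    def __init__(self, base, batch_shape):
        self.base = base
        self.batch_shape = tuple(batch_shape)

    def _fill(self, shape, fn):
        # Build a nested list with the given shape, calling fn() at each leaf.
        if not shape:
            return fn()
        return [self._fill(shape[1:], fn) for _ in range(shape[0])]

    def sample(self):
        return self._fill(self.batch_shape, self.base.sample)

    def log_prob(self, value):
        # Apply the base log_prob elementwise over a nested list of matching shape.
        if isinstance(value, list):
            return [self.log_prob(v) for v in value]
        return self.base.log_prob(value)


def expand(dist, batch_shape):
    """Use dist.expand() if it exists and works, else fall back to the wrapper."""
    expand_fn = getattr(dist, "expand", None)
    if expand_fn is not None:
        try:
            return expand_fn(batch_shape)
        except NotImplementedError:
            pass  # declared but not implemented: use the generic fallback
    return ExpandedDistribution(dist, batch_shape)
```

With this design a subclass author only writes `sample` and `log_prob`; expansion is handled generically unless the class opts in with its own `.expand()`.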

@neerajprad
Member

neerajprad commented Nov 22, 2019

> This issue proposes to resurrect ReshapedDistribution as a simpler ExpandedDistribution and to use that as a fallback in case .expand() is not implemented.

We can easily resurrect this. The reason we had an .expand method (first in Pyro, and later moved upstream) was to allow a generic broadcasting mechanism in which a distribution with a batch shape like (10, 1, 2) could be broadcast to (10, 100, 2) inside a plate messenger. Unfortunately, this generic broadcasting behavior requires a custom .expand per distribution. But we can have a fallback chain: use .expand if available, otherwise use ReshapedDistribution (for broadcasting on the left), otherwise raise an exception. Does that seem reasonable?
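The three-step fallback can be sketched on plain shape tuples. This is a hypothetical illustration: `broadcast_shapes` implements the usual right-aligned NumPy/PyTorch broadcasting rule, and `expand_strategy` (with its `has_expand` flag standing in for "the distribution implements .expand") is not a Pyro API. It shows why (10, 1, 2) to (10, 100, 2) cannot be handled by a left-only reshape.

```python
def broadcast_shapes(a, b):
    """Broadcast two shapes using the right-aligned NumPy/PyTorch rule."""
    n = max(len(a), len(b))
    a = (1,) * (n - len(a)) + tuple(a)
    b = (1,) * (n - len(b)) + tuple(b)
    out = []
    for x, y in zip(a, b):
        if x != y and x != 1 and y != 1:
            raise ValueError(f"shapes {a} and {b} are not broadcastable")
        out.append(y if x == 1 else x)
    return tuple(out)


def expand_strategy(batch_shape, target_shape, has_expand):
    """Pick a mechanism: custom .expand(), left-only reshape, or an error."""
    batch_shape, target_shape = tuple(batch_shape), tuple(target_shape)
    if broadcast_shapes(batch_shape, target_shape) != target_shape:
        raise ValueError(f"cannot expand {batch_shape} to {target_shape}")
    if has_expand:
        return "expand"
    n_new = len(target_shape) - len(batch_shape)
    if target_shape[n_new:] == batch_shape:
        return "reshape"  # only new leading (left) dims were added
    # e.g. (10, 1, 2) -> (10, 100, 2): an interstitial size-1 dim must grow
    raise NotImplementedError(
        f"{batch_shape} -> {target_shape} needs a custom .expand()"
    )
```

For example, `expand_strategy((1, 2), (10, 100, 1, 2), has_expand=False)` returns `"reshape"`, while `expand_strategy((10, 1, 2), (10, 100, 2), has_expand=False)` raises, which is exactly the interstitial case.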

@fritzo
Member Author

fritzo commented Nov 22, 2019

Thanks for explaining; I had forgotten about our early difficulty with interstitial dims. In retrospect, though, I believe we could support interstitial dims in both ExpandedDistribution.log_prob() and .rsample() by combining .transpose() and .reshape() logic. Do you agree?
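One way the transpose/reshape idea could work for sampling, sketched as pure shape bookkeeping: draw the extra copies as a `sample_shape`, permute each sample dim next to the size-1 base dim it replaces, then reshape to merge the pair. The helper `interstitial_plan` and its `event_rank` parameter are hypothetical. The intended (assumed) PyTorch usage would be `base.rsample(sample_shape).permute(perm).reshape(tuple(target_shape) + tuple(base.event_shape))`; `log_prob` would need the inverse mapping, which is not shown.

```python
def interstitial_plan(base_shape, target_shape, event_rank=0):
    """Return (sample_shape, perm) for expanding base batch dims to target_shape.

    Samples of shape sample_shape + base_shape + event_shape, permuted by perm,
    have each drawn dim sitting directly before its size-1 base dim, so a final
    reshape to target_shape + event_shape merges each (size, 1) pair.
    """
    base_shape, target_shape = tuple(base_shape), tuple(target_shape)
    n_left = len(target_shape) - len(base_shape)
    if n_left < 0:
        raise ValueError("target must have at least as many dims as base")
    aligned = target_shape[n_left:]
    expanded = []  # base dims of size 1 that must grow
    for j, (b, t) in enumerate(zip(base_shape, aligned)):
        if b == t:
            continue
        if b != 1:
            raise ValueError(f"cannot expand size {b} to {t} at dim {j}")
        expanded.append(j)
    sample_shape = target_shape[:n_left] + tuple(aligned[j] for j in expanded)
    offset = len(sample_shape)  # base batch dims start after the sample dims
    perm = list(range(n_left))  # brand-new leading dims stay in front
    k = n_left  # index of the next interstitial sample dim
    for j in range(len(base_shape)):
        if j in expanded:
            perm.append(k)  # move the drawn dim right before base dim j
            k += 1
        perm.append(offset + j)
    # Event dims are left untouched at the end.
    perm += [offset + len(base_shape) + e for e in range(event_rank)]
    return sample_shape, perm
```

For the (10, 1, 2) to (10, 100, 2) case this gives `sample_shape == (100,)` and `perm == [1, 0, 2, 3]`: samples of shape (100, 10, 1, 2) become (10, 100, 1, 2), which reshapes to (10, 100, 2).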

Let me clarify motivation.
When a user wants to implement a custom inference algorithm for part of a Pyro model, there are basically two ways: (1) implementing a custom TorchDistribution subclass that encapsulates their logic in a shape-safe way, or (2) calling pyro.factor() with a relatively shape-unsafe loss. Since our initial 0.1 release, I have always advocated for custom distributions over pyro.factor() statements. This issue aims to lower the bar for implementing those custom distributions.

@neerajprad
Member

> This issue aims to lower the bar for implementing those custom distributions.

That makes sense.

> I believe we could support interstitial dims in both ExpandedDistribution.log_prob() and .rsample() by combining .transpose() and .reshape() logic. Do you agree?

I have thought about this a couple of times. :) I don't fully remember the constraints we were operating under at the time. I think one reason was that it might be hard to support manually batched distributions in a generic way, but if we rely only on poutine.broadcast, this should be pretty straightforward. I also remember ReshapedDistribution.expand being a source of a few bugs. I do think this should be straightforward for the large majority of cases, if not all (as an example, numpyro doesn't have .expand and we haven't needed it so far). I'll update this after thinking about it some more.
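If broadcasting is driven only by the enclosing plates, the target batch shape is determined by each plate's `dim` and `size`. The sketch below is a hypothetical illustration of that computation on plain tuples, loosely modeled on `pyro.plate(name, size, dim=...)`; it is not Pyro's `poutine.broadcast` implementation, and `plate_batch_shape` is an invented name.

```python
def plate_batch_shape(batch_shape, plates):
    """Batch shape a distribution must have inside the given plates.

    plates: list of (dim, size) pairs with negative dims, one per plate.
    Each plate dim must be 1 (to be broadcast) or already equal to the size.
    """
    rank = max([len(batch_shape)] + [-dim for dim, _ in plates])
    # Left-pad with 1s so every plate dim exists in the target shape.
    target = [1] * (rank - len(batch_shape)) + list(batch_shape)
    for dim, size in plates:
        if target[dim] not in (1, size):
            raise ValueError(
                f"batch size {target[dim]} at dim {dim} conflicts with plate size {size}"
            )
        target[dim] = size
    return tuple(target)
```

For instance, `plate_batch_shape((10, 1, 2), [(-2, 100)])` gives `(10, 100, 2)`, the interstitial case discussed in this thread; the resulting shape is what would then be passed to `.expand()` or to a generic fallback.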

fritzo self-assigned this Nov 23, 2019
fritzo added this to the 1.1 milestone Dec 3, 2019