RenyiELBO no longer compatible with sequential plates? [question] [documentation] #2361
@fritzo I believe this is a valid fix? `sum_dims = list(sorted(f.dim for f in sum_plates if f.dim is not None))` at https://github.com/pyro-ppl/pyro/blob/dev/pyro/infer/util.py#L116. I've just got around to putting the plating fix (#2321) into rws, and it looks like this fix works for a simple sequential plate situation like the one above. Any reason why this might be incorrect?
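A minimal sketch of why the `if f.dim is not None` filter matters, assuming (as in Pyro) that a sequential, manually unrolled plate has no tensor dimension of its own and so reports `dim=None`. `Frame` and `sum_dims` are hypothetical stand-ins for illustration, not Pyro's `CondIndepStackFrame` or the actual code in `pyro/infer/util.py`.

```python
from collections import namedtuple

# Hypothetical stand-in for a plate frame; only `name` and `dim` are modelled.
Frame = namedtuple("Frame", ["name", "dim"])


def sum_dims(plates):
    """Return the sorted tensor dims to sum over, skipping sequential plates.

    Assumes a sequential (manually unrolled) plate carries dim=None because
    it does not occupy a tensor dimension.
    """
    return sorted(f.dim for f in plates if f.dim is not None)


plates = [Frame("batch", -1), Frame("time", None), Frame("outer", -2)]

# Without the filter, sorted() would have to compare None with int,
# which raises TypeError in Python 3.
try:
    sorted(f.dim for f in plates)
except TypeError as err:
    print("unfiltered:", err)

print("filtered:", sum_dims(plates))  # filtered: [-2, -1]
```

With the filter, sequential plates simply contribute no dimension to the reduction, which matches the intuition that they are iterated in Python rather than broadcast.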
@iffsid Good catch! Thanks for debugging this issue! Could you submit a PR so that we can see if all other tests pass with the change?
* fix summation with plates when manually unrolling (issue #2361)
* add test for manual unrolling of plates
* drop unnecessary assertions and iterations; we only need to ensure that the model runs without crashing
Fixed by #2541
Thanks for addressing this in 1.4.0! However, I'm seeing another issue with … Thanks again! 😃

Code snippet:

```python
import sys, torch, pyro
print(sys.version, torch.__version__, pyro.__version__)

# Generative model: data = x @ weights + eps
def model(data, weights):
    loc = torch.tensor(1.0)
    scale = torch.tensor(0.1)
    # Sample latents (shares no dimensions with data)
    with pyro.plate('x_plate', weights.shape[0]):
        x = pyro.sample('x', pyro.distributions.Normal(loc, scale))
    # Combine with weights and sample
    with pyro.plate('data_plate_1', data.shape[-1]):
        with pyro.plate('data_plate_2', data.shape[-2]):
            pyro.sample('data', pyro.distributions.Normal(x @ weights, scale), obs=data)
    return

def guide(data, weights):
    loc = pyro.param('x_loc', torch.tensor(0.5))
    scale = torch.tensor(0.1)
    with pyro.plate('x_plate', weights.shape[0]):
        x = pyro.sample('x', pyro.distributions.Normal(loc, scale))
    return

# Works with another ELBO
loss = pyro.infer.TraceGraph_ELBO(num_particles=3, vectorize_particles=True)
svi = pyro.infer.SVI(model, guide, pyro.optim.Adam({'lr': 0.01}), loss)
svi.step(torch.randn([5, 3]), torch.randn([2, 3]))

# Dim collision with RenyiELBO
loss = pyro.infer.RenyiELBO(num_particles=7, vectorize_particles=True)
svi = pyro.infer.SVI(model, guide, pyro.optim.Adam({'lr': 0.01}), loss)
svi.step(torch.randn([5, 3]), torch.randn([2, 3]))
```

Output:
* Fix summation over non-nested plates sharing dimensions

  From issue #2361. Remove repeated dimensions in `get_dependent_plate_dims()`, which ensures the output can be passed to, e.g., `torch.sum()`. Minor change: `list(sorted())` simplified, as `sorted` returns a `list` directly.

* Add regression test

  Based on the example in #2361, and follows the format of the similar test introduced by #2541.

Co-authored-by: Sam Harrison <harrison@biomed.ee.ethz.ch>
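The dimension collision in the 1.4.0 report can be sketched in plain Python. This is a simplified model, assuming automatic plate-dim allocation hands each new plate the rightmost dim not held by a currently open plate; `Frame`, `allocate_dims` and `dependent_plate_dims` are hypothetical names, and the extra particle plate added by `vectorize_particles=True` is not modelled. Because `x_plate` closes before the data plates open, it and `data_plate_1` both end up on dim `-1`; dropping repeats with a set is one way to keep the list usable as a `dim` argument for a reduction such as `torch.sum()`.

```python
from collections import namedtuple

# Hypothetical stand-in for a plate frame; only `name` and `dim` are modelled.
Frame = namedtuple("Frame", ["name", "dim"])


def allocate_dims(events):
    """Simplified model of automatic plate-dim allocation.

    Each plate entered without an explicit dim takes the rightmost dim
    (-1, then -2, ...) not held by a currently open plate; exiting frees it.
    """
    open_dims = {}
    frames = []
    for kind, name in events:
        if kind == "enter":
            dim = -1
            while dim in open_dims.values():
                dim -= 1
            open_dims[name] = dim
            frames.append(Frame(name, dim))
        else:  # "exit"
            del open_dims[name]
    return frames


def dependent_plate_dims(frames):
    # Non-nested plates can reuse a dim; the set drops the repeat and
    # sequential plates (dim=None) are skipped as before.
    return sorted({f.dim for f in frames if f.dim is not None})


# Plate order from the model in the 1.4.0 report:
# x_plate closes before the two data plates open.
events = [
    ("enter", "x_plate"), ("exit", "x_plate"),
    ("enter", "data_plate_1"), ("enter", "data_plate_2"),
    ("exit", "data_plate_2"), ("exit", "data_plate_1"),
]
frames = allocate_dims(events)
print([(f.name, f.dim) for f in frames])
# [('x_plate', -1), ('data_plate_1', -1), ('data_plate_2', -2)]
print(sorted(f.dim for f in frames))  # [-2, -1, -1]  (repeated dim)
print(dependent_plate_dims(frames))   # [-2, -1]
```

Nested plates (`data_plate_1` and `data_plate_2`) still receive distinct dims in this model; only plates that are never open at the same time can share one.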
Is this now resolved?
Yes, I think so.
Issue Description
In Pyro 1.3.0, the behaviour of `RenyiELBO` changed. In particular, it now seems to be incompatible with sequential plates. See the MWE below.

Questions:
… `RenyiELBO` in Pyro < 1.3.0 (i.e. what was supported before "Support plates in RenyiELBO", #2321)?

Many thanks! 😄
Code Snippet
Output: