Skip to content

Commit

Permalink
fix summation with plates when manually unrolling (#2541)
Browse files Browse the repository at this point in the history
* fix summation with plates when manually unrolling

issue: #2361

* add test for manual unrolling of plates

* drop unnecessary assertions and interations

 - only need to ensure that the model runs without crashing
  • Loading branch information
iffsid authored Jun 28, 2020
1 parent 5e8d475 commit 2b3cbbe
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 1 deletion.
2 changes: 1 addition & 1 deletion pyro/infer/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def get_dependent_plate_dims(sites):
all_plates = set().union(*plate_sets)
common_plates = all_plates.intersection(*plate_sets)
sum_plates = all_plates - common_plates
sum_dims = list(sorted(f.dim for f in sum_plates))
sum_dims = list(sorted(f.dim for f in sum_plates if f.dim is not None))
return sum_dims


Expand Down
26 changes: 26 additions & 0 deletions tests/infer/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -745,3 +745,29 @@ def model():
loss = svi.step()
if step % 20 == 0:
logger.info("step {} loss = {:0.4g}".format(step, loss))


@pytest.mark.stage("integration", "integration_batch_1")
def test_sequential_plating_sum():
"""Example from https://github.com/pyro-ppl/pyro/issues/2361"""

def model(data):
x = pyro.sample('x', dist.Bernoulli(torch.tensor(0.5)))
for i in pyro.plate('data_plate', len(data)):
pyro.sample('data_{:d}'.format(i),
dist.Normal(x, scale=torch.tensor(0.1)),
obs=data[i])

def guide(data):
p = pyro.param('p', torch.tensor(0.5))
pyro.sample('x', pyro.distributions.Bernoulli(p))

data = torch.cat([torch.randn([5]), 1. + torch.randn([5])])
adam = optim.Adam({"lr": 0.01})
loss_fn = RenyiELBO(alpha=0, num_particles=30, vectorize_particles=True)
svi = SVI(model, guide, adam, loss_fn)

for step in range(1):
loss = svi.step(data)
if step % 20 == 0:
logger.info("step {} loss = {:0.4g}".format(step, loss))

0 comments on commit 2b3cbbe

Please sign in to comment.