Skip to content

Commit

Permalink
add missing annotation
Browse files Browse the repository at this point in the history
  • Loading branch information
neerajprad committed Jun 6, 2018
1 parent 4995fcf commit 76fa97c
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions tests/infer/test_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,14 @@ def test_subsample_gradient(Elbo, reparameterized, subsample):
precision = 0.06
Normal = dist.Normal if reparameterized else fakes.NonreparameterizedNormal

@poutine.broadcast
def model(subsample):
with pyro.iarange("data", len(data), subsample_size, subsample) as ind:
x = data[ind]
z = pyro.sample("z", Normal(0, 1))
pyro.sample("x", Normal(z, 1), obs=x)

@poutine.broadcast
def guide(subsample):
loc = pyro.param("loc", lambda: torch.zeros(len(data), requires_grad=True))
scale = pyro.param("scale", lambda: torch.tensor([1.0], requires_grad=True))
Expand Down

0 comments on commit 76fa97c

Please sign in to comment.