diff --git a/tests/infer/test_gradient.py b/tests/infer/test_gradient.py index 656b5dda61..011665a559 100644 --- a/tests/infer/test_gradient.py +++ b/tests/infer/test_gradient.py @@ -29,12 +29,14 @@ def test_subsample_gradient(Elbo, reparameterized, subsample): precision = 0.06 Normal = dist.Normal if reparameterized else fakes.NonreparameterizedNormal + @poutine.broadcast def model(subsample): with pyro.iarange("data", len(data), subsample_size, subsample) as ind: x = data[ind] z = pyro.sample("z", Normal(0, 1)) pyro.sample("x", Normal(z, 1), obs=x) + @poutine.broadcast def guide(subsample): loc = pyro.param("loc", lambda: torch.zeros(len(data), requires_grad=True)) scale = pyro.param("scale", lambda: torch.tensor([1.0], requires_grad=True))