From 7d9a93f6fd0b5ee07784c9105ad669d37d841825 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Thu, 21 Jan 2021 17:28:15 -0500 Subject: [PATCH 1/2] Make TransformReparam compatible with .to_event() --- pyro/infer/reparam/transform.py | 12 ++++++++++-- tests/infer/reparam/test_transform.py | 24 +++++++++++++----------- 2 files changed, 23 insertions(+), 13 deletions(-) diff --git a/pyro/infer/reparam/transform.py b/pyro/infer/reparam/transform.py index 4eb14a7dba..ec271e4f88 100644 --- a/pyro/infer/reparam/transform.py +++ b/pyro/infer/reparam/transform.py @@ -20,15 +20,23 @@ class TransformReparam(Reparam): """ def __call__(self, name, fn, obs): assert obs is None, "TransformReparam does not support observe statements" + fn, event_dim = self._unwrap(fn) assert isinstance(fn, dist.TransformedDistribution) # Draw noise from the base distribution. - x = pyro.sample("{}_base".format(name), fn.base_dist) + base_event_dim = event_dim + try: # requires https://github.com/pyro-ppl/pyro/pull/2739 + for t in reversed(fn.transforms): + base_event_dim += t.domain.event_dim - t.codomain.event_dim + except AttributeError: + pass + x = pyro.sample("{}_base".format(name), + self._wrap(fn.base_dist, base_event_dim)) # Differentiably transform. for t in fn.transforms: x = t(x) # Simulate a pyro.deterministic() site. - new_fn = dist.Delta(x, event_dim=fn.event_dim) + new_fn = dist.Delta(x, event_dim=event_dim).mask(False) return new_fn, x diff --git a/tests/infer/reparam/test_transform.py b/tests/infer/reparam/test_transform.py index 07d3f7e89e..8fa4d5e096 100644 --- a/tests/infer/reparam/test_transform.py +++ b/tests/infer/reparam/test_transform.py @@ -27,29 +27,31 @@ def get_moments(x): return torch.stack([m1, m2, m3, m4]) -@pytest.mark.parametrize("shape", [(), (4,), (2, 3)], ids=str) -def test_log_normal(shape): +@pytest.mark.parametrize("batch_shape", [(), (4,), (2, 3)], ids=str) +@pytest.mark.parametrize("event_shape", [(), (5,)], ids=str) +def test_log_normal(batch_shape, event_shape): + shape = batch_shape + event_shape loc = torch.empty(shape).uniform_(-1, 1) scale = torch.empty(shape).uniform_(0.5, 1.5) def model(): - with pyro.plate_stack("plates", shape): + fn = dist.TransformedDistribution( + dist.Normal(torch.zeros_like(loc), torch.ones_like(scale)), + [AffineTransform(loc, scale), ExpTransform()]) + fn = fn.to_event(len(event_shape)) + with pyro.plate_stack("plates", batch_shape): with pyro.plate("particles", 200000): - return pyro.sample("x", - dist.TransformedDistribution( - dist.Normal(torch.zeros_like(loc), - torch.ones_like(scale)), - [AffineTransform(loc, scale), - ExpTransform()])) + return pyro.sample("x", fn) with poutine.trace() as tr: value = model() - assert isinstance(tr.trace.nodes["x"]["fn"], dist.TransformedDistribution) + assert isinstance(tr.trace.nodes["x"]["fn"], + (dist.TransformedDistribution, dist.Independent)) expected_moments = get_moments(value) with poutine.reparam(config={"x": TransformReparam()}): with poutine.trace() as tr: value = model() - assert isinstance(tr.trace.nodes["x"]["fn"], dist.Delta) + assert isinstance(tr.trace.nodes["x"]["fn"], (dist.Delta, dist.MaskedDistribution)) actual_moments = get_moments(value) assert_close(actual_moments, expected_moments, atol=0.05) From 2a383164c331b2495f21bf62823c340beb6111c6 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Thu, 21 Jan 2021 17:54:21 -0500 Subject: [PATCH 2/2] Strengthen test --- tests/infer/reparam/test_transform.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/infer/reparam/test_transform.py b/tests/infer/reparam/test_transform.py index 8fa4d5e096..d0dc746f87 100644 --- a/tests/infer/reparam/test_transform.py +++ b/tests/infer/reparam/test_transform.py @@ -38,7 +38,8 @@ def model(): fn = dist.TransformedDistribution( dist.Normal(torch.zeros_like(loc), torch.ones_like(scale)), [AffineTransform(loc, scale), ExpTransform()]) - fn = fn.to_event(len(event_shape)) + if event_shape: + fn = fn.to_event(len(event_shape)) with pyro.plate_stack("plates", batch_shape): with pyro.plate("particles", 200000): return pyro.sample("x", fn)