From 2f89668b5ab994a9824e765d52ce286bd85e4e00 Mon Sep 17 00:00:00 2001
From: Fritz Obermeyer
Date: Thu, 14 Oct 2021 11:21:03 -0400
Subject: [PATCH 01/16] Attempt to fix AutoGaussian dispatch

---
 pyro/infer/autoguide/gaussian.py       |  5 +++--
 tests/infer/autoguide/test_gaussian.py |  4 ++++
 tests/infer/test_autoguide.py          | 26 ++++++++++++++++++++------
 3 files changed, 27 insertions(+), 8 deletions(-)

diff --git a/pyro/infer/autoguide/gaussian.py b/pyro/infer/autoguide/gaussian.py
index 4c763de714..12869afd71 100644
--- a/pyro/infer/autoguide/gaussian.py
+++ b/pyro/infer/autoguide/gaussian.py
@@ -41,8 +41,9 @@ def __init__(cls, *args, **kwargs):
         cls.backends[key] = cls

     def __call__(cls, *args, **kwargs):
-        backend = kwargs.pop("backend", cls.default_backend)
-        cls = cls.backends[backend]
+        if cls is AutoGaussian:
+            backend = kwargs.pop("backend", cls.default_backend)
+            cls = cls.backends[backend]
         return super(AutoGaussianMeta, cls).__call__(*args, **kwargs)
diff --git a/tests/infer/autoguide/test_gaussian.py b/tests/infer/autoguide/test_gaussian.py
index 9387ec64d1..a35a6fc995 100644
--- a/tests/infer/autoguide/test_gaussian.py
+++ b/tests/infer/autoguide/test_gaussian.py
@@ -81,8 +81,12 @@ def model():
     guide = AutoGaussian(model, backend=backend)
     if backend == "dense":
         assert isinstance(guide, AutoGaussianDense)
+        guide = AutoGaussianDense(model)
+        assert isinstance(guide, AutoGaussianDense)
     elif backend == "funsor":
         assert isinstance(guide, AutoGaussianFunsor)
+        guide = AutoGaussianFunsor(model)
+        assert isinstance(guide, AutoGaussianFunsor)
     else:
         raise ValueError(f"Unknown backend: {backend}")
diff --git a/tests/infer/test_autoguide.py b/tests/infer/test_autoguide.py
index 69bed3bd31..ce8c61cf68 100644
--- a/tests/infer/test_autoguide.py
+++ b/tests/infer/test_autoguide.py
@@ -14,7 +14,14 @@
 import pyro
 import pyro.distributions as dist
 import pyro.poutine as poutine
-from pyro.infer import SVI, Predictive, Trace_ELBO, TraceEnum_ELBO, TraceGraph_ELBO
+from pyro.infer import (
+    SVI,
+    JitTrace_ELBO,
+    Predictive,
+    Trace_ELBO,
+    TraceEnum_ELBO,
+    TraceGraph_ELBO,
+)
 from pyro.infer.autoguide import (
     AutoCallable,
     AutoDelta,
@@ -1277,7 +1284,9 @@ def model(data):
     expected_loss = float(g.event_logsumexp() - g.condition(data).event_logsumexp())

     guide = Guide(model)
-    elbo = Trace_ELBO(num_particles=100, vectorize_particles=True)
+    elbo = JitTrace_ELBO(
+        num_particles=100, vectorize_particles=True, ignore_jit_warnings=True
+    )
     optim = Adam({"lr": 0.01})
     svi = SVI(model, guide, optim, elbo)
     for step in range(500):
@@ -1335,10 +1344,13 @@ def model(data):
     )

     guide = Guide(model)
-    elbo = Trace_ELBO(num_particles=100, vectorize_particles=True)
-    optim = Adam({"lr": 0.01})
+    elbo = JitTrace_ELBO(
+        num_particles=100, vectorize_particles=True, ignore_jit_warnings=True
+    )
+    num_steps = 500
+    optim = ClippedAdam({"lr": 0.05, "lrd": 0.1 ** (1 / num_steps)})
     svi = SVI(model, guide, optim, elbo)
-    for step in range(500):
+    for step in range(num_steps):
         svi.step(data)

     guide.requires_grad_(False)
@@ -1402,7 +1414,9 @@ def model(data):
     expected_loss = float(g.event_logsumexp() - g_cond.event_logsumexp())

     guide = Guide(model)
-    elbo = Trace_ELBO(num_particles=100, vectorize_particles=True)
+    elbo = JitTrace_ELBO(
+        num_particles=100, vectorize_particles=True, ignore_jit_warnings=True
+    )
     num_steps = 500
     optim = ClippedAdam({"lr": 0.05, "lrd": 0.1 ** (1 / num_steps)})
     svi = SVI(model, guide, optim, elbo)
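PATCH 01 guards the metaclass __call__ so that only the front-end class re-dispatches on the backend kwarg; constructing a backend subclass directly (as the new tests do) must not be redirected. A minimal runnable sketch of the pattern, using placeholder names rather than the actual Pyro classes:

    # Sketch of the dispatch pattern fixed above (placeholder names only).
    # Without the `if cls is Base` guard, calling a subclass directly would
    # be re-dispatched onto the default backend.
    class BackendMeta(type):
        backends = {}

        def __init__(cls, name, bases, namespace):
            super().__init__(name, bases, namespace)
            key = namespace.get("backend_key")
            if key is not None:
                BackendMeta.backends[key] = cls

        def __call__(cls, *args, **kwargs):
            if cls is Base:  # dispatch only from the front-end class
                backend = kwargs.pop("backend", "dense")
                cls = BackendMeta.backends[backend]
            return super(BackendMeta, cls).__call__(*args, **kwargs)

    class Base(metaclass=BackendMeta):
        def __init__(self, model):
            self.model = model

    class Dense(Base):
        backend_key = "dense"

    class Funsor(Base):
        backend_key = "funsor"

    assert type(Base(model=None)) is Dense     # dispatched to default backend
    assert type(Funsor(model=None)) is Funsor  # direct construction preserved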
From 532f2f132a2ec6f768e9fc56cd9ae9d1335d5b26 Mon Sep 17 00:00:00 2001
From: Fritz Obermeyer
Date: Wed, 20 Oct 2021 18:34:25 -0400
Subject: [PATCH 02/16] Add xfailing tests

---
 pyro/infer/autoguide/gaussian.py       |  29 ++--
 tests/infer/autoguide/test_gaussian.py | 191 ++++++++++++++++++++-----
 2 files changed, 176 insertions(+), 44 deletions(-)

diff --git a/pyro/infer/autoguide/gaussian.py b/pyro/infer/autoguide/gaussian.py
index bc0bebbdd2..c993804091 100644
--- a/pyro/infer/autoguide/gaussian.py
+++ b/pyro/infer/autoguide/gaussian.py
@@ -334,17 +334,17 @@ def _setup_prototype(self, *args, **kwargs):
             self._dense_scatter[d] = index.reshape(-1)

     def _sample_aux_values(self) -> Dict[str, torch.Tensor]:
-        # Sample from a dense joint Gaussian over flattened variables.
-        precision = self._get_precision()
-        loc = precision.new_zeros(self._dense_size)
         flat_samples = pyro.sample(
-            f"_{self._pyro_name}",
-            dist.MultivariateNormal(loc, precision_matrix=precision),
+            f"_{self._pyro_name}_latent",
+            self._dense_get_mvn(),
             infer={"is_auxiliary": True},
         )
-        sample_shape = flat_samples.shape[:-1]
+        samples = self._dense_unflatten(flat_samples)
+        return samples

-        # Convert flat to shaped tensors.
+    def _dense_unflatten(self, flat_samples: torch.Tensor) -> Dict[str, torch.Tensor]:
+        # Convert a single flattened sample to a dict of shaped samples.
+        sample_shape = flat_samples.shape[:-1]
         samples = {}
         pos = 0
         for d, (batch_shape, event_shape) in self._dense_shapes.items():
@@ -357,14 +357,25 @@ def _sample_aux_values(self) -> Dict[str, torch.Tensor]:
             )
         return samples

-    def _get_precision(self):
+    def _dense_flatten(self, samples: Dict[str, torch.Tensor]) -> torch.Tensor:
+        # Convert a dict of shaped samples to a single flattened sample.
+        flat_samples = []
+        for d, (batch_shape, event_shape) in self._dense_shapes.items():
+            shape = samples[d].shape
+            sample_shape = shape[: len(shape) - len(batch_shape) - len(event_shape)]
+            flat_samples.append(samples[d].reshape(sample_shape + (-1,)))
+        return torch.cat(flat_samples, dim=-1)
+
+    def _dense_get_mvn(self):
+        # Create a dense joint Gaussian over flattened variables.
         flat_precision = torch.zeros(self._dense_size ** 2)
         for d, index in self._dense_scatter.items():
             sqrt = deep_getattr(self.factors, d)
             precision = sqrt @ sqrt.transpose(-1, -2)
             flat_precision.scatter_add_(0, index, precision.reshape(-1))
         precision = flat_precision.reshape(self._dense_size, self._dense_size)
-        return precision
+        loc = precision.new_zeros(self._dense_size)
+        return dist.MultivariateNormal(loc, precision_matrix=precision)


 class AutoGaussianFunsor(AutoGaussian):
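The refactor above splits dense sampling into _dense_get_mvn, _dense_unflatten, and _dense_flatten, whose implied contract is that flatten and unflatten are mutual inverses over the flattened layout. A hedged sketch of that round trip, with invented shapes standing in for the real _dense_shapes bookkeeping:

    # Sketch of the flatten/unflatten round trip; shapes are invented.
    import torch

    dense_shapes = {
        "a": (torch.Size([]), torch.Size([2])),  # (batch_shape, event_shape)
        "b": (torch.Size([3]), torch.Size([])),
    }

    def unflatten(flat, dense_shapes):
        samples, pos = {}, 0
        sample_shape = flat.shape[:-1]
        for d, (batch_shape, event_shape) in dense_shapes.items():
            numel = (batch_shape + event_shape).numel()
            samples[d] = flat[..., pos:pos + numel].reshape(
                sample_shape + batch_shape + event_shape
            )
            pos += numel
        return samples

    def flatten(samples, dense_shapes):
        flat = []
        for d, (batch_shape, event_shape) in dense_shapes.items():
            shape = samples[d].shape
            sample_shape = shape[: len(shape) - len(batch_shape) - len(event_shape)]
            flat.append(samples[d].reshape(sample_shape + (-1,)))
        return torch.cat(flat, dim=-1)

    flat = torch.randn(5)  # dense size = 2 + 3
    assert torch.equal(flatten(unflatten(flat, dense_shapes), dense_shapes), flat)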
diff --git a/tests/infer/autoguide/test_gaussian.py b/tests/infer/autoguide/test_gaussian.py
index 6f35ada3b2..c85d63991e 100644
--- a/tests/infer/autoguide/test_gaussian.py
+++ b/tests/infer/autoguide/test_gaussian.py
@@ -18,7 +18,7 @@
 )
 from pyro.infer.reparam import LocScaleReparam
 from pyro.optim import Adam
-from tests.common import assert_equal, xfail_if_not_implemented
+from tests.common import assert_close, assert_equal, xfail_if_not_implemented

 BACKENDS = [
     "dense",
@@ -100,32 +100,138 @@ def check_structure(model, expected_str):
         parameter.data.normal_()

     with torch.no_grad():
-        precision = guide._get_precision()
+        # Check flatten & unflatten.
+        mvn = guide._dense_get_mvn()
+        expected = mvn.sample()
+        samples = guide._dense_unflatten(expected)
+        actual = guide._dense_flatten(samples)
+        assert_equal(actual, expected)
+
+        # Check sparsity structure.
+        precision = mvn.precision_matrix
         actual = precision.abs().gt(1e-5).long()
+    str_to_number = {"?": 1, ".": 0}
+    expected = torch.tensor(
+        [[str_to_number[c] for c in row if c != " "] for row in expected_str]
+    )
+    assert (actual == expected).all()
+
+
+def check_backends_agree(model):
+    guide1 = AutoGaussian(model, backend="dense")
+    guide2 = AutoGaussian(model, backend="funsor")
+    guide1()
+    with xfail_if_not_implemented():
+        guide2()

-    str_to_number = {"?": 1, ".": 0}
-    expected = torch.tensor(
-        [[str_to_number[c] for c in row if c != " "] for row in expected_str]
+    # Inject random noise into all unconstrained parameters.
+    params1 = dict(guide1.named_parameters())
+    params2 = dict(guide2.named_parameters())
+    assert set(params1) == set(params2)
+    for k, v in params1.items():
+        v.data.normal_()
+        params2[k].data.copy_(v.data)
+
+    # Check that densities agree between backends.
+    with torch.no_grad(), poutine.trace() as tr:
+        aux = guide2._sample_aux_values()
+        flat = guide1._dense_flatten(aux)
+        tr.trace.compute_log_prob()
+    log_prob_funsor = tr.trace.nodes["_AutoGaussianFunsor_latent"]["log_prob"]
+    with torch.no_grad(), poutine.trace() as tr:
+        with poutine.condition(data={"_AutoGaussianDense_latent": flat}):
+            guide1._sample_aux_values()
+        tr.trace.compute_log_prob()
+    log_prob_dense = tr.trace.nodes["_AutoGaussianDense_latent"]["log_prob"]
+    assert_equal(log_prob_funsor, log_prob_dense)
+
+    # Check Monte Carlo estimate of entropy.
+    expected_entropy = guide1._dense_get_mvn().entropy()
+    with pyro.plate("particle", 100000, dim=-3), poutine.trace() as tr:
+        guide2._sample_aux_values()
+        tr.trace.compute_log_prob()
+    actual_entropy = -tr.trace.nodes["_AutoGaussianFunsor_latent"]["log_prob"].mean()
+    assert_close(actual_entropy, expected_entropy, atol=1e-1)
+
+    # Check gradients.
+    names = sorted(params1)
+    expected_grads = torch.autograd.grad(
+        expected_entropy, [params1[k] for k in names], allow_unused=True
     )
-    assert_equal(actual, expected)
+    actual_grads = torch.autograd.grad(
+        actual_entropy, [params2[k] for k in names], allow_unused=True
+    )
+    for name, actual, expected in zip(names, actual_grads, expected_grads):
+        assert_close(actual, expected, msg=name)


-def test_structure_1():
+@pytest.mark.parametrize("backend", BACKENDS)
+def test_structure_1(backend):
     def model():
         a = pyro.sample("a", dist.Normal(0, 1))
         b = pyro.sample("b", dist.Normal(a, 1))
         c = pyro.sample("c", dist.Normal(b, 1))
         pyro.sample("d", dist.Normal(c, 1), obs=torch.tensor(0.0))

-    expected = [
+    # size = 1 + 1 + 1 = 3
+    structure = [
         "? ? .",
         "? ? ?",
         ". ? ?",
     ]
-    check_structure(model, expected)
+    check_structure(model, structure)
+    check_backends_agree(model)
+
+
+@pytest.mark.parametrize("backend", BACKENDS)
+def test_structure_2(backend):
+    def model():
+        with pyro.plate("i", 2):
+            a = pyro.sample("a", dist.Normal(0, 1))
+            b = pyro.sample("b", dist.Normal(a, 1))
+            c = pyro.sample("c", dist.Normal(b, 1))
+            pyro.sample("d", dist.Normal(c, 1), obs=torch.tensor(0.0))
+
+    # size = 2 + 2 + 2 = 6
+    structure = [
+        "? . ? . . .",
+        ". ? . ? . .",
+        "? . ? . ? .",
+        ". ? . ? . ?",
+        ". . ? . ? .",
+        ". . . ? . ?",
+    ]
+    if backend == "funsor":
+        check_backends_agree(model)
+    else:
+        check_structure(model, structure)
+
+
+@pytest.mark.parametrize("backend", BACKENDS)
+def test_structure_3(backend):
+    def model():
+        a = pyro.sample("a", dist.Normal(0, 1))
+        with pyro.plate("i", 2):
+            b = pyro.sample("b", dist.Normal(a, 1))
+            c = pyro.sample("c", dist.Normal(b, 1))
+        pyro.sample("d", dist.Normal(c.sum(), 1), obs=torch.tensor(0.0))
+
+    # size = 1 + 2 + 2 = 5
+    structure = [
+        "? ? ? . .",
+        "? ? . ? .",
+        "? . ? . ?",
+        ". ? . ? ?",
+        ". . ? ? ?",
+    ]
+    if backend == "funsor":
+        check_backends_agree(model)
+    else:
+        check_structure(model, structure)

-def test_structure_2():
+@pytest.mark.parametrize("backend", BACKENDS)
+def test_structure_4(backend):
     def model():
         a = pyro.sample("a", dist.Normal(0, 1))
         b = pyro.sample("b", dist.Normal(0, 1))
@@ -134,16 +240,20 @@ def model():
         pyro.sample("d", dist.Normal(c, 1), obs=torch.tensor(0.0))

     # size = 1 + 1 + 2 = 4
-    expected = [
-        "? . ? ?",
-        ". ? ? ?",
+    structure = [
+        "? ? ? ?",
+        "? ? ? ?",
         "? ? ? .",
         "? ? . ?",
     ]
-    check_structure(model, expected)
+    if backend == "funsor":
+        check_backends_agree(model)
+    else:
+        check_structure(model, structure)


-def test_structure_3():
+@pytest.mark.parametrize("backend", BACKENDS)
+def test_structure_5(backend):
     I, J = 2, 3

     def model():
@@ -158,12 +268,12 @@ def model():
         pyro.sample("z", dist.Normal(0, 1), obs=y)

     # size = 2 + 3 + 2 * 3 = 2 + 3 + 6 = 11
-    expected = [
-        "? . . . . ? . ? . ? .",
-        ". ? . . . . ? . ? . ?",
-        ". . ? . . ? ? . . . .",
-        ". . . ? . . . ? ? . .",
-        ". . . . ? . . . . ? ?",
+    structure = [
+        "? . ? ? ? ? . ? . ? .",
+        ". ? ? ? ? . ? . ? . ?",
+        "? ? ? . . ? ? . . . .",
+        "? ? . ? . . . ? ? . .",
+        "? ? . . ? . . . . ? ?",
         "? . ? . . ? . . . . .",
         ". ? ? . . . ? . . . .",
         "? . . ? . . . ? . . .",
         ". ? . ? . . . . ? . .",
         "? . . . ? . . . . ? .",
         ". ? . . ? . . . . . ?",
     ]
-    check_structure(model, expected)
+    if backend == "funsor":
+        check_backends_agree(model)
+    else:
+        check_structure(model, structure)


-def test_structure_4():
+@pytest.mark.parametrize("backend", BACKENDS)
+def test_structure_6(backend):
     I, J = 2, 3

     def model():
@@ -189,19 +303,23 @@ def model():
         pyro.sample("e", dist.Normal(0, 1), obs=d)

     # size = 1 + 2 + 3 + 1 = 7
-    expected = [
+    structure = [
         "? ? ? . . . .",
-        "? ? . ? ? ? .",
-        "? . ? ? ? ? .",
-        ". ? ? ? . . ?",
-        ". ? ? . ? . ?",
-        ". ? ? . . ? ?",
+        "? ? ? ? ? ? .",
+        "? ? ? ? ? ? .",
+        ". ? ? ? ? ? ?",
+        ". ? ? ? ? ? ?",
+        ". ? ? ? ? ? ?",
         ". . . ? ? ? ?",
     ]
-    check_structure(model, expected)
+    if backend == "funsor":
+        check_backends_agree(model)
+    else:
+        check_structure(model, structure)


-def test_structure_5():
+@pytest.mark.parametrize("backend", BACKENDS)
+def test_structure_7(backend):
     def model():
         i_plate = pyro.plate("i", 2, dim=-1)
         with i_plate:
@@ -211,12 +329,15 @@ def model():
         pyro.sample("c", dist.Normal(b, 1), obs=torch.zeros(2))

     # size = 2 + 1 = 3
-    expected = [
-        "? . ?",
-        ". ? ?",
+    structure = [
+        "? ? ?",
+        "? ? ?",
         "? ? ?",
     ]
-    check_structure(model, expected)
+    if backend == "funsor":
+        check_backends_agree(model)
+    else:
+        check_structure(model, structure)


 @pytest.mark.parametrize("backend", BACKENDS)
From 0c4c4acbcbb84f6d32b951acb311191bf0e28ad8 Mon Sep 17 00:00:00 2001
From: Fritz Obermeyer
Date: Wed, 20 Oct 2021 22:14:49 -0400
Subject: [PATCH 03/16] Fix bug excluding obs sites from prototype_trace

---
 pyro/infer/autoguide/gaussian.py |  6 ++++++
 pyro/infer/autoguide/guides.py   | 10 +++++-----
 2 files changed, 11 insertions(+), 5 deletions(-)

diff --git a/pyro/infer/autoguide/gaussian.py b/pyro/infer/autoguide/gaussian.py
index c993804091..56827a4842 100644
--- a/pyro/infer/autoguide/gaussian.py
+++ b/pyro/infer/autoguide/gaussian.py
@@ -47,6 +47,11 @@ def __call__(cls, *args, **kwargs):
         return super(AutoGaussianMeta, cls).__call__(*args, **kwargs)


+def prototype_hide_fn(msg):
+    # Record only sample and observe sites in the prototype_trace.
+    return msg["type"] != "sample" or site_is_subsample(msg)
+
+
 class AutoGaussian(AutoGuide, metaclass=AutoGaussianMeta):
     """
     Gaussian guide with optimal conditional independence structure.
@@ -102,6 +107,7 @@ class AutoGaussian(AutoGuide, metaclass=AutoGaussianMeta):
     """

     scale_constraint = constraints.softplus_positive
+    _prototype_hide_fn = staticmethod(prototype_hide_fn)

     def __init__(
         self,
diff --git a/pyro/infer/autoguide/guides.py b/pyro/infer/autoguide/guides.py
index 27a7fed53a..27d6ac55af 100644
--- a/pyro/infer/autoguide/guides.py
+++ b/pyro/infer/autoguide/guides.py
@@ -149,9 +149,11 @@ def _create_plates(self, *args, **kwargs):
             self.plates = self.master().plates
         return self.plates

+    _prototype_hide_fn = staticmethod(prototype_hide_fn)
+
     def _setup_prototype(self, *args, **kwargs):
         # run the model so we can inspect its structure
-        model = poutine.block(self.model, prototype_hide_fn)
+        model = poutine.block(self.model, self._prototype_hide_fn)
         self.prototype_trace = poutine.block(poutine.trace(model).get_trace)(
             *args, **kwargs
         )
@@ -193,9 +195,7 @@ class AutoGuideList(AutoGuide, nn.ModuleList):
     """

     def _check_prototype(self, part_trace):
-        for name, part_site in part_trace.nodes.items():
-            if part_site["type"] != "sample":
-                continue
+        for name, part_site in part_trace.iter_stochastic_nodes():
             self_site = self.prototype_trace.nodes[name]
             assert part_site["fn"].batch_shape == self_site["fn"].batch_shape
             assert part_site["fn"].event_shape == self_site["fn"].event_shape
@@ -1187,7 +1187,7 @@ class AutoDiscreteParallel(AutoGuide):

     def _setup_prototype(self, *args, **kwargs):
         # run the model so we can inspect its structure
-        model = poutine.block(config_enumerate(self.model), prototype_hide_fn)
+        model = poutine.block(config_enumerate(self.model), self._prototype_hide_fn)
         self.prototype_trace = poutine.block(poutine.trace(model).get_trace)(
             *args, **kwargs
         )
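PATCH 03 turns the prototype hide function into a class attribute so AutoGaussian can keep observe sites in its prototype trace. A small sketch of how such a hide_fn interacts with poutine.block when tracing; the two-site model here is invented for illustration:

    # Sketch: tracing a model through poutine.block with a hide_fn that
    # keeps all sample sites (latent and observed), as in the patched guide.
    import torch
    import pyro
    import pyro.distributions as dist
    import pyro.poutine as poutine
    from pyro.poutine.util import site_is_subsample

    def model():
        a = pyro.sample("a", dist.Normal(0.0, 1.0))
        pyro.sample("b", dist.Normal(a, 1.0), obs=torch.tensor(0.0))

    def hide_fn(msg):
        # Hide everything except (non-subsample) sample sites.
        return msg["type"] != "sample" or site_is_subsample(msg)

    trace = poutine.trace(poutine.block(model, hide_fn=hide_fn)).get_trace()
    assert "a" in trace.nodes and "b" in trace.nodes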
From c271a927ac257a0a7ab2009dfd6bc6764d006345 Mon Sep 17 00:00:00 2001
From: Fritz Obermeyer
Date: Thu, 21 Oct 2021 21:04:58 -0400
Subject: [PATCH 04/16] Fix more bugs

---
 pyro/distributions/unit.py             |  1 +
 pyro/infer/autoguide/gaussian.py       | 47 ++++++++++++++++++--------
 tests/infer/autoguide/test_gaussian.py | 32 +++++++++++++++---
 3 files changed, 61 insertions(+), 19 deletions(-)

diff --git a/pyro/distributions/unit.py b/pyro/distributions/unit.py
index 455d5d24d8..7326583ce9 100644
--- a/pyro/distributions/unit.py
+++ b/pyro/distributions/unit.py
@@ -28,6 +28,7 @@ def __init__(self, log_factor, validate_args=None):
         super().__init__(batch_shape, event_shape, validate_args=validate_args)

     def expand(self, batch_shape, _instance=None):
+        batch_shape = torch.Size(batch_shape)
         new = self._get_checked_instance(Unit, _instance)
         new.log_factor = self.log_factor.expand(batch_shape)
         super(Unit, new).__init__(batch_shape, self.event_shape, validate_args=False)
diff --git a/pyro/infer/autoguide/gaussian.py b/pyro/infer/autoguide/gaussian.py
index 56827a4842..2df99909ab 100644
--- a/pyro/infer/autoguide/gaussian.py
+++ b/pyro/infer/autoguide/gaussian.py
@@ -14,7 +14,7 @@
 import pyro.distributions as dist
 import pyro.poutine as poutine
 from pyro.distributions import constraints
-from pyro.infer.inspect import get_dependencies
+from pyro.infer.inspect import get_dependencies, is_sample_site
 from pyro.nn.module import PyroModule, PyroParam
 from pyro.poutine.runtime import am_i_wrapped, get_plates
 from pyro.poutine.util import site_is_subsample
@@ -47,11 +47,6 @@ def __call__(cls, *args, **kwargs):
         return super(AutoGaussianMeta, cls).__call__(*args, **kwargs)


-def prototype_hide_fn(msg):
-    # Record only sample and observe sites in the prototype_trace.
-    return msg["type"] != "sample" or site_is_subsample(msg)
-
-
 class AutoGaussian(AutoGuide, metaclass=AutoGaussianMeta):
     """
     Gaussian guide with optimal conditional independence structure.
@@ -107,7 +102,6 @@ class AutoGaussian(AutoGuide, metaclass=AutoGaussianMeta):
     """

     scale_constraint = constraints.softplus_positive
-    _prototype_hide_fn = staticmethod(prototype_hide_fn)

     def __init__(
         self,
@@ -124,6 +118,12 @@ def __init__(
         model = InitMessenger(init_loc_fn)(model)
         super().__init__(model)

+    @staticmethod
+    def _prototype_hide_fn(msg):
+        # In contrast to the AutoGuide base class, this includes observation
+        # sites and excludes deterministic sites.
+        return not is_sample_site(msg)
+
     def _setup_prototype(self, *args, **kwargs) -> None:
         super()._setup_prototype(*args, **kwargs)

@@ -142,6 +142,12 @@ def _setup_prototype(self, *args, **kwargs) -> None:
             "prior_dependencies"
         ]

+        # Eliminate observations with no upstream latents.
+        for d, upstreams in list(self.dependencies.items()):
+            if all(self.prototype_trace.nodes[u]["is_observed"] for u in upstreams):
+                del self.dependencies[d]
+                del self.prototype_trace.nodes[d]
+
         # Collect factors and plates.
         for d, site in self.prototype_trace.nodes.items():
             # Prune non-essential parts of the trace to save memory.
@@ -161,13 +167,15 @@ def _setup_prototype(self, *args, **kwargs) -> None:
                 )
             if site["is_observed"]:
                 # Eagerly eliminate irrelevant observation plates.
-                plates &= frozenset.union(
+                plates &= frozenset().union(
                     *(self._plates[u] for u in self.dependencies[d] if u != d)
                 )
             self._plates[d] = plates

             # Create location-scale parameters, one per latent variable.
             if site["is_observed"]:
+                # This may slightly overestimate, e.g. for Multinomial.
+                self._event_numel[d] = site["fn"].event_shape.numel()
                 continue
             with helpful_support_errors(site):
                 init_loc = biject_to(site["fn"].support).inv(site["value"]).detach()
@@ -191,8 +199,9 @@ def _setup_prototype(self, *args, **kwargs) -> None:
         for d, site in self._factors.items():
             u_size = 0
             for u in self.dependencies[d]:
-                broken_shape = _plates_to_shape(self._plates[u] - self._plates[d])
-                u_size += broken_shape.numel() * self._event_numel[u]
+                if not self._factors[u]["is_observed"]:
+                    broken_shape = _plates_to_shape(self._plates[u] - self._plates[d])
+                    u_size += broken_shape.numel() * self._event_numel[u]
             d_size = self._event_numel[d]
             if site["is_observed"]:
                 d_size = min(d_size, u_size)  # just an optimization
@@ -200,7 +209,7 @@ def _setup_prototype(self, *args, **kwargs) -> None:

             # Create a square root parameter (full, not lower triangular).
             sqrt = init_loc.new_zeros(batch_shape + (u_size, d_size))
-            if d in self.dependencies[d]:
+            if not site["is_observed"]:
                 # Initialize the [d,d] block to the identity matrix.
                 sqrt.diagonal(dim1=-2, dim2=-1).fill_(1)
             deep_setattr(self.factors, d, PyroParam(sqrt, event_dim=2))
@@ -230,6 +239,8 @@ def forward(self, *args, **kwargs) -> Dict[str, torch.Tensor]:
         # Replay via Pyro primitives.
         plates = self._create_plates(*args, **kwargs)
         for name, site in self._factors.items():
+            if site["is_observed"]:
+                continue
             with ExitStack() as stack:
                 for frame in site["cond_indep_stack"]:
                     stack.enter_context(plates[frame.name])
@@ -260,6 +271,8 @@ def _transform_values(
         log_densities = defaultdict(float)
         compute_density = am_i_wrapped() and poutine.get_mask() is not False
         for name, site in self._factors.items():
+            if site["is_observed"]:
+                continue
             loc = deep_getattr(self.locs, name)
             scale = deep_getattr(self.scales, name)
             unconstrained = aux_values[name] * scale + loc
@@ -312,15 +325,19 @@ def _setup_prototype(self, *args, **kwargs):
             index = torch.zeros(precision_shape, dtype=torch.long)

             # Collect local offsets.
+            upstreams = [
+                u for u in self.dependencies[d] if not self._factors[u]["is_observed"]
+            ]
             local_offsets = {}
             pos = 0
-            for u in self.dependencies[d]:
+            for u in upstreams:
                 local_offsets[u] = pos
                 broken_plates = self._plates[u] - self._plates[d]
+                print("DEBUG", d, u, [f.name for f in broken_plates])
                 pos += self._event_numel[u] * _plates_to_shape(broken_plates).numel()

             # Create indices blockwise.
-            for u, v in itertools.product(self.dependencies[d], self.dependencies[d]):
+            for u, v in itertools.product(upstreams, upstreams):
                 u_index = global_indices[u]
                 v_index = global_indices[v]
@@ -421,11 +438,13 @@ def _setup_prototype(self, *args, **kwargs):
         plate_to_dim: Dict[str, int] = {}
         for d, site in self._factors.items():
             inputs = OrderedDict()
-            for f in sorted(site["cond_indep_stack"], key=lambda f: f.dim):
+            for f in sorted(self._plates[d], key=lambda f: f.dim):
                 plate_to_dim[f.name] = f.dim
                 inputs[f.name] = funsor.Bint[f.size]
                 eliminate.add(f.name)
             for u in self.dependencies[d]:
+                if self._factors[u]["is_observed"]:
+                    continue
                 inputs[u] = funsor.Reals[self._unconstrained_event_shapes[u]]
                 eliminate.add(u)
             factor_inputs[d] = inputs
diff --git a/tests/infer/autoguide/test_gaussian.py b/tests/infer/autoguide/test_gaussian.py
index c85d63991e..14a8dabafa 100644
--- a/tests/infer/autoguide/test_gaussian.py
+++ b/tests/infer/autoguide/test_gaussian.py
@@ -91,9 +91,11 @@ def model():
         raise ValueError(f"Unknown backend: {backend}")


-def check_structure(model, expected_str):
+def check_structure(model, expected_str, expected_dependencies=None):
     guide = AutoGaussian(model, backend="dense")
     guide()  # initialize
+    if expected_dependencies is not None:
+        assert guide.dependencies == expected_dependencies

     # Inject random noise into all unconstrained parameters.
     for parameter in guide.parameters():
@@ -179,8 +181,16 @@ def model():
         "? ? ?",
         ". ? ?",
     ]
-    check_structure(model, structure)
-    check_backends_agree(model)
+    dependencies = {
+        "a": {"a": set()},
+        "b": {"b": set(), "a": set()},
+        "c": {"c": set(), "b": set()},
+        "d": {"c": set(), "d": set()},
+    }
+    if backend == "funsor":
+        check_backends_agree(model)
+    else:
+        check_structure(model, structure, dependencies)


 @pytest.mark.parametrize("backend", BACKENDS)
@@ -201,10 +211,16 @@ def model():
         ". . ? . ? .",
         ". . . ? . ?",
     ]
+    dependencies = {
+        "a": {"a": set()},
+        "b": {"b": set(), "a": set()},
+        "c": {"c": set(), "b": set()},
+        "d": {"c": set(), "d": set()},
+    }
     if backend == "funsor":
         check_backends_agree(model)
     else:
-        check_structure(model, structure)
+        check_structure(model, structure, dependencies)


 @pytest.mark.parametrize("backend", BACKENDS)
@@ -224,10 +240,16 @@ def model():
         ". ? . ? ?",
         ". . ? ? ?",
     ]
+    dependencies = {
+        "a": {"a": set()},
+        "b": {"b": set(), "a": set()},
+        "c": {"c": set(), "b": set()},
+        "d": {"c": set(), "d": set()},
+    }
     if backend == "funsor":
         check_backends_agree(model)
     else:
-        check_structure(model, structure)
+        check_structure(model, structure, dependencies)
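The one-line Unit.expand fix in PATCH 04 coerces batch_shape to a torch.Size, since callers may pass a plain list or tuple. A hedged illustration of the call it makes safe (the exact pre-fix failure mode is not shown in the diff):

    # With the coercion in place, expand() accepts a plain list and still
    # reports a proper torch.Size batch shape.
    import torch
    import pyro.distributions as dist

    d = dist.Unit(torch.zeros(2))
    e = d.expand([3, 2])  # list argument, coerced via torch.Size internally
    assert e.batch_shape == torch.Size([3, 2])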
From 46939339ccff32ae5c4d11a07d98a9bc4e8aa29c Mon Sep 17 00:00:00 2001
From: Fritz Obermeyer
Date: Thu, 21 Oct 2021 22:43:39 -0400
Subject: [PATCH 05/16] Fix more tests

---
 pyro/infer/autoguide/gaussian.py       |  13 ++-
 tests/infer/autoguide/test_gaussian.py | 131 ++++++++++++++++++-------
 2 files changed, 108 insertions(+), 36 deletions(-)

diff --git a/pyro/infer/autoguide/gaussian.py b/pyro/infer/autoguide/gaussian.py
index 2df99909ab..ad21f3b578 100644
--- a/pyro/infer/autoguide/gaussian.py
+++ b/pyro/infer/autoguide/gaussian.py
@@ -166,7 +166,7 @@ def _setup_prototype(self, *args, **kwargs) -> None:
                 )
             if site["is_observed"]:
-                # Eagerly eliminate irrelevant observation plates.
+                # Break irrelevant observation plates.
                 plates &= frozenset().union(
                     *(self._plates[u] for u in self.dependencies[d] if u != d)
                 )
@@ -176,6 +176,9 @@ def _setup_prototype(self, *args, **kwargs) -> None:
             if site["is_observed"]:
                 # This may slightly overestimate, e.g. for Multinomial.
                 self._event_numel[d] = site["fn"].event_shape.numel()
+                # Account for broken irrelevant observation plates.
+                for f in set(site["cond_indep_stack"]) - plates:
+                    self._event_numel[d] *= f.size
                 continue
             with helpful_support_errors(site):
                 init_loc = biject_to(site["fn"].support).inv(site["value"]).detach()
@@ -208,7 +211,12 @@ def _setup_prototype(self, *args, **kwargs) -> None:
             batch_shape = _plates_to_shape(self._plates[d])

             # Create a square root parameter (full, not lower triangular).
-            sqrt = init_loc.new_zeros(batch_shape + (u_size, d_size))
+            # We initialize with noise to avoid singular gradient.
+            sqrt = torch.randn(
+                batch_shape + (u_size, d_size),
+                dtype=init_loc.dtype,
+                device=init_loc.device,
+            ).mul_(self._init_scale)
             if not site["is_observed"]:
                 # Initialize the [d,d] block to the identity matrix.
                 sqrt.diagonal(dim1=-2, dim2=-1).fill_(1)
@@ -333,7 +341,6 @@ def _setup_prototype(self, *args, **kwargs):
             for u in upstreams:
                 local_offsets[u] = pos
                 broken_plates = self._plates[u] - self._plates[d]
-                print("DEBUG", d, u, [f.name for f in broken_plates])
                 pos += self._event_numel[u] * _plates_to_shape(broken_plates).numel()

             # Create indices blockwise.
diff --git a/tests/infer/autoguide/test_gaussian.py b/tests/infer/autoguide/test_gaussian.py
index 14a8dabafa..101468ccc9 100644
--- a/tests/infer/autoguide/test_gaussian.py
+++ b/tests/infer/autoguide/test_gaussian.py
@@ -10,7 +10,7 @@
 import pyro.distributions as dist
 import pyro.poutine as poutine
 from pyro.infer import SVI, JitTrace_ELBO, Predictive, Trace_ELBO
-from pyro.infer.autoguide import AutoGaussian
+from pyro.infer.autoguide import AutoGaussian, AutoGuideList
 from pyro.infer.autoguide.gaussian import (
     AutoGaussianDense,
     AutoGaussianFunsor,
@@ -447,6 +447,53 @@ def pyrocov_model(dataset):
     )


+# This is modified by relaxing rate from deterministic to latent.
+def pyrocov_model_relaxed(dataset):
+    # Tensor shapes are commented at the end of some lines.
+    features = dataset["features"]
+    local_time = dataset["local_time"][..., None]  # [T, P, 1]
+    T, P, _ = local_time.shape
+    S, F = features.shape
+    weekly_strains = dataset["weekly_strains"]
+    assert weekly_strains.shape == (T, P, S)
+
+    # Sample global random variables.
+    coef_scale = pyro.sample("coef_scale", dist.InverseGamma(5e3, 1e2))[..., None]
+    rate_loc_scale = pyro.sample("rate_loc_scale", dist.LogNormal(-4, 2))[..., None]
+    rate_scale = pyro.sample("rate_scale", dist.LogNormal(-4, 2))[..., None]
+    init_loc_scale = pyro.sample("init_loc_scale", dist.LogNormal(0, 2))[..., None]
+    init_scale = pyro.sample("init_scale", dist.LogNormal(0, 2))[..., None]
+
+    # Assume relative growth rate depends strongly on mutations and weakly on place.
+    coef_loc = torch.zeros(F)
+    coef = pyro.sample("coef", dist.Logistic(coef_loc, coef_scale).to_event(1))  # [F]
+    rate_loc = pyro.sample(
+        "rate_loc",
+        dist.Normal(0.01 * coef @ features.T, rate_loc_scale).to_event(1),
+    )  # [S]
+
+    # Assume initial infections depend strongly on strain and place.
+    init_loc = pyro.sample(
+        "init_loc", dist.Normal(torch.zeros(S), init_loc_scale).to_event(1)
+    )  # [S]
+    with pyro.plate("place", P, dim=-1):
+        rate = pyro.sample(
+            "rate", dist.Normal(rate_loc, rate_scale).to_event(1)
+        )  # [P, S]
+        init = pyro.sample(
+            "init", dist.Normal(init_loc, init_scale).to_event(1)
+        )  # [P, S]
+
+    # Finally observe counts.
+    with pyro.plate("time", T, dim=-2):
+        logits = init + rate * local_time  # [T, P, S]
+        pyro.sample(
+            "obs",
+            dist.Multinomial(logits=logits, validate_args=False),
+            obs=weekly_strains,
+        )
+
+
 # This is modified by more precisely tracking plates for features and strains.
 def pyrocov_model_plated(dataset):
     # Tensor shapes are commented at the end of some lines.
@@ -537,16 +584,36 @@ def pyrocov_model_poisson(dataset):
         pyro.sample("obs", dist.Poisson(logits), obs=weekly_strains)


+class PoissonGuide(AutoGuideList):
+    def __init__(self, model, backend):
+        super().__init__(model)
+        self.append(
+            AutoGaussian(poutine.block(model, hide_fn=self.hide_fn_1), backend=backend)
+        )
+        self.append(
+            AutoGaussian(poutine.block(model, hide_fn=self.hide_fn_2), backend=backend)
+        )
+
+    @staticmethod
+    def hide_fn_1(msg):
+        return msg["type"] == "sample" and "pois" in msg["name"]
+
+    @staticmethod
+    def hide_fn_2(msg):
+        return msg["type"] == "sample" and "pois" not in msg["name"]
+
+
 PYRO_COV_MODELS = [
-    pyrocov_model,
-    pyrocov_model_plated,
-    pyrocov_model_poisson,
+    (pyrocov_model, AutoGaussian),
+    (pyrocov_model_relaxed, AutoGaussian),
+    (pyrocov_model_plated, AutoGaussian),
+    (pyrocov_model_poisson, PoissonGuide),
 ]


-@pytest.mark.parametrize("model", PYRO_COV_MODELS)
+@pytest.mark.parametrize("model, Guide", PYRO_COV_MODELS)
 @pytest.mark.parametrize("backend", BACKENDS)
-def test_pyrocov_smoke(model, backend):
+def test_pyrocov_smoke(model, Guide, backend):
     T, P, S, F = 3, 4, 5, 6
     dataset = {
         "features": torch.randn(S, F),
@@ -554,7 +621,7 @@ def test_pyrocov_smoke(model, backend):
         "weekly_strains": torch.randn(T, P, S).exp().round(),
     }

-    guide = AutoGaussian(model, backend=backend)
+    guide = Guide(model, backend=backend)
     svi = SVI(model, guide, Adam({"lr": 1e-8}), Trace_ELBO())
     for step in range(2):
         with xfail_if_not_implemented():
@@ -564,9 +631,9 @@ def test_pyrocov_smoke(model, backend):
         predictive(dataset)


-@pytest.mark.parametrize("model", PYRO_COV_MODELS)
+@pytest.mark.parametrize("model, Guide", PYRO_COV_MODELS)
 @pytest.mark.parametrize("backend", BACKENDS)
-def test_pyrocov_reparam(model, backend):
+def test_pyrocov_reparam(model, Guide, backend):
     T, P, S, F = 2, 3, 4, 5
     dataset = {
         "features": torch.randn(S, F),
@@ -583,7 +650,7 @@ def test_pyrocov_reparam(model, backend):
         "init": LocScaleReparam(),
     }
     model = poutine.reparam(model, config)
-    guide = AutoGaussian(model, backend=backend)
+    guide = Guide(model, backend=backend)
     svi = SVI(model, guide, Adam({"lr": 1e-8}), Trace_ELBO())
     for step in range(2):
         with xfail_if_not_implemented():
@@ -604,30 +671,27 @@ def test_pyrocov_structure():
         "weekly_strains": torch.randn(T, P, S).exp().round(),
     }

-    guide = AutoGaussian(pyrocov_model_poisson, backend="funsor")
+    guide = PoissonGuide(pyrocov_model_poisson, backend="funsor")
     guide(dataset)  # initialize
+    guide = guide[0]  # pull out AutoGaussian part of PoissonGuide

-    expected_plates = frozenset(["time", "place", "strain"])
+    expected_plates = frozenset(["place", "strain"])
     assert guide._funsor_plates == expected_plates

     expected_eliminate = frozenset(
         [
-            "time",
-            "place",
-            "strain",
+            "coef",
             "coef_scale",
-            "rate_loc_scale",
-            "rate_scale",
+            "init",
+            "init_loc",
             "init_loc_scale",
             "init_scale",
-            "coef",
-            "rate_loc",
-            "init_loc",
+            "place",
             "rate",
-            "init",
-            "pois_loc",
-            "pois_scale",
-            "pois",
+            "rate_loc",
+            "rate_loc_scale",
+            "rate_scale",
+            "strain",
         ]
     )
     assert guide._funsor_eliminate == expected_eliminate

         "rate_scale": OrderedDict([("rate_scale", Real)]),
         "init_loc_scale": OrderedDict([("init_loc_scale", Real)]),
         "init_scale": OrderedDict([("init_scale", Real)]),
-        "pois_loc": OrderedDict([("pois_loc", Real)]),
-        "pois_scale": OrderedDict([("pois_scale", Real)]),
         "coef": OrderedDict([("coef", Reals[5]), ("coef_scale", Real)]),
         "rate_loc": OrderedDict(
             [
@@ -650,7 +712,11 @@ def test_pyrocov_structure():
             ]
         ),
         "init_loc": OrderedDict(
-            [("strain", Bint[4]), ("init_loc", Real), ("init_loc_scale", Real)]
+            [
+                ("strain", Bint[4]),
+                ("init_loc", Real),
+                ("init_loc_scale", Real),
+            ]
         ),
         "rate": OrderedDict(
             [
@@ -670,13 +736,12 @@ def test_pyrocov_structure():
                 ("init_loc", Real),
             ]
         ),
-        "pois": OrderedDict(
+        "obs": OrderedDict(
             [
-                ("time", Bint[2]),
                 ("place", Bint[3]),
-                ("pois", Real),
-                ("pois_loc", Real),
-                ("pois_scale", Real),
+                ("strain", Bint[4]),
+                ("rate", Real),
+                ("init", Real),
             ]
         ),
     }
@@ -699,7 +764,7 @@ def test_profile(backend, jit, n=1, num_steps=1, log_every=1):
     }

     print("Initializing guide")
-    guide = AutoGaussian(model, backend=backend)
+    guide = PoissonGuide(model, backend=backend)
     guide(dataset)  # initialize
     print("Parameter shapes:")
     for name, param in guide.named_parameters():
From e606e615a728985444e70d0d933fa433d5526162 Mon Sep 17 00:00:00 2001
From: Fritz Obermeyer
Date: Thu, 21 Oct 2021 23:36:29 -0400
Subject: [PATCH 06/16] Add failing test of elbo gradient

---
 pyro/infer/autoguide/gaussian.py       |  5 +-
 tests/infer/autoguide/test_gaussian.py | 67 ++++++++++++++++++--------
 tests/infer/test_autoguide.py          |  5 +-
 3 files changed, 54 insertions(+), 23 deletions(-)

diff --git a/pyro/infer/autoguide/gaussian.py b/pyro/infer/autoguide/gaussian.py
index ad21f3b578..25474415d1 100644
--- a/pyro/infer/autoguide/gaussian.py
+++ b/pyro/infer/autoguide/gaussian.py
@@ -212,11 +212,12 @@ def _setup_prototype(self, *args, **kwargs) -> None:

             # Create a square root parameter (full, not lower triangular).
             # We initialize with noise to avoid singular gradient.
-            sqrt = torch.randn(
+            sqrt = torch.rand(
                 batch_shape + (u_size, d_size),
                 dtype=init_loc.dtype,
                 device=init_loc.device,
-            ).mul_(self._init_scale)
+            )
+            sqrt.sub_(0.5).mul_(self._init_scale)
             if not site["is_observed"]:
                 # Initialize the [d,d] block to the identity matrix.
                 sqrt.diagonal(dim1=-2, dim2=-1).fill_(1)
diff --git a/tests/infer/autoguide/test_gaussian.py b/tests/infer/autoguide/test_gaussian.py
index 101468ccc9..bde3d8e07d 100644
--- a/tests/infer/autoguide/test_gaussian.py
+++ b/tests/infer/autoguide/test_gaussian.py
@@ -133,8 +133,9 @@ def check_backends_agree(model):
     for k, v in params1.items():
         v.data.normal_()
         params2[k].data.copy_(v.data)
+    names = sorted(params1)

-    # Check that densities agree between backends.
+    # Check densities agree between backends.
     with torch.no_grad(), poutine.trace() as tr:
         aux = guide2._sample_aux_values()
         flat = guide1._dense_flatten(aux)
@@ -147,28 +148,56 @@ def check_backends_agree(model):
     log_prob_dense = tr.trace.nodes["_AutoGaussianDense_latent"]["log_prob"]
     assert_equal(log_prob_funsor, log_prob_dense)

+    # Check elbos agree between backends.
+    elbo = Trace_ELBO(num_particles=100000, vectorize_particles=True)
+    loss1 = elbo.differentiable_loss(model, guide1)
+    loss2 = elbo.differentiable_loss(model, guide2)
+    assert_close(loss1, loss2, atol=0.05, rtol=0.05)
+    grads1 = torch.autograd.grad(loss1, [params1[k] for k in names], allow_unused=True)
+    grads2 = torch.autograd.grad(loss2, [params2[k] for k in names], allow_unused=True)
+    for name, grad1, grad2 in zip(names, grads1, grads2):
+        assert_close(grad1, grad2, atol=0.05, msg=f"{name}:\n{grad1} vs {grad2}")
+
     # Check Monte Carlo estimate of entropy.
-    expected_entropy = guide1._dense_get_mvn().entropy()
+    entropy1 = guide1._dense_get_mvn().entropy()
     with pyro.plate("particle", 100000, dim=-3), poutine.trace() as tr:
         guide2._sample_aux_values()
         tr.trace.compute_log_prob()
-    actual_entropy = -tr.trace.nodes["_AutoGaussianFunsor_latent"]["log_prob"].mean()
-    assert_close(actual_entropy, expected_entropy, atol=1e-1)
-
-    # Check gradients.
-    names = sorted(params1)
-    expected_grads = torch.autograd.grad(
-        expected_entropy, [params1[k] for k in names], allow_unused=True
+    aentropy2 = -tr.trace.nodes["_AutoGaussianFunsor_latent"]["log_prob"].mean()
+    assert_close(entropy1, entropy2, atol=1e-1)
+    grads1 = torch.autograd.grad(
+        entropy1, [params1[k] for k in names], allow_unused=True
     )
-    actual_grads = torch.autograd.grad(
-        actual_entropy, [params2[k] for k in names], allow_unused=True
+    grads2 = torch.autograd.grad(
+        entropy2, [params2[k] for k in names], allow_unused=True
     )
-    for name, actual, expected in zip(names, actual_grads, expected_grads):
-        assert_close(actual, expected, msg=name)
+    for name, grad1, grad2 in zip(names, grads1, grads2):
+        assert_close(grad1, grad2, atol=0.05, msg=f"{name}:\n{grad1} vs {grad2}")

 @pytest.mark.parametrize("backend", BACKENDS)
 def test_structure_1(backend):
+    def model():
+        a = pyro.sample("a", dist.Normal(0, 1))
+        with pyro.plate("i", 3):
+            pyro.sample("b", dist.Normal(a, 1), obs=torch.zeros(3))
+
+    # size = 1
+    structure = [
+        "?",
+    ]
+    dependencies = {
+        "a": {"a": set()},
+        "b": {"b": set(), "a": set()},
+    }
+    if backend == "funsor":
+        check_backends_agree(model)
+    else:
+        check_structure(model, structure, dependencies)
+
+
+@pytest.mark.parametrize("backend", BACKENDS)
+def test_structure_2(backend):
     def model():
         a = pyro.sample("a", dist.Normal(0, 1))
         b = pyro.sample("b", dist.Normal(a, 1))
         c = pyro.sample("c", dist.Normal(b, 1))
         pyro.sample("d", dist.Normal(c, 1), obs=torch.tensor(0.0))

 @pytest.mark.parametrize("backend", BACKENDS)
-def test_structure_2(backend):
+def test_structure_3(backend):
     def model():
         with pyro.plate("i", 2):

 @pytest.mark.parametrize("backend", BACKENDS)
-def test_structure_3(backend):
+def test_structure_4(backend):
     def model():
         a = pyro.sample("a", dist.Normal(0, 1))
         with pyro.plate("i", 2):

 @pytest.mark.parametrize("backend", BACKENDS)
-def test_structure_4(backend):
+def test_structure_5(backend):
     def model():
         a = pyro.sample("a", dist.Normal(0, 1))
         b = pyro.sample("b", dist.Normal(0, 1))

 @pytest.mark.parametrize("backend", BACKENDS)
-def test_structure_5(backend):
+def test_structure_6(backend):
     I, J = 2, 3

     def model():

 @pytest.mark.parametrize("backend", BACKENDS)
-def test_structure_6(backend):
+def test_structure_7(backend):
     I, J = 2, 3

     def model():

 @pytest.mark.parametrize("backend", BACKENDS)
-def test_structure_7(backend):
+def test_structure_8(backend):
     def model():
         i_plate = pyro.plate("i", 2, dim=-1)
         with i_plate:
diff --git a/tests/infer/test_autoguide.py b/tests/infer/test_autoguide.py
index 399eda3d6e..73711a4280 100644
--- a/tests/infer/test_autoguide.py
+++ b/tests/infer/test_autoguide.py
@@ -1309,9 +1309,10 @@ def model(data):
     elbo = JitTrace_ELBO(
         num_particles=100, vectorize_particles=True, ignore_jit_warnings=True
     )
-    optim = Adam({"lr": 0.01})
+    num_steps = 500
+    optim = ClippedAdam({"lr": 0.05, "lrd": 0.1 ** (1 / num_steps)})
     svi = SVI(model, guide, optim, elbo)
-    for step in range(500):
+    for step in range(num_steps):
         svi.step(data)

     guide.requires_grad_(False)
From f3771493d443c7af27da8011ecb1409c0ce15d4b Mon Sep 17 00:00:00 2001
From: Fritz Obermeyer
Date: Thu, 21 Oct 2021 23:54:28 -0400
Subject: [PATCH 07/16] lint

---
 tests/infer/autoguide/test_gaussian.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/infer/autoguide/test_gaussian.py b/tests/infer/autoguide/test_gaussian.py
index bde3d8e07d..894fe80542 100644
--- a/tests/infer/autoguide/test_gaussian.py
+++ b/tests/infer/autoguide/test_gaussian.py
@@ -163,7 +163,7 @@ def check_backends_agree(model):
     with pyro.plate("particle", 100000, dim=-3), poutine.trace() as tr:
         guide2._sample_aux_values()
         tr.trace.compute_log_prob()
-    aentropy2 = -tr.trace.nodes["_AutoGaussianFunsor_latent"]["log_prob"].mean()
+    entropy2 = -tr.trace.nodes["_AutoGaussianFunsor_latent"]["log_prob"].mean()
     assert_close(entropy1, entropy2, atol=1e-1)
     grads1 = torch.autograd.grad(
         entropy1, [params1[k] for k in names], allow_unused=True
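The PoissonGuide added in PATCH 05 composes an AutoGuideList from two AutoGaussian sub-guides that partition the model's sample sites with poutine.block. A condensed sketch of the same pattern; the "pois" name split follows the test model, and the helper name is invented:

    # Sketch: split a model's latent sites between two AutoGaussian guides.
    import pyro.poutine as poutine
    from pyro.infer.autoguide import AutoGaussian, AutoGuideList

    def make_guide(model, backend="dense"):
        hide_pois = lambda msg: msg["type"] == "sample" and "pois" in msg["name"]
        hide_rest = lambda msg: msg["type"] == "sample" and "pois" not in msg["name"]
        guide = AutoGuideList(model)
        # Each sub-guide sees only the sites its block does not hide.
        guide.append(
            AutoGaussian(poutine.block(model, hide_fn=hide_pois), backend=backend)
        )
        guide.append(
            AutoGaussian(poutine.block(model, hide_fn=hide_rest), backend=backend)
        )
        return guide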
From f378036a7eb82d6449b47fbc59862a082dc575c1 Mon Sep 17 00:00:00 2001
From: Fritz Obermeyer
Date: Tue, 26 Oct 2021 18:59:55 -0400
Subject: [PATCH 08/16] Make test less trivial

---
 tests/infer/autoguide/test_gaussian.py | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/tests/infer/autoguide/test_gaussian.py b/tests/infer/autoguide/test_gaussian.py
index 894fe80542..44a213245f 100644
--- a/tests/infer/autoguide/test_gaussian.py
+++ b/tests/infer/autoguide/test_gaussian.py
@@ -180,7 +180,7 @@ def test_structure_1(backend):
     def model():
         a = pyro.sample("a", dist.Normal(0, 1))
         with pyro.plate("i", 3):
-            pyro.sample("b", dist.Normal(a, 1), obs=torch.zeros(3))
+            pyro.sample("b", dist.Normal(a, 1), obs=torch.ones(3))

     # size = 1
     structure = [
@@ -202,7 +202,7 @@ def model():
         a = pyro.sample("a", dist.Normal(0, 1))
         b = pyro.sample("b", dist.Normal(a, 1))
         c = pyro.sample("c", dist.Normal(b, 1))
-        pyro.sample("d", dist.Normal(c, 1), obs=torch.tensor(0.0))
+        pyro.sample("d", dist.Normal(c, 1), obs=torch.tensor(1.0))

     # size = 1 + 1 + 1 = 3
     structure = [
@@ -229,7 +229,7 @@ def model():
         a = pyro.sample("a", dist.Normal(0, 1))
         b = pyro.sample("b", dist.Normal(a, 1))
         c = pyro.sample("c", dist.Normal(b, 1))
-        pyro.sample("d", dist.Normal(c, 1), obs=torch.tensor(0.0))
+        pyro.sample("d", dist.Normal(c, 1), obs=torch.tensor(1.0))

     # size = 2 + 2 + 2 = 6
     structure = [
@@ -259,7 +259,7 @@ def model():
         with pyro.plate("i", 2):
             b = pyro.sample("b", dist.Normal(a, 1))
             c = pyro.sample("c", dist.Normal(b, 1))
-        pyro.sample("d", dist.Normal(c.sum(), 1), obs=torch.tensor(0.0))
+        pyro.sample("d", dist.Normal(c.sum(), 1), obs=torch.tensor(1.0))

     # size = 1 + 2 + 2 = 5
     structure = [
@@ -288,7 +288,7 @@ def model():
         b = pyro.sample("b", dist.Normal(0, 1))
         with pyro.plate("i", 2):
             c = pyro.sample("c", dist.Normal(a, b.exp()))
-            pyro.sample("d", dist.Normal(c, 1), obs=torch.tensor(0.0))
+            pyro.sample("d", dist.Normal(c, 1), obs=torch.tensor(1.0))

     # size = 1 + 1 + 2 = 4
     structure = [
@@ -316,7 +316,7 @@ def model():
         x = pyro.sample("x", dist.Normal(0, 1))
         with i_plate, j_plate:
             y = pyro.sample("y", dist.Normal(w, x.exp()))
-        pyro.sample("z", dist.Normal(0, 1), obs=y)
+        pyro.sample("z", dist.Normal(1, 1), obs=y)

     # size = 2 + 3 + 2 * 3 = 2 + 3 + 6 = 11
     structure = [
@@ -351,7 +351,7 @@ def model():
         with j_plate:
             c = pyro.sample("c", dist.Normal(b.mean(), 1))
         d = pyro.sample("d", dist.Normal(c.mean(), 1))
-        pyro.sample("e", dist.Normal(0, 1), obs=d)
+        pyro.sample("e", dist.Normal(1, 1), obs=d)

     # size = 1 + 2 + 3 + 1 = 7
     structure = [
@@ -377,7 +377,7 @@ def model():
         a = pyro.sample("a", dist.Normal(0, 1))
         b = pyro.sample("b", dist.Normal(a.mean(-1), 1))
         with i_plate:
-            pyro.sample("c", dist.Normal(b, 1), obs=torch.zeros(2))
+            pyro.sample("c", dist.Normal(b, 1), obs=torch.ones(2))

     # size = 2 + 1 = 3
     structure = [
From 34ce03c580f4e66aa94bac97c3c2d42757a1e327 Mon Sep 17 00:00:00 2001
From: Fritz Obermeyer
Date: Wed, 27 Oct 2021 10:10:21 -0400
Subject: [PATCH 09/16] Strengthen tests, make AutoGaussian abstract

---
 pyro/infer/autoguide/gaussian.py       |  4 ++-
 tests/infer/autoguide/test_gaussian.py | 43 +++++++++++++++++++-------
 2 files changed, 35 insertions(+), 12 deletions(-)

diff --git a/pyro/infer/autoguide/gaussian.py b/pyro/infer/autoguide/gaussian.py
index 25474415d1..93f5820ef2 100644
--- a/pyro/infer/autoguide/gaussian.py
+++ b/pyro/infer/autoguide/gaussian.py
@@ -2,6 +2,7 @@
 # SPDX-License-Identifier: Apache-2.0

 import itertools
+from abc import ABCMeta, abstractmethod
 from collections import OrderedDict, defaultdict
 from contextlib import ExitStack
 from types import SimpleNamespace
@@ -30,7 +31,7 @@
 #     AutoGaussianDense(model)
 # The intent is to avoid proliferation of subclasses and docstrings,
 # and provide a single interface AutoGaussian(...).
-class AutoGaussianMeta(type(AutoGuide)):
+class AutoGaussianMeta(type(AutoGuide), ABCMeta):
     backends = {}
     default_backend = "dense"
@@ -297,6 +298,7 @@ def _transform_values(

         return values, log_densities

+    @abstractmethod
     def _sample_aux_values(self) -> Dict[str, torch.Tensor]:
         raise NotImplementedError
diff --git a/tests/infer/autoguide/test_gaussian.py b/tests/infer/autoguide/test_gaussian.py
index 44a213245f..3e4c56826e 100644
--- a/tests/infer/autoguide/test_gaussian.py
+++ b/tests/infer/autoguide/test_gaussian.py
@@ -148,33 +148,54 @@ def check_backends_agree(model):
     log_prob_dense = tr.trace.nodes["_AutoGaussianDense_latent"]["log_prob"]
     assert_equal(log_prob_funsor, log_prob_dense)

-    # Check elbos agree between backends.
-    elbo = Trace_ELBO(num_particles=100000, vectorize_particles=True)
-    loss1 = elbo.differentiable_loss(model, guide1)
-    loss2 = elbo.differentiable_loss(model, guide2)
-    assert_close(loss1, loss2, atol=0.05, rtol=0.05)
-    grads1 = torch.autograd.grad(loss1, [params1[k] for k in names], allow_unused=True)
-    grads2 = torch.autograd.grad(loss2, [params2[k] for k in names], allow_unused=True)
-    for name, grad1, grad2 in zip(names, grads1, grads2):
-        assert_close(grad1, grad2, atol=0.05, msg=f"{name}:\n{grad1} vs {grad2}")
-
     # Check Monte Carlo estimate of entropy.
     entropy1 = guide1._dense_get_mvn().entropy()
     with pyro.plate("particle", 100000, dim=-3), poutine.trace() as tr:
         guide2._sample_aux_values()
         tr.trace.compute_log_prob()
     entropy2 = -tr.trace.nodes["_AutoGaussianFunsor_latent"]["log_prob"].mean()
-    assert_close(entropy1, entropy2, atol=1e-1)
+    assert_close(entropy1, entropy2, atol=1e-2)
     grads1 = torch.autograd.grad(
         entropy1, [params1[k] for k in names], allow_unused=True
     )
     grads2 = torch.autograd.grad(
         entropy2, [params2[k] for k in names], allow_unused=True
     )
     for name, grad1, grad2 in zip(names, grads1, grads2):
-        assert_close(grad1, grad2, atol=0.05, msg=f"{name}:\n{grad1} vs {grad2}")
+        assert_close(grad1, grad2, msg=f"{name}:\n{grad1} vs {grad2}")
+
+    # Check elbos agree between backends.
+    elbo = Trace_ELBO(num_particles=1000000, vectorize_particles=True)
+    loss1 = elbo.differentiable_loss(model, guide1)
+    loss2 = elbo.differentiable_loss(model, guide2)
+    assert_close(loss1, loss2, atol=1e-2, rtol=0.05)
+    grads1 = torch.autograd.grad(loss1, [params1[k] for k in names], allow_unused=True)
+    grads2 = torch.autograd.grad(loss2, [params2[k] for k in names], allow_unused=True)
+    for name, grad1, grad2 in zip(names, grads1, grads2):
+        assert_close(grad1, grad2, atol=0.05, msg=f"{name}:\n{grad1} vs {grad2}")
+
+
+@pytest.mark.parametrize("backend", BACKENDS)
+def test_structure_0(backend):
+    @poutine.scale(scale=1e-20)  # DEBUG
+    def model():
+        a = pyro.sample("a", dist.Normal(0, 1))
+        pyro.sample("b", dist.Normal(a, 1), obs=torch.ones(()))
+
+    # size = 1
+    structure = [
+        "?",
+    ]
+    dependencies = {
+        "a": {"a": set()},
+        "b": {"b": set(), "a": set()},
+    }
+    if backend == "funsor":
+        check_backends_agree(model)
+    else:
+        check_structure(model, structure, dependencies)
+
+
 @pytest.mark.parametrize("backend", BACKENDS)
 def test_structure_1(backend):
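PATCH 09 mixes ABCMeta into AutoGaussianMeta so that the @abstractmethod on _sample_aux_values blocks direct instantiation of the abstract front end. A minimal sketch of the mechanism with placeholder names (the real metaclass also inherits type(AutoGuide)):

    # Sketch: an ABCMeta-based metaclass makes abstract methods enforceable.
    from abc import ABCMeta, abstractmethod

    class Meta(ABCMeta):
        pass

    class Base(metaclass=Meta):
        @abstractmethod
        def _sample_aux_values(self):
            raise NotImplementedError

    class Impl(Base):
        def _sample_aux_values(self):
            return {}

    try:
        Base()
    except TypeError:
        pass  # abstract method prevents instantiation of the front end
    Impl()  # concrete subclass instantiates fine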
From dab6d958c676d09ff9ec8c85c35e84eac4b9ce83 Mon Sep 17 00:00:00 2001
From: Fritz Obermeyer
Date: Wed, 27 Oct 2021 10:34:06 -0400
Subject: [PATCH 10/16] Add has_rsample kwarg to pyro.factor

---
 pyro/distributions/unit.py |  7 ++++++-
 pyro/primitives.py         | 17 ++++++++++-------
 2 files changed, 16 insertions(+), 8 deletions(-)

diff --git a/pyro/distributions/unit.py b/pyro/distributions/unit.py
index 7326583ce9..d474564673 100644
--- a/pyro/distributions/unit.py
+++ b/pyro/distributions/unit.py
@@ -20,17 +20,19 @@ class Unit(TorchDistribution):
     arg_constraints = {"log_factor": constraints.real}
     support = constraints.real

-    def __init__(self, log_factor, validate_args=None):
+    def __init__(self, log_factor, *, has_rsample=False, validate_args=None):
         log_factor = torch.as_tensor(log_factor)
         batch_shape = log_factor.shape
         event_shape = torch.Size((0,))  # This satisfies .numel() == 0.
         self.log_factor = log_factor
+        self.has_rsample = has_rsample
         super().__init__(batch_shape, event_shape, validate_args=validate_args)

     def expand(self, batch_shape, _instance=None):
         batch_shape = torch.Size(batch_shape)
         new = self._get_checked_instance(Unit, _instance)
         new.log_factor = self.log_factor.expand(batch_shape)
+        new.has_rsample = self.has_rsample
         super(Unit, new).__init__(batch_shape, self.event_shape, validate_args=False)
         new._validate_args = self._validate_args
         return new
@@ -38,6 +40,9 @@ def expand(self, batch_shape, _instance=None):
     def sample(self, sample_shape=torch.Size()):
         return self.log_factor.new_empty(sample_shape + self.shape())

+    def rsample(self, sample_shape=torch.Size()):
+        return self.log_factor.new_empty(sample_shape + self.shape())
+
     def log_prob(self, value):
         shape = broadcast_shape(self.batch_shape, value.shape[:-1])
         return self.log_factor.expand(shape)
diff --git a/pyro/primitives.py b/pyro/primitives.py
index a96e75e44d..d722737486 100644
--- a/pyro/primitives.py
+++ b/pyro/primitives.py
@@ -165,22 +165,25 @@ def sample(name, fn, *args, **kwargs):
     return msg["value"]


-def factor(name, log_factor):
+def factor(name, log_factor, *, has_rsample=False):
     """
     Factor statement to add arbitrary log probability factor to a
     probabilistic model.

     .. warning:: Beware using factor statements in guides. Factor statements
-        assume ``log_factor`` is computed from non-reparametrized statements
-        such as observation statements ``pyro.sample(..., obs=...)``. If
-        instead ``log_factor`` is computed from e.g. the Jacobian determinant
-        of a transformation of a reparametrized variable, factor statements
-        in the guide will result in incorrect results.
+        assume by default that ``log_factor`` is computed from
+        non-reparametrized statements such as observation statements
+        ``pyro.sample(..., obs=...)``. If instead ``log_factor`` is computed
+        from e.g. the Jacobian determinant of a transformation of a
+        reparametrized variable, you'll need to set ``has_rsample=True``.

     :param str name: Name of the trivial sample
     :param torch.Tensor log_factor: A possibly batched log probability factor.
+    :param bool has_rsample: Whether the ``log_factor`` arose from a fully
+        reparametrized distribution. Defaults to False, which is safe for use
+        in models (but may not be safe for use in guides).
     """
-    unit_dist = dist.Unit(log_factor)
+    unit_dist = dist.Unit(log_factor, has_rsample=has_rsample)
     unit_value = unit_dist.sample()
     sample(name, unit_dist, obs=unit_value, infer={"is_auxiliary": True})
From 722581f0039f357f35b54f837092d87fa7a41a9f Mon Sep 17 00:00:00 2001
From: Fritz Obermeyer
Date: Wed, 27 Oct 2021 10:35:09 -0400
Subject: [PATCH 11/16] Fix tests

---
 pyro/infer/autoguide/gaussian.py       | 2 +-
 tests/infer/autoguide/test_gaussian.py | 7 +++++--
 2 files changed, 6 insertions(+), 3 deletions(-)

diff --git a/pyro/infer/autoguide/gaussian.py b/pyro/infer/autoguide/gaussian.py
index 93f5820ef2..5f362d2468 100644
--- a/pyro/infer/autoguide/gaussian.py
+++ b/pyro/infer/autoguide/gaussian.py
@@ -509,7 +509,7 @@ def _sample_aux_values(self) -> Dict[str, torch.Tensor]:
         # Convert funsor to torch.
         if am_i_wrapped() and poutine.get_mask() is not False:
             log_prob = funsor.to_data(log_prob, name_to_dim=plate_to_dim)
-            pyro.factor(f"_{self._pyro_name}_latent", log_prob)
+            pyro.factor(f"_{self._pyro_name}_latent", log_prob, has_rsample=True)
         samples = {
             k: funsor.to_data(v, name_to_dim=plate_to_dim) for k, v in samples.items()
         }
diff --git a/tests/infer/autoguide/test_gaussian.py b/tests/infer/autoguide/test_gaussian.py
index 3e4c56826e..f4c153fee2 100644
--- a/tests/infer/autoguide/test_gaussian.py
+++ b/tests/infer/autoguide/test_gaussian.py
@@ -162,17 +162,20 @@ def check_backends_agree(model):
         entropy2, [params2[k] for k in names], allow_unused=True
     )
     for name, grad1, grad2 in zip(names, grads1, grads2):
+        # Gradients should agree to very high precision.
         assert_close(grad1, grad2, msg=f"{name}:\n{grad1} vs {grad2}")

     # Check elbos agree between backends.
-    elbo = Trace_ELBO(num_particles=1000000, vectorize_particles=True)
+    elbo = Trace_ELBO(num_particles=100000, vectorize_particles=True)
     loss1 = elbo.differentiable_loss(model, guide1)
     loss2 = elbo.differentiable_loss(model, guide2)
     assert_close(loss1, loss2, atol=1e-2, rtol=0.05)
     grads1 = torch.autograd.grad(loss1, [params1[k] for k in names], allow_unused=True)
     grads2 = torch.autograd.grad(loss2, [params2[k] for k in names], allow_unused=True)
     for name, grad1, grad2 in zip(names, grads1, grads2):
-        assert_close(grad1, grad2, atol=0.05, msg=f"{name}:\n{grad1} vs {grad2}")
+        assert_close(
+            grad1, grad2, atol=0.05, rtol=0.05, msg=f"{name}:\n{grad1} vs {grad2}"
+        )
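With PATCH 10 and PATCH 11 applied, guide-side factor statements can declare that their log-density came from fully reparametrized sampling. A hedged usage sketch; the model, guide, and the particular log-factor below are invented for illustration:

    # Sketch: marking a guide-side factor as reparametrized.
    import torch
    import pyro
    import pyro.distributions as dist

    def model():
        pyro.sample("z", dist.Normal(0.0, 1.0))

    def guide():
        loc = pyro.param("loc", torch.tensor(0.0))
        z = pyro.sample("z", dist.Normal(loc, 1.0))
        # e.g. a correction term computed from a reparametrized sample:
        pyro.factor("correction", -0.5 * z ** 2, has_rsample=True)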
self.log_factor = log_factor + self.has_rsample = has_rsample super().__init__(batch_shape, event_shape, validate_args=validate_args) def expand(self, batch_shape, _instance=None): new = self._get_checked_instance(Unit, _instance) new.log_factor = self.log_factor.expand(batch_shape) + new.has_rsample = self.has_rsample super(Unit, new).__init__(batch_shape, self.event_shape, validate_args=False) new._validate_args = self._validate_args return new @@ -37,6 +39,9 @@ def expand(self, batch_shape, _instance=None): def sample(self, sample_shape=torch.Size()): return self.log_factor.new_empty(sample_shape + self.shape()) + def rsample(self, sample_shape=torch.Size()): + return self.log_factor.new_empty(sample_shape + self.shape()) + def log_prob(self, value): shape = broadcast_shape(self.batch_shape, value.shape[:-1]) return self.log_factor.expand(shape) diff --git a/pyro/primitives.py b/pyro/primitives.py index a96e75e44d..d722737486 100644 --- a/pyro/primitives.py +++ b/pyro/primitives.py @@ -165,22 +165,25 @@ def sample(name, fn, *args, **kwargs): return msg["value"] -def factor(name, log_factor): +def factor(name, log_factor, *, has_rsample=False): """ Factor statement to add arbitrary log probability factor to a probabilisitic model. .. warning:: Beware using factor statements in guides. Factor statements - assume ``log_factor`` is computed from non-reparametrized statements - such as observation statements ``pyro.sample(..., obs=...)``. If - instead ``log_factor`` is computed from e.g. the Jacobian determinant - of a transformation of a reparametrized variable, factor statements - in the guide will result in incorrect results. + assume by default that ``log_factor`` is computed from + non-reparametrized statements such as observation statements + ``pyro.sample(..., obs=...)``. If instead ``log_factor`` is computed + from e.g. the Jacobian determinant of a transformation of a + reparametrized variable, you'll need to set ``has_rsample=True``. :param str name: Name of the trivial sample :param torch.Tensor log_factor: A possibly batched log probability factor. + :param bool has_rsample: Whether the ``log_factor`` arose from a fully + reparametrized distribution. Defaults to False, which is safe for use + in models (but may not be safe for use in guides). """ - unit_dist = dist.Unit(log_factor) + unit_dist = dist.Unit(log_factor, has_rsample=has_rsample) unit_value = unit_dist.sample() sample(name, unit_dist, obs=unit_value, infer={"is_auxiliary": True}) From 60925d7fa45472987c087dd3260f4ac4ad619bcc Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Wed, 27 Oct 2021 10:55:39 -0400 Subject: [PATCH 13/16] Require specification of has_rsample for pyro.factor in guides --- pyro/distributions/unit.py | 8 +++++--- pyro/primitives.py | 19 ++++++++++--------- pyro/util.py | 16 ++++++++++++++++ tests/infer/test_valid_models.py | 23 ++++++++++++++++++++++- 4 files changed, 53 insertions(+), 13 deletions(-) diff --git a/pyro/distributions/unit.py b/pyro/distributions/unit.py index a5b55bda76..50d721da0a 100644 --- a/pyro/distributions/unit.py +++ b/pyro/distributions/unit.py @@ -20,18 +20,20 @@ class Unit(TorchDistribution): arg_constraints = {"log_factor": constraints.real} support = constraints.real - def __init__(self, log_factor, *, has_rsample=False, validate_args=None): + def __init__(self, log_factor, *, has_rsample=None, validate_args=None): log_factor = torch.as_tensor(log_factor) batch_shape = log_factor.shape event_shape = torch.Size((0,)) # This satisfies .numel() == 0. 
         self.log_factor = log_factor
-        self.has_rsample = has_rsample
+        if has_rsample is not None:
+            self.has_rsample = has_rsample
         super().__init__(batch_shape, event_shape, validate_args=validate_args)

     def expand(self, batch_shape, _instance=None):
         new = self._get_checked_instance(Unit, _instance)
         new.log_factor = self.log_factor.expand(batch_shape)
-        new.has_rsample = self.has_rsample
+        if "has_rsample" in self.__dict__:
+            new.has_rsample = self.has_rsample
         super(Unit, new).__init__(batch_shape, self.event_shape, validate_args=False)
         new._validate_args = self._validate_args
         return new
diff --git a/pyro/primitives.py b/pyro/primitives.py
index d722737486..7ab6a97425 100644
--- a/pyro/primitives.py
+++ b/pyro/primitives.py
@@ -165,23 +165,24 @@ def sample(name, fn, *args, **kwargs):
     return msg["value"]


-def factor(name, log_factor, *, has_rsample=False):
+def factor(name, log_factor, *, has_rsample=None):
     """
     Factor statement to add arbitrary log probability factor to a
     probabilistic model.

-    .. warning:: Beware using factor statements in guides. Factor statements
-        assume by default that ``log_factor`` is computed from
-        non-reparametrized statements such as observation statements
-        ``pyro.sample(..., obs=...)``. If instead ``log_factor`` is computed
-        from e.g. the Jacobian determinant of a transformation of a
-        reparametrized variable, you'll need to set ``has_rsample=True``.
+    .. warning:: When using factor statements in guides, you'll need to specify
+        whether the factor statement originated from fully reparametrized
+        sampling (e.g. the Jacobian determinant of a transformation of a
+        reparametrized variable) or from non-reparametrized sampling (e.g.
+        discrete samples). For the fully reparametrized case, set
+        ``has_rsample=True``; for the non-reparametrized case, set
+        ``has_rsample=False``. This is needed only in guides, not in models.

     :param str name: Name of the trivial sample
     :param torch.Tensor log_factor: A possibly batched log probability factor.
     :param bool has_rsample: Whether the ``log_factor`` arose from a fully
-        reparametrized distribution. Defaults to False, which is safe for use
-        in models (but may not be safe for use in guides).
+        reparametrized distribution. Defaults to False when used in models, but
+        must be specified for use in guides.
     """
     unit_dist = dist.Unit(log_factor, has_rsample=has_rsample)
     unit_value = unit_dist.sample()
diff --git a/pyro/util.py b/pyro/util.py
index 72ea4e1519..4726e8e239 100644
--- a/pyro/util.py
+++ b/pyro/util.py
@@ -367,6 +367,22 @@ def check_model_guide_match(model_trace, guide_trace, max_plate_nesting=math.inf
             )
         )

+    # Check factor statements in guide specify has_rsample.
+    for name, site in guide_trace.nodes.items():
+        if not site["type"] == "sample":
+            continue
+        if not site["infer"].get("is_auxiliary"):
+            continue
+        if type(site["fn"]).__name__ != "Unit":
+            continue
+        if "has_rsample" not in site["fn"].__dict__:
+            raise ValueError(
+                f'At guide site pyro.factor("{name}",...), '
+                "missing specification of has_rsample. "
+                "Please set has_rsample=True if the factor statement arises "
+                "from reparametrized sampling, or has_rsample=False otherwise."
+            )
+

 def check_site_shape(site, max_plate_nesting):
     actual_shape = list(site["log_prob"].shape)
diff --git a/tests/infer/test_valid_models.py b/tests/infer/test_valid_models.py
index cf6aecb3b2..9b9a030bf0 100644
--- a/tests/infer/test_valid_models.py
+++ b/tests/infer/test_valid_models.py
@@ -2126,13 +2126,34 @@ def guide():
         TraceTMC_ELBO,
     ],
 )
-def test_factor_in_guide_ok(Elbo):
+def test_factor_in_guide_error(Elbo):
     def model():
         pass

     def guide():
         pyro.factor("f", torch.tensor(0.0))

+    elbo = Elbo(strict_enumeration_warning=False)
+    assert_error(model, guide, elbo, match=".*missing specification of has_rsample.*")
+
+
+@pytest.mark.parametrize(
+    "Elbo",
+    [
+        Trace_ELBO,
+        TraceGraph_ELBO,
+        TraceEnum_ELBO,
+        TraceTMC_ELBO,
+    ],
+)
+@pytest.mark.parametrize("has_rsample", [False, True])
+def test_factor_in_guide_ok(Elbo, has_rsample):
+    def model():
+        pass
+
+    def guide():
+        pyro.factor("f", torch.tensor(0.0), has_rsample=has_rsample)
+
     elbo = Elbo(strict_enumeration_warning=False)
     assert_ok(model, guide, elbo)

From c817cac9f0e20b37d5cef95144859de550fbc62c Mon Sep 17 00:00:00 2001
From: Fritz Obermeyer
Date: Wed, 27 Oct 2021 11:09:33 -0400
Subject: [PATCH 14/16] Remove debug statement

---
 tests/infer/autoguide/test_gaussian.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tests/infer/autoguide/test_gaussian.py b/tests/infer/autoguide/test_gaussian.py
index f4c153fee2..a0f2dec2a9 100644
--- a/tests/infer/autoguide/test_gaussian.py
+++ b/tests/infer/autoguide/test_gaussian.py
@@ -180,7 +180,6 @@ def check_backends_agree(model):

 @pytest.mark.parametrize("backend", BACKENDS)
 def test_structure_0(backend):
-    @poutine.scale(scale=1e-20)  # DEBUG
     def model():
         a = pyro.sample("a", dist.Normal(0, 1))
         pyro.sample("b", dist.Normal(a, 1), obs=torch.ones(()))

From 882d36ed3ed20575a0a62983909d53d640e8e362 Mon Sep 17 00:00:00 2001
From: Fritz Obermeyer
Date: Wed, 27 Oct 2021 11:17:21 -0400
Subject: [PATCH 15/16] Update AutoGaussian

---
 pyro/infer/autoguide/gaussian.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pyro/infer/autoguide/gaussian.py b/pyro/infer/autoguide/gaussian.py
index 9b9a8a12bb..64b86a0f35 100644
--- a/pyro/infer/autoguide/gaussian.py
+++ b/pyro/infer/autoguide/gaussian.py
@@ -462,7 +462,7 @@ def _sample_aux_values(self) -> Dict[str, torch.Tensor]:
         # Convert funsor to torch.
         if am_i_wrapped() and poutine.get_mask() is not False:
             log_prob = funsor.to_data(log_prob, name_to_dim=plate_to_dim)
-            pyro.factor(f"_{self._pyro_name}_latent", log_prob)
+            pyro.factor(f"_{self._pyro_name}_latent", log_prob, has_rsample=True)
         samples = {
             k: funsor.to_data(v, name_to_dim=plate_to_dim) for k, v in samples.items()
         }

From 4f3361d74166e498cb44c0bf93ce230767491a92 Mon Sep 17 00:00:00 2001
From: Fritz Obermeyer
Date: Wed, 27 Oct 2021 13:27:59 -0400
Subject: [PATCH 16/16] Fix scanvi example

---
 examples/scanvi/scanvi.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/examples/scanvi/scanvi.py b/examples/scanvi/scanvi.py
index c457e5ec69..1de492c407 100644
--- a/examples/scanvi/scanvi.py
+++ b/examples/scanvi/scanvi.py
@@ -267,7 +267,11 @@ def guide(self, x, y=None):
                 classification_loss = y_dist.log_prob(y)
                 # Note that the negative sign appears because we're adding this term in the guide
                 # and the guide log_prob appears in the ELBO as -log q
-                pyro.factor("classification_loss", -self.alpha * classification_loss)
+                pyro.factor(
+                    "classification_loss",
+                    -self.alpha * classification_loss,
+                    has_rsample=False,
+                )

             z1_loc, z1_scale = self.z1_encoder(z2, y)
             pyro.sample("z1", dist.Normal(z1_loc, z1_scale).to_event(1))
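
For context, a minimal usage sketch of the ``has_rsample`` kwarg introduced in patches 12 and 13. The model, guide, and parameter names below are illustrative only and appear nowhere in this patch series; only ``pyro.factor(..., has_rsample=...)`` itself comes from these patches.

    import torch

    import pyro
    import pyro.distributions as dist
    from pyro.infer import SVI, Trace_ELBO
    from pyro.optim import Adam

    def model():
        x = pyro.sample("x", dist.Normal(0.0, 1.0))
        # In a model, has_rsample may be omitted; the factor behaves
        # like an observation term.
        pyro.factor("model_penalty", -x.pow(2))

    def guide():
        loc = pyro.param("loc", torch.tensor(0.0))
        # x is a reparametrized (rsample-able) draw.
        x = pyro.sample("x", dist.Normal(loc, 1.0))
        # In a guide, has_rsample must now be specified. This factor is a
        # differentiable function of a reparametrized sample, hence
        # has_rsample=True; a factor computed from non-reparametrized
        # (e.g. discrete) samples would pass has_rsample=False instead.
        pyro.factor("guide_penalty", -x.pow(2), has_rsample=True)

    svi = SVI(model, guide, Adam({"lr": 0.01}), Trace_ELBO())
    svi.step()

Note the tri-state design in patch 13: ``Unit.__init__`` sets the ``has_rsample`` instance attribute only when the kwarg is given, so ``check_model_guide_match`` can detect an unspecified guide factor via ``"has_rsample" not in site["fn"].__dict__``, while model-side factors silently fall back to the class-level default.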
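And a sketch of the new guide-side check from patch 13, mirroring ``test_factor_in_guide_error``: omitting ``has_rsample`` in a guide factor now fails trace matching with a ValueError, assuming validation is enabled (the default).

    import torch

    import pyro
    from pyro.infer import SVI, Trace_ELBO
    from pyro.optim import Adam

    def model():
        pass

    def guide():
        pyro.factor("f", torch.tensor(0.0))  # has_rsample not specified

    svi = SVI(model, guide, Adam({"lr": 0.01}), Trace_ELBO())
    try:
        svi.step()
    except ValueError as e:
        assert "missing specification of has_rsample" in str(e)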