diff --git a/pyro/distributions/unit.py b/pyro/distributions/unit.py
index 50d721da0a..c80bab1ac6 100644
--- a/pyro/distributions/unit.py
+++ b/pyro/distributions/unit.py
@@ -30,6 +30,7 @@ def __init__(self, log_factor, *, has_rsample=None, validate_args=None):
         super().__init__(batch_shape, event_shape, validate_args=validate_args)

     def expand(self, batch_shape, _instance=None):
+        batch_shape = torch.Size(batch_shape)
         new = self._get_checked_instance(Unit, _instance)
         new.log_factor = self.log_factor.expand(batch_shape)
         if "has_rsample" in self.__dict__:
diff --git a/pyro/infer/autoguide/gaussian.py b/pyro/infer/autoguide/gaussian.py
index 64b86a0f35..5f362d2468 100644
--- a/pyro/infer/autoguide/gaussian.py
+++ b/pyro/infer/autoguide/gaussian.py
@@ -2,6 +2,7 @@
 # SPDX-License-Identifier: Apache-2.0

 import itertools
+from abc import ABCMeta, abstractmethod
 from collections import OrderedDict, defaultdict
 from contextlib import ExitStack
 from types import SimpleNamespace
@@ -14,7 +15,7 @@
 import pyro.distributions as dist
 import pyro.poutine as poutine
 from pyro.distributions import constraints
-from pyro.infer.inspect import get_dependencies
+from pyro.infer.inspect import get_dependencies, is_sample_site
 from pyro.nn.module import PyroModule, PyroParam
 from pyro.poutine.runtime import am_i_wrapped, get_plates
 from pyro.poutine.util import site_is_subsample
@@ -30,7 +31,7 @@
 #     AutoGaussianDense(model)
 # The intent is to avoid proliferation of subclasses and docstrings,
 # and provide a single interface AutoGaussian(...).
-class AutoGaussianMeta(type(AutoGuide)):
+class AutoGaussianMeta(type(AutoGuide), ABCMeta):
     backends = {}
     default_backend = "dense"

@@ -41,8 +42,9 @@ def __init__(cls, *args, **kwargs):
         cls.backends[key] = cls

     def __call__(cls, *args, **kwargs):
-        backend = kwargs.pop("backend", cls.default_backend)
-        cls = cls.backends[backend]
+        if cls is AutoGaussian:
+            backend = kwargs.pop("backend", cls.default_backend)
+            cls = cls.backends[backend]
         return super(AutoGaussianMeta, cls).__call__(*args, **kwargs)


@@ -117,6 +119,12 @@ def __init__(
             model = InitMessenger(init_loc_fn)(model)
         super().__init__(model)

+    @staticmethod
+    def _prototype_hide_fn(msg):
+        # In contrast to the AutoGuide base class, this includes observation
+        # sites and excludes deterministic sites.
+        return not is_sample_site(msg)
+
     def _setup_prototype(self, *args, **kwargs) -> None:
         super()._setup_prototype(*args, **kwargs)

@@ -135,6 +143,12 @@ def _setup_prototype(self, *args, **kwargs) -> None:
             "prior_dependencies"
         ]

+        # Eliminate observations with no upstream latents.
+        for d, upstreams in list(self.dependencies.items()):
+            if all(self.prototype_trace.nodes[u]["is_observed"] for u in upstreams):
+                del self.dependencies[d]
+                del self.prototype_trace.nodes[d]
+
         # Collect factors and plates.
         for d, site in self.prototype_trace.nodes.items():
             # Prune non-essential parts of the trace to save memory.
@@ -153,14 +167,19 @@ def _setup_prototype(self, *args, **kwargs) -> None:
                     "Are you missing a pyro.plate() or .to_event()?"
                 )
             if site["is_observed"]:
-                # Eagerly eliminate irrelevant observation plates.
-                plates &= frozenset.union(
+                # Break irrelevant observation plates.
+                plates &= frozenset().union(
                     *(self._plates[u] for u in self.dependencies[d] if u != d)
                 )
             self._plates[d] = plates

             # Create location-scale parameters, one per latent variable.
             if site["is_observed"]:
+                # This may slightly overestimate, e.g. for Multinomial.
+                self._event_numel[d] = site["fn"].event_shape.numel()
+                # Account for broken irrelevant observation plates.
+                for f in set(site["cond_indep_stack"]) - plates:
+                    self._event_numel[d] *= f.size
                 continue
             with helpful_support_errors(site):
                 init_loc = biject_to(site["fn"].support).inv(site["value"]).detach()
@@ -184,16 +203,23 @@ def _setup_prototype(self, *args, **kwargs) -> None:
         for d, site in self._factors.items():
             u_size = 0
             for u in self.dependencies[d]:
-                broken_shape = _plates_to_shape(self._plates[u] - self._plates[d])
-                u_size += broken_shape.numel() * self._event_numel[u]
+                if not self._factors[u]["is_observed"]:
+                    broken_shape = _plates_to_shape(self._plates[u] - self._plates[d])
+                    u_size += broken_shape.numel() * self._event_numel[u]
             d_size = self._event_numel[d]
             if site["is_observed"]:
                 d_size = min(d_size, u_size)  # just an optimization
             batch_shape = _plates_to_shape(self._plates[d])

             # Create a square root parameter (full, not lower triangular).
-            sqrt = init_loc.new_zeros(batch_shape + (u_size, d_size))
-            if d in self.dependencies[d]:
+            # We initialize with noise to avoid singular gradient.
+            sqrt = torch.rand(
+                batch_shape + (u_size, d_size),
+                dtype=init_loc.dtype,
+                device=init_loc.device,
+            )
+            sqrt.sub_(0.5).mul_(self._init_scale)
+            if not site["is_observed"]:
                 # Initialize the [d,d] block to the identity matrix.
                 sqrt.diagonal(dim1=-2, dim2=-1).fill_(1)
             deep_setattr(self.factors, d, PyroParam(sqrt, event_dim=2))
@@ -223,6 +249,8 @@ def forward(self, *args, **kwargs) -> Dict[str, torch.Tensor]:
         # Replay via Pyro primitives.
         plates = self._create_plates(*args, **kwargs)
         for name, site in self._factors.items():
+            if site["is_observed"]:
+                continue
             with ExitStack() as stack:
                 for frame in site["cond_indep_stack"]:
                     stack.enter_context(plates[frame.name])
@@ -253,6 +281,8 @@ def _transform_values(
         log_densities = defaultdict(float)
         compute_density = am_i_wrapped() and poutine.get_mask() is not False
         for name, site in self._factors.items():
+            if site["is_observed"]:
+                continue
             loc = deep_getattr(self.locs, name)
             scale = deep_getattr(self.scales, name)
             unconstrained = aux_values[name] * scale + loc
@@ -268,6 +298,7 @@ def _transform_values(

         return values, log_densities

+    @abstractmethod
     def _sample_aux_values(self) -> Dict[str, torch.Tensor]:
         raise NotImplementedError

@@ -305,15 +336,18 @@ def _setup_prototype(self, *args, **kwargs):
             index = torch.zeros(precision_shape, dtype=torch.long)

             # Collect local offsets.
+            upstreams = [
+                u for u in self.dependencies[d] if not self._factors[u]["is_observed"]
+            ]
             local_offsets = {}
             pos = 0
-            for u in self.dependencies[d]:
+            for u in upstreams:
                 local_offsets[u] = pos
                 broken_plates = self._plates[u] - self._plates[d]
                 pos += self._event_numel[u] * _plates_to_shape(broken_plates).numel()

             # Create indices blockwise.
-            for u, v in itertools.product(self.dependencies[d], self.dependencies[d]):
+            for u, v in itertools.product(upstreams, upstreams):
                 u_index = global_indices[u]
                 v_index = global_indices[v]

@@ -333,17 +367,17 @@ def _setup_prototype(self, *args, **kwargs):
             self._dense_scatter[d] = index.reshape(-1)

     def _sample_aux_values(self) -> Dict[str, torch.Tensor]:
-        # Sample from a dense joint Gaussian over flattened variables.
-        precision = self._get_precision()
-        loc = precision.new_zeros(self._dense_size)
         flat_samples = pyro.sample(
-            f"_{self._pyro_name}",
-            dist.MultivariateNormal(loc, precision_matrix=precision),
+            f"_{self._pyro_name}_latent",
+            self._dense_get_mvn(),
             infer={"is_auxiliary": True},
         )
-        sample_shape = flat_samples.shape[:-1]
+        samples = self._dense_unflatten(flat_samples)
+        return samples

-        # Convert flat to shaped tensors.
+    def _dense_unflatten(self, flat_samples: torch.Tensor) -> Dict[str, torch.Tensor]:
+        # Convert a single flattened sample to a dict of shaped samples.
+        sample_shape = flat_samples.shape[:-1]
         samples = {}
         pos = 0
         for d, (batch_shape, event_shape) in self._dense_shapes.items():
@@ -356,14 +390,25 @@ def _sample_aux_values(self) -> Dict[str, torch.Tensor]:
             )
         return samples

-    def _get_precision(self):
+    def _dense_flatten(self, samples: Dict[str, torch.Tensor]) -> torch.Tensor:
+        # Convert a dict of shaped samples to a single flattened sample.
+        flat_samples = []
+        for d, (batch_shape, event_shape) in self._dense_shapes.items():
+            shape = samples[d].shape
+            sample_shape = shape[: len(shape) - len(batch_shape) - len(event_shape)]
+            flat_samples.append(samples[d].reshape(sample_shape + (-1,)))
+        return torch.cat(flat_samples, dim=-1)
+
+    def _dense_get_mvn(self):
+        # Create a dense joint Gaussian over flattened variables.
         flat_precision = torch.zeros(self._dense_size ** 2)
         for d, index in self._dense_scatter.items():
             sqrt = deep_getattr(self.factors, d)
             precision = sqrt @ sqrt.transpose(-1, -2)
             flat_precision.scatter_add_(0, index, precision.reshape(-1))
         precision = flat_precision.reshape(self._dense_size, self._dense_size)
-        return precision
+        loc = precision.new_zeros(self._dense_size)
+        return dist.MultivariateNormal(loc, precision_matrix=precision)


 class AutoGaussianFunsor(AutoGaussian):
@@ -403,11 +448,13 @@ def _setup_prototype(self, *args, **kwargs):
         plate_to_dim: Dict[str, int] = {}
         for d, site in self._factors.items():
             inputs = OrderedDict()
-            for f in sorted(site["cond_indep_stack"], key=lambda f: f.dim):
+            for f in sorted(self._plates[d], key=lambda f: f.dim):
                 plate_to_dim[f.name] = f.dim
                 inputs[f.name] = funsor.Bint[f.size]
                 eliminate.add(f.name)
             for u in self.dependencies[d]:
+                if self._factors[u]["is_observed"]:
+                    continue
                 inputs[u] = funsor.Reals[self._unconstrained_event_shapes[u]]
                 eliminate.add(u)
             factor_inputs[d] = inputs
diff --git a/pyro/infer/autoguide/guides.py b/pyro/infer/autoguide/guides.py
index 27a7fed53a..27d6ac55af 100644
--- a/pyro/infer/autoguide/guides.py
+++ b/pyro/infer/autoguide/guides.py
@@ -149,9 +149,11 @@ def _create_plates(self, *args, **kwargs):
             self.plates = self.master().plates
         return self.plates

+    _prototype_hide_fn = staticmethod(prototype_hide_fn)
+
     def _setup_prototype(self, *args, **kwargs):
         # run the model so we can inspect its structure
-        model = poutine.block(self.model, prototype_hide_fn)
+        model = poutine.block(self.model, self._prototype_hide_fn)
         self.prototype_trace = poutine.block(poutine.trace(model).get_trace)(
             *args, **kwargs
         )
@@ -193,9 +195,7 @@ class AutoGuideList(AutoGuide, nn.ModuleList):
     """

     def _check_prototype(self, part_trace):
-        for name, part_site in part_trace.nodes.items():
-            if part_site["type"] != "sample":
-                continue
+        for name, part_site in part_trace.iter_stochastic_nodes():
             self_site = self.prototype_trace.nodes[name]
             assert part_site["fn"].batch_shape == self_site["fn"].batch_shape
             assert part_site["fn"].event_shape == self_site["fn"].event_shape
@@ -1187,7 +1187,7 @@ class AutoDiscreteParallel(AutoGuide):

     def _setup_prototype(self, *args, **kwargs):
         # run the model so we can inspect its structure
-        model = poutine.block(config_enumerate(self.model), prototype_hide_fn)
+        model = poutine.block(config_enumerate(self.model), self._prototype_hide_fn)
         self.prototype_trace = poutine.block(poutine.trace(model).get_trace)(
             *args, **kwargs
         )
diff --git a/tests/infer/autoguide/test_gaussian.py b/tests/infer/autoguide/test_gaussian.py
index d0f79b0622..a0f2dec2a9 100644
--- a/tests/infer/autoguide/test_gaussian.py
+++ b/tests/infer/autoguide/test_gaussian.py
@@ -10,7 +10,7 @@
 import pyro.distributions as dist
 import pyro.poutine as poutine
 from pyro.infer import SVI, JitTrace_ELBO, Predictive, Trace_ELBO
-from pyro.infer.autoguide import AutoGaussian
+from pyro.infer.autoguide import AutoGaussian, AutoGuideList
 from pyro.infer.autoguide.gaussian import (
     AutoGaussianDense,
     AutoGaussianFunsor,
@@ -18,7 +18,7 @@
 )
 from pyro.infer.reparam import LocScaleReparam
 from pyro.optim import Adam
-from tests.common import assert_equal, xfail_if_not_implemented
+from tests.common import assert_close, assert_equal, xfail_if_not_implemented

 BACKENDS = [
     "dense",
@@ -81,65 +81,253 @@ def model():
     guide = AutoGaussian(model, backend=backend)
     if backend == "dense":
         assert isinstance(guide, AutoGaussianDense)
+        guide = AutoGaussianDense(model)
+        assert isinstance(guide, AutoGaussianDense)
     elif backend == "funsor":
         assert isinstance(guide, AutoGaussianFunsor)
+        guide = AutoGaussianFunsor(model)
+        assert isinstance(guide, AutoGaussianFunsor)
     else:
         raise ValueError(f"Unknown backend: {backend}")


-def check_structure(model, expected_str):
+def check_structure(model, expected_str, expected_dependencies=None):
     guide = AutoGaussian(model, backend="dense")
     guide()  # initialize
+    if expected_dependencies is not None:
+        assert guide.dependencies == expected_dependencies

     # Inject random noise into all unconstrained parameters.
     for parameter in guide.parameters():
         parameter.data.normal_()

     with torch.no_grad():
-        precision = guide._get_precision()
+        # Check flatten & unflatten.
+        mvn = guide._dense_get_mvn()
+        expected = mvn.sample()
+        samples = guide._dense_unflatten(expected)
+        actual = guide._dense_flatten(samples)
+        assert_equal(actual, expected)
+
+        # Check sparsity structure.
+        precision = mvn.precision_matrix
         actual = precision.abs().gt(1e-5).long()

+    str_to_number = {"?": 1, ".": 0}
+    expected = torch.tensor(
+        [[str_to_number[c] for c in row if c != " "] for row in expected_str]
+    )
+    assert (actual == expected).all()
+
+
+def check_backends_agree(model):
+    guide1 = AutoGaussian(model, backend="dense")
+    guide2 = AutoGaussian(model, backend="funsor")
+    guide1()
+    with xfail_if_not_implemented():
+        guide2()

-    str_to_number = {"?": 1, ".": 0}
-    expected = torch.tensor(
-        [[str_to_number[c] for c in row if c != " "] for row in expected_str]
+    # Inject random noise into all unconstrained parameters.
+    params1 = dict(guide1.named_parameters())
+    params2 = dict(guide2.named_parameters())
+    assert set(params1) == set(params2)
+    for k, v in params1.items():
+        v.data.normal_()
+        params2[k].data.copy_(v.data)
+    names = sorted(params1)
+
+    # Check densities agree between backends.
+    with torch.no_grad(), poutine.trace() as tr:
+        aux = guide2._sample_aux_values()
+        flat = guide1._dense_flatten(aux)
+        tr.trace.compute_log_prob()
+    log_prob_funsor = tr.trace.nodes["_AutoGaussianFunsor_latent"]["log_prob"]
+    with torch.no_grad(), poutine.trace() as tr:
+        with poutine.condition(data={"_AutoGaussianDense_latent": flat}):
+            guide1._sample_aux_values()
+        tr.trace.compute_log_prob()
+    log_prob_dense = tr.trace.nodes["_AutoGaussianDense_latent"]["log_prob"]
+    assert_equal(log_prob_funsor, log_prob_dense)
+
+    # Check Monte Carlo estimate of entropy.
+    entropy1 = guide1._dense_get_mvn().entropy()
+    with pyro.plate("particle", 100000, dim=-3), poutine.trace() as tr:
+        guide2._sample_aux_values()
+        tr.trace.compute_log_prob()
+    entropy2 = -tr.trace.nodes["_AutoGaussianFunsor_latent"]["log_prob"].mean()
+    assert_close(entropy1, entropy2, atol=1e-2)
+    grads1 = torch.autograd.grad(
+        entropy1, [params1[k] for k in names], allow_unused=True
     )
-    assert_equal(actual, expected)
+    grads2 = torch.autograd.grad(
+        entropy2, [params2[k] for k in names], allow_unused=True
+    )
+    for name, grad1, grad2 in zip(names, grads1, grads2):
+        # Gradients should agree to very high precision.
+        assert_close(grad1, grad2, msg=f"{name}:\n{grad1} vs {grad2}")
+
+    # Check elbos agree between backends.
+    elbo = Trace_ELBO(num_particles=100000, vectorize_particles=True)
+    loss1 = elbo.differentiable_loss(model, guide1)
+    loss2 = elbo.differentiable_loss(model, guide2)
+    assert_close(loss1, loss2, atol=1e-2, rtol=0.05)
+    grads1 = torch.autograd.grad(loss1, [params1[k] for k in names], allow_unused=True)
+    grads2 = torch.autograd.grad(loss2, [params2[k] for k in names], allow_unused=True)
+    for name, grad1, grad2 in zip(names, grads1, grads2):
+        assert_close(
+            grad1, grad2, atol=0.05, rtol=0.05, msg=f"{name}:\n{grad1} vs {grad2}"
+        )
+
+
+@pytest.mark.parametrize("backend", BACKENDS)
+def test_structure_0(backend):
+    def model():
+        a = pyro.sample("a", dist.Normal(0, 1))
+        pyro.sample("b", dist.Normal(a, 1), obs=torch.ones(()))
+
+    # size = 1
+    structure = [
+        "?",
+    ]
+    dependencies = {
+        "a": {"a": set()},
+        "b": {"b": set(), "a": set()},
+    }
+    if backend == "funsor":
+        check_backends_agree(model)
+    else:
+        check_structure(model, structure, dependencies)


-def test_structure_1():
+@pytest.mark.parametrize("backend", BACKENDS)
+def test_structure_1(backend):
+    def model():
+        a = pyro.sample("a", dist.Normal(0, 1))
+        with pyro.plate("i", 3):
+            pyro.sample("b", dist.Normal(a, 1), obs=torch.ones(3))
+
+    # size = 1
+    structure = [
+        "?",
+    ]
+    dependencies = {
+        "a": {"a": set()},
+        "b": {"b": set(), "a": set()},
+    }
+    if backend == "funsor":
+        check_backends_agree(model)
+    else:
+        check_structure(model, structure, dependencies)
+
+
+@pytest.mark.parametrize("backend", BACKENDS)
+def test_structure_2(backend):
     def model():
         a = pyro.sample("a", dist.Normal(0, 1))
         b = pyro.sample("b", dist.Normal(a, 1))
         c = pyro.sample("c", dist.Normal(b, 1))
-        pyro.sample("d", dist.Normal(c, 1), obs=torch.tensor(0.0))
+        pyro.sample("d", dist.Normal(c, 1), obs=torch.tensor(1.0))

-    expected = [
+    # size = 1 + 1 + 1 = 3
+    structure = [
         "? ? .",
         "? ? ?",
         ". ? ?",
     ]
-    check_structure(model, expected)
+    dependencies = {
+        "a": {"a": set()},
+        "b": {"b": set(), "a": set()},
+        "c": {"c": set(), "b": set()},
+        "d": {"c": set(), "d": set()},
+    }
+    if backend == "funsor":
+        check_backends_agree(model)
+    else:
+        check_structure(model, structure, dependencies)


-def test_structure_2():
+@pytest.mark.parametrize("backend", BACKENDS)
+def test_structure_3(backend):
+    def model():
+        with pyro.plate("i", 2):
+            a = pyro.sample("a", dist.Normal(0, 1))
+            b = pyro.sample("b", dist.Normal(a, 1))
+            c = pyro.sample("c", dist.Normal(b, 1))
+            pyro.sample("d", dist.Normal(c, 1), obs=torch.tensor(1.0))
+
+    # size = 2 + 2 + 2 = 6
+    structure = [
+        "? . ? . . .",
+        ". ? . ? . .",
+        "? . ? . ? .",
+        ". ? . ? . ?",
+        ". . ? . ? .",
+        ". . . ? . ?",
+    ]
+    dependencies = {
+        "a": {"a": set()},
+        "b": {"b": set(), "a": set()},
+        "c": {"c": set(), "b": set()},
+        "d": {"c": set(), "d": set()},
+    }
+    if backend == "funsor":
+        check_backends_agree(model)
+    else:
+        check_structure(model, structure, dependencies)
+
+
+@pytest.mark.parametrize("backend", BACKENDS)
+def test_structure_4(backend):
+    def model():
+        a = pyro.sample("a", dist.Normal(0, 1))
+        with pyro.plate("i", 2):
+            b = pyro.sample("b", dist.Normal(a, 1))
+            c = pyro.sample("c", dist.Normal(b, 1))
+        pyro.sample("d", dist.Normal(c.sum(), 1), obs=torch.tensor(1.0))
+
+    # size = 1 + 2 + 2 = 5
+    structure = [
+        "? ? ? . .",
+        "? ? . ? .",
+        "? . ? . ?",
+        ". ? . ? ?",
+        ". . ? ? ?",
+    ]
+    dependencies = {
+        "a": {"a": set()},
+        "b": {"b": set(), "a": set()},
+        "c": {"c": set(), "b": set()},
+        "d": {"c": set(), "d": set()},
+    }
+    if backend == "funsor":
+        check_backends_agree(model)
+    else:
+        check_structure(model, structure, dependencies)
+
+
+@pytest.mark.parametrize("backend", BACKENDS)
+def test_structure_5(backend):
     def model():
         a = pyro.sample("a", dist.Normal(0, 1))
         b = pyro.sample("b", dist.Normal(0, 1))
         with pyro.plate("i", 2):
             c = pyro.sample("c", dist.Normal(a, b.exp()))
-            pyro.sample("d", dist.Normal(c, 1), obs=torch.tensor(0.0))
+            pyro.sample("d", dist.Normal(c, 1), obs=torch.tensor(1.0))

     # size = 1 + 1 + 2 = 4
-    expected = [
-        "? . ? ?",
-        ". ? ? ?",
+    structure = [
+        "? ? ? ?",
+        "? ? ? ?",
         "? ? ? .",
         "? ? . ?",
     ]
-    check_structure(model, expected)
+    if backend == "funsor":
+        check_backends_agree(model)
+    else:
+        check_structure(model, structure)


-def test_structure_3():
+@pytest.mark.parametrize("backend", BACKENDS)
+def test_structure_6(backend):
     I, J = 2, 3

     def model():
@@ -151,15 +339,15 @@ def model():
         x = pyro.sample("x", dist.Normal(0, 1))
         with i_plate, j_plate:
             y = pyro.sample("y", dist.Normal(w, x.exp()))
-            pyro.sample("z", dist.Normal(0, 1), obs=y)
+            pyro.sample("z", dist.Normal(1, 1), obs=y)

     # size = 2 + 3 + 2 * 3 = 2 + 3 + 6 = 11
-    expected = [
-        "? . . . . ? . ? . ? .",
-        ". ? . . . . ? . ? . ?",
-        ". . ? . . ? ? . . . .",
-        ". . . ? . . . ? ? . .",
-        ". . . . ? . . . . ? ?",
+    structure = [
+        "? . ? ? ? ? . ? . ? .",
+        ". ? ? ? ? . ? . ? . ?",
+        "? ? ? . . ? ? . . . .",
+        "? ? . ? . . . ? ? . .",
+        "? ? . . ? . . . . ? ?",
         "? . ? . . ? . . . . .",
         ". ? ? . . . ? . . . .",
         "? . . ? . . . ? . . .",
@@ -167,10 +355,14 @@ def model():
         "? . . . ? . . . . ? .",
         ". ? . . ? . . . . . ?",
     ]
-    check_structure(model, expected)
+    if backend == "funsor":
+        check_backends_agree(model)
+    else:
+        check_structure(model, structure)


-def test_structure_4():
+@pytest.mark.parametrize("backend", BACKENDS)
+def test_structure_7(backend):
     I, J = 2, 3

     def model():
@@ -182,37 +374,44 @@ def model():
         with j_plate:
             c = pyro.sample("c", dist.Normal(b.mean(), 1))
         d = pyro.sample("d", dist.Normal(c.mean(), 1))
-        pyro.sample("e", dist.Normal(0, 1), obs=d)
+        pyro.sample("e", dist.Normal(1, 1), obs=d)

     # size = 1 + 2 + 3 + 1 = 7
-    expected = [
+    structure = [
         "? ? ? . . . .",
-        "? ? . ? ? ? .",
-        "? . ? ? ? ? .",
-        ". ? ? ? . . ?",
-        ". ? ? . ? . ?",
-        ". ? ? . . ? ?",
+        "? ? ? ? ? ? .",
+        "? ? ? ? ? ? .",
+        ". ? ? ? ? ? ?",
+        ". ? ? ? ? ? ?",
+        ". ? ? ? ? ? ?",
         ". . . ? ? ? ?",
     ]
-    check_structure(model, expected)
+    if backend == "funsor":
+        check_backends_agree(model)
+    else:
+        check_structure(model, structure)


-def test_structure_5():
+@pytest.mark.parametrize("backend", BACKENDS)
+def test_structure_8(backend):
     def model():
         i_plate = pyro.plate("i", 2, dim=-1)
         with i_plate:
             a = pyro.sample("a", dist.Normal(0, 1))
         b = pyro.sample("b", dist.Normal(a.mean(-1), 1))
         with i_plate:
-            pyro.sample("c", dist.Normal(b, 1), obs=torch.zeros(2))
+            pyro.sample("c", dist.Normal(b, 1), obs=torch.ones(2))

     # size = 2 + 1 = 3
-    expected = [
-        "? . ?",
-        ". ? ?",
+    structure = [
+        "? ? ?",
+        "? ? ?",
         "? ? ?",
     ]
-    check_structure(model, expected)
+    if backend == "funsor":
+        check_backends_agree(model)
+    else:
+        check_structure(model, structure)


 @pytest.mark.parametrize("backend", BACKENDS)
@@ -300,6 +499,53 @@ def pyrocov_model(dataset):
     )


+# This is modified by relaxing rate from deterministic to latent.
+def pyrocov_model_relaxed(dataset):
+    # Tensor shapes are commented at the end of some lines.
+    features = dataset["features"]
+    local_time = dataset["local_time"][..., None]  # [T, P, 1]
+    T, P, _ = local_time.shape
+    S, F = features.shape
+    weekly_strains = dataset["weekly_strains"]
+    assert weekly_strains.shape == (T, P, S)
+
+    # Sample global random variables.
+    coef_scale = pyro.sample("coef_scale", dist.InverseGamma(5e3, 1e2))[..., None]
+    rate_loc_scale = pyro.sample("rate_loc_scale", dist.LogNormal(-4, 2))[..., None]
+    rate_scale = pyro.sample("rate_scale", dist.LogNormal(-4, 2))[..., None]
+    init_loc_scale = pyro.sample("init_loc_scale", dist.LogNormal(0, 2))[..., None]
+    init_scale = pyro.sample("init_scale", dist.LogNormal(0, 2))[..., None]
+
+    # Assume relative growth rate depends strongly on mutations and weakly on place.
+    coef_loc = torch.zeros(F)
+    coef = pyro.sample("coef", dist.Logistic(coef_loc, coef_scale).to_event(1))  # [F]
+    rate_loc = pyro.sample(
+        "rate_loc",
+        dist.Normal(0.01 * coef @ features.T, rate_loc_scale).to_event(1),
+    )  # [S]
+
+    # Assume initial infections depend strongly on strain and place.
+    init_loc = pyro.sample(
+        "init_loc", dist.Normal(torch.zeros(S), init_loc_scale).to_event(1)
+    )  # [S]
+    with pyro.plate("place", P, dim=-1):
+        rate = pyro.sample(
+            "rate", dist.Normal(rate_loc, rate_scale).to_event(1)
+        )  # [P, S]
+        init = pyro.sample(
+            "init", dist.Normal(init_loc, init_scale).to_event(1)
+        )  # [P, S]
+
+        # Finally observe counts.
+        with pyro.plate("time", T, dim=-2):
+            logits = init + rate * local_time  # [T, P, S]
+            pyro.sample(
+                "obs",
+                dist.Multinomial(logits=logits, validate_args=False),
+                obs=weekly_strains,
+            )
+
+
 # This is modified by more precisely tracking plates for features and strains.
 def pyrocov_model_plated(dataset):
     # Tensor shapes are commented at the end of some lines.
@@ -390,16 +636,36 @@ def pyrocov_model_poisson(dataset):
         pyro.sample("obs", dist.Poisson(logits), obs=weekly_strains)


+class PoissonGuide(AutoGuideList):
+    def __init__(self, model, backend):
+        super().__init__(model)
+        self.append(
+            AutoGaussian(poutine.block(model, hide_fn=self.hide_fn_1), backend=backend)
+        )
+        self.append(
+            AutoGaussian(poutine.block(model, hide_fn=self.hide_fn_2), backend=backend)
+        )
+
+    @staticmethod
+    def hide_fn_1(msg):
+        return msg["type"] == "sample" and "pois" in msg["name"]
+
+    @staticmethod
+    def hide_fn_2(msg):
+        return msg["type"] == "sample" and "pois" not in msg["name"]
+
+
 PYRO_COV_MODELS = [
-    pyrocov_model,
-    pyrocov_model_plated,
-    pyrocov_model_poisson,
+    (pyrocov_model, AutoGaussian),
+    (pyrocov_model_relaxed, AutoGaussian),
+    (pyrocov_model_plated, AutoGaussian),
+    (pyrocov_model_poisson, PoissonGuide),
 ]


-@pytest.mark.parametrize("model", PYRO_COV_MODELS)
+@pytest.mark.parametrize("model, Guide", PYRO_COV_MODELS)
 @pytest.mark.parametrize("backend", BACKENDS)
-def test_pyrocov_smoke(model, backend):
+def test_pyrocov_smoke(model, Guide, backend):
     T, P, S, F = 3, 4, 5, 6
     dataset = {
         "features": torch.randn(S, F),
@@ -407,7 +673,7 @@ def test_pyrocov_smoke(model, Guide, backend):
         "weekly_strains": torch.randn(T, P, S).exp().round(),
     }

-    guide = AutoGaussian(model, backend=backend)
+    guide = Guide(model, backend=backend)
     svi = SVI(model, guide, Adam({"lr": 1e-8}), Trace_ELBO())
     for step in range(2):
         with xfail_if_not_implemented():
@@ -417,9 +683,9 @@ def test_pyrocov_smoke(model, Guide, backend):
     predictive(dataset)


-@pytest.mark.parametrize("model", PYRO_COV_MODELS)
+@pytest.mark.parametrize("model, Guide", PYRO_COV_MODELS)
 @pytest.mark.parametrize("backend", BACKENDS)
-def test_pyrocov_reparam(model, backend):
+def test_pyrocov_reparam(model, Guide, backend):
     T, P, S, F = 2, 3, 4, 5
     dataset = {
         "features": torch.randn(S, F),
@@ -436,7 +702,7 @@ def test_pyrocov_reparam(model, Guide, backend):
         "init": LocScaleReparam(),
     }
     model = poutine.reparam(model, config)
-    guide = AutoGaussian(model, backend=backend)
+    guide = Guide(model, backend=backend)
     svi = SVI(model, guide, Adam({"lr": 1e-8}), Trace_ELBO())
     for step in range(2):
         with xfail_if_not_implemented():
@@ -457,30 +723,27 @@ def test_pyrocov_structure():
         "weekly_strains": torch.randn(T, P, S).exp().round(),
     }

-    guide = AutoGaussian(pyrocov_model_poisson, backend="funsor")
+    guide = PoissonGuide(pyrocov_model_poisson, backend="funsor")
     guide(dataset)  # initialize
+    guide = guide[0]  # pull out AutoGaussian part of PoissonGuide

-    expected_plates = frozenset(["time", "place", "strain"])
+    expected_plates = frozenset(["place", "strain"])
     assert guide._funsor_plates == expected_plates

     expected_eliminate = frozenset(
         [
-            "time",
-            "place",
-            "strain",
+            "coef",
             "coef_scale",
-            "rate_loc_scale",
-            "rate_scale",
+            "init",
+            "init_loc",
             "init_loc_scale",
             "init_scale",
-            "coef",
-            "rate_loc",
-            "init_loc",
+            "place",
             "rate",
-            "init",
-            "pois_loc",
-            "pois_scale",
-            "pois",
+            "rate_loc",
+            "rate_loc_scale",
+            "rate_scale",
+            "strain",
         ]
     )
     assert guide._funsor_eliminate == expected_eliminate
@@ -491,8 +754,6 @@ def test_pyrocov_structure():
         "rate_scale": OrderedDict([("rate_scale", Real)]),
         "init_loc_scale": OrderedDict([("init_loc_scale", Real)]),
         "init_scale": OrderedDict([("init_scale", Real)]),
-        "pois_loc": OrderedDict([("pois_loc", Real)]),
-        "pois_scale": OrderedDict([("pois_scale", Real)]),
         "coef": OrderedDict([("coef", Reals[5]), ("coef_scale", Real)]),
         "rate_loc": OrderedDict(
             [
@@ -503,7 +764,11 @@ def test_pyrocov_structure():
             ]
         ),
         "init_loc": OrderedDict(
-            [("strain", Bint[4]), ("init_loc", Real), ("init_loc_scale", Real)]
+            [
+                ("strain", Bint[4]),
+                ("init_loc", Real),
+                ("init_loc_scale", Real),
+            ]
         ),
         "rate": OrderedDict(
             [
@@ -523,13 +788,12 @@ def test_pyrocov_structure():
                 ("init_loc", Real),
             ]
         ),
-        "pois": OrderedDict(
+        "obs": OrderedDict(
             [
-                ("time", Bint[2]),
                 ("place", Bint[3]),
-                ("pois", Real),
-                ("pois_loc", Real),
-                ("pois_scale", Real),
+                ("strain", Bint[4]),
+                ("rate", Real),
+                ("init", Real),
             ]
         ),
     }
@@ -552,7 +816,7 @@ def test_profile(backend, jit, n=1, num_steps=1, log_every=1):
     }

     print("Initializing guide")
-    guide = AutoGaussian(model, backend=backend)
+    guide = PoissonGuide(model, backend=backend)
     guide(dataset)  # initialize
     print("Parameter shapes:")
     for name, param in guide.named_parameters():
diff --git a/tests/infer/test_autoguide.py b/tests/infer/test_autoguide.py
index 399eda3d6e..73711a4280 100644
--- a/tests/infer/test_autoguide.py
+++ b/tests/infer/test_autoguide.py
@@ -1309,9 +1309,10 @@ def model(data):
     elbo = JitTrace_ELBO(
         num_particles=100, vectorize_particles=True, ignore_jit_warnings=True
     )
-    optim = Adam({"lr": 0.01})
+    num_steps = 500
+    optim = ClippedAdam({"lr": 0.05, "lrd": 0.1 ** (1 / num_steps)})
     svi = SVI(model, guide, optim, elbo)
-    for step in range(500):
+    for step in range(num_steps):
         svi.step(data)

     guide.requires_grad_(False)