From 81f163bffbb66c687d8f4d6eae5ba2351cf818de Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Sat, 16 May 2020 14:54:53 -0700 Subject: [PATCH 1/4] Use systematic resampling in SMCFilter --- pyro/infer/smcfilter.py | 24 +++++++++++++++++++----- tests/infer/test_smcfilter.py | 18 ++++++++++++++++++ 2 files changed, 37 insertions(+), 5 deletions(-) diff --git a/pyro/infer/smcfilter.py b/pyro/infer/smcfilter.py index 2199a89f80..3cb49592c1 100644 --- a/pyro/infer/smcfilter.py +++ b/pyro/infer/smcfilter.py @@ -130,16 +130,30 @@ def _maybe_importance_resample(self): return # Decide whether to resample based on ESS. logp = self.state._log_weights - logp -= logp.logsumexp(dim=-1) - ess = logp.mul(2).exp().sum().reciprocal() + logp -= logp.logsumexp(-1) + probs = logp.exp() + ess = probs.dot(probs).reciprocal() + print(f"DEBUG ess = {ess:0.1f}") if ess < self.ess_threshold * self.num_particles: - self._importance_resample() + self._importance_resample(probs) - def _importance_resample(self): - index = dist.Categorical(logits=self.state._log_weights).sample(sample_shape=(self.num_particles,)) + def _importance_resample(self, probs): + index = _systematic_sample(probs) self.state._resample(index) +def _systematic_sample(probs): + # Systematic sampling preserves diversity better than multinomial sampling + # via Categorical(probs).sample(). + batch_shape, size = probs.shape[:-1], probs.size(-1) + n = probs.cumsum(-1).mul_(size).add_(torch.rand(batch_shape + (1,))) + n = n.floor_().clamp_(min=0, max=size).long() + diff = probs.new_zeros(batch_shape + (size + 1,)) + diff.scatter_add_(-1, n, torch.ones_like(probs)) + index = diff[..., :-1].cumsum(-1).long() + return index + + class SMCState(dict): """ Dictionary-like object to hold a vectorized collection of tensors to diff --git a/tests/infer/test_smcfilter.py b/tests/infer/test_smcfilter.py index 1c180a656e..4073496f10 100644 --- a/tests/infer/test_smcfilter.py +++ b/tests/infer/test_smcfilter.py @@ -8,9 +8,27 @@ import pyro.distributions as dist import pyro.poutine as poutine from pyro.infer import SMCFilter +from pyro.infer.smcfilter import _systematic_sample from tests.common import assert_close +@pytest.mark.parametrize("size", range(1, 32)) +def test_systematic_sample(size): + pyro.set_rng_seed(size) + probs = torch.randn(size).exp() + probs /= probs.sum() + + num_samples = 20000 + index = _systematic_sample(probs.expand(num_samples, size)) + histogram = torch.zeros_like(probs) + histogram.scatter_add_(-1, index.reshape(-1), + probs.new_ones(1).expand(num_samples * size)) + + expected = probs * size + actual = histogram / num_samples + assert_close(actual, expected, atol=0.01) + + class SmokeModel: def __init__(self, state_size, plate_size): From cc829f0f2831aa609c48d0418b6d002954ce0741 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Sat, 16 May 2020 14:55:29 -0700 Subject: [PATCH 2/4] Remove debug statement --- pyro/infer/smcfilter.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pyro/infer/smcfilter.py b/pyro/infer/smcfilter.py index 3cb49592c1..0e59988d2c 100644 --- a/pyro/infer/smcfilter.py +++ b/pyro/infer/smcfilter.py @@ -133,7 +133,6 @@ def _maybe_importance_resample(self): logp -= logp.logsumexp(-1) probs = logp.exp() ess = probs.dot(probs).reciprocal() - print(f"DEBUG ess = {ess:0.1f}") if ess < self.ess_threshold * self.num_particles: self._importance_resample(probs) From 40909a5870efb6ad5a4efbc18a7b000ed83a9546 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Mon, 18 May 2020 08:25:26 -0700 Subject: [PATCH 3/4] Add retry logic to SMC for epidemiology --- pyro/contrib/epidemiology/compartmental.py | 28 +++++++++++++++------- pyro/infer/smcfilter.py | 16 +++++++++---- 2 files changed, 32 insertions(+), 12 deletions(-) diff --git a/pyro/contrib/epidemiology/compartmental.py b/pyro/contrib/epidemiology/compartmental.py index a95baf46e7..3f5d69aa31 100644 --- a/pyro/contrib/epidemiology/compartmental.py +++ b/pyro/contrib/epidemiology/compartmental.py @@ -20,6 +20,7 @@ from pyro.infer.autoguide import init_to_generated, init_to_value from pyro.infer.mcmc import ArrowheadMassMatrix from pyro.infer.reparam import DiscreteCosineReparam +from pyro.infer.smcfilter import SMCFailed from pyro.util import warn_if_nan from .distributions import set_approx_sample_thresh @@ -144,7 +145,7 @@ def _clear_plates(self): @torch.no_grad() @set_approx_sample_thresh(1000) - def heuristic(self, num_particles=1024, ess_threshold=0.5): + def heuristic(self, num_particles=1024, ess_threshold=0.5, retries=10): """ Finds an initial feasible guess of all latent variables, consistent with observed data. This is needed because not all hypotheses are @@ -163,12 +164,20 @@ def heuristic(self, num_particles=1024, ess_threshold=0.5): # Run SMC. model = _SMCModel(self) guide = _SMCGuide(self) - smc = SMCFilter(model, guide, num_particles=num_particles, - ess_threshold=ess_threshold, - max_plate_nesting=self.max_plate_nesting) - smc.init() - for t in range(1, self.duration): - smc.step() + for attempt in range(1, 1 + retries): + smc = SMCFilter(model, guide, num_particles=num_particles, + ess_threshold=ess_threshold, + max_plate_nesting=self.max_plate_nesting) + try: + smc.init() + for t in range(1, self.duration): + smc.step() + break + except SMCFailed as e: + if attempt == retries: + raise + logger.info("{}. Retrying...".format(e)) + continue # Select the most probable hypothesis. i = int(smc.state._log_weights.max(0).indices) @@ -309,7 +318,6 @@ def fit(self, **options): if k.startswith("heuristic_")} def heuristic(): - logger.info("Heuristically initializing...") with poutine.block(): init_values = self.heuristic(**heuristic_options) assert isinstance(init_values, dict) @@ -321,6 +329,10 @@ def heuristic(): x = biject_to(constraints.interval(-0.5, self.population + 0.5)).inv(x) x = DiscreteCosineTransform(smooth=self._dct)(x) init_values["auxiliary_dct"] = x + logger.info("Heuristic init: {}".format(", ".join( + "{}={:0.3g}".format(k, v) + for k, v in init_values.items() + if v.numel() == 1))) return init_to_value(values=init_values) # Configure a kernel. diff --git a/pyro/infer/smcfilter.py b/pyro/infer/smcfilter.py index 0e59988d2c..22d2360748 100644 --- a/pyro/infer/smcfilter.py +++ b/pyro/infer/smcfilter.py @@ -13,6 +13,14 @@ from pyro.poutine.util import prune_subsample_sites +class SMCFailed(ValueError): + """ + Exception raised when :class:`SMCFilter` fails to find any hypothesis with + nonzero probability. + """ + pass + + class SMCFilter: """ :class:`SMCFilter` is the top-level interface for filtering via sequential @@ -112,16 +120,16 @@ def _update_weights(self, model_trace, guide_trace): log_q = guide_site["log_prob"].reshape(self.num_particles, -1).sum(-1) self.state._log_weights += log_p - log_q if not (self.state._log_weights.max() > -math.inf): - raise ValueError("Failed to find feasible hypothesis after site {}" - .format(name)) + raise SMCFailed("Failed to find feasible hypothesis after site {}" + .format(name)) for site in model_trace.nodes.values(): if site["type"] == "sample" and site["is_observed"]: log_p = site["log_prob"].reshape(self.num_particles, -1).sum(-1) self.state._log_weights += log_p if not (self.state._log_weights.max() > -math.inf): - raise ValueError("Failed to find feasible hypothesis after site {}" - .format(site["name"])) + raise SMCFailed("Failed to find feasible hypothesis after site {}" + .format(site["name"])) self.state._log_weights -= self.state._log_weights.max() From 143a61a85f5db458889e74733f38cd56b7aa824c Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Mon, 18 May 2020 13:36:30 -0700 Subject: [PATCH 4/4] Fix formatting issue --- pyro/contrib/epidemiology/compartmental.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyro/contrib/epidemiology/compartmental.py b/pyro/contrib/epidemiology/compartmental.py index 3f5d69aa31..0633b69110 100644 --- a/pyro/contrib/epidemiology/compartmental.py +++ b/pyro/contrib/epidemiology/compartmental.py @@ -330,7 +330,7 @@ def heuristic(): x = DiscreteCosineTransform(smooth=self._dct)(x) init_values["auxiliary_dct"] = x logger.info("Heuristic init: {}".format(", ".join( - "{}={:0.3g}".format(k, v) + "{}={:0.3g}".format(k, v.item()) for k, v in init_values.items() if v.numel() == 1))) return init_to_value(values=init_values)