Use systematic resampling in SMCFilter #2488

Merged: 4 commits, May 18, 2020
pyro/contrib/epidemiology/compartmental.py (20 additions, 8 deletions)
@@ -20,6 +20,7 @@
 from pyro.infer.autoguide import init_to_generated, init_to_value
 from pyro.infer.mcmc import ArrowheadMassMatrix
 from pyro.infer.reparam import DiscreteCosineReparam
+from pyro.infer.smcfilter import SMCFailed
 from pyro.util import warn_if_nan

 from .distributions import set_approx_sample_thresh
@@ -144,7 +145,7 @@ def _clear_plates(self):

     @torch.no_grad()
     @set_approx_sample_thresh(1000)
-    def heuristic(self, num_particles=1024, ess_threshold=0.5):
+    def heuristic(self, num_particles=1024, ess_threshold=0.5, retries=10):
         """
         Finds an initial feasible guess of all latent variables, consistent
         with observed data. This is needed because not all hypotheses are
@@ -163,12 +164,20 @@ def heuristic(self, num_particles=1024, ess_threshold=0.5):
         # Run SMC.
         model = _SMCModel(self)
         guide = _SMCGuide(self)
-        smc = SMCFilter(model, guide, num_particles=num_particles,
-                        ess_threshold=ess_threshold,
-                        max_plate_nesting=self.max_plate_nesting)
-        smc.init()
-        for t in range(1, self.duration):
-            smc.step()
+        for attempt in range(1, 1 + retries):
+            smc = SMCFilter(model, guide, num_particles=num_particles,
+                            ess_threshold=ess_threshold,
+                            max_plate_nesting=self.max_plate_nesting)
+            try:
+                smc.init()
+                for t in range(1, self.duration):
+                    smc.step()
+                break
+            except SMCFailed as e:
+                if attempt == retries:
+                    raise
+                logger.info("{}. Retrying...".format(e))
+                continue

         # Select the most probable hypothesis.
         i = int(smc.state._log_weights.max(0).indices)
@@ -309,7 +318,6 @@ def fit(self, **options):
                              if k.startswith("heuristic_")}

         def heuristic():
-            logger.info("Heuristically initializing...")
             with poutine.block():
                 init_values = self.heuristic(**heuristic_options)
             assert isinstance(init_values, dict)
@@ -321,6 +329,10 @@ def heuristic():
                 x = biject_to(constraints.interval(-0.5, self.population + 0.5)).inv(x)
                 x = DiscreteCosineTransform(smooth=self._dct)(x)
                 init_values["auxiliary_dct"] = x
+            logger.info("Heuristic init: {}".format(", ".join(
+                "{}={:0.3g}".format(k, v)
+                for k, v in init_values.items()
+                if v.numel() == 1)))
             return init_to_value(values=init_values)

         # Configure a kernel.
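Note on the retry logic above: heuristic() now catches the new SMCFailed exception, logs it, and reruns SMC up to retries times before letting the failure propagate. A minimal sketch of the behavior from calling code; SimpleSIRModel, its import path, and the toy data are illustrative assumptions, not part of this diff:

    import torch
    from pyro.contrib.epidemiology import SimpleSIRModel  # assumed import path
    from pyro.infer.smcfilter import SMCFailed

    # Hypothetical observed counts of new infections per time step.
    data = torch.tensor([1., 2., 3., 5., 8., 13., 8., 5., 3., 2.])
    model = SimpleSIRModel(1000, 7.0, data)  # population, recovery_time, data
    try:
        # Each failed SMC attempt is logged as "... Retrying..." and rerun.
        init_values = model.heuristic(num_particles=1024, retries=10)
    except SMCFailed:
        # All attempts found only zero-probability hypotheses; try again
        # with more particles before giving up.
        init_values = model.heuristic(num_particles=4096, retries=10)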
pyro/infer/smcfilter.py (30 additions, 9 deletions)
@@ -13,6 +13,14 @@
 from pyro.poutine.util import prune_subsample_sites


+class SMCFailed(ValueError):
+    """
+    Exception raised when :class:`SMCFilter` fails to find any hypothesis with
+    nonzero probability.
+    """
+    pass
+
+
 class SMCFilter:
     """
     :class:`SMCFilter` is the top-level interface for filtering via sequential
@@ -112,16 +120,16 @@ def _update_weights(self, model_trace, guide_trace):
                 log_q = guide_site["log_prob"].reshape(self.num_particles, -1).sum(-1)
                 self.state._log_weights += log_p - log_q
                 if not (self.state._log_weights.max() > -math.inf):
-                    raise ValueError("Failed to find feasible hypothesis after site {}"
-                                     .format(name))
+                    raise SMCFailed("Failed to find feasible hypothesis after site {}"
+                                    .format(name))

         for site in model_trace.nodes.values():
             if site["type"] == "sample" and site["is_observed"]:
                 log_p = site["log_prob"].reshape(self.num_particles, -1).sum(-1)
                 self.state._log_weights += log_p
                 if not (self.state._log_weights.max() > -math.inf):
-                    raise ValueError("Failed to find feasible hypothesis after site {}"
-                                     .format(site["name"]))
+                    raise SMCFailed("Failed to find feasible hypothesis after site {}"
+                                    .format(site["name"]))

         self.state._log_weights -= self.state._log_weights.max()

@@ -130,16 +138,29 @@ def _maybe_importance_resample(self):
             return
         # Decide whether to resample based on ESS.
         logp = self.state._log_weights
-        logp -= logp.logsumexp(dim=-1)
-        ess = logp.mul(2).exp().sum().reciprocal()
+        logp -= logp.logsumexp(-1)
+        probs = logp.exp()
+        ess = probs.dot(probs).reciprocal()
         if ess < self.ess_threshold * self.num_particles:
-            self._importance_resample()
+            self._importance_resample(probs)

-    def _importance_resample(self):
-        index = dist.Categorical(logits=self.state._log_weights).sample(sample_shape=(self.num_particles,))
+    def _importance_resample(self, probs):
+        index = _systematic_sample(probs)
         self.state._resample(index)


+def _systematic_sample(probs):
+    # Systematic sampling preserves diversity better than multinomial sampling
+    # via Categorical(probs).sample().
+    batch_shape, size = probs.shape[:-1], probs.size(-1)
+    n = probs.cumsum(-1).mul_(size).add_(torch.rand(batch_shape + (1,)))
+    n = n.floor_().clamp_(min=0, max=size).long()
+    diff = probs.new_zeros(batch_shape + (size + 1,))
+    diff.scatter_add_(-1, n, torch.ones_like(probs))
+    index = diff[..., :-1].cumsum(-1).long()
+    return index
+
+
 class SMCState(dict):
     """
     Dictionary-like object to hold a vectorized collection of tensors to
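Two details in the rewrite above are worth unpacking. First, for normalized weights probs, the effective sample size is ESS = 1 / sum(probs ** 2), which is exactly probs.dot(probs).reciprocal(); it equals num_particles for uniform weights and 1 when a single particle carries all the mass, so resampling triggers once the weights become too concentrated. Second, systematic sampling copies particle i either floor(size * p_i) or ceil(size * p_i) times, so every count stays within 1 of its expectation, while multinomial resampling via Categorical(probs).sample() leaves counts with variance roughly size * p_i * (1 - p_i). A small sketch (not part of the PR) illustrating both points:

    import torch
    from pyro.infer.smcfilter import _systematic_sample

    torch.manual_seed(0)
    size = 8
    probs = torch.randn(size).exp()
    probs /= probs.sum()
    expected = probs * size

    # ESS = 1 / sum(p ** 2): equals size for uniform weights, 1 for a point mass.
    ess = probs.dot(probs).reciprocal()

    # Multinomial resampling: counts fluctuate around size * p.
    multi = torch.multinomial(probs, size, replacement=True)
    multi_counts = torch.bincount(multi, minlength=size).float()

    # Systematic resampling: each count deviates from size * p by less than 1.
    sys_counts = torch.bincount(_systematic_sample(probs), minlength=size).float()

    print(ess)                                    # between 1 and size
    print((multi_counts - expected).abs().max())  # frequently >= 1
    print((sys_counts - expected).abs().max())    # always < 1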
tests/infer/test_smcfilter.py (18 additions, 0 deletions)
@@ -8,9 +8,27 @@
 import pyro.distributions as dist
 import pyro.poutine as poutine
 from pyro.infer import SMCFilter
+from pyro.infer.smcfilter import _systematic_sample
 from tests.common import assert_close


+@pytest.mark.parametrize("size", range(1, 32))
+def test_systematic_sample(size):
+    pyro.set_rng_seed(size)
+    probs = torch.randn(size).exp()
+    probs /= probs.sum()
+
+    num_samples = 20000
+    index = _systematic_sample(probs.expand(num_samples, size))
+    histogram = torch.zeros_like(probs)
+    histogram.scatter_add_(-1, index.reshape(-1),
+                           probs.new_ones(1).expand(num_samples * size))
+
+    expected = probs * size
+    actual = histogram / num_samples
+    assert_close(actual, expected, atol=0.01)
+
+
 class SmokeModel:

     def __init__(self, state_size, plate_size):
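Finally, for orientation, here is roughly where the new resampler fires in an end-to-end SMCFilter run. This sketch follows the model/guide protocol from the SMCFilter docstring, but the Gaussian random-walk model, guide, and noise scales are made up for illustration:

    import torch
    import pyro
    import pyro.distributions as dist
    from pyro.infer import SMCFilter

    class RandomWalkModel:
        def init(self, state):
            self.t = 0
            state["z"] = pyro.sample("z_init", dist.Normal(0., 1.))

        def step(self, state, datum=None):
            self.t += 1
            state["z"] = pyro.sample("z_{}".format(self.t),
                                     dist.Normal(state["z"], 1.))
            pyro.sample("obs_{}".format(self.t),
                        dist.Normal(state["z"], 0.5), obs=datum)

    class RandomWalkGuide:
        def init(self, state):
            self.t = 0
            pyro.sample("z_init", dist.Normal(0., 1.))

        def step(self, state, datum=None):
            self.t += 1
            # Bootstrap proposal: sample from the transition prior.
            pyro.sample("z_{}".format(self.t), dist.Normal(state["z"], 1.))

    pyro.set_rng_seed(0)
    smc = SMCFilter(RandomWalkModel(), RandomWalkGuide(),
                    num_particles=100, max_plate_nesting=0)
    smc.init()
    for datum in torch.randn(20).cumsum(-1):
        # Each step reweights particles by the observation likelihood and,
        # whenever ESS drops below ess_threshold * num_particles, resamples
        # them via _systematic_sample instead of Categorical.
        smc.step(datum)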