Skip to content

Commit

Permalink
Use SMC to initialize compartmental models (#2452)
Browse files Browse the repository at this point in the history
* WIP Sketch SMC initialization for compartmental models

* Get SMC heuristic working

* Remove old heuristics

* Reword docs

* nits
  • Loading branch information
fritzo authored Apr 28, 2020
1 parent cf0f0a8 commit 204e9e6
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 53 deletions.
73 changes: 69 additions & 4 deletions pyro/contrib/epidemiology/compartmental.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import pyro.distributions.hmm
import pyro.poutine as poutine
from pyro.distributions.transforms import DiscreteCosineTransform
from pyro.infer import MCMC, NUTS, infer_discrete
from pyro.infer import MCMC, NUTS, SMCFilter, infer_discrete
from pyro.infer.autoguide import init_to_value
from pyro.infer.reparam import DiscreteCosineReparam
from pyro.util import warn_if_nan
Expand Down Expand Up @@ -101,17 +101,40 @@ def __init__(self, compartments, duration, population, *,
series = ()
full_mass = False

@abstractmethod
def heuristic(self):
@torch.no_grad()
def heuristic(self, num_particles=1024):
"""
Finds an initial feasible guess of all latent variables, consistent
with observed data. This is needed because not all hypotheses are
feasible and HMC needs to start at a feasible solution to progress.
The default implementation attempts to find a feasible state using
:class:`~pyro.infer.smcfilter.SMCFilter` with proprosals from the
prior. However this method may be overridden in cases where SMC
performs poorly e.g. in high-dimensional models.
:param int num_particles: Number of particles used for SMC.
:returns: A dictionary mapping sample site name to tensor value.
:rtype: dict
"""
raise NotImplementedError
# Run SMC.
model = _SMCModel(self)
guide = _SMCGuide(self)
smc = SMCFilter(model, guide, num_particles=num_particles,
max_plate_nesting=self.max_plate_nesting)
smc.init()
for t in range(1, self.duration):
smc.step()

# Select the most probably hypothesis.
i = int(smc.state._log_weights.max(0).indices)
init = {key: value[i] for key, value in smc.state.items()}

# Fill in sample site values.
init = self.generate(init)
init["auxiliary"] = torch.stack(
[init[name] for name in self.compartments]).clamp_(min=0.5)
return init

def global_model(self):
"""
Expand Down Expand Up @@ -433,3 +456,45 @@ def enum_shape(position):
logp = logp.reshape(-1).logsumexp(0)
warn_if_nan(logp)
pyro.factor("transition", logp)


class _SMCModel:
"""
Helper to initialize a CompartmentalModel to a feasible initial state.
"""
def __init__(self, model):
assert isinstance(model, CompartmentalModel)
self.model = model

def init(self, state):
with poutine.trace() as tr:
params = self.model.global_model()
for name, site in tr.trace.nodes.items():
if site["type"] == "sample":
state[name] = site["value"]

self.t = 0
state.update(self.model.initialize(params))
self.step(state) # Take one step since model.initialize is deterministic.

def step(self, state):
with poutine.block(), poutine.condition(data=state):
params = self.model.global_model()
with poutine.trace() as tr:
self.model.transition_fwd(params, state, self.t)
for name, site in tr.trace.nodes.items():
if site["type"] == "sample" and not site["is_observed"]:
state[name] = site["value"]
self.t += 1


class _SMCGuide(_SMCModel):
"""
Like _SMCModel but does not update state and does not observe.
"""
def init(self, state):
super().init(state.copy())

def step(self, state):
with poutine.block(hide_types=["observe"]):
super().step(state.copy())
28 changes: 0 additions & 28 deletions pyro/contrib/epidemiology/seir.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import torch
from torch.nn.functional import pad

import pyro
import pyro.distributions as dist
from pyro.ops.tensor_utils import convolve

from .compartmental import CompartmentalModel
from .distributions import infection_dist
Expand Down Expand Up @@ -51,30 +47,6 @@ def __init__(self, population, incubation_time, recovery_time, data):
series = ("S2E", "E2I", "I2R", "obs")
full_mass = [("R0", "rho")]

def heuristic(self):
T = len(self.data)
# Start with a single exposure.
S0 = self.population - 1
# Assume 50% <= response rate <= 100%.
E2I = self.data * min(2., (S0 / self.data.sum()).sqrt())
# Assume recovery less than a month.
recovery = torch.arange(30.).div(self.recovery_time).neg().exp()
I_aux = convolve(E2I, recovery)[:T]
# Assume incubation takes less than a month.
incubation = torch.arange(30.).div(self.incubation_time).exp()
incubation = pad(incubation, (0, 1), value=0)
incubation /= incubation.sum()
S2E = convolve(E2I, incubation)
S2E_cumsum = S2E[:-T].sum() + S2E[-T:].cumsum(-1)
S_aux = S0 - S2E_cumsum
E_aux = S2E_cumsum - E2I.cumsum(-1)

return {
"R0": torch.tensor(2.0),
"rho": torch.tensor(0.5),
"auxiliary": torch.stack([S_aux, E_aux, I_aux]).clamp(min=0.5),
}

def global_model(self):
tau_e = self.incubation_time
tau_i = self.recovery_time
Expand Down
21 changes: 0 additions & 21 deletions pyro/contrib/epidemiology/sir.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import torch

import pyro
import pyro.distributions as dist
from pyro.ops.tensor_utils import convolve

from .compartmental import CompartmentalModel
from .distributions import infection_dist
Expand Down Expand Up @@ -44,24 +41,6 @@ def __init__(self, population, recovery_time, data):
series = ("S2I", "I2R", "obs")
full_mass = [("R0", "rho")]

def heuristic(self):
# Start with a single infection.
S0 = self.population - 1
# Assume 50% <= response rate <= 100%.
S2I = self.data * min(2., (S0 / self.data.sum()).sqrt())
S_aux = S0 - S2I.cumsum(-1)
# Account for the single initial infection.
S2I[0] += 1
# Assume infection lasts less than a month.
recovery = torch.arange(30.).div(self.recovery_time).neg().exp()
I_aux = convolve(S2I, recovery)[:len(self.data)]

return {
"R0": torch.tensor(2.0),
"rho": torch.tensor(0.5),
"auxiliary": torch.stack([S_aux, I_aux]).clamp(min=0.5),
}

def global_model(self):
tau = self.recovery_time
R0 = pyro.sample("R0", dist.LogNormal(0., 1.))
Expand Down

0 comments on commit 204e9e6

Please sign in to comment.