Add example of sparsely observed SIR model #2457

Merged · 2 commits · May 1, 2020
3 changes: 2 additions & 1 deletion pyro/contrib/epidemiology/__init__.py
@@ -4,13 +4,14 @@
from .compartmental import CompartmentalModel
from .distributions import infection_dist
from .seir import OverdispersedSEIRModel, SimpleSEIRModel
from .sir import OverdispersedSIRModel, SimpleSIRModel
from .sir import OverdispersedSIRModel, SimpleSIRModel, SparseSIRModel

__all__ = [
"CompartmentalModel",
"OverdispersedSEIRModel",
"OverdispersedSIRModel",
"SimpleSEIRModel",
"SimpleSIRModel",
"SparseSIRModel",
"infection_dist",
]
110 changes: 108 additions & 2 deletions pyro/contrib/epidemiology/sir.py
@@ -22,7 +22,6 @@ class SimpleSIRModel(CompartmentalModel):
:param int population: Total ``population = S + I + R``.
:param float recovery_time: Mean recovery time (duration in state
``I``). Must be greater than 1.
:param iterable data: Time series of new observed infections.
:param iterable data: Time series of new observed infections. Each time
step is Binomial distributed between 0 and the number of ``S -> I``
transitions. This allows false negatives but no false positives.
@@ -136,7 +135,6 @@ class OverdispersedSIRModel(CompartmentalModel):
:param int population: Total ``population = S + I + R``.
:param float recovery_time: Mean recovery time (duration in state
``I``). Must be greater than 1.
:param iterable data: Time series of new observed infections.
:param iterable data: Time series of new observed infections. Each time
step is Binomial distributed between 0 and the number of ``S -> I``
transitions. This allows false negatives but no false positives.
@@ -212,3 +210,111 @@ def transition_bwd(self, params, prev, curr, t):
pyro.sample("obs_{}".format(t),
dist.ExtendedBinomial(S2I, rho),
obs=self.data[t])


class SparseSIRModel(CompartmentalModel):
"""
Susceptible-Infected-Recovered model with sparsely observed infections.

To customize this model we recommend forking and editing this class.

This is a stochastic discrete-time discrete-state model with four
compartments: "S" for susceptible, "I" for infected, "O" for the
cumulative number of observations (an auxiliary compartment described
below), and "R" for recovered individuals (the recovered individuals
are implicit: ``R = population - S - I``), with transitions
``S -> I -> R``.

This model allows observations of **cumulative** infections at uneven time
intervals. To preserve Markov structure (and hence tractable inference)
this model adds an auxiliary compartment ``O`` denoting the fully-observed
cumulative number of observations at each time point. At observed times
(when ``mask[t] == True``) ``O`` must exactly match the provided data;
between observed times, ``O`` stochastically imputes the latent cumulative count.

:param int population: Total ``population = S + I + R``.
:param float recovery_time: Mean recovery time (duration in state
``I``). Must be greater than 1.
:param iterable data: Time series of **cumulative** observed infections.
Whenever ``mask[t] == True``, ``data[t]`` corresponds to an
observation; otherwise ``data[t]`` can be arbitrary, e.g. NAN.
:param iterable mask: Boolean time series denoting whether an observation
is made at each time step. Should satisfy ``len(mask) == len(data)``.
"""

def __init__(self, population, recovery_time, data, mask):
assert len(data) == len(mask)
duration = len(data)
compartments = ("S", "I", "O") # O is auxiliary, R is implicit.
super().__init__(compartments, duration, population)

assert isinstance(recovery_time, float)
assert recovery_time > 1
self.recovery_time = recovery_time

self.data = data
self.mask = mask

series = ("S2I", "I2R", "S2O", "obs")
full_mass = [("R0", "rho")]

def global_model(self):
tau = self.recovery_time
R0 = pyro.sample("R0", dist.LogNormal(0., 1.))
rho = pyro.sample("rho", dist.Uniform(0, 1))
return R0, tau, rho

def initialize(self, params):
# Start with a single infection.
return {"S": self.population - 1, "I": 1, "O": 0}

def transition_fwd(self, params, state, t):
R0, tau, rho = params

# Sample flows between compartments.
S2I = pyro.sample("S2I_{}".format(t),
infection_dist(individual_rate=R0 / tau,
num_susceptible=state["S"],
num_infectious=state["I"],
population=self.population))
I2R = pyro.sample("I2R_{}".format(t),
dist.Binomial(state["I"], 1 / tau))
S2O = pyro.sample("S2O_{}".format(t),
dist.ExtendedBinomial(S2I, rho))

# Update compartments with flows.
state["S"] = state["S"] - S2I
state["I"] = state["I"] + S2I - I2R
state["O"] = state["O"] + S2O

# Condition on cumulative observations.
mask_t = self.mask[t] if t < self.duration else False
data_t = self.data[t] if t < self.duration else None
pyro.sample("obs_{}".format(t),
dist.Delta(state["O"]).mask(mask_t),

Collaborator:
why is the auxiliary necessary? or does this just play better with the structure of CompartmentalModel?

Member Author (@fritzo), May 1, 2020:
tl;dr The auxiliary variable is needed to preserve Markov structure.

The observations in this model are aggregated over intervals: obs = S2I[t_prev+1 : t_curr+1].sum(), where t_prev is the time of the last observation and t_curr is the time of the current observation. In our enumeration strategy this would couple all t_curr - t_prev enumeration variables in the gap, with cost growing exponentially in the number of coupled variables. While the non-parallel-scan enumeration strategy could handle this without erroring, it would be prohibitively expensive and would not allow large gaps in sensor data (e.g. when a government shuts down testing or runs out of tests). The trick we're using is to add an auxiliary variable for the entire cumulative observation trajectory (with the same likelihood as in the usual SIR models), and then Delta-clamp that auxiliary variable to the true observations at the few sparsely observed time steps. This makes more work for HMC, adds one enumeration variable per time step, and increases the complexity of variable elimination by a constant factor of Q**2, but crucially this factor is independent of gap size.

I had been struggling with this issue for a while, since Lucy's model simulates four time steps per day but is observed only once per day. The only alternative I could see was to do parallel-scan variable elimination where each DiscreteHMM state covered the joint distribution over an entire day (four time steps), resulting in complexity Q**(2 * 4 * 2) for an SIR model or Q**(3 * 4 * 2) for an SEIR model.
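
For concreteness, a minimal sketch of this trick with toy values (not from the PR), assuming a reporting probability of rho = 1 so that S2O equals S2I; here S2I, data, and mask are stand-ins for the quantities discussed above:

import torch

# Toy setting: four simulated time steps per day, observed once per day.
S2I = torch.tensor([1., 0., 2., 1., 3., 2., 0., 1.])  # latent new infections
mask = torch.arange(8) % 4 == 3                       # observations at t = 3, 7
data = torch.cumsum(S2I, 0)                           # cumulative counts
data[~mask] = float("nan")                            # unobserved entries are arbitrary

# Naive likelihood: each observation aggregates a whole gap, e.g.
# data[7] - data[3] == S2I[4:8].sum(), which couples the enumeration
# variables at t = 4, 5, 6, 7.

# Auxiliary formulation: carry a cumulative compartment O forward one
# step at a time and Delta-clamp it to data[t] only where mask[t] is True.
O = torch.tensor(0.)
for t in range(8):
    O = O + S2I[t]           # per-step update; Markov in (S, I, O)
    if mask[t]:
        assert O == data[t]  # the Delta factor clamps O to the data here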

Collaborator:
thanks for the explanation. to clarify though: if all you had was occasional missing data you wouldn't need this construction. this is really for the cumulative case

Member Author (@fritzo):
Correct. It appears the cumulative case is more common in epidemiology.
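
For contrast, here is a sketch of the simpler case mentioned above: with occasionally missing per-step (non-cumulative) counts one could keep the usual SimpleSIRModel likelihood and simply mask out the missing steps, with no auxiliary compartment. This is an illustration, not code from the PR; S2I, rho, data, and mask are stand-ins.

import torch
import pyro
import pyro.distributions as dist

def observe_step(t, S2I, rho, data, mask):
    # Per-step counts with occasional gaps: .mask(False) zeroes the
    # log-likelihood at missing steps, so the placeholder value there
    # never matters.
    obs = data[t] if mask[t] else torch.tensor(0.)
    pyro.sample("obs_{}".format(t),
                dist.ExtendedBinomial(S2I, rho).mask(mask[t]),
                obs=obs)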

obs=data_t)

def transition_bwd(self, params, prev, curr, t):
R0, tau, rho = params

# Reverse the flow computation.
S2I = prev["S"] - curr["S"]
I2R = prev["I"] - curr["I"] + S2I
S2O = curr["O"] - prev["O"]

# Condition on flows between compartments.
pyro.sample("S2I_{}".format(t),
infection_dist(individual_rate=R0 / tau,
num_susceptible=prev["S"],
num_infectious=prev["I"],
population=self.population),
obs=S2I)
pyro.sample("I2R_{}".format(t),
dist.ExtendedBinomial(prev["I"], 1 / tau),
obs=I2R)
pyro.sample("S2O_{}".format(t),
dist.ExtendedBinomial(S2I, rho),
obs=S2O)

# Condition on cumulative observations.
pyro.sample("obs_{}".format(t),
dist.Delta(curr["O"]).mask(self.mask[t]),
obs=self.data[t])
48 changes: 47 additions & 1 deletion tests/contrib/epidemiology/test_sir.py
@@ -1,9 +1,15 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import logging
import math

import pytest
import torch

from pyro.contrib.epidemiology import OverdispersedSIRModel, SimpleSIRModel
from pyro.contrib.epidemiology import OverdispersedSIRModel, SimpleSIRModel, SparseSIRModel

logger = logging.getLogger(__name__)


@pytest.mark.parametrize("duration", [3, 7])
@@ -66,3 +72,43 @@ def test_overdispersed_smoke(duration, forecast, options):
samples = model.predict(forecast=forecast)
assert samples["S"].shape == (num_samples, duration + forecast)
assert samples["I"].shape == (num_samples, duration + forecast)


@pytest.mark.parametrize("duration", [4, 12])
@pytest.mark.parametrize("forecast", [7])
@pytest.mark.parametrize("options", [
{},
{"dct": 1.},
{"num_quant_bins": 8},
], ids=str)
def test_sparse_smoke(duration, forecast, options):
population = 100
recovery_time = 7.0

# Generate data.
data = [None] * duration
mask = torch.arange(duration) % 4 == 3
model = SparseSIRModel(population, recovery_time, data, mask)
for attempt in range(100):
data = model.generate({"R0": 1.5, "rho": 0.5})["obs"]
if data.sum():
break
assert data.sum() > 0, "failed to generate positive data"
assert (data[1:] >= data[:-1]).all()
data[~mask] = math.nan
logger.info("data:\n{}".format(data))

# Infer.
model = SparseSIRModel(population, recovery_time, data, mask)
num_samples = 5
model.fit(warmup_steps=1, num_samples=num_samples, max_tree_depth=2, **options)

# Predict and forecast.
samples = model.predict(forecast=forecast)
assert samples["S"].shape == (num_samples, duration + forecast)
assert samples["I"].shape == (num_samples, duration + forecast)
assert samples["O"].shape == (num_samples, duration + forecast)
assert (samples["O"][..., 1:] >= samples["O"][..., :-1]).all()
for O in samples["O"]:
logger.info("imputed:\n{}".format(O))
assert (O[:duration][mask] == data[mask]).all()
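
For reference, a minimal usage sketch of the new model, mirroring the smoke test above; the population, cumulative counts, and MCMC settings below are illustrative placeholders rather than recommendations:

import math

import torch

from pyro.contrib.epidemiology import SparseSIRModel

duration, forecast = 20, 7
population, recovery_time = 1000, 7.0

# Cumulative infection counts observed every fourth time step; unobserved
# entries are NAN and are imputed by the model.
mask = torch.arange(duration) % 4 == 3
data = torch.full((duration,), math.nan)
data[mask] = torch.tensor([2., 5., 9., 14., 20.])

model = SparseSIRModel(population, recovery_time, data, mask)
model.fit(warmup_steps=100, num_samples=100, max_tree_depth=5)
samples = model.predict(forecast=forecast)  # includes the imputed "O" series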