Fix recovery_time to known value in SIR models (#2429)
* Generate and plot forecasted true new infections

* Fix recovery time to known value
fritzo authored Apr 22, 2020
1 parent ae6c36b commit 91f31ae
Showing 5 changed files with 46 additions and 35 deletions.
examples/sir_hmc.py: 52 changes (30 additions, 22 deletions)
@@ -62,8 +62,8 @@
 # return -inf.
 
 def global_model(population):
+    tau = args.recovery_time  # Assume this can be measured exactly.
     R0 = pyro.sample("R0", dist.LogNormal(0., 1.))
-    tau = pyro.sample("tau", dist.LogNormal(math.log(7.), 1.))
     rho = pyro.sample("rho", dist.Uniform(0, 1))
 
     # Convert interpretable parameters to distribution parameters.
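
Side note (not part of the commit): the prior removed here, LogNormal(log 7, 1), has its median at exp(log 7) = 7, i.e. it was already centered on a one-week recovery time; the commit replaces that weakly informative prior with the known --recovery-time value. A quick sanity-check sketch (the names tau_prior and samples and the sample count are illustrative):

import math

import pyro.distributions as dist

tau_prior = dist.LogNormal(math.log(7.), 1.)   # the prior removed above
samples = tau_prior.sample((100000,))
print(samples.median())                        # close to tensor(7.), the prior median
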
@@ -98,23 +98,30 @@ def discrete_model(data, population):
 def generate_data(args):
     logging.info("Generating data...")
     params = {"R0": torch.tensor(args.basic_reproduction_number),
-              "tau": torch.tensor(args.recovery_time),
               "rho": torch.tensor(args.response_rate)}
-    empty_data = [None] * args.duration
+    empty_data = [None] * (args.duration + args.forecast)
 
     # We'll retry until we get an actual outbreak.
     for attempt in range(100):
         with poutine.trace() as tr:
             with poutine.condition(data=params):
                 discrete_model(empty_data, args.population)
 
-        data = torch.stack([site["value"]
-                            for site in tr.trace.nodes.values()
-                            if site["name"].startswith("obs_")])
-        if data.sum() >= args.min_observations:
-            logging.info("Generated {:0.0f} observed infections:\n{}"
-                         .format(data.sum(), " ".join([str(int(x)) for x in data])))
-            return data
+        # Concatenate sequential time series into tensors.
+        obs = torch.stack([site["value"]
+                           for name, site in tr.trace.nodes.items()
+                           if re.match("obs_[0-9]+", name)])
+        S2I = torch.stack([site["value"]
+                           for name, site in tr.trace.nodes.items()
+                           if re.match("S2I_[0-9]+", name)])
+        assert len(obs) == len(empty_data)
+
+        obs_sum = int(obs[:args.duration].sum())
+        S2I_sum = int(S2I[:args.duration].sum())
+        if obs_sum >= args.min_observations:
+            logging.info("Observed {:d}/{:d} infections:\n{}".format(
+                obs_sum, S2I_sum, " ".join([str(int(x)) for x in obs[:args.duration]])))
+            return {"S2I": S2I, "obs": obs}
 
     raise ValueError("Failed to generate {} observations. Try increasing "
                      "--population or decreasing --min-observations"
@@ -193,7 +200,7 @@ def infer_hmc_enum(args, data):
 def _infer_hmc(args, data, model, init_values={}):
     logging.info("Running inference...")
     kernel = NUTS(model,
-                  full_mass=[("R0", "tau", "rho")],
+                  full_mass=[("R0", "rho")],
                   max_tree_depth=args.max_tree_depth,
                   init_strategy=init_to_value(values=init_values),
                   jit_compile=args.jit, ignore_jit_warnings=True)
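
As context (not part of the diff), full_mass takes a list of tuples of site names that share a dense block of the mass matrix, so once "tau" is no longer a sample site it must be dropped from that tuple. A minimal sketch of the same option on a toy two-latent model; toy_model, its Normal likelihood, and the sampler settings are assumptions of this sketch, not the example's configuration:

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import MCMC, NUTS


def toy_model():
    # Two latent sites with the same names as in the example.
    R0 = pyro.sample("R0", dist.LogNormal(0., 1.))
    rho = pyro.sample("rho", dist.Uniform(0., 1.))
    pyro.sample("obs", dist.Normal(R0 * rho, 1.), obs=torch.tensor(1.5))


# One dense mass-matrix block covering the two remaining latents.
kernel = NUTS(toy_model, full_mass=[("R0", "rho")])
MCMC(kernel, num_samples=50, warmup_steps=50).run()
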
@@ -318,7 +325,6 @@ def heuristic_init(args, data):
                                    DiscreteCosineTransform(dim=-1)])
 
     return {
-        "tau": torch.tensor(10.0),
         "R0": torch.tensor(2.0),
         "rho": torch.tensor(0.5),
         "S_aux": S_aux,
@@ -429,7 +435,6 @@ def vectorized_model(data, population):
 def evaluate(args, samples):
     # Print estimated values.
     names = {"basic_reproduction_number": "R0",
-             "recovery_time": "tau",
              "response_rate": "rho"}
     for name, key in names.items():
         mean = samples[key].mean().item()
@@ -441,7 +446,7 @@ def evaluate(args, samples):
     if args.plot:
         import matplotlib.pyplot as plt
         import seaborn as sns
-        fig, axes = plt.subplots(3, 1, figsize=(5, 8))
+        fig, axes = plt.subplots(2, 1, figsize=(5, 5))
         axes[0].set_title("Posterior parameter estimates")
         for ax, (name, key) in zip(axes, names.items()):
             truth = getattr(args, name)
@@ -468,7 +473,7 @@ def evaluate(args, samples):
 # generated via infer_hmc_cont(vectorized_model, ...).
 
 @torch.no_grad()
-def predict(args, data, samples):
+def predict(args, data, samples, truth=None):
     logging.info("Forecasting {} steps ahead...".format(args.forecast))
     particle_plate = pyro.plate("particles", args.num_samples, dim=-1)
 
@@ -517,9 +522,11 @@ def predict(args, data, samples):
         time = torch.arange(len(data) + args.forecast)
         p05 = S2I.kthvalue(int(round(0.5 + 0.05 * args.num_samples)), dim=0).values
         p95 = S2I.kthvalue(int(round(0.5 + 0.95 * args.num_samples)), dim=0).values
-        plt.plot(time[:len(data)], data, "k.", label="observed")
-        plt.plot(time, median, "r-", label="median")
         plt.fill_between(time, p05, p95, color="red", alpha=0.3, label="90% CI")
+        plt.plot(time, median, "r-", label="median")
+        plt.plot(time[:len(data)], data, "k.", label="observed")
+        if truth is not None:
+            plt.plot(time, truth, "k--", label="truth")
         plt.axvline(args.duration - 0.5, color="gray", lw=1)
         plt.xlim(0, len(time) - 1)
         plt.ylim(0, None)
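
As context (not part of the diff), the reordering above draws the 90% band first so that the median line, the observed points, and the optional dashed truth curve render on top of it, with a vertical line marking the end of the observation window. A standalone sketch with synthetic placeholder data; every name and size below (duration, forecast, num_samples, S2I, data, truth) is made up for illustration:

import torch
import matplotlib.pyplot as plt

duration, forecast, num_samples = 60, 30, 200
time = torch.arange(duration + forecast)
S2I = torch.rand(num_samples, duration + forecast).cumsum(-1)  # fake forecast samples
data = S2I[0, :duration]                                       # fake observations
truth = S2I[1]                                                 # fake ground truth

median = S2I.median(dim=0).values
p05 = S2I.kthvalue(int(round(0.5 + 0.05 * num_samples)), dim=0).values
p95 = S2I.kthvalue(int(round(0.5 + 0.95 * num_samples)), dim=0).values

plt.fill_between(time, p05, p95, color="red", alpha=0.3, label="90% CI")  # band first
plt.plot(time, median, "r-", label="median")                              # lines on top
plt.plot(time[:duration], data, "k.", label="observed")
plt.plot(time, truth, "k--", label="truth")
plt.axvline(duration - 0.5, color="gray", lw=1)  # end of the observation window
plt.legend()
plt.show()
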
@@ -545,22 +552,23 @@ def main(args):
     pyro.enable_validation(__debug__)
     pyro.set_rng_seed(args.rng_seed)
 
-    data = generate_data(args)
+    dataset = generate_data(args)
+    obs = dataset["obs"][:args.duration]
 
     # Choose among inference methods.
     if args.enum:
-        samples = infer_hmc_enum(args, data)
+        samples = infer_hmc_enum(args, obs)
     elif args.sequential:
-        samples = infer_hmc_cont(continuous_model, args, data)
+        samples = infer_hmc_cont(continuous_model, args, obs)
     else:
-        samples = infer_hmc_cont(vectorized_model, args, data)
+        samples = infer_hmc_cont(vectorized_model, args, obs)
 
     # Evaluate fit.
     evaluate(args, samples)
 
     # Predict latent time series.
     if args.forecast:
-        samples = predict(args, data, samples)
+        samples = predict(args, obs, samples, truth=dataset["S2I"])
 
     return samples
 
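
As context (not part of the diff), main() now relies on a simple split: generate_data simulates duration + forecast steps, inference sees only the first duration observations, and the full S2I series is kept as ground truth for the forecast plot. A tiny sketch with illustrative sizes and placeholder tensors:

import torch

duration, forecast = 60, 30
dataset = {"obs": torch.randint(0, 5, (duration + forecast,)).float(),
           "S2I": torch.randint(0, 10, (duration + forecast,)).float()}

obs = dataset["obs"][:duration]   # what the inference routines see
truth = dataset["S2I"]            # what predict(..., truth=...) plots
assert len(obs) == duration
assert len(truth) == duration + forecast
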
Binary file modified tutorial/source/_static/img/sir_hmc/energy-trace.png
Binary file modified tutorial/source/_static/img/sir_hmc/forecast.png
Binary file modified tutorial/source/_static/img/sir_hmc/parameters.png
tutorial/source/sir_hmc.rst: 29 changes (16 additions, 13 deletions)
@@ -13,23 +13,26 @@ another 30 days. This takes about 3 minutes on my laptop.
 
 .. code-block:: none
 
-    $ python -O examples/sir_hmc.py -p 10000 -d 60 -f 30 --plot
-    ...
-    Generated 452 observed infections:
+    $ python examples/sir_hmc.py -p 10000 -d 60 -f 30 --plot
+    Generating data...
+    Observed 452/871 infections:
     0 0 2 1 2 0 0 3 2 0 1 3 1 3 0 1 0 6 4 3 6 4 4 3 3 3 5 3 3 3 5 1 4 6 4 2 6 8 7 4 11 8 14 9 17 13 9 14 10 15 16 22 20 22 19 20 28 25 23 21
     Running inference...
-    Sample: 100%|==========================| 300/300 [02:52, 1.74it/s, step size=1.12e-01, acc. prob=0.747]
-             mean   std  median   5.0%  95.0%  n_eff  r_hat
-      R0     1.41  0.13    1.39   1.21   1.63   7.03   1.14
-      tau    7.46  1.32    7.01   5.65   9.68   5.38   1.29
-      rho    0.55  0.03    0.55   0.49   0.59   4.83   1.44
-    ...
+    Sample: 100%|=========================| 300/300 [02:35, 1.93it/s, step size=9.67e-02, acc. prob=0.878]
+                    mean      std   median     5.0%    95.0%  n_eff  r_hat
+      R0            1.40     0.07     1.40     1.28     1.49  26.56   1.06
+      rho           0.47     0.02     0.47     0.44     0.52   7.08   1.22
+      S_aux[0]   9998.74     0.64  9998.75  9997.84  9999.67  28.74   1.00
+      S_aux[1]   9998.37     0.72  9998.38  9997.28  9999.44  52.24   1.02
+      ...
+      I_aux[0]      1.11     0.64     0.99     0.19     2.02  22.01   1.00
+      I_aux[1]      1.55     0.74     1.65     0.05     2.47  10.05   1.10
+      ...
     Number of divergences: 0
-    R0: truth = 1.5, estimate = 1.41 ± 0.134
-    tau: truth = 7, estimate = 7.46 ± 1.32
-    rho: truth = 0.5, estimate = 0.546 ± 0.0297
+    R0: truth = 1.5, estimate = 1.4 ± 0.0654
+    rho: truth = 0.5, estimate = 0.475 ± 0.023
 
 .. image:: _static/img/sir_hmc/forecast.png
    :alt: Forecast of new infections
