Skip to content

Commit

Permalink
Add pearson correlation plot to examples/.../sir.py (#2497)
Browse files Browse the repository at this point in the history
  • Loading branch information
fritzo authored May 21, 2020
1 parent 8cc51fb commit a11e170
Showing 1 changed file with 35 additions and 2 deletions.
37 changes: 35 additions & 2 deletions examples/contrib/epidemiology/sir.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import math

import torch
from torch.distributions import biject_to, constraints

import pyro
from pyro.contrib.epidemiology import OverdispersedSEIRModel, OverdispersedSIRModel, SimpleSEIRModel, SimpleSIRModel
Expand Down Expand Up @@ -90,7 +91,7 @@ def hook_fn(kernel, *unused):
return model.samples


def evaluate(args, samples):
def evaluate(args, model, samples):
# Print estimated values.
names = {"basic_reproduction_number": "R0",
"response_rate": "rho"}
Expand All @@ -107,6 +108,7 @@ def evaluate(args, samples):
import matplotlib.pyplot as plt
import seaborn as sns

# Plot individual histograms.
fig, axes = plt.subplots(len(names), 1, figsize=(5, 2.5 * len(names)))
axes[0].set_title("Posterior parameter estimates")
for ax, (name, key) in zip(axes, names.items()):
Expand All @@ -118,6 +120,7 @@ def evaluate(args, samples):
ax.legend(loc="best")
plt.tight_layout()

# Plot pairwise joint distributions for selected variables.
covariates = [(name, samples[name]) for name in names.values()]
for i, aux in enumerate(samples["auxiliary"].unbind(-2)):
covariates.append(("aux[{},0]".format(i), aux[:, 0]))
Expand All @@ -137,6 +140,36 @@ def evaluate(args, samples):
plt.tight_layout()
plt.subplots_adjust(wspace=0, hspace=0)

# Plot Pearson correlation for every pair of unconstrained variables.
def unconstrain(constraint, value):
value = biject_to(constraint).inv(value)
return value.reshape(args.num_samples, -1)

covariates = [
("R1", unconstrain(constraints.positive, samples["R0"])),
("rho", unconstrain(constraints.unit_interval, samples["rho"]))]
if "k" in samples:
covariates.append(
("k", unconstrain(constraints.positive, samples["k"])))
constraint = constraints.interval(-0.5, model.population + 0.5)
for name, aux in zip(model.compartments, samples["auxiliary"].unbind(-2)):
covariates.append((name, unconstrain(constraint, aux)))
x = torch.cat([v for _, v in covariates], dim=-1)
x -= x.mean(0)
x /= x.std(0)
x = x.t().matmul(x)
x /= args.num_samples
x.clamp_(min=-1, max=1)
plt.figure(figsize=(8, 8))
plt.imshow(x, cmap="bwr")
ticks = torch.tensor([0] + [v.size(-1) for _, v in covariates]).cumsum(0)
ticks = (ticks[1:] + ticks[:-1]) / 2
plt.yticks(ticks, [name for name, _ in covariates])
plt.xticks(())
plt.tick_params(length=0)
plt.title("Pearson correlation (unconstrained coordinates)")
plt.tight_layout()


def predict(args, model, truth):
samples = model.predict(forecast=args.forecast)
Expand Down Expand Up @@ -183,7 +216,7 @@ def main(args):
samples = infer(args, model)

# Evaluate fit.
evaluate(args, samples)
evaluate(args, model, samples)

# Predict latent time series.
if args.forecast:
Expand Down

0 comments on commit a11e170

Please sign in to comment.