Skip to content

Commit

Permalink
Merge branch 'dev' into add_dim_to_coupling
Browse files Browse the repository at this point in the history
  • Loading branch information
fritzo authored May 22, 2020
2 parents 7315a5b + a11e170 commit 8b3e050
Show file tree
Hide file tree
Showing 39 changed files with 1,010 additions and 258 deletions.
18 changes: 14 additions & 4 deletions docs/source/distributions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,13 @@ Stable
:undoc-members:
:show-inheritance:

TruncatedPolyaGamma
-------------------
.. autoclass:: pyro.distributions.TruncatedPolyaGamma
:members:
:undoc-members:
:show-inheritance:

Unit
----
.. autoclass:: pyro.distributions.Unit
Expand Down Expand Up @@ -340,6 +347,13 @@ ELUTransform
:undoc-members:
:show-inheritance:

HaarTransform
-------------
.. autoclass:: pyro.distributions.transforms.HaarTransform
:members:
:undoc-members:
:show-inheritance:

LeakyReLUTransform
------------------
.. autoclass:: pyro.distributions.transforms.LeakyReLUTransform
Expand Down Expand Up @@ -622,7 +636,3 @@ spline
sylvester
---------
.. autofunction:: pyro.distributions.transforms.sylvester

tanh
----
.. autofunction:: pyro.distributions.transforms.tanh
27 changes: 27 additions & 0 deletions docs/source/infer.reparam.rst
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,24 @@ Discrete Cosine Transform
:special-members: __call__
:show-inheritance:

Haar Transform
--------------
.. automodule:: pyro.infer.reparam.haar
:members:
:undoc-members:
:member-order: bysource
:special-members: __call__
:show-inheritance:

Unit Jacobian Transforms
------------------------
.. automodule:: pyro.infer.reparam.unit_jacobian
:members:
:undoc-members:
:member-order: bysource
:special-members: __call__
:show-inheritance:

StudentT Distributions
----------------------
.. automodule:: pyro.infer.reparam.studentt
Expand Down Expand Up @@ -77,6 +95,15 @@ Hidden Markov Models
:special-members: __call__
:show-inheritance:

Site Splitting
--------------
.. automodule:: pyro.infer.reparam.split
:members:
:undoc-members:
:member-order: bysource
:special-members: __call__
:show-inheritance:

Neural Transport
----------------
.. automodule:: pyro.infer.reparam.neutra
Expand Down
4 changes: 4 additions & 0 deletions examples/contrib/epidemiology/regional.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ def hook_fn(kernel, *unused):
num_samples=args.num_samples,
max_tree_depth=args.max_tree_depth,
num_quant_bins=args.num_bins,
haar=args.haar,
haar_full_mass=args.haar_full_mass,
hook_fn=hook_fn)

mcmc.summary()
Expand Down Expand Up @@ -135,6 +137,8 @@ def main(args):
parser.add_argument("-R0", "--basic-reproduction-number", default=1.5, type=float)
parser.add_argument("-tau", "--recovery-time", default=7.0, type=float)
parser.add_argument("-rho", "--response-rate", default=0.5, type=float)
parser.add_argument("--haar", action="store_true")
parser.add_argument("-hfm", "--haar-full-mass", default=0, type=int)
parser.add_argument("-n", "--num-samples", default=200, type=int)
parser.add_argument("-np", "--num-particles", default=1024, type=int)
parser.add_argument("-ess", "--ess-threshold", default=0.5, type=float)
Expand Down
44 changes: 39 additions & 5 deletions examples/contrib/epidemiology/sir.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import math

import torch
from torch.distributions import biject_to, constraints

import pyro
from pyro.contrib.epidemiology import OverdispersedSEIRModel, OverdispersedSIRModel, SimpleSEIRModel, SimpleSIRModel
Expand Down Expand Up @@ -73,7 +74,8 @@ def hook_fn(kernel, *unused):
max_tree_depth=args.max_tree_depth,
arrowhead_mass=args.arrowhead_mass,
num_quant_bins=args.num_bins,
dct=args.dct,
haar=args.haar,
haar_full_mass=args.haar_full_mass,
hook_fn=hook_fn)

mcmc.summary()
Expand All @@ -89,7 +91,7 @@ def hook_fn(kernel, *unused):
return model.samples


def evaluate(args, samples):
def evaluate(args, model, samples):
# Print estimated values.
names = {"basic_reproduction_number": "R0",
"response_rate": "rho"}
Expand All @@ -106,6 +108,7 @@ def evaluate(args, samples):
import matplotlib.pyplot as plt
import seaborn as sns

# Plot individual histograms.
fig, axes = plt.subplots(len(names), 1, figsize=(5, 2.5 * len(names)))
axes[0].set_title("Posterior parameter estimates")
for ax, (name, key) in zip(axes, names.items()):
Expand All @@ -117,6 +120,7 @@ def evaluate(args, samples):
ax.legend(loc="best")
plt.tight_layout()

# Plot pairwise joint distributions for selected variables.
covariates = [(name, samples[name]) for name in names.values()]
for i, aux in enumerate(samples["auxiliary"].unbind(-2)):
covariates.append(("aux[{},0]".format(i), aux[:, 0]))
Expand All @@ -136,6 +140,36 @@ def evaluate(args, samples):
plt.tight_layout()
plt.subplots_adjust(wspace=0, hspace=0)

# Plot Pearson correlation for every pair of unconstrained variables.
def unconstrain(constraint, value):
value = biject_to(constraint).inv(value)
return value.reshape(args.num_samples, -1)

covariates = [
("R1", unconstrain(constraints.positive, samples["R0"])),
("rho", unconstrain(constraints.unit_interval, samples["rho"]))]
if "k" in samples:
covariates.append(
("k", unconstrain(constraints.positive, samples["k"])))
constraint = constraints.interval(-0.5, model.population + 0.5)
for name, aux in zip(model.compartments, samples["auxiliary"].unbind(-2)):
covariates.append((name, unconstrain(constraint, aux)))
x = torch.cat([v for _, v in covariates], dim=-1)
x -= x.mean(0)
x /= x.std(0)
x = x.t().matmul(x)
x /= args.num_samples
x.clamp_(min=-1, max=1)
plt.figure(figsize=(8, 8))
plt.imshow(x, cmap="bwr")
ticks = torch.tensor([0] + [v.size(-1) for _, v in covariates]).cumsum(0)
ticks = (ticks[1:] + ticks[:-1]) / 2
plt.yticks(ticks, [name for name, _ in covariates])
plt.xticks(())
plt.tick_params(length=0)
plt.title("Pearson correlation (unconstrained coordinates)")
plt.tight_layout()


def predict(args, model, truth):
samples = model.predict(forecast=args.forecast)
Expand Down Expand Up @@ -182,7 +216,7 @@ def main(args):
samples = infer(args, model)

# Evaluate fit.
evaluate(args, samples)
evaluate(args, model, samples)

# Predict latent time series.
if args.forecast:
Expand All @@ -204,8 +238,8 @@ def main(args):
parser.add_argument("-k", "--concentration", default=math.inf, type=float,
help="If finite, use a superspreader model.")
parser.add_argument("-rho", "--response-rate", default=0.5, type=float)
parser.add_argument("--dct", type=float,
help="smoothing for discrete cosine reparameterizer")
parser.add_argument("--haar", action="store_true")
parser.add_argument("-hfm", "--haar-full-mass", default=0, type=int)
parser.add_argument("-n", "--num-samples", default=200, type=int)
parser.add_argument("-np", "--num-particles", default=1024, type=int)
parser.add_argument("-ess", "--ess-threshold", default=0.5, type=float)
Expand Down
104 changes: 74 additions & 30 deletions pyro/contrib/epidemiology/compartmental.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@
import pyro.distributions as dist
import pyro.distributions.hmm
import pyro.poutine as poutine
from pyro.distributions.transforms import DiscreteCosineTransform
from pyro.distributions.transforms import HaarTransform
from pyro.infer import MCMC, NUTS, SMCFilter, infer_discrete
from pyro.infer.autoguide import init_to_generated, init_to_value
from pyro.infer.mcmc import ArrowheadMassMatrix
from pyro.infer.reparam import DiscreteCosineReparam
from pyro.infer.reparam import HaarReparam, SplitReparam
from pyro.infer.smcfilter import SMCFailed
from pyro.util import warn_if_nan

from .distributions import set_approx_sample_thresh
Expand Down Expand Up @@ -143,8 +144,8 @@ def _clear_plates(self):
full_mass = False

@torch.no_grad()
@set_approx_sample_thresh(1000)
def heuristic(self, num_particles=1024, ess_threshold=0.5):
@set_approx_sample_thresh(100) # This is robust to gross approximation.
def heuristic(self, num_particles=1024, ess_threshold=0.5, retries=10):
"""
Finds an initial feasible guess of all latent variables, consistent
with observed data. This is needed because not all hypotheses are
Expand All @@ -163,12 +164,20 @@ def heuristic(self, num_particles=1024, ess_threshold=0.5):
# Run SMC.
model = _SMCModel(self)
guide = _SMCGuide(self)
smc = SMCFilter(model, guide, num_particles=num_particles,
ess_threshold=ess_threshold,
max_plate_nesting=self.max_plate_nesting)
smc.init()
for t in range(1, self.duration):
smc.step()
for attempt in range(1, 1 + retries):
smc = SMCFilter(model, guide, num_particles=num_particles,
ess_threshold=ess_threshold,
max_plate_nesting=self.max_plate_nesting)
try:
smc.init()
for t in range(1, self.duration):
smc.step()
break
except SMCFailed as e:
if attempt == retries:
raise
logger.info("{}. Retrying...".format(e))
continue

# Select the most probable hypothesis.
i = int(smc.state._log_weights.max(0).indices)
Expand Down Expand Up @@ -249,6 +258,7 @@ def transition_bwd(self, params, prev, curr, t):
# Inference interface ########################################

@torch.no_grad()
@set_approx_sample_thresh(1000)
def generate(self, fixed={}):
"""
Generate data from the prior.
Expand Down Expand Up @@ -290,50 +300,74 @@ def fit(self, **options):
:param int num_quant_bins: The number of quantization bins to use. Note
that computational cost is exponential in `num_quant_bins`.
Defaults to 4.
:param float dct: If provided, use a discrete cosine reparameterizer
with this value as smoothness.
:param bool haar: Whether to use a Haar wavelet reparameterizer.
:param int haar_full_mass: Number of low frequency Haar components to
include in the full mass matrix. If nonzero this implies
``haar=True``.
:param int heuristic_num_particles: Passed to :meth:`heuristic` as
``num_particles``. Defaults to 1024.
:returns: An MCMC object for diagnostics, e.g. ``MCMC.summary()``.
:rtype: ~pyro.infer.mcmc.api.MCMC
"""
# Save these options for .predict().
# Parse options, saving some for use in .predict().
self.num_quant_bins = options.pop("num_quant_bins", 4)
self._dct = options.pop("dct", None)
if self._dct is not None and self.is_regional:
raise NotImplementedError("regional models do not support DiscreteCosineReparam")

# Heuristically initialze to feasible latents.
haar = options.pop("haar", False)
assert isinstance(haar, bool)
haar_full_mass = options.pop("haar_full_mass", 0)
assert isinstance(haar_full_mass, int)
assert haar_full_mass >= 0
haar_full_mass = min(haar_full_mass, self.duration)
haar = haar or (haar_full_mass > 0)

# Heuristically initialize to feasible latents.
heuristic_options = {k.replace("heuristic_", ""): options.pop(k)
for k in list(options)
if k.startswith("heuristic_")}

def heuristic():
logger.info("Heuristically initializing...")
with poutine.block():
init_values = self.heuristic(**heuristic_options)
assert isinstance(init_values, dict)
assert "auxiliary" in init_values, \
".heuristic() did not define auxiliary value"
if self._dct is not None:
# Also initialize DCT transformed coordinates.
if haar:
# Also initialize Haar transformed coordinates.
x = init_values["auxiliary"]
x = biject_to(constraints.interval(-0.5, self.population + 0.5)).inv(x)
x = DiscreteCosineTransform(smooth=self._dct)(x)
init_values["auxiliary_dct"] = x
x = HaarTransform(dim=-2 if self.is_regional else -1, flip=True)(x)
init_values["auxiliary_haar"] = x
if haar_full_mass:
# Also split into low- and high-frequency parts.
x0, x1 = init_values["auxiliary_haar"].split(
[haar_full_mass, self.duration - haar_full_mass],
dim=-2 if self.is_regional else -1)
init_values["auxiliary_haar_split_0"] = x0
init_values["auxiliary_haar_split_1"] = x1
logger.info("Heuristic init: {}".format(", ".join(
"{}={:0.3g}".format(k, v.item())
for k, v in init_values.items()
if v.numel() == 1)))
return init_to_value(values=init_values)

# Configure a kernel.
logger.info("Running inference...")
max_tree_depth = options.pop("max_tree_depth", 5)
full_mass = options.pop("full_mass", self.full_mass)
model = self._vectorized_model
if self._dct is not None:
rep = DiscreteCosineReparam(smooth=self._dct)
if haar:
rep = HaarReparam(dim=-2 if self.is_regional else -1, flip=True)
model = poutine.reparam(model, {"auxiliary": rep})
if haar_full_mass:
assert full_mass and isinstance(full_mass, list)
full_mass = full_mass[:]
full_mass[0] = full_mass[0] + ("auxiliary_haar_split_0",)
rep = SplitReparam([haar_full_mass, self.duration - haar_full_mass],
dim=-2 if self.is_regional else -1)
model = poutine.reparam(model, {"auxiliary_haar": rep})
kernel = NUTS(model,
full_mass=full_mass,
init_strategy=init_to_generated(generate=heuristic),
max_plate_nesting=self.max_plate_nesting,
max_tree_depth=max_tree_depth)
if options.pop("arrowhead_mass", False):
kernel.mass_matrix_adapter = ArrowheadMassMatrix()
Expand All @@ -342,13 +376,27 @@ def heuristic():
mcmc = MCMC(kernel, **options)
mcmc.run()
self.samples = mcmc.get_samples()
if haar_full_mass:
# Transform back from SplitReparam coordinates.
self.samples["auxiliary_haar"] = torch.cat([
self.samples.pop("auxiliary_haar_split_0"),
self.samples.pop("auxiliary_haar_split_1"),
], dim=-2 if self.is_regional else -1)
if haar:
# Transform back from Haar coordinates.
x = self.samples.pop("auxiliary_haar")
x = HaarTransform(dim=-2 if self.is_regional else -1, flip=True).inv(x)
x = biject_to(constraints.interval(-0.5, self.population + 0.5))(x)
self.samples["auxiliary"] = x

# Unsqueeze samples to align particle dim for use in poutine.condition.
# TODO refactor to an align_samples or particle_dim kwarg to MCMC.get_samples().
self.samples = align_samples(self.samples, model,
self.samples = align_samples(self.samples, self._vectorized_model,
particle_dim=-1 - self.max_plate_nesting)
return mcmc # E.g. so user can run mcmc.summary().

@torch.no_grad()
@set_approx_sample_thresh(10000)
def predict(self, forecast=0):
"""
Predict latent variables and optionally forecast forward.
Expand Down Expand Up @@ -377,10 +425,6 @@ def predict(self, forecast=0):
model = self._sequential_model
model = poutine.condition(model, samples)
model = particle_plate(model)
if self._dct is not None:
# Apply the same reparameterizer as during inference.
rep = DiscreteCosineReparam(smooth=self._dct)
model = poutine.reparam(model, {"auxiliary": rep})
model = infer_discrete(model, first_available_dim=-2 - self.max_plate_nesting)
trace = poutine.trace(model).get_trace()
samples = OrderedDict((name, site["value"])
Expand Down
Loading

0 comments on commit 8b3e050

Please sign in to comment.