Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add naive relaxed inference for CompartmentalModel #2510

Closed
wants to merge 10 commits into from
4 changes: 4 additions & 0 deletions examples/contrib/epidemiology/regional.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,10 @@ def hook_fn(kernel, *unused):
heuristic_ess_threshold=args.ess_threshold,
warmup_steps=args.warmup_steps,
num_samples=args.num_samples,
relax=args.relax,
max_tree_depth=args.max_tree_depth,
num_quant_bins=args.num_bins,
dct=args.dct,
haar=args.haar,
haar_full_mass=args.haar_full_mass,
hook_fn=hook_fn)
Expand Down Expand Up @@ -137,6 +139,8 @@ def main(args):
parser.add_argument("-R0", "--basic-reproduction-number", default=1.5, type=float)
parser.add_argument("-tau", "--recovery-time", default=7.0, type=float)
parser.add_argument("-rho", "--response-rate", default=0.5, type=float)
parser.add_argument("--relax", action="store_true")
parser.add_argument("--dct", type=float)
parser.add_argument("--haar", action="store_true")
parser.add_argument("-hfm", "--haar-full-mass", default=0, type=int)
parser.add_argument("-n", "--num-samples", default=200, type=int)
Expand Down
4 changes: 4 additions & 0 deletions examples/contrib/epidemiology/sir.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,11 @@ def hook_fn(kernel, *unused):
heuristic_ess_threshold=args.ess_threshold,
warmup_steps=args.warmup_steps,
num_samples=args.num_samples,
relax=args.relax,
max_tree_depth=args.max_tree_depth,
arrowhead_mass=args.arrowhead_mass,
num_quant_bins=args.num_bins,
dct=args.dct,
haar=args.haar,
haar_full_mass=args.haar_full_mass,
hook_fn=hook_fn)
Expand Down Expand Up @@ -238,6 +240,8 @@ def main(args):
parser.add_argument("-k", "--concentration", default=math.inf, type=float,
help="If finite, use a superspreader model.")
parser.add_argument("-rho", "--response-rate", default=0.5, type=float)
parser.add_argument("--relax", action="store_true")
parser.add_argument("--dct", type=float)
parser.add_argument("--haar", action="store_true")
parser.add_argument("-hfm", "--haar-full-mass", default=0, type=int)
parser.add_argument("-n", "--num-samples", default=200, type=int)
Expand Down
91 changes: 77 additions & 14 deletions pyro/contrib/epidemiology/compartmental.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,16 @@
import pyro.distributions as dist
import pyro.distributions.hmm
import pyro.poutine as poutine
from pyro.distributions.transforms import HaarTransform
from pyro.distributions.transforms import DiscreteCosineTransform, HaarTransform
from pyro.infer import MCMC, NUTS, SMCFilter, infer_discrete
from pyro.infer.autoguide import init_to_generated, init_to_value
from pyro.infer.mcmc import ArrowheadMassMatrix
from pyro.infer.reparam import HaarReparam, SplitReparam
from pyro.infer.reparam import DiscreteCosineReparam, HaarReparam, SplitReparam
from pyro.infer.smcfilter import SMCFailed
from pyro.util import warn_if_nan

from .distributions import set_approx_log_prob_tol, set_approx_sample_thresh
from .util import align_samples, cat2, clamp, quantize, quantize_enumerate
from .util import align_samples, cat2, clamp, differentiably_round, quantize, quantize_enumerate

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -295,8 +295,12 @@ def fit(self, **options):
pulled out and have special meaning.
:param int max_tree_depth: (Default 5). Max tree depth of the
:class:`~pyro.infer.mcmc.nuts.NUTS` kernel.
:param full_mass: (Default ``False``). Specification of mass matrix
of the :class:`~pyro.infer.mcmc.nuts.NUTS` kernel.
:param bool relax: Whether to use a relaxed model rather than the
default discrete model. The relaxed model is a biased approximation
of the discrete model, but is cheaper, allowing more MCMC samples
and hence lower variance. Defaults to False.
:param full_mass: Specification of mass matrix of the
:class:`~pyro.infer.mcmc.nuts.NUTS` kernel. Defaults to False.
:param bool arrowhead_mass: Whether to treat ``full_mass`` as the head
of an arrowhead matrix versus simply as a block. Defaults to False.
:param int num_quant_bins: The number of quantization bins to use. Note
Expand All @@ -306,43 +310,53 @@ def fit(self, **options):
:param int haar_full_mass: Number of low frequency Haar components to
include in the full mass matrix. If nonzero this implies
``haar=True``.
:param float dct: If provided, use a discrete cosine reparameterizer
with this value as smoothness.
:param int heuristic_num_particles: Passed to :meth:`heuristic` as
``num_particles``. Defaults to 1024.
:returns: An MCMC object for diagnostics, e.g. ``MCMC.summary()``.
:rtype: ~pyro.infer.mcmc.api.MCMC
"""
# Parse options, saving some for use in .predict().
self.num_quant_bins = options.pop("num_quant_bins", 4)
dct = options.pop("dct", None)
haar = options.pop("haar", False)
assert isinstance(haar, bool)
haar_full_mass = options.pop("haar_full_mass", 0)
assert isinstance(haar_full_mass, int)
assert haar_full_mass >= 0
haar_full_mass = min(haar_full_mass, self.duration)
haar = haar or (haar_full_mass > 0)
assert haar is False or dct is None, "Cannot combine dct with haar"

# Heuristically initialize to feasible latents.
heuristic_options = {k.replace("heuristic_", ""): options.pop(k)
for k in list(options)
if k.startswith("heuristic_")}
time_dim = -2 if self.is_regional else -1

def heuristic():
with poutine.block():
init_values = self.heuristic(**heuristic_options)
assert isinstance(init_values, dict)
assert "auxiliary" in init_values, \
".heuristic() did not define auxiliary value"
if dct is not None:
# Also initialize DCT transformed coordinates.
x = init_values["auxiliary"]
x = biject_to(constraints.interval(-0.5, self.population + 0.5)).inv(x)
x = DiscreteCosineTransform(dim=time_dim, smooth=dct)(x)
init_values["auxiliary_dct"] = x
if haar:
# Also initialize Haar transformed coordinates.
x = init_values["auxiliary"]
x = biject_to(constraints.interval(-0.5, self.population + 0.5)).inv(x)
x = HaarTransform(dim=-2 if self.is_regional else -1, flip=True)(x)
x = HaarTransform(dim=time_dim, flip=True)(x)
init_values["auxiliary_haar"] = x
if haar_full_mass:
# Also split into low- and high-frequency parts.
x0, x1 = init_values["auxiliary_haar"].split(
[haar_full_mass, self.duration - haar_full_mass],
dim=-2 if self.is_regional else -1)
[haar_full_mass, self.duration - haar_full_mass], dim=time_dim)
init_values["auxiliary_haar_split_0"] = x0
init_values["auxiliary_haar_split_1"] = x1
logger.info("Heuristic init: {}".format(", ".join(
Expand All @@ -353,18 +367,22 @@ def heuristic():

# Configure a kernel.
logger.info("Running inference...")
relax = options.pop("relax", False)
max_tree_depth = options.pop("max_tree_depth", 5)
full_mass = options.pop("full_mass", self.full_mass)
model = self._vectorized_model
model = self._relaxed_model if relax else self._vectorized_model
if dct is not None:
rep = DiscreteCosineReparam(dim=time_dim, smooth=dct)
model = poutine.reparam(model, {"auxiliary": rep})
if haar:
rep = HaarReparam(dim=-2 if self.is_regional else -1, flip=True)
rep = HaarReparam(dim=time_dim, flip=True)
model = poutine.reparam(model, {"auxiliary": rep})
if haar_full_mass:
assert full_mass and isinstance(full_mass, list)
full_mass = full_mass[:]
full_mass[0] = full_mass[0] + ("auxiliary_haar_split_0",)
rep = SplitReparam([haar_full_mass, self.duration - haar_full_mass],
dim=-2 if self.is_regional else -1)
dim=time_dim)
model = poutine.reparam(model, {"auxiliary_haar": rep})
kernel = NUTS(model,
full_mass=full_mass,
Expand All @@ -384,11 +402,17 @@ def heuristic():
self.samples["auxiliary_haar"] = torch.cat([
self.samples.pop("auxiliary_haar_split_0"),
self.samples.pop("auxiliary_haar_split_1"),
], dim=-2 if self.is_regional else -1)
], dim=time_dim)
if haar:
# Transform back from Haar coordinates.
x = self.samples.pop("auxiliary_haar")
x = HaarTransform(dim=-2 if self.is_regional else -1, flip=True).inv(x)
x = HaarTransform(dim=time_dim, flip=True).inv(x)
x = biject_to(constraints.interval(-0.5, self.population + 0.5))(x)
self.samples["auxiliary"] = x
if dct is not None:
# Transform back from discrete cosine coordinates.
x = self.samples.pop("auxiliary_dct")
x = DiscreteCosineTransform(dim=time_dim, smooth=dct).inv(x)
x = biject_to(constraints.interval(-0.5, self.population + 0.5))(x)
self.samples["auxiliary"] = x

Expand Down Expand Up @@ -580,7 +604,7 @@ def enum_reshape(tensor, position):
# Enable approximate inference by using aux as a non-enumerated proxy
# for enumerated compartment values.
for name in self.approximate:
aux = auxiliary[self.compartments.index(name)]
aux = differentiably_round(auxiliary[self.compartments.index(name)])
curr[name + "_approx"] = aux
prev[name + "_approx"] = cat2(init[name], aux[:-1],
dim=-2 if self.is_regional else -1)
Expand All @@ -606,6 +630,45 @@ def enum_reshape(tensor, position):

self._clear_plates()

def _relaxed_model(self):
"""
Relaxed vectorized model used for approximate inference.
"""
C = len(self.compartments)
T = self.duration
R_shape = getattr(self.population, "shape", ()) # Region shape.

# Sample global parameters.
params = self.global_model()

# Sample the continuous reparameterizing variable.
shape = (C, T) + R_shape
auxiliary = pyro.sample("auxiliary",
dist.Uniform(-0.5, self.population + 0.5)
.mask(False).expand(shape).to_event())
assert auxiliary.shape == shape, "particle plates are not supported"

# Constrain.
curr = dict(zip(self.compartments, differentiably_round(auxiliary)))

# Truncate final value from the right then pad initial value onto the left.
init = self.initialize(params)
prev = {name: cat2(init[name], value[:-1], dim=-2 if self.is_regional else -1)
for name, value in curr.items()}

# Enable approximate inference by using aux as a non-enumerated proxy
# for enumerated compartment values.
for name in self.approximate:
curr[name + "_approx"] = curr[name]
prev[name + "_approx"] = prev[name]

# Transition.
with pyro.plate("time", T, dim=-1 - self.max_plate_nesting):
t = slice(None) # Used to slice data tensors.
self.transition_bwd(params, prev, curr, t)

self._clear_plates()


class _SMCModel:
"""
Expand Down
17 changes: 17 additions & 0 deletions pyro/contrib/epidemiology/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,23 @@ def cat2(lhs, rhs, *, dim=-1):
return torch.cat([lhs.expand(shape), rhs.expand(shape)], dim=dim)


class _Round(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
return x.round()

@staticmethod
def backward(ctx, grad):
return grad


def differentiably_round(x):
"""
Like :func:`torch.round` but passes gradients as if no rounding occurred.
"""
return _Round.apply(x)


@torch.no_grad()
def align_samples(samples, model, particle_dim):
"""
Expand Down
5 changes: 5 additions & 0 deletions pyro/infer/mcmc/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,11 @@ def _find_valid_initial_params(model, model_args, model_kwargs, transforms, pote
params = {k: transforms[k](v) for k, v in samples.items()}
pe_grad, pe = potential_grad(potential_fn, params)

if not torch.isfinite(pe):
print("DEBUG energy is not finite")
elif not all(torch.isfinite(g).all() for g in pe_grad.values()):
print("DEBUG grad is not finite")

if torch.isfinite(pe) and all(map(torch.all, map(torch.isfinite, pe_grad.values()))):
for k, v in params.items():
params_per_chain[k].append(v)
Expand Down
9 changes: 8 additions & 1 deletion tests/contrib/epidemiology/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,15 @@
@pytest.mark.parametrize("forecast", [0, 7])
@pytest.mark.parametrize("options", [
{},
{"dct": 1.0},
{"haar": True},
{"haar_full_mass": 2},
{"num_quant_bins": 2},
{"num_quant_bins": 8},
{"num_quant_bins": 12},
{"num_quant_bins": 16},
{"arrowhead_mass": True},
{"relax": True},
], ids=str)
def test_simple_sir_smoke(duration, forecast, options):
population = 100
Expand Down Expand Up @@ -56,6 +58,7 @@ def test_simple_sir_smoke(duration, forecast, options):
{"haar": True},
{"haar_full_mass": 2},
{"num_quant_bins": 8},
{"relax": True},
], ids=str)
def test_simple_seir_smoke(duration, forecast, options):
population = 100
Expand Down Expand Up @@ -91,6 +94,7 @@ def test_simple_seir_smoke(duration, forecast, options):
{"haar": True},
{"haar_full_mass": 2},
{"num_quant_bins": 8},
{"relax": True},
], ids=str)
def test_superspreading_sir_smoke(duration, forecast, options):
population = 100
Expand Down Expand Up @@ -122,6 +126,7 @@ def test_superspreading_sir_smoke(duration, forecast, options):
{"haar": True},
{"haar_full_mass": 2},
{"num_quant_bins": 8},
{"relax": True},
], ids=str)
def test_superspreading_seir_smoke(duration, forecast, options):
population = 100
Expand Down Expand Up @@ -272,12 +277,14 @@ def test_unknown_start_smoke(duration, pre_obs_window, forecast, options):


@pytest.mark.parametrize("duration", [3, 7])
@pytest.mark.parametrize("forecast", [0, 7])
@pytest.mark.parametrize("forecast", [9])
@pytest.mark.parametrize("options", [
{},
{"dct": 1.0},
{"haar": True},
{"haar_full_mass": 2},
{"num_quant_bins": 8},
{"relax": True},
], ids=str)
def test_regional_smoke(duration, forecast, options):
num_regions = 6
Expand Down
2 changes: 2 additions & 0 deletions tests/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,12 @@
'contrib/epidemiology/sir.py -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 --haar',
'contrib/epidemiology/sir.py -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -nb=8',
'contrib/epidemiology/sir.py -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -hfm=3',
'contrib/epidemiology/sir.py -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 --relax',
'contrib/epidemiology/sir.py -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -a',
'contrib/epidemiology/regional.py -t=2 -w=2 -n=4 -r=3 -d=20 -p=1000 -f 2',
'contrib/epidemiology/regional.py -t=2 -w=2 -n=4 -r=3 -d=20 -p=1000 -f 2 --haar',
'contrib/epidemiology/regional.py -t=2 -w=2 -n=4 -r=3 -d=20 -p=1000 -f 2 -hfm=3',
'contrib/epidemiology/regional.py -t=2 -w=2 -n=4 -r=3 -d=20 -p=1000 -f 2 --relax',
'contrib/forecast/bart.py --num-steps=2 --stride=99999',
'contrib/gp/sv-dkl.py --epochs=1 --num-inducing=4 --batch-size=1000',
'contrib/gp/sv-dkl.py --binary --epochs=1 --num-inducing=4 --batch-size=1000',
Expand Down