diff --git a/pyro/contrib/epidemiology/compartmental.py b/pyro/contrib/epidemiology/compartmental.py index 3538be5b4f..20541e34f9 100644 --- a/pyro/contrib/epidemiology/compartmental.py +++ b/pyro/contrib/epidemiology/compartmental.py @@ -17,7 +17,7 @@ import pyro.poutine as poutine from pyro.distributions.transforms import DiscreteCosineTransform from pyro.infer import MCMC, NUTS, SMCFilter, infer_discrete -from pyro.infer.autoguide import init_to_value +from pyro.infer.autoguide import init_to_generated, init_to_value from pyro.infer.mcmc import ArrowheadMassMatrix from pyro.infer.reparam import DiscreteCosineReparam from pyro.util import warn_if_nan @@ -301,19 +301,24 @@ def fit(self, **options): raise NotImplementedError("regional models do not support DiscreteCosineReparam") # Heuristically initialze to feasible latents. - logger.info("Heuristically initializing...") heuristic_options = {k.replace("heuristic_", ""): options.pop(k) - for k in list(options) if k.startswith("heuristic_")} - init_values = self.heuristic(**heuristic_options) - assert isinstance(init_values, dict) - assert "auxiliary" in init_values, \ - ".heuristic() did not define auxiliary value" - if self._dct is not None: - # Also initialize DCT transformed coordinates. - x = init_values["auxiliary"] - x = biject_to(constraints.interval(-0.5, self.population + 0.5)).inv(x) - x = DiscreteCosineTransform(smooth=self._dct)(x) - init_values["auxiliary_dct"] = x + for k in list(options) + if k.startswith("heuristic_")} + + def heuristic(): + logger.info("Heuristically initializing...") + with poutine.block(): + init_values = self.heuristic(**heuristic_options) + assert isinstance(init_values, dict) + assert "auxiliary" in init_values, \ + ".heuristic() did not define auxiliary value" + if self._dct is not None: + # Also initialize DCT transformed coordinates. + x = init_values["auxiliary"] + x = biject_to(constraints.interval(-0.5, self.population + 0.5)).inv(x) + x = DiscreteCosineTransform(smooth=self._dct)(x) + init_values["auxiliary_dct"] = x + return init_to_value(values=init_values) # Configure a kernel. logger.info("Running inference...") @@ -325,7 +330,7 @@ def fit(self, **options): model = poutine.reparam(model, {"auxiliary": rep}) kernel = NUTS(model, full_mass=full_mass, - init_strategy=init_to_value(values=init_values), + init_strategy=init_to_generated(generate=heuristic), max_tree_depth=max_tree_depth) if options.pop("arrowhead_mass", False): kernel.mass_matrix_adapter = ArrowheadMassMatrix() diff --git a/pyro/infer/autoguide/__init__.py b/pyro/infer/autoguide/__init__.py index 2c751f5808..f82c2f81e9 100644 --- a/pyro/infer/autoguide/__init__.py +++ b/pyro/infer/autoguide/__init__.py @@ -5,10 +5,9 @@ AutoDiscreteParallel, AutoGuide, AutoGuideList, AutoIAFNormal, AutoLaplaceApproximation, AutoLowRankMultivariateNormal, AutoMultivariateNormal, AutoNormal, AutoNormalizingFlow) -from pyro.infer.autoguide.utils import mean_field_entropy -from pyro.infer.autoguide.initialization import (init_to_feasible, init_to_mean, init_to_median, +from pyro.infer.autoguide.initialization import (init_to_feasible, init_to_generated, init_to_mean, init_to_median, init_to_sample, init_to_uniform, init_to_value) - +from pyro.infer.autoguide.utils import mean_field_entropy __all__ = [ 'AutoCallable', @@ -25,6 +24,7 @@ 'AutoNormal', 'AutoNormalizingFlow', 'init_to_feasible', + 'init_to_generated', 'init_to_mean', 'init_to_median', 'init_to_sample', diff --git a/pyro/infer/autoguide/initialization.py b/pyro/infer/autoguide/initialization.py index cff6fcfc09..70b1a430f3 100644 --- a/pyro/infer/autoguide/initialization.py +++ b/pyro/infer/autoguide/initialization.py @@ -100,7 +100,8 @@ def init_to_mean(site=None): def init_to_uniform(site=None, radius=2): """ - Initialize to a random point in the area `(-radius, radius)` of unconstrained domain. + Initialize to a random point in the area ``(-radius, radius)`` of + unconstrained domain. :param float radius: specifies the range to draw an initial point in the unconstrained domain. """ @@ -114,8 +115,8 @@ def init_to_uniform(site=None, radius=2): def init_to_value(site=None, values={}): """ - Initialize to the value specified in `values`. We defer to - :func:`init_to_uniform` strategy for sites which do not appear in `values`. + Initialize to the value specified in ``values``. We defer to + :func:`init_to_uniform` strategy for sites which do not appear in ``values``. :param dict values: dictionary of initial values keyed by site name. """ @@ -128,6 +129,42 @@ def init_to_value(site=None, values={}): return init_to_uniform(site) +class _InitToGenerated: + def __init__(self, generate): + self.generate = generate + self._init = None + self._seen = set() + + def __call__(self, site): + if self._init is None or site["name"] in self._seen: + self._init = self.generate() + self._seen = {site["name"]} + return self._init(site) + + +def init_to_generated(site=None, generate=lambda: init_to_uniform): + """ + Initialize to another initialization strategy returned by the callback + ``generate`` which is called once per model execution. + + This is like :func:`init_to_value` but can produce different (e.g. random) + values once per model execution. For example to generate values and return + ``init_to_value`` you could define:: + + def generate(): + values = {"x": torch.randn(100), "y": torch.rand(5)} + return init_to_value(values=values) + + my_init_fn = init_to_generated(generate=generate) + + :param callable generate: A callable returning another initialization + function, e.g. returning an ``init_to_value(values={...})`` populated + with a dictionary of random samples. + """ + init = _InitToGenerated(generate) + return init if site is None else init(site) + + class InitMessenger(Messenger): """ Initializes a site by replacing ``.sample()`` calls with values diff --git a/tests/infer/mcmc/test_mcmc_util.py b/tests/infer/mcmc/test_mcmc_util.py index 413ec36cde..6ead1a3322 100644 --- a/tests/infer/mcmc/test_mcmc_util.py +++ b/tests/infer/mcmc/test_mcmc_util.py @@ -9,8 +9,8 @@ import pyro import pyro.distributions as dist from pyro.infer import Predictive -from pyro.infer.autoguide import (init_to_feasible, init_to_mean, init_to_median, - init_to_sample, init_to_uniform, init_to_value) +from pyro.infer.autoguide import (init_to_feasible, init_to_generated, init_to_mean, init_to_median, init_to_sample, + init_to_uniform, init_to_value) from pyro.infer.mcmc import NUTS from pyro.infer.mcmc.api import MCMC from pyro.infer.mcmc.util import initialize_model @@ -114,7 +114,9 @@ def model(): init_to_sample(), init_to_uniform(radius=0.1), init_to_value(values={"x": torch.tensor(3.)}), -]) + init_to_generated( + generate=lambda: init_to_value(values={"x": torch.rand(())})), +], ids=str) def test_init_strategy_smoke(init_strategy): def model(): pyro.sample("x", dist.LogNormal(0, 1)) diff --git a/tests/infer/test_initialization.py b/tests/infer/test_initialization.py new file mode 100644 index 0000000000..7e1527e4b1 --- /dev/null +++ b/tests/infer/test_initialization.py @@ -0,0 +1,34 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +import torch + +import pyro +import pyro.distributions as dist +from pyro.infer.autoguide.initialization import InitMessenger, init_to_generated, init_to_value + + +def test_init_to_generated(): + def model(): + x = pyro.sample("x", dist.Normal(0, 1)) + y = pyro.sample("y", dist.Normal(0, 1)) + z = pyro.sample("z", dist.Normal(0, 1)) + return x, y, z + + class MockGenerate: + def __init__(self): + self.counter = 0 + + def __call__(self): + values = {"x": torch.tensor(self.counter + 0.0), + "y": torch.tensor(self.counter + 0.5)} + self.counter += 1 + return init_to_value(values=values) + + mock_generate = MockGenerate() + with InitMessenger(init_to_generated(generate=mock_generate)): + for i in range(5): + x, y, z = model() + assert x == i + assert y == i + 0.5 + assert mock_generate.counter == 5