diff --git a/pyro/infer/autoguide/__init__.py b/pyro/infer/autoguide/__init__.py
index f31e4352bc..2c751f5808 100644
--- a/pyro/infer/autoguide/__init__.py
+++ b/pyro/infer/autoguide/__init__.py
@@ -6,7 +6,9 @@
                                          AutoLaplaceApproximation, AutoLowRankMultivariateNormal,
                                          AutoMultivariateNormal, AutoNormal, AutoNormalizingFlow)
 from pyro.infer.autoguide.utils import mean_field_entropy
-from pyro.infer.autoguide.initialization import init_to_feasible, init_to_mean, init_to_median, init_to_sample
+from pyro.infer.autoguide.initialization import (init_to_feasible, init_to_mean, init_to_median,
+                                                 init_to_sample, init_to_uniform, init_to_value)
+
 
 __all__ = [
     'AutoCallable',
@@ -26,5 +28,7 @@
     'init_to_mean',
     'init_to_median',
     'init_to_sample',
+    'init_to_uniform',
+    'init_to_value',
     'mean_field_entropy',
 ]
diff --git a/pyro/infer/autoguide/initialization.py b/pyro/infer/autoguide/initialization.py
index d6cea16699..cff6fcfc09 100644
--- a/pyro/infer/autoguide/initialization.py
+++ b/pyro/infer/autoguide/initialization.py
@@ -9,6 +9,8 @@
 trace ``site`` dict and returns an appropriately sized ``value`` to serve
 as an initial constrained value for a guide estimate.
 """
+import functools
+
 import torch
 from torch.distributions import transform_to
 
@@ -19,34 +21,45 @@
 from pyro.util import torch_isnan
 
 
+# TODO: move this file out of `autoguide` in a minor release
+
 def _is_multivariate(d):
     while isinstance(d, (Independent, MaskedDistribution)):
         d = d.base_dist
     return any(size > 1 for size in d.event_shape)
 
 
-def init_to_feasible(site):
+def init_to_feasible(site=None):
     """
     Initialize to an arbitrary feasible point, ignoring distribution
     parameters.
     """
+    if site is None:
+        return init_to_feasible
+
     value = site["fn"].sample().detach()
     t = transform_to(site["fn"].support)
     return t(torch.zeros_like(t.inv(value)))
 
 
-def init_to_sample(site):
+def init_to_sample(site=None):
     """
     Initialize to a random sample from the prior.
     """
+    if site is None:
+        return init_to_sample
+
     return site["fn"].sample().detach()
 
 
-def init_to_median(site, num_samples=15):
+def init_to_median(site=None, num_samples=15):
     """
     Initialize to the prior median; fallback to a feasible point if median is
     undefined.
     """
+    if site is None:
+        return functools.partial(init_to_median, num_samples=num_samples)
+
     # The median undefined for multivariate distributions.
     if _is_multivariate(site["fn"]):
         return init_to_feasible(site)
@@ -64,10 +77,13 @@ def init_to_median(site, num_samples=15):
         return init_to_feasible(site)
 
 
-def init_to_mean(site):
+def init_to_mean(site=None):
     """
     Initialize to the prior mean; fallback to median if mean is undefined.
     """
+    if site is None:
+        return init_to_mean
+
     try:
         # Try .mean() method.
         value = site["fn"].mean.detach()
@@ -78,10 +94,40 @@ def init_to_mean(site):
         return value
     except (NotImplementedError, ValueError):
         # Fall back to a median.
-        # This is requred for distributions with infinite variance, e.g. Cauchy.
+        # This is required for distributions with infinite variance, e.g. Cauchy.
         return init_to_median(site)
 
 
+def init_to_uniform(site=None, radius=2):
+    """
+    Initialize to a random point in the region `(-radius, radius)` of the unconstrained domain.
+
+    :param float radius: specifies the range from which to draw an initial point in the unconstrained domain.
+ """ + if site is None: + return functools.partial(init_to_uniform, radius=radius) + + value = site["fn"].sample().detach() + t = transform_to(site["fn"].support) + return t(torch.rand_like(t.inv(value)) * (2 * radius) - radius) + + +def init_to_value(site=None, values={}): + """ + Initialize to the value specified in `values`. We defer to + :func:`init_to_uniform` strategy for sites which do not appear in `values`. + + :param dict values: dictionary of initial values keyed by site name. + """ + if site is None: + return functools.partial(init_to_value, values=values) + + if site["name"] in values: + return values[site["name"]] + else: + return init_to_uniform(site) + + class InitMessenger(Messenger): """ Initializes a site by replacing ``.sample()`` calls with values diff --git a/pyro/infer/mcmc/hmc.py b/pyro/infer/mcmc/hmc.py index 44b29b983b..2d3a92da61 100644 --- a/pyro/infer/mcmc/hmc.py +++ b/pyro/infer/mcmc/hmc.py @@ -10,6 +10,7 @@ import pyro.distributions as dist from pyro.distributions.util import eye_like, scalar_like +from pyro.infer.autoguide import init_to_uniform from pyro.infer.mcmc.adaptation import WarmupAdapter from pyro.infer.mcmc.mcmc_kernel import MCMCKernel from pyro.infer.mcmc.util import initialize_model @@ -63,6 +64,8 @@ class HMC(MCMCKernel): tracer when ``jit_compile=True``. Default is False. :param float target_accept_prob: Increasing this value will lead to a smaller step size, hence the sampling will be slower and more robust. Default to 0.8. + :param callable init_strategy: A per-site initialization function. + See :ref:`autoguide-initialization` section for available functions. .. note:: Internally, the mass matrix will be ordered according to the order of the names of latent variables, not the order of their appearance in @@ -102,7 +105,8 @@ def __init__(self, jit_compile=False, jit_options=None, ignore_jit_warnings=False, - target_accept_prob=0.8): + target_accept_prob=0.8, + init_strategy=init_to_uniform): if not ((model is None) ^ (potential_fn is None)): raise ValueError("Only one of `model` or `potential_fn` must be specified.") # NB: deprecating args - model, transforms @@ -112,6 +116,7 @@ def __init__(self, self._jit_compile = jit_compile self._jit_options = jit_options self._ignore_jit_warnings = ignore_jit_warnings + self._init_strategy = init_strategy self.potential_fn = potential_fn if trajectory_length is not None: @@ -237,6 +242,7 @@ def _initialize_model_properties(self, model_args, model_kwargs): jit_compile=self._jit_compile, jit_options=self._jit_options, skip_jit_warnings=self._ignore_jit_warnings, + init_strategy=self._init_strategy, ) self.potential_fn = potential_fn self.transforms = transforms diff --git a/pyro/infer/mcmc/nuts.py b/pyro/infer/mcmc/nuts.py index 653139f169..bdada30f63 100644 --- a/pyro/infer/mcmc/nuts.py +++ b/pyro/infer/mcmc/nuts.py @@ -8,6 +8,7 @@ import pyro import pyro.distributions as dist from pyro.distributions.util import scalar_like +from pyro.infer.autoguide import init_to_uniform from pyro.infer.mcmc.hmc import HMC from pyro.ops.integrator import velocity_verlet from pyro.util import optional, torch_isnan @@ -95,6 +96,8 @@ class NUTS(HMC): so the sampling will be slower but more robust. Default to 0.8. :param int max_tree_depth: Max depth of the binary tree created during the doubling scheme of NUTS sampler. Default to 10. + :param callable init_strategy: A per-site initialization function. + See :ref:`autoguide-initialization` section for available functions. 
 
     Example:
 
@@ -130,7 +133,8 @@ def __init__(self,
                  jit_options=None,
                  ignore_jit_warnings=False,
                  target_accept_prob=0.8,
-                 max_tree_depth=10):
+                 max_tree_depth=10,
+                 init_strategy=init_to_uniform):
         super().__init__(model,
                          potential_fn,
                          step_size,
@@ -142,7 +146,8 @@ def __init__(self,
                          jit_compile=jit_compile,
                          jit_options=jit_options,
                          ignore_jit_warnings=ignore_jit_warnings,
-                         target_accept_prob=target_accept_prob)
+                         target_accept_prob=target_accept_prob,
+                         init_strategy=init_strategy)
         self.use_multinomial_sampling = use_multinomial_sampling
         self._max_tree_depth = max_tree_depth
         # There are three conditions to stop doubling process:
diff --git a/pyro/infer/mcmc/util.py b/pyro/infer/mcmc/util.py
index 82f38e3e70..45a37f8c32 100644
--- a/pyro/infer/mcmc/util.py
+++ b/pyro/infer/mcmc/util.py
@@ -14,9 +14,9 @@
 import pyro
 import pyro.poutine as poutine
-import pyro.distributions as dist
 from pyro.distributions.util import broadcast_shape, logsumexp
 from pyro.infer import config_enumerate
+from pyro.infer.autoguide.initialization import InitMessenger, init_to_uniform
 from pyro.infer.util import is_validation_enabled
 from pyro.ops import stats
 from pyro.ops.contract import contract_to_tensor
@@ -299,9 +299,9 @@ def get_potential_fn(self, jit_compile=False, skip_jit_warnings=True, jit_option
         return self._potential_fn
 
 
-# TODO: expose init_strategy using separate functions.
-def _get_init_params(model, model_args, model_kwargs, transforms, potential_fn, prototype_params,
-                     max_tries_initial_params=100, num_chains=1, strategy="uniform"):
+def _find_valid_initial_params(model, model_args, model_kwargs, transforms, potential_fn,
+                               prototype_params, max_tries_initial_params=100, num_chains=1,
+                               init_strategy=init_to_uniform):
     params = prototype_params
 
     # For empty models, exit early
@@ -310,14 +310,11 @@ def _get_init_params(model, model_args, model_kwargs, transforms, potential_fn,
 
     params_per_chain = defaultdict(list)
     num_found = 0
+    model = InitMessenger(init_strategy)(model)
     for attempt in range(num_chains * max_tries_initial_params):
-        if strategy == "uniform":
-            params = {k: dist.Uniform(v.new_full(v.shape, -2), v.new_full(v.shape, 2)).sample()
-                      for k, v in params.items()}
-        elif strategy == "prior":
-            trace = poutine.trace(model).get_trace(*model_args, **model_kwargs)
-            samples = {name: trace.nodes[name]["value"].detach() for name in params}
-            params = {k: transforms[k](v) for k, v in samples.items()}
+        trace = poutine.trace(model).get_trace(*model_args, **model_kwargs)
+        samples = {name: trace.nodes[name]["value"].detach() for name in params}
+        params = {k: transforms[k](v) for k, v in samples.items()}
         pe_grad, pe = potential_grad(potential_fn, params)
 
         if torch.isfinite(pe) and all(map(torch.all, map(torch.isfinite, pe_grad.values()))):
@@ -333,7 +330,8 @@ def _get_init_params(model, model_args, model_kwargs, transforms, potential_fn,
 
 
 def initialize_model(model, model_args=(), model_kwargs={}, transforms=None, max_plate_nesting=None,
-                     jit_compile=False, jit_options=None, skip_jit_warnings=False, num_chains=1):
+                     jit_compile=False, jit_options=None, skip_jit_warnings=False, num_chains=1,
+                     init_strategy=init_to_uniform):
     """
     Given a Python callable with Pyro primitives, generates the following model-specific
     properties needed for inference using HMC/NUTS kernels:
@@ -364,6 +362,8 @@ def initialize_model(model, model_args=(), model_kwargs={}, transforms=None, max
         tracer when ``jit_compile=True``. Default is False.
     :param int num_chains: Number of parallel chains. If `num_chains > 1`,
         the returned `initial_params` will be a list with `num_chains` elements.
+    :param callable init_strategy: A per-site initialization function.
+        See the :ref:`autoguide-initialization` section for available functions.
     :returns: a tuple of (`initial_params`, `potential_fn`, `transforms`, `prototype_trace`)
     """
     # XXX `transforms` domains are sites' supports
@@ -407,8 +407,9 @@ def initialize_model(model, model_args=(), model_kwargs={}, transforms=None, max
 
     # Note that we deliberately do not exercise jit compilation here so as to
     # enable potential_fn to be picklable (a torch._C.Function cannot be pickled).
-    init_params = _get_init_params(model, model_args, model_kwargs, transforms,
-                                   pe_maker.get_potential_fn(), prototype_params, num_chains=num_chains)
+    init_params = _find_valid_initial_params(model, model_args, model_kwargs, transforms,
+                                             pe_maker.get_potential_fn(), prototype_params,
+                                             num_chains=num_chains, init_strategy=init_strategy)
     potential_fn = pe_maker.get_potential_fn(jit_compile, skip_jit_warnings, jit_options)
 
     return init_params, potential_fn, transforms, model_trace
diff --git a/tests/infer/mcmc/test_mcmc_api.py b/tests/infer/mcmc/test_mcmc_api.py
index d75fcdfb7b..30ea24e92f 100644
--- a/tests/infer/mcmc/test_mcmc_api.py
+++ b/tests/infer/mcmc/test_mcmc_api.py
@@ -103,8 +103,8 @@ def test_mcmc_interface(num_draws, group_by_chain, num_chains):
         samples = {k: v.reshape((-1,) + v.shape[2:]) for k, v in samples.items()}
     sample_mean = samples['y'].mean()
     sample_std = samples['y'].std()
-    assert_close(sample_mean, torch.tensor(0.0), atol=0.05)
-    assert_close(sample_std, torch.tensor(1.0), atol=0.05)
+    assert_close(sample_mean, torch.tensor(0.0), atol=0.1)
+    assert_close(sample_std, torch.tensor(1.0), atol=0.1)
 
 
 @pytest.mark.parametrize("num_chains, cpu_count", [
diff --git a/tests/infer/mcmc/test_mcmc_util.py b/tests/infer/mcmc/test_mcmc_util.py
index f3b0644b00..45276e3444 100644
--- a/tests/infer/mcmc/test_mcmc_util.py
+++ b/tests/infer/mcmc/test_mcmc_util.py
@@ -1,12 +1,16 @@
 # Copyright (c) 2017-2019 Uber Technologies, Inc.
 # SPDX-License-Identifier: Apache-2.0
 
+from functools import partial
+
 import pytest
 import torch
 
 import pyro
 import pyro.distributions as dist
 from pyro.infer import Predictive
+from pyro.infer.autoguide import (init_to_feasible, init_to_mean, init_to_median,
+                                  init_to_sample, init_to_uniform, init_to_value)
 from pyro.infer.mcmc import NUTS
 from pyro.infer.mcmc.api import MCMC
 from pyro.infer.mcmc.util import initialize_model
@@ -85,3 +89,35 @@ def model():
         mcmc.run()
     else:
         mcmc.run()
+
+
+def test_init_to_value():
+    def model():
+        pyro.sample("x", dist.LogNormal(0, 1))
+
+    value = torch.randn(()).exp() * 10
+    kernel = NUTS(model, init_strategy=partial(init_to_value, values={"x": value}))
+    kernel.setup(warmup_steps=10)
+    assert_close(value, kernel.initial_params['x'].exp())
+
+
+@pytest.mark.parametrize("init_strategy", [
+    init_to_feasible,
+    init_to_mean,
+    init_to_median,
+    init_to_sample,
+    init_to_uniform,
+    init_to_value,
+    init_to_feasible(),
+    init_to_mean(),
+    init_to_median(num_samples=4),
+    init_to_sample(),
+    init_to_uniform(radius=0.1),
+    init_to_value(values={"x": torch.tensor(3.)}),
+])
+def test_init_strategy_smoke(init_strategy):
+    def model():
+        pyro.sample("x", dist.LogNormal(0, 1))
+
+    kernel = NUTS(model, init_strategy=init_strategy)
+    kernel.setup(warmup_steps=10)
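Usage sketch (not part of the patch): the snippet below illustrates how the `init_strategy` argument introduced by this diff is meant to be used, mirroring the calls exercised in the tests above. The toy model, site name "loc", and the initial value are illustrative assumptions, not code from the PR.

import torch

import pyro
import pyro.distributions as dist
from pyro.infer.autoguide import init_to_median, init_to_value
from pyro.infer.mcmc import NUTS
from pyro.infer.mcmc.api import MCMC


def model():
    # Toy model (illustrative only): a latent location with noisy observations.
    loc = pyro.sample("loc", dist.Normal(0., 10.))
    pyro.sample("obs", dist.Normal(loc, 1.), obs=torch.randn(100) + 3.)


# A strategy can be passed bare; the kernel applies it per site.
kernel = NUTS(model, init_strategy=init_to_median)

# Or pre-configured: called without a `site`, each strategy returns a
# functools.partial of itself, as the smoke test above exercises.
kernel = NUTS(model, init_strategy=init_to_median(num_samples=4))
kernel = NUTS(model, init_strategy=init_to_value(values={"loc": torch.tensor(3.)}))

MCMC(kernel, num_samples=100, warmup_steps=100).run()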