Expose initialization strategy in HMC (#2417)
fehiepsi authored Apr 16, 2020
1 parent 05586c5 commit de0ab28
Showing 7 changed files with 124 additions and 25 deletions.
6 changes: 5 additions & 1 deletion pyro/infer/autoguide/__init__.py
@@ -6,7 +6,9 @@
AutoLaplaceApproximation, AutoLowRankMultivariateNormal,
AutoMultivariateNormal, AutoNormal, AutoNormalizingFlow)
from pyro.infer.autoguide.utils import mean_field_entropy
from pyro.infer.autoguide.initialization import init_to_feasible, init_to_mean, init_to_median, init_to_sample
from pyro.infer.autoguide.initialization import (init_to_feasible, init_to_mean, init_to_median,
init_to_sample, init_to_uniform, init_to_value)


__all__ = [
'AutoCallable',
@@ -26,5 +28,7 @@
'init_to_mean',
'init_to_median',
'init_to_sample',
'init_to_uniform',
'init_to_value',
'mean_field_entropy',
]
56 changes: 51 additions & 5 deletions pyro/infer/autoguide/initialization.py
@@ -9,6 +9,8 @@
trace ``site`` dict and returns an appropriately sized ``value`` to serve
as an initial constrained value for a guide estimate.
"""
import functools

import torch
from torch.distributions import transform_to

@@ -19,34 +21,45 @@
from pyro.util import torch_isnan


# TODO: move this file out of `autoguide` in a minor release

def _is_multivariate(d):
while isinstance(d, (Independent, MaskedDistribution)):
d = d.base_dist
return any(size > 1 for size in d.event_shape)


def init_to_feasible(site):
def init_to_feasible(site=None):
"""
Initialize to an arbitrary feasible point, ignoring distribution
parameters.
"""
if site is None:
return init_to_feasible

value = site["fn"].sample().detach()
t = transform_to(site["fn"].support)
return t(torch.zeros_like(t.inv(value)))


def init_to_sample(site):
def init_to_sample(site=None):
"""
Initialize to a random sample from the prior.
"""
if site is None:
return init_to_sample

return site["fn"].sample().detach()


def init_to_median(site, num_samples=15):
def init_to_median(site=None, num_samples=15):
"""
Initialize to the prior median; fall back to a feasible point if the median
is undefined.
"""
if site is None:
return functools.partial(init_to_median, num_samples=num_samples)

# The median is undefined for multivariate distributions.
if _is_multivariate(site["fn"]):
return init_to_feasible(site)
@@ -64,10 +77,13 @@ def init_to_median(site, num_samples=15):
return init_to_feasible(site)


def init_to_mean(site):
def init_to_mean(site=None):
"""
Initialize to the prior mean; fall back to the median if the mean is undefined.
"""
if site is None:
return init_to_mean

try:
# Try .mean() method.
value = site["fn"].mean.detach()
@@ -78,10 +94,40 @@ def init_to_mean(site):
return value
except (NotImplementedError, ValueError):
# Fall back to a median.
# This is requred for distributions with infinite variance, e.g. Cauchy.
# This is required for distributions with infinite variance, e.g. Cauchy.
return init_to_median(site)


def init_to_uniform(site=None, radius=2):
"""
Initialize to a random point in the region `(-radius, radius)` of the unconstrained domain.

:param float radius: specifies the range from which to draw an initial point in the unconstrained domain.
"""
if site is None:
return functools.partial(init_to_uniform, radius=radius)

value = site["fn"].sample().detach()
t = transform_to(site["fn"].support)
return t(torch.rand_like(t.inv(value)) * (2 * radius) - radius)


def init_to_value(site=None, values={}):
"""
Initialize to the values specified in `values`; we defer to the
:func:`init_to_uniform` strategy for sites that do not appear in `values`.

:param dict values: dictionary of initial values keyed by site name.
"""
if site is None:
return functools.partial(init_to_value, values=values)

if site["name"] in values:
return values[site["name"]]
else:
return init_to_uniform(site)


class InitMessenger(Messenger):
"""
Initializes a site by replacing ``.sample()`` calls with values
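
A quick sketch (not part of the commit) of the calling convention these strategies share: called with no site, a strategy partially applies its options and returns a reusable initializer; called on a trace-site dict, it returns an initial tensor value. The `site` dict below is a hypothetical stand-in for what Pyro's tracing machinery actually builds.

import pyro.distributions as dist
from pyro.infer.autoguide import init_to_median, init_to_uniform

# With no site argument, the strategy returns itself (with options
# partially applied), so it can be handed to HMC/NUTS as `init_strategy`.
strategy = init_to_median(num_samples=4)

# A minimal stand-in for a trace-site dict; real sites carry more
# bookkeeping keys than just "name" and "fn".
site = {"name": "x", "fn": dist.LogNormal(0., 1.)}

print(strategy(site))                     # roughly the prior median
print(init_to_uniform(radius=0.1)(site))  # near exp(0) = 1 for LogNormal
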
8 changes: 7 additions & 1 deletion pyro/infer/mcmc/hmc.py
@@ -10,6 +10,7 @@
import pyro.distributions as dist
from pyro.distributions.util import eye_like, scalar_like

from pyro.infer.autoguide import init_to_uniform
from pyro.infer.mcmc.adaptation import WarmupAdapter
from pyro.infer.mcmc.mcmc_kernel import MCMCKernel
from pyro.infer.mcmc.util import initialize_model
@@ -63,6 +64,8 @@ class HMC(MCMCKernel):
tracer when ``jit_compile=True``. Default is False.
:param float target_accept_prob: Increasing this value will lead to a smaller
step size, hence the sampling will be slower and more robust. Defaults to 0.8.
:param callable init_strategy: A per-site initialization function.
See the :ref:`autoguide-initialization` section for available functions.
.. note:: Internally, the mass matrix will be ordered according to the order
of the names of latent variables, not the order of their appearance in
@@ -102,7 +105,8 @@ def __init__(self,
jit_compile=False,
jit_options=None,
ignore_jit_warnings=False,
target_accept_prob=0.8):
target_accept_prob=0.8,
init_strategy=init_to_uniform):
if not ((model is None) ^ (potential_fn is None)):
raise ValueError("Only one of `model` or `potential_fn` must be specified.")
# NB: deprecating args - model, transforms
@@ -112,6 +116,7 @@
self._jit_compile = jit_compile
self._jit_options = jit_options
self._ignore_jit_warnings = ignore_jit_warnings
self._init_strategy = init_strategy

self.potential_fn = potential_fn
if trajectory_length is not None:
@@ -237,6 +242,7 @@ def _initialize_model_properties(self, model_args, model_kwargs):
jit_compile=self._jit_compile,
jit_options=self._jit_options,
skip_jit_warnings=self._ignore_jit_warnings,
init_strategy=self._init_strategy,
)
self.potential_fn = potential_fn
self.transforms = transforms
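
For context, a usage sketch of the new keyword (not part of the diff; the model and data below are hypothetical):

import torch
import pyro
import pyro.distributions as dist
from pyro.infer.autoguide import init_to_median
from pyro.infer.mcmc import HMC
from pyro.infer.mcmc.api import MCMC

def model(data):
    # Hypothetical model with a positive latent scale.
    scale = pyro.sample("scale", dist.LogNormal(0., 1.))
    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Normal(0., scale), obs=data)

data = torch.randn(100)
# init_strategy defaults to init_to_uniform; here each chain instead
# starts near the empirical prior median.
kernel = HMC(model, init_strategy=init_to_median(num_samples=10))
MCMC(kernel, num_samples=200, warmup_steps=100).run(data)
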
9 changes: 7 additions & 2 deletions pyro/infer/mcmc/nuts.py
@@ -8,6 +8,7 @@
import pyro
import pyro.distributions as dist
from pyro.distributions.util import scalar_like
from pyro.infer.autoguide import init_to_uniform
from pyro.infer.mcmc.hmc import HMC
from pyro.ops.integrator import velocity_verlet
from pyro.util import optional, torch_isnan
@@ -95,6 +96,8 @@ class NUTS(HMC):
so the sampling will be slower but more robust. Defaults to 0.8.
:param int max_tree_depth: Max depth of the binary tree created during the doubling
scheme of the NUTS sampler. Defaults to 10.
:param callable init_strategy: A per-site initialization function.
See the :ref:`autoguide-initialization` section for available functions.
Example:
@@ -130,7 +133,8 @@ def __init__(self,
jit_options=None,
ignore_jit_warnings=False,
target_accept_prob=0.8,
max_tree_depth=10):
max_tree_depth=10,
init_strategy=init_to_uniform):
super().__init__(model,
potential_fn,
step_size,
@@ -142,7 +146,8 @@
jit_compile=jit_compile,
jit_options=jit_options,
ignore_jit_warnings=ignore_jit_warnings,
target_accept_prob=target_accept_prob)
target_accept_prob=target_accept_prob,
init_strategy=init_strategy)
self.use_multinomial_sampling = use_multinomial_sampling
self._max_tree_depth = max_tree_depth
# There are three conditions to stop doubling process:
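
The same keyword threads through to NUTS; for instance (a sketch reusing the hypothetical `model` from the HMC example above):

import torch
from pyro.infer.autoguide import init_to_value
from pyro.infer.mcmc import NUTS

# Pin the warmup start of "scale" to a chosen value; any site missing
# from `values` falls back to init_to_uniform.
kernel = NUTS(model, init_strategy=init_to_value(values={"scale": torch.tensor(1.5)}))
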
29 changes: 15 additions & 14 deletions pyro/infer/mcmc/util.py
@@ -14,9 +14,9 @@

import pyro
import pyro.poutine as poutine
import pyro.distributions as dist
from pyro.distributions.util import broadcast_shape, logsumexp
from pyro.infer import config_enumerate
from pyro.infer.autoguide.initialization import InitMessenger, init_to_uniform
from pyro.infer.util import is_validation_enabled
from pyro.ops import stats
from pyro.ops.contract import contract_to_tensor
@@ -299,9 +299,9 @@ def get_potential_fn(self, jit_compile=False, skip_jit_warnings=True, jit_option
return self._potential_fn


# TODO: expose init_strategy using separate functions.
def _get_init_params(model, model_args, model_kwargs, transforms, potential_fn, prototype_params,
max_tries_initial_params=100, num_chains=1, strategy="uniform"):
def _find_valid_initial_params(model, model_args, model_kwargs, transforms, potential_fn,
prototype_params, max_tries_initial_params=100, num_chains=1,
init_strategy=init_to_uniform):
params = prototype_params

# For empty models, exit early
@@ -310,14 +310,11 @@ def _get_init_params(model, model_args, model_kwargs, transforms, potential_fn,

params_per_chain = defaultdict(list)
num_found = 0
model = InitMessenger(init_strategy)(model)
for attempt in range(num_chains * max_tries_initial_params):
if strategy == "uniform":
params = {k: dist.Uniform(v.new_full(v.shape, -2), v.new_full(v.shape, 2)).sample()
for k, v in params.items()}
elif strategy == "prior":
trace = poutine.trace(model).get_trace(*model_args, **model_kwargs)
samples = {name: trace.nodes[name]["value"].detach() for name in params}
params = {k: transforms[k](v) for k, v in samples.items()}
trace = poutine.trace(model).get_trace(*model_args, **model_kwargs)
samples = {name: trace.nodes[name]["value"].detach() for name in params}
params = {k: transforms[k](v) for k, v in samples.items()}
pe_grad, pe = potential_grad(potential_fn, params)

if torch.isfinite(pe) and all(map(torch.all, map(torch.isfinite, pe_grad.values()))):
@@ -333,7 +330,8 @@


def initialize_model(model, model_args=(), model_kwargs={}, transforms=None, max_plate_nesting=None,
jit_compile=False, jit_options=None, skip_jit_warnings=False, num_chains=1):
jit_compile=False, jit_options=None, skip_jit_warnings=False, num_chains=1,
init_strategy=init_to_uniform):
"""
Given a Python callable with Pyro primitives, generates the following model-specific
properties needed for inference using HMC/NUTS kernels:
@@ -364,6 +362,8 @@ def initialize_model(model, model_args=(), model_kwargs={}, transforms=None, max
tracer when ``jit_compile=True``. Default is False.
:param int num_chains: Number of parallel chains. If `num_chains > 1`,
the returned `initial_params` will be a list with `num_chains` elements.
:param callable init_strategy: A per-site initialization function.
See the :ref:`autoguide-initialization` section for available functions.
:returns: a tuple of (`initial_params`, `potential_fn`, `transforms`, `prototype_trace`)
"""
# XXX `transforms` domains are sites' supports
@@ -407,8 +407,9 @@ def initialize_model(model, model_args=(), model_kwargs={}, transforms=None, max

# Note that we deliberately do not exercise jit compilation here so as to
# enable potential_fn to be picklable (a torch._C.Function cannot be pickled).
init_params = _get_init_params(model, model_args, model_kwargs, transforms,
pe_maker.get_potential_fn(), prototype_params, num_chains=num_chains)
init_params = _find_valid_initial_params(model, model_args, model_kwargs, transforms,
pe_maker.get_potential_fn(), prototype_params,
num_chains=num_chains, init_strategy=init_strategy)
potential_fn = pe_maker.get_potential_fn(jit_compile, skip_jit_warnings, jit_options)
return init_params, potential_fn, transforms, model_trace

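
A sketch of driving `initialize_model` directly with the new argument (again reusing the hypothetical `model` and `data` from the HMC example):

from pyro.infer.autoguide import init_to_sample
from pyro.infer.mcmc.util import initialize_model

init_params, potential_fn, transforms, prototype_trace = initialize_model(
    model, model_args=(data,), num_chains=2, init_strategy=init_to_sample)
# Per the docstring above, with num_chains > 1 init_params is a list with
# one element per chain; each draw from the strategy has been vetted by
# _find_valid_initial_params for a finite potential energy and gradient.
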
4 changes: 2 additions & 2 deletions tests/infer/mcmc/test_mcmc_api.py
@@ -103,8 +103,8 @@ def test_mcmc_interface(num_draws, group_by_chain, num_chains):
samples = {k: v.reshape((-1,) + v.shape[2:]) for k, v in samples.items()}
sample_mean = samples['y'].mean()
sample_std = samples['y'].std()
assert_close(sample_mean, torch.tensor(0.0), atol=0.05)
assert_close(sample_std, torch.tensor(1.0), atol=0.05)
assert_close(sample_mean, torch.tensor(0.0), atol=0.1)
assert_close(sample_std, torch.tensor(1.0), atol=0.1)


@pytest.mark.parametrize("num_chains, cpu_count", [
37 changes: 37 additions & 0 deletions tests/infer/mcmc/test_mcmc_util.py
@@ -1,12 +1,16 @@
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

from functools import partial

import pytest
import torch

import pyro
import pyro.distributions as dist
from pyro.infer import Predictive
from pyro.infer.autoguide import (init_to_feasible, init_to_mean, init_to_median,
init_to_sample, init_to_uniform, init_to_value)
from pyro.infer.mcmc import NUTS
from pyro.infer.mcmc.api import MCMC
from pyro.infer.mcmc.util import initialize_model
@@ -85,3 +89,36 @@ def model():
mcmc.run()
else:
mcmc.run()


def test_init_to_value():
def model():
pyro.sample("x", dist.LogNormal(0, 1))

value = torch.randn(()).exp() * 10
kernel = NUTS(model, init_strategy=partial(init_to_value, values={"x": value}))
kernel.setup(warmup_steps=10)
assert_close(value, kernel.initial_params['x'].exp())


@pytest.mark.parametrize("init_strategy", [
init_to_feasible,
init_to_mean,
init_to_median,
init_to_sample,
init_to_uniform,
init_to_value,
init_to_feasible(),
init_to_mean(),
init_to_median(num_samples=4),
init_to_sample(),
init_to_uniform(radius=0.1),
init_to_value(values={"x": torch.tensor(3.)}),
])
def test_init_strategy_smoke(init_strategy):
def model():
pyro.sample("x", dist.LogNormal(0, 1))

kernel = NUTS(model, init_strategy=init_strategy)
kernel.setup(warmup_steps=10)
