Add initialization strategy that randomly samples values #2482

Merged 4 commits on May 15, 2020
pyro/contrib/epidemiology/compartmental.py (19 additions, 14 deletions)
@@ -17,7 +17,7 @@
 import pyro.poutine as poutine
 from pyro.distributions.transforms import DiscreteCosineTransform
 from pyro.infer import MCMC, NUTS, SMCFilter, infer_discrete
-from pyro.infer.autoguide import init_to_value
+from pyro.infer.autoguide import init_to_generated, init_to_value
 from pyro.infer.mcmc import ArrowheadMassMatrix
 from pyro.infer.reparam import DiscreteCosineReparam
 from pyro.util import warn_if_nan
The eager initialization code is wrapped in a local ``heuristic()`` closure that runs the heuristic under ``poutine.block()`` and returns an ``init_to_value`` strategy, so that ``init_to_generated`` (wired up in the next hunk) can re-run it on each fresh model execution:

@@ -301,19 +301,24 @@ def fit(self, **options):
             raise NotImplementedError("regional models do not support DiscreteCosineReparam")

         # Heuristically initialize to feasible latents.
-        logger.info("Heuristically initializing...")
         heuristic_options = {k.replace("heuristic_", ""): options.pop(k)
-                             for k in list(options) if k.startswith("heuristic_")}
-        init_values = self.heuristic(**heuristic_options)
-        assert isinstance(init_values, dict)
-        assert "auxiliary" in init_values, \
-            ".heuristic() did not define auxiliary value"
-        if self._dct is not None:
-            # Also initialize DCT transformed coordinates.
-            x = init_values["auxiliary"]
-            x = biject_to(constraints.interval(-0.5, self.population + 0.5)).inv(x)
-            x = DiscreteCosineTransform(smooth=self._dct)(x)
-            init_values["auxiliary_dct"] = x
+                             for k in list(options)
+                             if k.startswith("heuristic_")}
+
+        def heuristic():
+            logger.info("Heuristically initializing...")
+            with poutine.block():
+                init_values = self.heuristic(**heuristic_options)
+            assert isinstance(init_values, dict)
+            assert "auxiliary" in init_values, \
+                ".heuristic() did not define auxiliary value"
+            if self._dct is not None:
+                # Also initialize DCT transformed coordinates.
+                x = init_values["auxiliary"]
+                x = biject_to(constraints.interval(-0.5, self.population + 0.5)).inv(x)
+                x = DiscreteCosineTransform(smooth=self._dct)(x)
+                init_values["auxiliary_dct"] = x
+            return init_to_value(values=init_values)

         # Configure a kernel.
         logger.info("Running inference...")
@@ -325,7 +330,7 @@ def fit(self, **options):
             model = poutine.reparam(model, {"auxiliary": rep})
         kernel = NUTS(model,
                       full_mass=full_mass,
-                      init_strategy=init_to_value(values=init_values),
+                      init_strategy=init_to_generated(generate=heuristic),
                       max_tree_depth=max_tree_depth)
         if options.pop("arrowhead_mass", False):
             kernel.mass_matrix_adapter = ArrowheadMassMatrix()
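A minimal sketch of the pattern this change enables (not part of the diff; the toy model, the site name "x", and the sampler settings are illustrative assumptions): because ``generate`` is re-invoked on every fresh model execution, each initialization attempt, and hence each chain, can start from a different random point:

    import torch
    import pyro
    import pyro.distributions as dist
    from pyro.infer import MCMC, NUTS
    from pyro.infer.autoguide import init_to_generated, init_to_value

    def model():
        pyro.sample("x", dist.Normal(0., 1.))

    def generate():
        # Called once per model execution, so each initialization
        # attempt begins from a fresh random draw.
        return init_to_value(values={"x": torch.randn(())})

    kernel = NUTS(model, init_strategy=init_to_generated(generate=generate))
    MCMC(kernel, num_samples=5, warmup_steps=5).run()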
pyro/infer/autoguide/__init__.py (3 additions, 3 deletions)
@@ -5,10 +5,9 @@
                                          AutoDiscreteParallel, AutoGuide, AutoGuideList, AutoIAFNormal,
                                          AutoLaplaceApproximation, AutoLowRankMultivariateNormal,
                                          AutoMultivariateNormal, AutoNormal, AutoNormalizingFlow)
-from pyro.infer.autoguide.utils import mean_field_entropy
-from pyro.infer.autoguide.initialization import (init_to_feasible, init_to_mean, init_to_median,
+from pyro.infer.autoguide.initialization import (init_to_feasible, init_to_generated, init_to_mean, init_to_median,
                                                  init_to_sample, init_to_uniform, init_to_value)
-
+from pyro.infer.autoguide.utils import mean_field_entropy

 __all__ = [
     'AutoCallable',
@@ -25,6 +24,7 @@
     'AutoNormal',
     'AutoNormalizingFlow',
     'init_to_feasible',
+    'init_to_generated',
     'init_to_mean',
     'init_to_median',
     'init_to_sample',
pyro/infer/autoguide/initialization.py (40 additions, 3 deletions)
@@ -100,7 +100,8 @@ def init_to_mean(site=None):

 def init_to_uniform(site=None, radius=2):
     """
-    Initialize to a random point in the area `(-radius, radius)` of unconstrained domain.
+    Initialize to a random point in the area ``(-radius, radius)`` of
+    unconstrained domain.

     :param float radius: specifies the range to draw an initial point in the unconstrained domain.
     """
@@ -114,8 +115,8 @@ def init_to_uniform(site=None, radius=2):

 def init_to_value(site=None, values={}):
     """
-    Initialize to the value specified in `values`. We defer to
-    :func:`init_to_uniform` strategy for sites which do not appear in `values`.
+    Initialize to the value specified in ``values``. We defer to
+    :func:`init_to_uniform` strategy for sites which do not appear in ``values``.

     :param dict values: dictionary of initial values keyed by site name.
     """
@@ -128,6 +129,42 @@
         return init_to_uniform(site)


+class _InitToGenerated:
+    def __init__(self, generate):
+        self.generate = generate
+        self._init = None
+        self._seen = set()
+
+    def __call__(self, site):
+        if self._init is None or site["name"] in self._seen:
+            self._init = self.generate()
+            self._seen = {site["name"]}
+        return self._init(site)

Reviewer comment: The logic LGTM.

How the regeneration works: ``_seen`` records the name of the site that triggered the last regeneration, which in practice is the first site of a model execution. Encountering that site again means a new execution has started, so ``generate`` is called afresh and the strategy it returns is reused for the remaining sites of that execution.

+def init_to_generated(site=None, generate=lambda: init_to_uniform):
+    """
+    Initialize to another initialization strategy returned by the callback
+    ``generate`` which is called once per model execution.
+
+    This is like :func:`init_to_value` but can produce different (e.g. random)
+    values once per model execution. For example, to generate values and return
+    ``init_to_value``, you could define::
+
+        def generate():
+            values = {"x": torch.randn(100), "y": torch.rand(5)}
+            return init_to_value(values=values)
+
+        my_init_fn = init_to_generated(generate=generate)
+
+    :param callable generate: A callable returning another initialization
+        function, e.g. returning an ``init_to_value(values={...})`` populated
+        with a dictionary of random samples.
+    """
+    init = _InitToGenerated(generate)
+    return init if site is None else init(site)
+
+
 class InitMessenger(Messenger):
     """
     Initializes a site by replacing ``.sample()`` calls with values
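For contrast, a small sketch (the site name and values are illustrative, not from this diff) of the difference between the two strategies: ``init_to_value`` captures one fixed dictionary when it is constructed, while ``init_to_generated`` defers to ``generate`` and so can yield a different starting point on every model execution:

    import torch
    from pyro.infer.autoguide import init_to_generated, init_to_value

    # Built once: every initialization attempt reuses the same point.
    static_init = init_to_value(values={"x": torch.randn(())})

    # Rebuilt per model execution: each attempt gets a fresh draw.
    dynamic_init = init_to_generated(
        generate=lambda: init_to_value(values={"x": torch.randn(())}))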
tests/infer/mcmc/test_mcmc_util.py (5 additions, 3 deletions)
@@ -9,8 +9,8 @@
 import pyro
 import pyro.distributions as dist
 from pyro.infer import Predictive
-from pyro.infer.autoguide import (init_to_feasible, init_to_mean, init_to_median,
-                                  init_to_sample, init_to_uniform, init_to_value)
+from pyro.infer.autoguide import (init_to_feasible, init_to_generated, init_to_mean, init_to_median, init_to_sample,
+                                  init_to_uniform, init_to_value)
 from pyro.infer.mcmc import NUTS
 from pyro.infer.mcmc.api import MCMC
 from pyro.infer.mcmc.util import initialize_model
@@ -114,7 +114,9 @@ def model():
     init_to_sample(),
     init_to_uniform(radius=0.1),
     init_to_value(values={"x": torch.tensor(3.)}),
-])
+    init_to_generated(
+        generate=lambda: init_to_value(values={"x": torch.rand(())})),
+], ids=str)
 def test_init_strategy_smoke(init_strategy):
     def model():
         pyro.sample("x", dist.LogNormal(0, 1))
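A side note on the new ``ids=str`` argument: pytest applies a callable ``ids`` to each parameter to build the test id, so each strategy is rendered via ``str()``, which keeps the test report readable now that the list includes a lambda-based ``init_to_generated`` entry.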
tests/infer/test_initialization.py (34 additions, 0 deletions, new file)
@@ -0,0 +1,34 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import torch

import pyro
import pyro.distributions as dist
from pyro.infer.autoguide.initialization import InitMessenger, init_to_generated, init_to_value


def test_init_to_generated():
def model():
x = pyro.sample("x", dist.Normal(0, 1))
y = pyro.sample("y", dist.Normal(0, 1))
z = pyro.sample("z", dist.Normal(0, 1))
return x, y, z

class MockGenerate:
def __init__(self):
self.counter = 0

def __call__(self):
values = {"x": torch.tensor(self.counter + 0.0),
"y": torch.tensor(self.counter + 0.5)}
self.counter += 1
return init_to_value(values=values)

mock_generate = MockGenerate()
with InitMessenger(init_to_generated(generate=mock_generate)):
for i in range(5):
x, y, z = model()
assert x == i
assert y == i + 0.5
assert mock_generate.counter == 5