
Implement Resampler for interactive prior tuning #3118

Merged
merged 17 commits on Jul 25, 2022
2 changes: 1 addition & 1 deletion docs/source/conf.py
@@ -92,7 +92,7 @@
#
# This is also used if you do content translation via gettext catalogs.
# Usually you set "language" from the command line for these cases.
language = None
language = "en"
Member Author:

this silences a new sphinx warning


# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
7 changes: 7 additions & 0 deletions docs/source/infer.util.rst
@@ -11,3 +11,10 @@ Model inspection
.. automodule:: pyro.infer.inspect
    :members:
    :member-order: bysource

Interactive prior tuning
------------------------

.. automodule:: pyro.infer.resampler
    :members:
    :member-order: bysource
134 changes: 134 additions & 0 deletions pyro/infer/resampler.py
@@ -0,0 +1,134 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from typing import Callable, Dict, Optional

import torch

import pyro
import pyro.poutine as poutine
from pyro.poutine.trace_struct import Trace
from pyro.poutine.util import site_is_subsample


class Resampler:
    """Resampler for interactive tuning of generative models, typically
    when performing prior predictive checks as an early step of Bayesian
    workflow.

    This is intended as a computational cache to speed up the interactive
    tuning of the parameters of prior distributions based on samples from a
    downstream simulation. The idea is that the simulation can be expensive,
    but that when one slightly tweaks parameters of the prior distribution,
    most of the previous samples can be reused via importance resampling.

    :param callable guide: A pyro model that takes no arguments. The guide
        should be diffuse, covering more space than the subsequent ``model``
        passed to :meth:`sample`.
    :param callable simulator: An optional larger pyro model with a superset of
        the guide's latent variables.
    :param int num_guide_samples: Number of initial samples to draw from the
        guide. This should be much larger than the ``num_samples`` requested in
        subsequent calls to :meth:`sample`.
    :param int max_plate_nesting: The maximum plate nesting in the model.
        If absent this will be guessed by running the guide.
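
    Example (a hypothetical usage sketch; the ``guide``, ``simulator``, and
    tweaked ``model`` below are user-defined, not part of this class)::

        def guide():
            pyro.sample("rate", dist.LogNormal(0.0, 2.0))  # diffuse prior

        def simulator():
            rate = pyro.sample("rate", dist.LogNormal(0.0, 2.0))
            pyro.sample("obs", dist.Poisson(rate))

        resampler = Resampler(guide, simulator, num_guide_samples=10000)

        # Interactively tweak the prior, cheaply redrawing samples each time.
        def model():
            pyro.sample("rate", dist.LogNormal(1.0, 1.0))

        samples = resampler.sample(model, num_samples=1000)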
"""

    def __init__(
        self,
        guide: Callable,
        simulator: Optional[Callable] = None,
        *,
        num_guide_samples: int,
        max_plate_nesting: Optional[int] = None,
    ):
        super().__init__()
        if max_plate_nesting is None:
            max_plate_nesting = _guess_max_plate_nesting(
                guide if simulator is None else simulator
            )
        self._particle_dim = -1 - max_plate_nesting
        self._gumbels: Optional[torch.Tensor] = None

        # Draw samples from the initial guide.
        with pyro.plate("particles", num_guide_samples, dim=self._particle_dim):
            trace = poutine.trace(guide).get_trace()
            self._old_logp = _log_prob_sum(trace, num_guide_samples)

            if simulator:
                # Draw extended samples from the simulator.
                trace = poutine.trace(poutine.replay(simulator, trace)).get_trace()
        self._samples = {
            name: site["value"]
            for name, site in trace.nodes.items()
            if site["type"] == "sample" and not site_is_subsample(site)
        }

    @torch.no_grad()
    def sample(
        self, model: Callable, num_samples: int, stable: bool = True
    ) -> Dict[str, torch.Tensor]:
        """Draws at most ``num_samples`` model samples, optionally extended
        by the ``simulator``.

        Internally this importance-resamples the samples generated by the
        ``guide`` in ``.__init__()``, and does not rerun the ``guide`` or
        ``simulator``. If the original guide samples poorly cover the model
        distribution, samples will show low diversity.

        :param callable model: A model with the same latent variables as the
            original ``guide``.
        :param int num_samples: The number of samples to draw.
        :param bool stable: Whether to use piecewise-constant multinomial
            sampling. Set to True for visualization, False for Monte Carlo
            integration. Defaults to True.
        :returns: A dictionary of stacked samples.
        :rtype: Dict[str, torch.Tensor]
        """
        # Importance sample: keep all weights >= 1; subsample weights < 1.
        num_guide_samples = len(self._old_logp)
        with pyro.plate("particles", num_guide_samples, dim=self._particle_dim):
            trace = poutine.trace(poutine.condition(model, self._samples)).get_trace()
        new_logp = _log_prob_sum(trace, num_guide_samples)
        logits = new_logp - self._old_logp
        i = self._categorical_sample(logits, num_samples, stable)
        samples = {k: v[i] for k, v in self._samples.items()}
        return samples
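
    # Note on the resampling math above: each cached guide draw z_i carries an
    # importance weight w_i = p_model(z_i) / p_guide(z_i), computed in log
    # space as ``logits = new_logp - self._old_logp``; drawing indices with
    # probability proportional to w_i turns the cached guide samples into
    # approximate model samples without rerunning the guide or simulator.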

    def _categorical_sample(
        self, logits: torch.Tensor, num_samples: int, stable: bool
    ) -> torch.Tensor:
        if not stable:
            return torch.multinomial(logits.exp(), num_samples, replacement=True)

        # Implement stable categorical sampling via the Gumbel-max trick.
        if self._gumbels is None or len(self._gumbels) < num_samples:
            # gumbel ~ -log(-log(uniform(0,1)))
            tiny = torch.finfo(logits.dtype).tiny
            self._gumbels = logits.new_empty(num_samples, len(logits)).uniform_()
            self._gumbels.clamp_(min=tiny).log_().neg_().clamp_(min=tiny).log_().neg_()
        return self._gumbels[:num_samples].add(logits).max(-1).indices
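
    # Why Gumbel-max yields "stable" sampling: with the cached noise g fixed,
    # argmax_j (g[k, j] + logits[j]) is a draw from the categorical
    # distribution with the given logits (the Gumbel-max trick), yet it is a
    # piecewise-constant function of ``logits``, so small interactive tweaks
    # of the prior change only a few of the selected indices.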


def _log_prob_sum(trace: Trace, batch_size: int) -> torch.Tensor:
    """Computes vectorized log_prob_sum batched over the leftmost dimension."""
    trace.compute_log_prob()
    result = 0.0
    for site in trace.nodes.values():
        if site["type"] == "sample":
            logp = site["log_prob"]
            assert logp.shape[:1] == (batch_size,)
            result += logp.reshape(batch_size, -1).sum(-1)
    return result


def _guess_max_plate_nesting(model: Callable) -> int:
    with torch.no_grad(), poutine.block(), poutine.mask(mask=False):
        trace = poutine.trace(model).get_trace()
    plate_nesting = {0}.union(
        -f.dim
        for site in trace.nodes.values()
        for f in site.get("cond_indep_stack", [])
        if f.vectorized
    )
    return max(plate_nesting)
43 changes: 43 additions & 0 deletions tests/infer/test_resampler.py
@@ -0,0 +1,43 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import functools

import pytest
import torch

import pyro
import pyro.distributions as dist
from pyro.infer.resampler import Resampler
from tests.common import assert_close


@pytest.mark.parametrize("stable", [False, True])
def test_resampling_cache(stable):
    def model_(a):
        pyro.sample("alpha", dist.Dirichlet(a))

    def simulator():
        a = torch.tensor([2.0, 1.0, 1.0, 2.0])
        alpha = pyro.sample("alpha", dist.Dirichlet(a))
        pyro.sample("x", dist.Normal(alpha, 0.01).to_event(1))

    # initialize
    a = torch.tensor([1.0, 2.0, 1.0, 1.0])
    guide = functools.partial(model_, a)
    resampler = Resampler(guide, simulator, num_guide_samples=10000)

    # resample
    b = torch.tensor([1.0, 2.0, 3.0, 4.0])
    model = functools.partial(model_, b)
    samples = resampler.sample(model, 1000, stable=stable)
    assert all(v.shape[:1] == (1000,) for v in samples.values())
    num_unique = len(set(map(tuple, samples["alpha"].tolist())))
    assert num_unique >= 500

    # check moments; the mean of Dirichlet(b) is b / b.sum()
    expected_mean = b / b.sum()
    actual_mean = samples["alpha"].mean(0)
    assert_close(actual_mean, expected_mean, atol=0.01)
    actual_mean = samples["x"].mean(0)
    assert_close(actual_mean, expected_mean, atol=0.01)
2 changes: 1 addition & 1 deletion tutorial/source/conf.py
@@ -85,7 +85,7 @@
#
# This is also used if you do content translation via gettext catalogs.
# Usually you set "language" from the command line for these cases.
language = None
language = "en"

# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
1 change: 1 addition & 0 deletions tutorial/source/index.rst
@@ -93,6 +93,7 @@ List of Tutorials
tensor_shapes
modules
workflow
prior_predictive
jit
svi_horovod
