Implement Resamper for interactive prior tuning #3118
Merged
17 commits
542dbbf Implement ResamplingCache (fritzo)
ea3b824 Vectorize for speed (fritzo)
5eabf61 lint (fritzo)
f23dc65 optimization nit (fritzo)
15093bf More optimization (fritzo)
234b7e3 Completely vectorize (fritzo)
5c38313 Generalize to multiple distributions (fritzo)
9276156 Add a tutorial (fritzo)
1cc0749 Refactor to use models (fritzo)
7b93348 Simplified, but introduced a bug :confused: (fritzo)
ae4296a fix bug (fritzo)
6d45a77 Update tutorial (fritzo)
768d1e1 Implement stable sampling via Gumbel-max trick (fritzo)
aec8f96 Change nomenclature (fritzo)
e5448bf Update prior_predictive.ipynb (fritzo)
9798d06 Install pyro-ppl in colab (fritzo)
090afef Address review comments (fritzo)
@@ -0,0 +1,133 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from typing import Callable, Dict, Optional

import torch

import pyro
import pyro.poutine as poutine
from pyro.poutine.trace_struct import Trace
from pyro.poutine.util import site_is_subsample


class Resampler:
    """Resampler for interactive tuning of generative models, typically
    when performing prior predictive checks as an early step of Bayesian
    workflow.

    This is intended as a computational cache to speed up the interactive
    tuning of the parameters of prior distributions based on samples from a
    downstream simulation. The idea is that the simulation can be expensive,
    but that when one slightly tweaks parameters of the parameter distribution
    then one can reuse most of the previous samples via importance resampling.

    :param callable guide: A pyro model that takes no arguments. The guide
        should be diffuse, covering more space than the subsequent ``model``
        passed to :meth:`sample`. Must be vectorizable via ``pyro.plate``.
    :param callable simulator: An optional larger pyro model with a superset of
        the guide's latent variables. Must be vectorizable via ``pyro.plate``.
    :param int num_guide_samples: Number of initial samples to draw from the
        guide. This should be much larger than the ``num_samples`` requested in
        subsequent calls to :meth:`sample`.
    :param int max_plate_nesting: The maximum plate nesting in the model.
        If absent this will be guessed by running the guide.
    """

    def __init__(
        self,
        guide: Callable,
        simulator: Optional[Callable] = None,
        *,
        num_guide_samples: int,
        max_plate_nesting: Optional[int] = None,
    ):
        super().__init__()
        if max_plate_nesting is None:
            max_plate_nesting = _guess_max_plate_nesting(
                guide if simulator is None else simulator
            )
        self._particle_dim = -1 - max_plate_nesting
        self._gumbels: Optional[torch.Tensor] = None

        # Draw samples from the initial guide.
        with pyro.plate("particles", num_guide_samples, dim=self._particle_dim):
            trace = poutine.trace(guide).get_trace()
            self._old_logp = _log_prob_sum(trace, num_guide_samples)

            if simulator:
                # Draw extended samples from the simulator.
                trace = poutine.trace(poutine.replay(simulator, trace)).get_trace()
            self._samples = {
                name: site["value"]
                for name, site in trace.nodes.items()
                if site["type"] == "sample" and not site_is_subsample(site)
            }

    @torch.no_grad()
    def sample(
        self, model: Callable, num_samples: int, stable: bool = True
    ) -> Dict[str, torch.Tensor]:
        """Draws a set of at most ``num_samples`` many model samples,
        optionally extended by the ``simulator``.

        Internally this importance resamples the samples generated by the
        ``guide`` in ``.__init__()``, and does not rerun the ``guide`` or
        ``simulator``. If the original guide samples poorly cover the model
        distribution, samples will show low diversity.

        :param callable model: A model with the same latent variables as the
            original ``guide``. Must be vectorizable via ``pyro.plate``.
        :param int num_samples: The number of samples to draw.
        :param bool stable: Whether to use piecewise-constant multinomial
            sampling. Set to True for visualization, False for Monte Carlo
            integration. Defaults to True.
        :returns: A dictionary of stacked samples.
        :rtype: Dict[str, torch.Tensor]
        """
        num_guide_samples = len(self._old_logp)
        with pyro.plate("particles", num_guide_samples, dim=self._particle_dim):
            trace = poutine.trace(poutine.condition(model, self._samples)).get_trace()
            new_logp = _log_prob_sum(trace, num_guide_samples)
        logits = new_logp - self._old_logp
        i = self._categorical_sample(logits, num_samples, stable)
        samples = {k: v[i] for k, v in self._samples.items()}
        return samples

    def _categorical_sample(
        self, logits: torch.Tensor, num_samples: int, stable: bool
    ) -> torch.Tensor:
        if not stable:
            return torch.multinomial(logits.exp(), num_samples, replacement=True)

        # Implement stable categorical sampling via the Gumbel-max trick.
        if self._gumbels is None or len(self._gumbels) < num_samples:
            # gumbel ~ -log(-log(uniform(0,1)))
            tiny = torch.finfo(logits.dtype).tiny
            self._gumbels = logits.new_empty(num_samples, len(logits)).uniform_()
            self._gumbels.clamp_(min=tiny).log_().neg_().clamp_(min=tiny).log_().neg_()
        return self._gumbels[:num_samples].add(logits).max(-1).indices


def _log_prob_sum(trace: Trace, batch_size: int) -> torch.Tensor:
    """Computes vectorized log_prob_sum batched over the leftmost dimension."""
    trace.compute_log_prob()
    result = 0.0
    for site in trace.nodes.values():
        if site["type"] == "sample":
            logp = site["log_prob"]
            assert logp.shape[:1] == (batch_size,)
            result += logp.reshape(batch_size, -1).sum(-1)
    return result


def _guess_max_plate_nesting(model: Callable) -> int:
    with torch.no_grad(), poutine.block(), poutine.mask(mask=False):
        trace = poutine.trace(model).get_trace()
    plate_nesting = {0}.union(
        -f.dim
        for site in trace.nodes.values()
        for f in site.get("cond_indep_stack", [])
        if f.vectorized
    )
    return max(plate_nesting)
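The `stable=True` path above relies on the Gumbel-max trick with a cached noise table, so that small changes to the logits change only a few of the selected indices. The following is a minimal standard-library sketch of that idea, not the PR's implementation: the helper names `gumbel_noise` and `stable_categorical`, the `1e-300` guard constant, and the example probabilities are all assumptions made for illustration, and plain lists stand in for `torch` tensors.

```python
import math
import random


def gumbel_noise(rng, num_samples, num_categories):
    """Build a (num_samples x num_categories) table of standard Gumbel noise."""
    tiny = 1e-300  # illustrative guard against log(0)
    table = []
    for _ in range(num_samples):
        row = []
        for _ in range(num_categories):
            u = max(rng.random(), tiny)      # uniform draw in (0, 1)
            e = max(-math.log(u), tiny)      # standard exponential draw
            row.append(-math.log(e))         # gumbel = -log(-log(u))
        table.append(row)
    return table


def stable_categorical(logits, gumbels):
    """Return one index per noise row: argmax over j of logits[j] + g[j]."""
    indices = []
    for row in gumbels:
        scores = [logit + g for logit, g in zip(logits, row)]
        indices.append(max(range(len(scores)), key=scores.__getitem__))
    return indices


rng = random.Random(0)
# The noise table is drawn once and reused across calls, like a cache.
gumbels = gumbel_noise(rng, num_samples=2000, num_categories=3)

logits_a = [math.log(0.2), math.log(0.3), math.log(0.5)]
logits_b = [math.log(0.21), math.log(0.3), math.log(0.49)]  # slightly tweaked

draws_a = stable_categorical(logits_a, gumbels)
draws_b = stable_categorical(logits_b, gumbels)

# Because the noise is shared, most draws are identical between the two calls.
num_changed = sum(a != b for a, b in zip(draws_a, draws_b))
freq_a = [draws_a.count(k) / len(draws_a) for k in range(3)]
```

Reusing one noise table makes the selected indices a deterministic, piecewise-constant function of the logits, which is useful when an interactive plot should not flicker as a slider moves. Independent multinomial draws, as in the `stable=False` path, would reshuffle the selection on every call.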
@@ -0,0 +1,43 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import functools

import pytest
import torch

import pyro
import pyro.distributions as dist
from pyro.infer.resampler import Resampler
from tests.common import assert_close


@pytest.mark.parametrize("stable", [False, True])
def test_resampling_cache(stable):
    def model_(a):
        pyro.sample("alpha", dist.Dirichlet(a))

    def simulator():
        a = torch.tensor([2.0, 1.0, 1.0, 2.0])
        alpha = pyro.sample("alpha", dist.Dirichlet(a))
        pyro.sample("x", dist.Normal(alpha, 0.01).to_event(1))

    # initialize
    a = torch.tensor([1.0, 2.0, 1.0, 1.0])
    guide = functools.partial(model_, a)
    resampler = Resampler(guide, simulator, num_guide_samples=10000)

    # resample
    b = torch.tensor([1.0, 2.0, 3.0, 4.0])
    model = functools.partial(model_, b)
    # Pass the parametrized flag through so both sampling paths are exercised.
    samples = resampler.sample(model, 1000, stable=stable)
    assert all(v.shape[:1] == (1000,) for v in samples.values())
    num_unique = len(set(map(tuple, samples["alpha"].tolist())))
    assert num_unique >= 500

    # check moments
    expected_mean = b / b.sum()
    actual_mean = samples["alpha"].mean(0)
    assert_close(actual_mean, expected_mean, atol=0.01)
    actual_mean = samples["x"].mean(0)
    assert_close(actual_mean, expected_mean, atol=0.01)
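The test above checks that samples cached from a diffuse guide, once reweighted, match the moments of a different model. The same importance-resampling idea can be sketched in one dimension with only the standard library. This is an illustrative analogue, not Pyro code: the Normal guide and model, the helper names `normal_logpdf` and `resample`, and the sample sizes are all assumptions chosen for the example.

```python
import math
import random


def normal_logpdf(x, loc, scale):
    """Log density of Normal(loc, scale) at x."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2 * math.pi)


rng = random.Random(0)

# Diffuse "guide": Normal(0, 3). Sampled once and cached with its log density.
guide_samples = [rng.gauss(0.0, 3.0) for _ in range(20000)]
old_logp = [normal_logpdf(x, 0.0, 3.0) for x in guide_samples]


def resample(model_logpdf, num_samples):
    """Importance resample the cached guide samples under a new model."""
    # logits = log p_model(x) - log p_guide(x), as in Resampler.sample().
    logits = [model_logpdf(x) - lp for x, lp in zip(guide_samples, old_logp)]
    shift = max(logits)  # subtract the max before exponentiating, for stability
    weights = [math.exp(logit - shift) for logit in logits]
    return rng.choices(guide_samples, weights=weights, k=num_samples)


# Narrower "model": Normal(1, 1). No new draws from the guide are needed.
samples = resample(lambda x: normal_logpdf(x, 1.0, 1.0), 2000)
mean = sum(samples) / len(samples)
```

The resampled mean should land close to the model's mean of 1.0 because the guide covers the model's support well. If the guide were narrower than the model, a few cached points would dominate the weights and the resampled set would show low diversity.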
@@ -93,6 +93,7 @@ List of Tutorials
   tensor_shapes
   modules
   workflow
   prior_predictive
   jit
   svi_horovod

this silences a new sphinx warning