From 8fd0bf51d5f1813dac995f111edf041f461c1029 Mon Sep 17 00:00:00 2001
From: Mateusz Sokół <8431159+mtsokol@users.noreply.github.com>
Date: Wed, 16 Jun 2021 22:26:26 +0200
Subject: [PATCH] StreamingMCMC class (#2857)

---
 docs/source/pyro.infer.mcmc.txt   |  10 ++
 pyro/infer/mcmc/__init__.py       |   3 +-
 pyro/infer/mcmc/api.py            | 204 +++++++++++++++++++++++-------
 pyro/infer/mcmc/util.py           |  30 +++++
 tests/infer/mcmc/test_mcmc_api.py |  92 ++++++++++----
 5 files changed, 261 insertions(+), 78 deletions(-)

diff --git a/docs/source/pyro.infer.mcmc.txt b/docs/source/pyro.infer.mcmc.txt
index 9d6b483eea..0f51191474 100644
--- a/docs/source/pyro.infer.mcmc.txt
+++ b/docs/source/pyro.infer.mcmc.txt
@@ -6,6 +6,14 @@ MCMC
     :undoc-members:
     :show-inheritance:
 
+StreamingMCMC
+-------------
+
+.. autoclass:: pyro.infer.mcmc.api.StreamingMCMC
+    :members:
+    :undoc-members:
+    :show-inheritance:
+
 MCMCKernel
 ----------
 .. autoclass:: pyro.infer.mcmc.mcmc_kernel.MCMCKernel
@@ -43,3 +51,5 @@ Utilities
 .. autofunction:: pyro.infer.mcmc.util.initialize_model
 
 .. autofunction:: pyro.infer.mcmc.util.diagnostics
+
+.. autofunction:: pyro.infer.mcmc.util.select_samples
diff --git a/pyro/infer/mcmc/__init__.py b/pyro/infer/mcmc/__init__.py
index 99d241b162..7be4bae99c 100644
--- a/pyro/infer/mcmc/__init__.py
+++ b/pyro/infer/mcmc/__init__.py
@@ -2,7 +2,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 from pyro.infer.mcmc.adaptation import ArrowheadMassMatrix, BlockMassMatrix
-from pyro.infer.mcmc.api import MCMC
+from pyro.infer.mcmc.api import MCMC, StreamingMCMC
 from pyro.infer.mcmc.hmc import HMC
 from pyro.infer.mcmc.nuts import NUTS
 
@@ -12,4 +12,5 @@
     "HMC",
     "MCMC",
     "NUTS",
+    "StreamingMCMC",
 ]
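For orientation before the implementation below, here is a minimal usage sketch of the newly exported class. The model and data are hypothetical, not part of the patch; the API follows the `StreamingMCMC` class added in `api.py` below:

    import torch
    import pyro
    import pyro.distributions as dist
    from pyro.infer.mcmc import NUTS, StreamingMCMC

    def model(data):
        loc = pyro.sample("loc", dist.Normal(0., 1.))
        with pyro.plate("data", data.shape[0]):
            pyro.sample("obs", dist.Normal(loc, 1.), obs=data)

    data = torch.randn(100) + 1.0
    mcmc = StreamingMCMC(NUTS(model), num_samples=500, warmup_steps=200)
    mcmc.run(data)
    # No samples are retained in memory; only streaming statistics are aggregated.
    stats = mcmc.get_statistics(group_by_chain=False)
    # With the default statistics object this should yield, per site, a dict of
    # running "count", "mean" and "variance" entries, e.g. stats["loc"]["mean"].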
diff --git a/pyro/infer/mcmc/api.py b/pyro/infer/mcmc/api.py
index 1aeee9e5af..f121a51f4b 100644
--- a/pyro/infer/mcmc/api.py
+++ b/pyro/infer/mcmc/api.py
@@ -9,14 +9,16 @@
   code that works with different backends.
 - minimal memory consumption with multiprocessing and CUDA.
 """
-
+import copy
 import json
 import logging
 import queue
 import signal
 import threading
 import warnings
+from abc import ABC, abstractmethod
 from collections import OrderedDict
+from typing import Dict
 
 import torch
 import torch.multiprocessing as mp
@@ -31,7 +33,8 @@
     initialize_logger,
 )
 from pyro.infer.mcmc.nuts import NUTS
-from pyro.infer.mcmc.util import diagnostics, print_summary
+from pyro.infer.mcmc.util import diagnostics, print_summary, select_samples
+from pyro.ops.streaming import CountMeanVarianceStats, StatsOfDict, StreamingStats
 from pyro.util import optional
 
 MAX_SEED = 2**32 - 1
@@ -257,7 +260,46 @@ def run(self, *args, **kwargs):
         self.terminate(terminate_workers=exc_raised)
 
 
-class MCMC:
+class AbstractMCMC(ABC):
+    """
+    Base class for MCMC methods.
+    """
+    def __init__(self, kernel, num_chains, transforms):
+        self.kernel = kernel
+        self.num_chains = num_chains
+        self.transforms = transforms
+
+    @abstractmethod
+    def run(self, *args, **kwargs):
+        raise NotImplementedError
+
+    def _set_transforms(self, *args, **kwargs):
+        # Use `kernel.transforms` when available
+        if getattr(self.kernel, "transforms", None) is not None:
+            self.transforms = self.kernel.transforms
+        # Else, get transforms from model (e.g. in multiprocessing).
+        elif self.kernel.model:
+            warmup_steps = 0
+            self.kernel.setup(warmup_steps, *args, **kwargs)
+            self.transforms = self.kernel.transforms
+        # Assign default value
+        else:
+            self.transforms = {}
+
+    def _validate_kernel(self, initial_params):
+        if isinstance(self.kernel, (HMC, NUTS)) and self.kernel.potential_fn is not None:
+            if initial_params is None:
+                raise ValueError("Must provide valid initial parameters to begin sampling"
+                                 " when using `potential_fn` in HMC/NUTS kernel.")
+
+    def _validate_initial_params(self, initial_params):
+        for v in initial_params.values():
+            if v.shape[0] != self.num_chains:
+                raise ValueError("The leading dimension of tensors in `initial_params` "
+                                 "must match the number of chains.")
+
+
+class MCMC(AbstractMCMC):
     """
     Wrapper class for Markov Chain Monte Carlo algorithms. Specific MCMC algorithms
     are TraceKernel instances and need to be supplied as a ``kernel`` argument
@@ -307,28 +349,21 @@ class MCMC:
     def __init__(self, kernel, num_samples, warmup_steps=None, initial_params=None,
                  num_chains=1, hook_fn=None, mp_context=None, disable_progbar=False,
                  disable_validation=True, transforms=None, save_params=None):
+        super().__init__(kernel, num_chains, transforms)
         self.warmup_steps = num_samples if warmup_steps is None else warmup_steps  # Stan
         self.num_samples = num_samples
-        self.kernel = kernel
-        self.transforms = transforms
         self.disable_validation = disable_validation
         self._samples = None
         self._args = None
         self._kwargs = None
         if save_params is not None:
             kernel.save_params = save_params
-        if isinstance(self.kernel, (HMC, NUTS)) and self.kernel.potential_fn is not None:
-            if initial_params is None:
-                raise ValueError("Must provide valid initial parameters to begin sampling"
-                                 " when using `potential_fn` in HMC/NUTS kernel.")
+        self._validate_kernel(initial_params)
         parallel = False
         if num_chains > 1:
             # check that initial_params is different for each chain
             if initial_params:
-                for v in initial_params.values():
-                    if v.shape[0] != num_chains:
-                        raise ValueError("The leading dimension of tensors in `initial_params` "
-                                         "must match the number of chains.")
+                self._validate_initial_params(initial_params)
             # FIXME: we probably want to use the "spawn" method by default to avoid the
             # CUDA initialization error (https://github.com/pytorch/pytorch/issues/2517)
             # even when we run MCMC on the CPU.
@@ -348,10 +383,7 @@ def __init__(self, kernel, num_samples, warmup_steps=None, initial_params=None,
         else:
             if initial_params:
                 initial_params = {k: v.unsqueeze(0) for k, v in initial_params.items()}
-
-        self.num_chains = num_chains
         self._diagnostics = [None] * num_chains
-
         if parallel:
             self.sampler = _MultiSampler(kernel, num_samples, self.warmup_steps, num_chains, mp_context,
                                          disable_progbar, initial_params=initial_params, hook=hook_fn)
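To illustrate the refactor above: `AbstractMCMC` now owns the kernel/chain/transform bookkeeping, so a wrapper only needs to implement `run`. A toy, hypothetical subclass (not part of the patch) might look as follows; it assumes only the documented `MCMCKernel` interface (`setup`, `sample`, `cleanup`, `initial_params`):

    from pyro.infer.mcmc.api import AbstractMCMC

    class LastStateMCMC(AbstractMCMC):
        """Toy single-chain wrapper that keeps only the final kernel state."""

        def __init__(self, kernel, num_samples, transforms=None):
            super().__init__(kernel, num_chains=1, transforms=transforms)
            self.num_samples = num_samples

        def run(self, *args, **kwargs):
            self.kernel.setup(0, *args, **kwargs)  # no warmup in this toy sketch
            params = self.kernel.initial_params
            for _ in range(self.num_samples):
                params = self.kernel.sample(params)
            self.kernel.cleanup()
            return params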
@@ -422,17 +454,7 @@ def model(data):
         # If transforms is not explicitly provided, infer automatically using
         # model args, kwargs.
         if self.transforms is None:
-            # Use `kernel.transforms` when available
-            if getattr(self.kernel, "transforms", None) is not None:
-                self.transforms = self.kernel.transforms
-            # Else, get transforms from model (e.g. in multiprocessing).
-            elif self.kernel.model:
-                warmup_steps = 0
-                self.kernel.setup(warmup_steps, *args, **kwargs)
-                self.transforms = self.kernel.transforms
-            # Assign default value
-            else:
-                self.transforms = {}
+            self._set_transforms(*args, **kwargs)
 
         # transform samples back to constrained space
         for name, z in z_acc.items():
@@ -447,30 +469,10 @@ def get_samples(self, num_samples=None, group_by_chain=False):
         """
         Get samples from the MCMC run, potentially resampling with replacement.
 
-        :param int num_samples: Number of samples to return. If `None`, all the samples
-            from an MCMC chain are returned in their original ordering.
-        :param bool group_by_chain: Whether to preserve the chain dimension. If True,
-            all samples will have num_chains as the size of their leading dimension.
-        :return: dictionary of samples keyed by site name.
+        For parameter details see :func:`~pyro.infer.mcmc.util.select_samples`.
         """
         samples = self._samples
-        if num_samples is None:
-            # reshape to collapse chain dim when group_by_chain=False
-            if not group_by_chain:
-                samples = {k: v.reshape((-1,) + v.shape[2:]) for k, v in samples.items()}
-        else:
-            if not samples:
-                raise ValueError("No samples found from MCMC run.")
-            if group_by_chain:
-                batch_dim = 1
-            else:
-                samples = {k: v.reshape((-1,) + v.shape[2:]) for k, v in samples.items()}
-                batch_dim = 0
-            sample_tensor = list(samples.values())[0]
-            batch_size, device = sample_tensor.shape[batch_dim], sample_tensor.device
-            idxs = torch.randint(0, batch_size, size=(num_samples,), device=device)
-            samples = {k: v.index_select(batch_dim, idxs) for k, v in samples.items()}
-        return samples
+        return select_samples(samples, num_samples, group_by_chain)
 
     def diagnostics(self):
         """
@@ -496,3 +498,107 @@ def summary(self, prob=0.9):
         if 'divergences' in self._diagnostics[0]:
             print("Number of divergences: {}".format(
                 sum([len(self._diagnostics[i]['divergences']) for i in range(self.num_chains)])))
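The `get_samples` body above now simply delegates to the shared `select_samples` helper added to `util.py` further down in this patch; conceptually (a sketch; `mcmc._samples` is the private chain-grouped sample storage the method reads from):

    from pyro.infer.mcmc.util import select_samples

    # after mcmc.run(...), this call ...
    drawn = mcmc.get_samples(num_samples=100, group_by_chain=True)
    # ... selects exactly like the helper applied to the raw storage:
    drawn = select_samples(mcmc._samples, num_samples=100, group_by_chain=True)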
+
+
+class StreamingMCMC(AbstractMCMC):
+    """
+    MCMC that computes required statistics in a streaming fashion. No samples are
+    retained by this class, only aggregated statistics. This is useful for running
+    memory-expensive models where we care only about specific statistics (and is
+    especially useful in memory-constrained environments such as GPUs).
+
+    For available streaming ops please see :mod:`~pyro.ops.streaming`.
+    """
+    def __init__(self, kernel, num_samples, warmup_steps=None, initial_params=None,
+                 statistics=None, num_chains=1, hook_fn=None, disable_progbar=False,
+                 disable_validation=True, transforms=None, save_params=None):
+        super().__init__(kernel, num_chains, transforms)
+        self.warmup_steps = num_samples if warmup_steps is None else warmup_steps  # Stan
+        self.num_samples = num_samples
+        self.disable_validation = disable_validation
+        self._samples = None
+        self._args = None
+        self._kwargs = None
+        if statistics is None:
+            statistics = StatsOfDict(default=CountMeanVarianceStats)
+        self._statistics = statistics
+        self._default_statistics = copy.deepcopy(statistics)
+        if save_params is not None:
+            kernel.save_params = save_params
+        self._validate_kernel(initial_params)
+        if num_chains > 1:
+            if initial_params:
+                self._validate_initial_params(initial_params)
+        else:
+            if initial_params:
+                initial_params = {k: v.unsqueeze(0) for k, v in initial_params.items()}
+        self._diagnostics = [None] * num_chains
+        self.sampler = _UnarySampler(kernel, num_samples, self.warmup_steps, num_chains, disable_progbar,
+                                     initial_params=initial_params, hook=hook_fn)
+
+    @poutine.block
+    def run(self, *args, **kwargs):
+        """
+        Run StreamingMCMC to compute the statistics specified in `self._statistics`.
+        """
+        self._args, self._kwargs = args, kwargs
+        num_samples = [0] * self.num_chains
+
+        with optional(pyro.validation_enabled(not self.disable_validation),
+                      self.disable_validation is not None):
+            args = [arg.detach() if torch.is_tensor(arg) else arg for arg in args]
+            for x, chain_id in self.sampler.run(*args, **kwargs):
+                if num_samples[chain_id] == 0:
+                    # If transforms is not explicitly provided, infer automatically using
+                    # model args, kwargs.
+                    if self.transforms is None:
+                        self._set_transforms(*args, **kwargs)
+                    num_samples[chain_id] += 1
+                    z_structure = x
+                elif num_samples[chain_id] == self.num_samples + 1:
+                    self._diagnostics[chain_id] = x
+                else:
+                    num_samples[chain_id] += 1
+                    if self.num_chains > 1:
+                        x_cloned = x.clone()
+                        del x
+                    else:
+                        x_cloned = x
+
+                    # unpack latent
+                    pos = 0
+                    z_acc = z_structure.copy()
+                    for k in sorted(z_structure):
+                        shape = z_structure[k]
+                        next_pos = pos + shape.numel()
+                        z_acc[k] = x_cloned[pos:next_pos].reshape(shape)
+                        pos = next_pos
+
+                    for name, z in z_acc.items():
+                        if name in self.transforms:
+                            z_acc[name] = self.transforms[name].inv(z)
+
+                    self._statistics.update({
+                        (chain_id, name): transformed_sample for name, transformed_sample in z_acc.items()
+                    })
+
+        # terminate the sampler (shut down worker processes)
+        self.sampler.terminate(True)
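Since `run` above feeds `update` with keys of the form `(chain_id, site_name)`, the `statistics` constructor argument can be any compatible `StatsOfDict`. For example, to stack raw draws for one site of one chain while keeping only cheap running moments for everything else (a sketch; `model` is hypothetical, and it assumes the `pyro.ops.streaming` convention that per-key values are statistics classes rather than instances):

    from pyro.infer.mcmc import NUTS, StreamingMCMC
    from pyro.ops.streaming import CountMeanVarianceStats, StackStats, StatsOfDict

    statistics = StatsOfDict(
        {(0, "loc"): StackStats},        # full stacked draws of site "loc", chain 0
        default=CountMeanVarianceStats,  # count/mean/variance for all other keys
    )
    mcmc = StreamingMCMC(NUTS(model), num_samples=500, statistics=statistics)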
+
+    def get_statistics(self, group_by_chain=True):
+        """
+        Returns a dict of statistics, as defined by the `statistics` object passed
+        to the constructor.
+
+        :param bool group_by_chain: Whether statistics should be kept chain-wise or
+            merged together across chains.
+        """
+        if group_by_chain:
+            return self._statistics.get()
+        else:
+            # merge all chains with respect to names
+            merged_dict: Dict[str, StreamingStats] = {}
+            for (_, name), stat in self._statistics.stats.items():
+                if name in merged_dict:
+                    merged_dict[name] = merged_dict[name].merge(stat)
+                else:
+                    merged_dict[name] = stat
+
+            return {k: v.get() for k, v in merged_dict.items()}
diff --git a/pyro/infer/mcmc/util.py b/pyro/infer/mcmc/util.py
index 89831a93f8..f862fe8077 100644
--- a/pyro/infer/mcmc/util.py
+++ b/pyro/infer/mcmc/util.py
@@ -662,3 +662,33 @@ def wrapped_fn(*args, **kwargs):
             predictions[site] = value.reshape(shape)
 
     return predictions
+
+
+def select_samples(samples, num_samples=None, group_by_chain=False):
+    """
+    Performs selection from given MCMC samples.
+
+    :param dict samples: dictionary of sample tensors keyed by site name, with the
+        chain dimension leading.
+    :param int num_samples: Number of samples to return. If `None`, all the samples
+        from an MCMC chain are returned in their original ordering.
+    :param bool group_by_chain: Whether to preserve the chain dimension. If True,
+        all samples will have num_chains as the size of their leading dimension.
+    :return: dictionary of samples keyed by site name.
+    """
+    if num_samples is None:
+        # reshape to collapse chain dim when group_by_chain=False
+        if not group_by_chain:
+            samples = {k: v.reshape((-1,) + v.shape[2:]) for k, v in samples.items()}
+    else:
+        if not samples:
+            raise ValueError("No samples found from MCMC run.")
+        if group_by_chain:
+            batch_dim = 1
+        else:
+            samples = {k: v.reshape((-1,) + v.shape[2:]) for k, v in samples.items()}
+            batch_dim = 0
+        sample_tensor = list(samples.values())[0]
+        batch_size, device = sample_tensor.shape[batch_dim], sample_tensor.device
+        idxs = torch.randint(0, batch_size, size=(num_samples,), device=device)
+        samples = {k: v.index_select(batch_dim, idxs) for k, v in samples.items()}
+    return samples
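At a glance, the shape behavior of the new helper (a sketch with a toy samples dict holding 2 chains of 1000 draws of a 1-dimensional site):

    import torch
    from pyro.infer.mcmc.util import select_samples

    samples = {"y": torch.randn(2, 1000, 1)}

    select_samples(samples, group_by_chain=True)["y"].shape    # (2, 1000, 1), unchanged
    select_samples(samples, group_by_chain=False)["y"].shape   # (2000, 1), chains collapsed
    select_samples(samples, num_samples=500)["y"].shape        # (500, 1), resampled with replacement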
diff --git a/tests/infer/mcmc/test_mcmc_api.py b/tests/infer/mcmc/test_mcmc_api.py
index 0dc12cad50..878e134837 100644
--- a/tests/infer/mcmc/test_mcmc_api.py
+++ b/tests/infer/mcmc/test_mcmc_api.py
@@ -11,9 +11,10 @@
 import pyro.distributions as dist
 from pyro import poutine
 from pyro.infer.mcmc import HMC, NUTS
-from pyro.infer.mcmc.api import MCMC, _MultiSampler, _UnarySampler
+from pyro.infer.mcmc.api import MCMC, StreamingMCMC, _MultiSampler, _UnarySampler
 from pyro.infer.mcmc.mcmc_kernel import MCMCKernel
-from pyro.infer.mcmc.util import initialize_model
+from pyro.infer.mcmc.util import initialize_model, select_samples
+from pyro.ops.streaming import StackStats, StatsOfDict
 from pyro.util import optional
 from tests.common import assert_close
 
@@ -73,29 +74,68 @@ def normal_normal_model(data):
     return y
 
 
+def run_default_mcmc(data, kernel, num_samples, warmup_steps=None, initial_params=None,
+                     num_chains=1, hook_fn=None, mp_context=None, transforms=None, num_draws=None,
+                     group_by_chain=False):
+    mcmc = MCMC(kernel=kernel, num_samples=num_samples, warmup_steps=warmup_steps, initial_params=initial_params,
+                num_chains=num_chains, hook_fn=hook_fn, mp_context=mp_context, transforms=transforms)
+    mcmc.run(data)
+    return mcmc.get_samples(num_draws, group_by_chain=group_by_chain), mcmc.num_chains
+
+
+def run_streaming_mcmc(data, kernel, num_samples, warmup_steps=None, initial_params=None,
+                       num_chains=1, hook_fn=None, mp_context=None, transforms=None, num_draws=None,
+                       group_by_chain=False):
+    mcmc = StreamingMCMC(kernel=kernel, num_samples=num_samples, warmup_steps=warmup_steps,
+                         initial_params=initial_params, statistics=StatsOfDict(default=StackStats),
+                         num_chains=num_chains, hook_fn=hook_fn, transforms=transforms)
+    mcmc.run(data)
+    statistics = mcmc.get_statistics(group_by_chain=group_by_chain)
+
+    if group_by_chain:
+        samples = {}
+        agg = {}
+        for (_, name), stat in statistics.items():
+            if name in agg:
+                agg[name].append(stat['samples'])
+            else:
+                agg[name] = [stat['samples']]
+        for name, l in agg.items():
+            samples[name] = torch.stack(l)
+    else:
+        samples = {name: stat['samples'] for name, stat in statistics.items()}
+
+    samples = select_samples(samples, num_draws, group_by_chain)
+
+    if not group_by_chain:
+        samples = {name: stat.unsqueeze(-1) for name, stat in samples.items()}
+
+    return samples, mcmc.num_chains
+
+
+@pytest.mark.parametrize("run_mcmc_cls", [run_default_mcmc, run_streaming_mcmc])
 @pytest.mark.parametrize('num_draws', [None, 1800, 2200])
 @pytest.mark.parametrize('group_by_chain', [False, True])
 @pytest.mark.parametrize('num_chains', [1, 2])
 @pytest.mark.filterwarnings("ignore:num_chains")
-def test_mcmc_interface(num_draws, group_by_chain, num_chains):
+def test_mcmc_interface(run_mcmc_cls, num_draws, group_by_chain, num_chains):
     num_samples = 2000
     data = torch.tensor([1.0])
     initial_params, _, transforms, _ = initialize_model(normal_normal_model, model_args=(data,),
                                                         num_chains=num_chains)
     kernel = PriorKernel(normal_normal_model)
-    mcmc = MCMC(kernel=kernel, num_samples=num_samples, warmup_steps=100, num_chains=num_chains,
-                mp_context="spawn", initial_params=initial_params, transforms=transforms)
-    mcmc.run(data)
-    samples = mcmc.get_samples(num_draws, group_by_chain=group_by_chain)
+    samples, mcmc_num_chains = run_mcmc_cls(data, kernel, num_samples=num_samples, warmup_steps=100,
+                                            initial_params=initial_params, num_chains=num_chains, mp_context='spawn',
+                                            transforms=transforms, num_draws=num_draws, group_by_chain=group_by_chain)
     # test sample shape
     expected_samples = num_draws if num_draws is not None else num_samples
     if group_by_chain:
-        expected_shape = (mcmc.num_chains, expected_samples, 1)
+        expected_shape = (mcmc_num_chains, expected_samples, 1)
     elif num_draws is not None:
         # FIXME: what is the expected behavior when num_draws is not None and group_by_chain=False?
         expected_shape = (expected_samples, 1)
     else:
-        expected_shape = (mcmc.num_chains * expected_samples, 1)
+        expected_shape = (mcmc_num_chains * expected_samples, 1)
     assert samples['y'].shape == expected_shape
 
     # test sample stats
@@ -146,6 +186,10 @@ def _hook(iters, kernel, samples, stage, i):
     iters.append((stage, i))
 
 
+@pytest.mark.parametrize("run_mcmc_cls", [
+    run_default_mcmc,
+    run_streaming_mcmc
+])
 @pytest.mark.parametrize("kernel, model", [
     (HMC, _empty_model),
     (NUTS, _empty_model),
@@ -156,7 +200,7 @@ def _hook(iters, kernel, samples, stage, i):
     2
 ])
 @pytest.mark.filterwarnings("ignore:num_chains")
-def test_null_model_with_hook(kernel, model, jit, num_chains):
+def test_null_model_with_hook(run_mcmc_cls, kernel, model, jit, num_chains):
     num_warmup, num_samples = 10, 10
     initial_params, potential_fn, transforms, _ = initialize_model(model, num_chains=num_chains)
 
@@ -167,10 +211,8 @@ def test_null_model_with_hook(kernel, model, jit, num_chains):
     mp_context = "spawn" if "CUDA_TEST" in os.environ else None
 
     kern = kernel(potential_fn=potential_fn, transforms=transforms, jit_compile=jit)
-    mcmc = MCMC(kern, num_samples=num_samples, warmup_steps=num_warmup,
-                num_chains=num_chains, initial_params=initial_params, hook_fn=hook, mp_context=mp_context)
-    mcmc.run()
-    samples = mcmc.get_samples()
+    samples, _ = run_mcmc_cls(data=None, kernel=kern, num_samples=num_samples, warmup_steps=num_warmup,
+                              initial_params=initial_params, hook_fn=hook, num_chains=num_chains, mp_context=mp_context)
     assert samples == {}
     if num_chains == 1:
         expected = [("Warmup", i) for i in range(num_warmup)] + [("Sample", i) for i in range(num_samples)]
@@ -200,8 +242,9 @@ def test_mcmc_diagnostics(num_chains):
                                            for i in range(num_chains)}
 
 
+@pytest.mark.parametrize("run_mcmc_cls", [run_default_mcmc, run_streaming_mcmc])
 @pytest.mark.filterwarnings("ignore:num_chains")
-def test_sequential_consistent(monkeypatch):
+def test_sequential_consistent(run_mcmc_cls, monkeypatch):
     # test that nothing is left over from the previous chain
     monkeypatch.setattr(torch.multiprocessing, 'cpu_count', lambda: 1)
 
@@ -219,31 +262,24 @@ def setup(self, warmup_steps, *args, **kwargs):
 
     data = torch.tensor([1.0])
     kernel = FirstKernel(normal_normal_model)
-    mcmc = MCMC(kernel, num_samples=100, warmup_steps=100, num_chains=2)
-    mcmc.run(data)
-    samples1 = mcmc.get_samples(group_by_chain=True)
+    samples1, _ = run_mcmc_cls(data, kernel, num_samples=100, warmup_steps=100, num_chains=2, group_by_chain=True)
 
     kernel = SecondKernel(normal_normal_model)
-    mcmc = MCMC(kernel, num_samples=100, warmup_steps=100, num_chains=2)
-    mcmc.run(data)
-    samples2 = mcmc.get_samples(group_by_chain=True)
+    samples2, _ = run_mcmc_cls(data, kernel, num_samples=100, warmup_steps=100, num_chains=2, group_by_chain=True)
 
     assert_close(samples1["y"][0], samples2["y"][1])
     assert_close(samples1["y"][1], samples2["y"][0])
 
 
-def test_model_with_potential_fn():
+@pytest.mark.parametrize("run_mcmc_cls", [run_default_mcmc, run_streaming_mcmc])
+def test_model_with_potential_fn(run_mcmc_cls):
     init_params = {"z": torch.tensor(0.)}
 
     def potential_fn(params):
         return params["z"]
 
-    mcmc = MCMC(
-        kernel=HMC(potential_fn=potential_fn),
-        num_samples=10,
-        warmup_steps=10,
-        initial_params=init_params)
-    mcmc.run()
+    run_mcmc_cls(data=None, kernel=HMC(potential_fn=potential_fn), num_samples=10,
+                 warmup_steps=10, initial_params=init_params)
 
 
 @pytest.mark.parametrize("save_params", ["xy", "x", "y"])
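Finally, the two modes of `StreamingMCMC.get_statistics` exercised by these tests differ only in key structure (a sketch; the inner dicts depend on the statistics object, shown here for the default `CountMeanVarianceStats`):

    # group_by_chain=True (the default): keyed by (chain_id, site_name)
    stats = mcmc.get_statistics()
    # e.g. {(0, "loc"): {"count": ..., "mean": ..., "variance": ...},
    #       (1, "loc"): {"count": ..., "mean": ..., "variance": ...}}

    # group_by_chain=False: chains merged per site via StreamingStats.merge
    stats = mcmc.get_statistics(group_by_chain=False)
    # e.g. {"loc": {"count": ..., "mean": ..., "variance": ...}}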