Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

StreamingMCMC class #2857

Merged
merged 5 commits into from
Jun 16, 2021
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions docs/source/pyro.infer.mcmc.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,14 @@ MCMC
:undoc-members:
:show-inheritance:

StreamingMCMC
-------------

.. autoclass:: pyro.infer.mcmc.api.StreamingMCMC
:members:
:undoc-members:
:show-inheritance:

MCMCKernel
----------
.. autoclass:: pyro.infer.mcmc.mcmc_kernel.MCMCKernel
Expand Down
3 changes: 2 additions & 1 deletion pyro/infer/mcmc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# SPDX-License-Identifier: Apache-2.0

from pyro.infer.mcmc.adaptation import ArrowheadMassMatrix, BlockMassMatrix
from pyro.infer.mcmc.api import MCMC
from pyro.infer.mcmc.api import MCMC, StreamingMCMC
from pyro.infer.mcmc.hmc import HMC
from pyro.infer.mcmc.nuts import NUTS

Expand All @@ -12,4 +12,5 @@
"HMC",
"MCMC",
"NUTS",
"StreamingMCMC",
]
205 changes: 156 additions & 49 deletions pyro/infer/mcmc/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,16 @@
code that works with different backends.
- minimal memory consumption with multiprocessing and CUDA.
"""

import copy
import json
import logging
import queue
import signal
import threading
import warnings
from abc import ABC, abstractmethod
from collections import OrderedDict
from typing import Dict

import torch
import torch.multiprocessing as mp
Expand All @@ -31,7 +33,8 @@
initialize_logger,
)
from pyro.infer.mcmc.nuts import NUTS
from pyro.infer.mcmc.util import diagnostics, print_summary
from pyro.infer.mcmc.util import diagnostics, print_summary, select_samples
from pyro.ops.streaming import CountMeanVarianceStats, StatsOfDict, StreamingStats
from pyro.util import optional

MAX_SEED = 2**32 - 1
Expand Down Expand Up @@ -257,7 +260,46 @@ def run(self, *args, **kwargs):
self.terminate(terminate_workers=exc_raised)


class MCMC:
class AbstractMCMC(ABC):
"""
Base class for MCMC methods.
"""
def __init__(self, kernel, num_chains, transforms):
self.kernel = kernel
self.num_chains = num_chains
self.transforms = transforms

@abstractmethod
def run(self, *args, **kwargs):
raise NotImplementedError

def _set_transforms(self, *args, **kwargs):
# Use `kernel.transforms` when available
if getattr(self.kernel, "transforms", None) is not None:
self.transforms = self.kernel.transforms
# Else, get transforms from model (e.g. in multiprocessing).
elif self.kernel.model:
warmup_steps = 0
self.kernel.setup(warmup_steps, *args, **kwargs)
self.transforms = self.kernel.transforms
# Assign default value
else:
self.transforms = {}

def _validate_kernel(self, initial_params):
if isinstance(self.kernel, (HMC, NUTS)) and self.kernel.potential_fn is not None:
if initial_params is None:
raise ValueError("Must provide valid initial parameters to begin sampling"
" when using `potential_fn` in HMC/NUTS kernel.")

def _validate_initial_params(self, initial_params):
for v in initial_params.values():
if v.shape[0] != self.num_chains:
raise ValueError("The leading dimension of tensors in `initial_params` "
"must match the number of chains.")


class MCMC(AbstractMCMC):
"""
Wrapper class for Markov Chain Monte Carlo algorithms. Specific MCMC algorithms
are TraceKernel instances and need to be supplied as a ``kernel`` argument
Expand Down Expand Up @@ -307,28 +349,21 @@ class MCMC:
def __init__(self, kernel, num_samples, warmup_steps=None, initial_params=None,
num_chains=1, hook_fn=None, mp_context=None, disable_progbar=False,
disable_validation=True, transforms=None, save_params=None):
super().__init__(kernel, num_chains, transforms)
self.warmup_steps = num_samples if warmup_steps is None else warmup_steps # Stan
self.num_samples = num_samples
self.kernel = kernel
self.transforms = transforms
self.disable_validation = disable_validation
self._samples = None
self._args = None
self._kwargs = None
if save_params is not None:
kernel.save_params = save_params
if isinstance(self.kernel, (HMC, NUTS)) and self.kernel.potential_fn is not None:
if initial_params is None:
raise ValueError("Must provide valid initial parameters to begin sampling"
" when using `potential_fn` in HMC/NUTS kernel.")
self._validate_kernel(initial_params)
parallel = False
if num_chains > 1:
# check that initial_params is different for each chain
if initial_params:
for v in initial_params.values():
if v.shape[0] != num_chains:
raise ValueError("The leading dimension of tensors in `initial_params` "
"must match the number of chains.")
self._validate_initial_params(initial_params)
# FIXME: probably we want to use "spawn" method by default to avoid the error
# CUDA initialization error https://github.com/pytorch/pytorch/issues/2517
# even that we run MCMC in CPU.
Expand All @@ -348,10 +383,7 @@ def __init__(self, kernel, num_samples, warmup_steps=None, initial_params=None,
else:
if initial_params:
initial_params = {k: v.unsqueeze(0) for k, v in initial_params.items()}

self.num_chains = num_chains
self._diagnostics = [None] * num_chains

if parallel:
self.sampler = _MultiSampler(kernel, num_samples, self.warmup_steps, num_chains, mp_context,
disable_progbar, initial_params=initial_params, hook=hook_fn)
Expand Down Expand Up @@ -422,17 +454,7 @@ def model(data):
# If transforms is not explicitly provided, infer automatically using
# model args, kwargs.
if self.transforms is None:
# Use `kernel.transforms` when available
if getattr(self.kernel, "transforms", None) is not None:
self.transforms = self.kernel.transforms
# Else, get transforms from model (e.g. in multiprocessing).
elif self.kernel.model:
warmup_steps = 0
self.kernel.setup(warmup_steps, *args, **kwargs)
self.transforms = self.kernel.transforms
# Assign default value
else:
self.transforms = {}
self._set_transforms(*args, **kwargs)

# transform samples back to constrained space
for name, z in z_acc.items():
Expand All @@ -447,30 +469,11 @@ def get_samples(self, num_samples=None, group_by_chain=False):
"""
Get samples from the MCMC run, potentially resampling with replacement.

:param int num_samples: Number of samples to return. If `None`, all the samples
from an MCMC chain are returned in their original ordering.
:param bool group_by_chain: Whether to preserve the chain dimension. If True,
all samples will have num_chains as the size of their leading dimension.
:return: dictionary of samples keyed by site name.
For parameter details see:
:meth:`select_samples <pyro.infer.mcmc.util.select_samples>`.
"""
samples = self._samples
if num_samples is None:
# reshape to collapse chain dim when group_by_chain=False
if not group_by_chain:
samples = {k: v.reshape((-1,) + v.shape[2:]) for k, v in samples.items()}
else:
if not samples:
raise ValueError("No samples found from MCMC run.")
if group_by_chain:
batch_dim = 1
else:
samples = {k: v.reshape((-1,) + v.shape[2:]) for k, v in samples.items()}
batch_dim = 0
sample_tensor = list(samples.values())[0]
batch_size, device = sample_tensor.shape[batch_dim], sample_tensor.device
idxs = torch.randint(0, batch_size, size=(num_samples,), device=device)
samples = {k: v.index_select(batch_dim, idxs) for k, v in samples.items()}
return samples
return select_samples(samples, num_samples, group_by_chain)

def diagnostics(self):
"""
Expand All @@ -496,3 +499,107 @@ def summary(self, prob=0.9):
if 'divergences' in self._diagnostics[0]:
print("Number of divergences: {}".format(
sum([len(self._diagnostics[i]['divergences']) for i in range(self.num_chains)])))


class StreamingMCMC(AbstractMCMC):
"""
MCMC that computes required statistics in a streaming fashion. For this class no samples are retained
but only aggregated statistics. This is useful for running memory expensive models where we care only
about specific statistics (especially useful in a memory constrained environments like GPU).

For available streaming ops please see :mod:`~pyro.ops.streaming`.
"""
def __init__(self, kernel, num_samples, warmup_steps=None, initial_params=None,
statistics=None, num_chains=1, hook_fn=None, disable_progbar=False,
disable_validation=True, transforms=None, save_params=None):
super().__init__(kernel, num_chains, transforms)
self.warmup_steps = num_samples if warmup_steps is None else warmup_steps # Stan
self.num_samples = num_samples
self.disable_validation = disable_validation
self._samples = None
self._args = None
self._kwargs = None
if statistics is None:
statistics = StatsOfDict(default=CountMeanVarianceStats)
self._statistics = statistics
self._default_statistics = copy.deepcopy(statistics)
if save_params is not None:
kernel.save_params = save_params
self._validate_kernel(initial_params)
if num_chains > 1:
if initial_params:
self._validate_initial_params(initial_params)
else:
if initial_params:
initial_params = {k: v.unsqueeze(0) for k, v in initial_params.items()}
self._diagnostics = [None] * num_chains
self.sampler = _UnarySampler(kernel, num_samples, self.warmup_steps, num_chains, disable_progbar,
initial_params=initial_params, hook=hook_fn)

@poutine.block
def run(self, *args, **kwargs):
"""
Run StreamingMCMC to compute required `self._statistics`.
"""
self._args, self._kwargs = args, kwargs
num_samples = [0] * self.num_chains

with optional(pyro.validation_enabled(not self.disable_validation),
self.disable_validation is not None):
args = [arg.detach() if torch.is_tensor(arg) else arg for arg in args]
for x, chain_id in self.sampler.run(*args, **kwargs):
if num_samples[chain_id] == 0:
# If transforms is not explicitly provided, infer automatically using
# model args, kwargs.
if self.transforms is None:
self._set_transforms(*args, **kwargs)
num_samples[chain_id] += 1
z_structure = x
elif num_samples[chain_id] == self.num_samples + 1:
self._diagnostics[chain_id] = x
else:
num_samples[chain_id] += 1
if self.num_chains > 1:
x_cloned = x.clone()
del x
else:
x_cloned = x

# unpack latent
pos = 0
z_acc = z_structure.copy()
for k in sorted(z_structure):
shape = z_structure[k]
next_pos = pos + shape.numel()
z_acc[k] = x_cloned[pos:next_pos].reshape(shape)
pos = next_pos

for name, z in z_acc.items():
if name in self.transforms:
z_acc[name] = self.transforms[name].inv(z)

self._statistics.update({
(chain_id, name): transformed_sample for name, transformed_sample in z_acc.items()
})

# terminate the sampler (shut down worker processes)
self.sampler.terminate(True)

def get_statistics(self, group_by_chain=True):
"""
Returns a dict of statistics defined by those passed to the class constructor.

:param bool group_by_chain: Whether statistics should be chain-wise or merged together.
"""
if group_by_chain:
return self._statistics.get()
else:
# merge all chains with respect to names
merged_dict: Dict[str, StreamingStats] = {}
for (_, name), stat in self._statistics.stats.items():
if name in merged_dict:
merged_dict[name] = merged_dict[name].merge(stat)
else:
merged_dict[name] = stat

return {k: v.get() for k, v in merged_dict.items()}
30 changes: 30 additions & 0 deletions pyro/infer/mcmc/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -662,3 +662,33 @@ def wrapped_fn(*args, **kwargs):
predictions[site] = value.reshape(shape)

return predictions


def select_samples(samples, num_samples=None, group_by_chain=False):
"""
Performs selection from given MCMC samples.

:param dictionary samples: Samples object to sample from.
:param int num_samples: Number of samples to return. If `None`, all the samples
from an MCMC chain are returned in their original ordering.
:param bool group_by_chain: Whether to preserve the chain dimension. If True,
all samples will have num_chains as the size of their leading dimension.
:return: dictionary of samples keyed by site name.
"""
if num_samples is None:
# reshape to collapse chain dim when group_by_chain=False
if not group_by_chain:
samples = {k: v.reshape((-1,) + v.shape[2:]) for k, v in samples.items()}
else:
if not samples:
raise ValueError("No samples found from MCMC run.")
if group_by_chain:
batch_dim = 1
else:
samples = {k: v.reshape((-1,) + v.shape[2:]) for k, v in samples.items()}
batch_dim = 0
sample_tensor = list(samples.values())[0]
batch_size, device = sample_tensor.shape[batch_dim], sample_tensor.device
idxs = torch.randint(0, batch_size, size=(num_samples,), device=device)
samples = {k: v.index_select(batch_dim, idxs) for k, v in samples.items()}
return samples
Loading