Skip to content

Commit

Permalink
Implemented get_statistics
Browse files Browse the repository at this point in the history
  • Loading branch information
mtsokol committed Jun 3, 2021
1 parent b8927be commit bd9c5f9
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 9 deletions.
29 changes: 23 additions & 6 deletions pyro/infer/mcmc/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,16 @@
code that works with different backends.
- minimal memory consumption with multiprocessing and CUDA.
"""

from abc import ABC
import copy
import json
import logging
import queue
import signal
import threading
import warnings
from abc import ABC, abstractmethod
from collections import OrderedDict
from typing import Dict

import torch
import torch.multiprocessing as mp
Expand All @@ -33,8 +34,8 @@
)
from pyro.infer.mcmc.nuts import NUTS
from pyro.infer.mcmc.util import diagnostics, print_summary
from pyro.ops.streaming import CountMeanVarianceStats, StatsOfDict, StreamingStats
from pyro.util import optional
from pyro.ops.streaming import CountMeanVarianceStats, StatsOfDict

MAX_SEED = 2**32 - 1

Expand Down Expand Up @@ -268,6 +269,10 @@ def __init__(self, kernel, num_chains, transforms):
self.num_chains = num_chains
self.transforms = transforms

@abstractmethod
def run(self, *args, **kwargs):
    """Abstract placeholder for running the sampler.

    Accepts arbitrary positional and keyword arguments and does nothing
    with them: this base body only raises.

    :raises NotImplementedError: always.
    """
    raise NotImplementedError()

def _set_transforms(self, *args, **kwargs):
# Use `kernel.transforms` when available
if getattr(self.kernel, "transforms", None) is not None:
Expand Down Expand Up @@ -529,6 +534,7 @@ def __init__(self, kernel, num_samples, warmup_steps=None, initial_params=None,
if statistics is None:
statistics = StatsOfDict(default=CountMeanVarianceStats)
self._statistics = statistics
self._default_statistics = copy.deepcopy(statistics)
if save_params is not None:
kernel.save_params = save_params
self._validate_kernel(initial_params)
Expand Down Expand Up @@ -583,11 +589,22 @@ def run(self, *args, **kwargs):
z_acc[name] = self.transforms[name].inv(z)

self._statistics.update({
name: transformed_sample for name, transformed_sample in z_acc.items()
(chain_id, name): transformed_sample for name, transformed_sample in z_acc.items()
})

# terminate the sampler (shut down worker processes)
self.sampler.terminate(True)

def summary(self, prob=0.9):
    """Return whatever the accumulated streaming statistics report.

    :param prob: accepted for signature compatibility; this implementation
        does not read it.
    :returns: the result of ``self._statistics.get()``.
    """
    accumulated = self._statistics
    return accumulated.get()
def get_statistics(self, group_by_chain=True):
    """Return the collected statistics, per chain or pooled over chains.

    :param group_by_chain: when truthy, return ``self._statistics.get()``
        unchanged. Otherwise pool the per-chain entries that share a name.
    :returns: for the pooled case, a dict mapping each name to the
        ``get()`` result of its combined statistic.
    """
    if group_by_chain:
        return self._statistics.get()

    # Entries of ``self._statistics.stats`` are keyed by 2-tuples whose
    # second item is the name; the first item is ignored when pooling.
    pooled: Dict[str, StreamingStats] = {}
    for key, chain_stat in self._statistics.stats.items():
        _, site = key
        # The first statistic seen for a name is stored as-is; later ones
        # are folded in through ``merge``, whose return value is kept.
        pooled[site] = pooled[site].merge(chain_stat) if site in pooled else chain_stat

    return {site: aggregate.get() for site, aggregate in pooled.items()}
11 changes: 8 additions & 3 deletions tests/infer/mcmc/test_mcmc_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
from pyro.infer.mcmc.api import MCMC, StreamingMCMC, _MultiSampler, _UnarySampler
from pyro.infer.mcmc.mcmc_kernel import MCMCKernel
from pyro.infer.mcmc.util import initialize_model
from pyro.ops.streaming import CountMeanVarianceStats, StatsOfDict
from pyro.util import optional
from pyro.ops.streaming import CountMeanVarianceStats, StatsOfDict, CountStats
from tests.common import assert_close


Expand Down Expand Up @@ -77,7 +77,8 @@ def normal_normal_model(data):
@pytest.mark.parametrize("mcmc_cls", [StreamingMCMC])
@pytest.mark.parametrize('num_chains', [1, 2])
@pytest.mark.filterwarnings("ignore:num_chains")
def test_mcmc_summary(mcmc_cls, num_chains):
@pytest.mark.parametrize('group_by_chain', [True, False])
def test_mcmc_summary(mcmc_cls, num_chains, group_by_chain):
num_samples = 2000
data = torch.tensor([1.0])
initial_params, _, transforms, _ = initialize_model(normal_normal_model, model_args=(data,),
Expand All @@ -87,7 +88,11 @@ def test_mcmc_summary(mcmc_cls, num_chains):
statistics=StatsOfDict(default=CountMeanVarianceStats),
num_chains=num_chains, initial_params=initial_params, transforms=transforms)
mcmc.run(data)
print(mcmc.summary()) # TODO Draft test
statistics = mcmc.get_statistics(group_by_chain=group_by_chain)
count = 0
for stat in statistics.values():
count += stat['count']
assert count == num_samples * num_chains


@pytest.mark.parametrize('num_draws', [None, 1800, 2200])
Expand Down

0 comments on commit bd9c5f9

Please sign in to comment.