From dd68da2f8b583b22afdce87f22896cfffbe74745 Mon Sep 17 00:00:00 2001
From: Martin Robinson
Date: Mon, 16 Jul 2018 16:31:06 +0100
Subject: [PATCH] #391 add method to MCMCSampling to access underlying samplers

---
 pints/_mcmc/__init__.py    | 23 ++++++++++++++++++-----
 test/test_mcmc_sampling.py |  3 +++
 2 files changed, 21 insertions(+), 5 deletions(-)

diff --git a/pints/_mcmc/__init__.py b/pints/_mcmc/__init__.py
index 6d7415d66..1c2c9d24f 100644
--- a/pints/_mcmc/__init__.py
+++ b/pints/_mcmc/__init__.py
@@ -18,6 +18,7 @@ class MCMCSampler(pints.Loggable):
 
     All MCMC samplers implement the :class:`pints.Loggable` interface.
     """
+
     def name(self):
         """
         Returns this method's full name.
@@ -41,6 +42,7 @@ class SingleChainMCMC(MCMCSampler):
         covariance of the distribution to estimate, around ``x0``.
 
     """
+
     def __init__(self, x0, sigma0=None):
 
         # Check initial position
@@ -104,6 +106,7 @@ class SingleChainAdaptiveMCMC(SingleChainMCMC):
         of the distribution to estimate, around ``x0``.
 
     """
+
     def __init__(self, x0, sigma0=None):
         super(SingleChainAdaptiveMCMC, self).__init__(x0, sigma0)
 
@@ -155,6 +158,7 @@ class MultiChainMCMC(MCMCSampler):
         ``diag(sigma0)`` will be used.
 
     """
+
     def __init__(self, chains, x0, sigma0=None):
 
         # Check number of chains
@@ -238,6 +242,7 @@ class MCMCSampling(object):
         :class:`AdaptiveCovarianceMCMC` is used.
 
     """
+
     def __init__(self, log_pdf, chains, x0, sigma0=None, method=None):
 
         # Store function
@@ -308,7 +313,15 @@ def __init__(self, log_pdf, chains, x0, sigma0=None, method=None):
         self._max_iterations = None
         self.set_max_iterations()
 
-        #TODO: Add more stopping criteria
+        # TODO: Add more stopping criteria
+
+    def samplers(self):
+        """
+        Returns the underlying array of samplers. The length of the array will
+        either be the number of chains, or one for samplers that sample
+        multiple chains.
+        """
+        return self._samplers
 
     def adaptation_free_iterations(self):
         """
@@ -388,9 +401,9 @@ def run(self):
             logger.add_time('Time m:s')
 
         # Create chains
-        #TODO Pre-allocate?
-        #TODO Thinning
-        #TODO Advanced logging
+        # TODO Pre-allocate?
+        # TODO Thinning
+        # TODO Advanced logging
         chains = []
 
         # Start sampling
@@ -446,7 +459,7 @@ def run(self):
                 halt_message = ('Halting: Maximum number of iterations ('
                                 + str(iteration) + ') reached.')
 
-            #TODO Add more stopping criteria
+            # TODO Add more stopping criteria
 
             #
             # Adaptive methods
diff --git a/test/test_mcmc_sampling.py b/test/test_mcmc_sampling.py
index 850b3c239..16b42720c 100755
--- a/test/test_mcmc_sampling.py
+++ b/test/test_mcmc_sampling.py
@@ -44,6 +44,7 @@ class TestMCMCSampling(unittest.TestCase):
     """
     Tests the MCMCSampling class.
     """
+
     def __init__(self, name):
         super(TestMCMCSampling, self).__init__(name)
 
@@ -90,6 +91,7 @@ def test_single(self):
         mcmc = pints.MCMCSampling(self.log_posterior, nchains, xs)
         mcmc.set_max_iterations(niterations)
         mcmc.set_log_to_screen(False)
+        self.assertEqual(len(mcmc.samplers()), nchains)
         chains = mcmc.run()
         self.assertEqual(chains.shape[0], nchains)
         self.assertEqual(chains.shape[1], niterations)
@@ -177,6 +179,7 @@ def test_multi(self):
         mcmc = pints.MCMCSampling(
             self.log_posterior, nchains, xs,
             method=pints.DifferentialEvolutionMCMC)
+        self.assertEqual(len(mcmc.samplers()), 1)
         mcmc.set_max_iterations(niterations)
         mcmc.set_log_to_screen(False)
         chains = mcmc.run()
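
A minimal usage sketch of the new accessor, assuming a pints.LogPosterior named
``log_posterior`` and a list ``xs`` of ``nchains`` starting points set up as in
the tests above (both names are placeholders, not defined by this patch):

    import pints

    nchains = 3

    # Default single-chain method: samplers() returns one sampler per chain,
    # so per-chain settings can be adjusted before calling run().
    mcmc = pints.MCMCSampling(log_posterior, nchains, xs)
    assert len(mcmc.samplers()) == nchains

    # Multi-chain method: one underlying sampler drives all chains, so
    # samplers() has length one.
    mcmc = pints.MCMCSampling(
        log_posterior, nchains, xs, method=pints.DifferentialEvolutionMCMC)
    assert len(mcmc.samplers()) == 1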