Skip to content

Commit

Permalink
Merge pull request #392 from pints-team/issue-391-access-sampler
Browse files Browse the repository at this point in the history
#391 add method to MCMCSampling to access underlying samplers
  • Loading branch information
MichaelClerx authored Jul 16, 2018
2 parents 2563f7d + dd68da2 commit 8d924d4
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 5 deletions.
23 changes: 18 additions & 5 deletions pints/_mcmc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class MCMCSampler(pints.Loggable):
All MCMC samplers implement the :class:`pints.Loggable` interface.
"""

def name(self):
"""
Returns this method's full name.
Expand All @@ -41,6 +42,7 @@ class SingleChainMCMC(MCMCSampler):
covariance of the distribution to estimate, around ``x0``.
"""

def __init__(self, x0, sigma0=None):

# Check initial position
Expand Down Expand Up @@ -104,6 +106,7 @@ class SingleChainAdaptiveMCMC(SingleChainMCMC):
of the distribution to estimate, around ``x0``.
"""

def __init__(self, x0, sigma0=None):
super(SingleChainAdaptiveMCMC, self).__init__(x0, sigma0)

Expand Down Expand Up @@ -155,6 +158,7 @@ class MultiChainMCMC(MCMCSampler):
``diag(sigma0)`` will be used.
"""

def __init__(self, chains, x0, sigma0=None):

# Check number of chains
Expand Down Expand Up @@ -238,6 +242,7 @@ class MCMCSampling(object):
:class:`AdaptiveCovarianceMCMC` is used.
"""

def __init__(self, log_pdf, chains, x0, sigma0=None, method=None):

# Store function
Expand Down Expand Up @@ -308,7 +313,15 @@ def __init__(self, log_pdf, chains, x0, sigma0=None, method=None):
self._max_iterations = None
self.set_max_iterations()

#TODO: Add more stopping criteria
# TODO: Add more stopping criteria

def samplers(self):
"""
Returns the underlying array of samplers. The length of the array will
either be the number of chains, or one for samplers that sample
multiple chains
"""
return self._samplers

def adaptation_free_iterations(self):
"""
Expand Down Expand Up @@ -388,9 +401,9 @@ def run(self):
logger.add_time('Time m:s')

# Create chains
#TODO Pre-allocate?
#TODO Thinning
#TODO Advanced logging
# TODO Pre-allocate?
# TODO Thinning
# TODO Advanced logging
chains = []

# Start sampling
Expand Down Expand Up @@ -446,7 +459,7 @@ def run(self):
halt_message = ('Halting: Maximum number of iterations ('
+ str(iteration) + ') reached.')

#TODO Add more stopping criteria
# TODO Add more stopping criteria

#
# Adaptive methods
Expand Down
3 changes: 3 additions & 0 deletions test/test_mcmc_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ class TestMCMCSampling(unittest.TestCase):
"""
Tests the MCMCSampling class.
"""

def __init__(self, name):
super(TestMCMCSampling, self).__init__(name)

Expand Down Expand Up @@ -90,6 +91,7 @@ def test_single(self):
mcmc = pints.MCMCSampling(self.log_posterior, nchains, xs)
mcmc.set_max_iterations(niterations)
mcmc.set_log_to_screen(False)
self.assertEqual(len(mcmc.samplers()), nchains)
chains = mcmc.run()
self.assertEqual(chains.shape[0], nchains)
self.assertEqual(chains.shape[1], niterations)
Expand Down Expand Up @@ -177,6 +179,7 @@ def test_multi(self):
mcmc = pints.MCMCSampling(
self.log_posterior, nchains, xs,
method=pints.DifferentialEvolutionMCMC)
self.assertEqual(len(mcmc.samplers()), 1)
mcmc.set_max_iterations(niterations)
mcmc.set_log_to_screen(False)
chains = mcmc.run()
Expand Down

0 comments on commit 8d924d4

Please sign in to comment.