4 An affine invariant Markov chain Monte Carlo (MCMC) sampler.
6 Goodman & Weare, Ensemble Samplers With Affine Invariance
7 Comm. App. Math. Comp. Sci., Vol. 5 (2010), No. 1, 65–80
11 from __future__
import (division, print_function, absolute_import,
14 __all__ = [
"EnsembleSampler"]
16 import multiprocessing
25 from .sampler
import Sampler
30 A generalized ensemble sampler that splits its walkers into two halves for parallelization.
31 The ``__init__`` function will raise an ``AssertionError`` if
32 ``k < 2 * dim`` (and you haven't set the ``live_dangerously`` parameter)
35 **Warning**: The :attr:`chain` member of this object has the shape:
36 ``(nwalkers, nlinks, dim)`` where ``nlinks`` is the number of steps
37 taken by the chain and ``nwalkers`` is the number of walkers. Use the
38 :attr:`flatchain` property to get the chain flattened to
39 ``(nwalkers * nlinks, dim)``. For users of pre-1.0 versions, this shape is
40 different, so be careful!
43 The number of Goodman & Weare "walkers".
46 Number of dimensions in the parameter space.
49 A function that takes a vector in the parameter space as input and
50 returns the natural logarithm of the posterior probability for that
54 The proposal scale parameter. (default: ``2.0``)
56 :param args: (optional)
57 A list of extra arguments for ``lnpostfn``. ``lnpostfn`` will be
58 called with the sequence ``lnpostfn(p, *args)``.
60 :param postargs: (optional)
61 Alias of ``args`` for backwards compatibility.
63 :param threads: (optional)
64 The number of threads to use for parallelization. If ``threads == 1``,
65 then the ``multiprocessing`` module is not used but if
66 ``threads > 1``, then a ``Pool`` object is created and calls to
67 ``lnpostfn`` are run in parallel.
69 :param pool: (optional)
70 An alternative method of using the parallelized algorithm. If
71 provided, the value of ``threads`` is ignored and the
72 object provided by ``pool`` is used for all parallelization. It
73 can be any object with a ``map`` method that follows the same
74 calling sequence as the built-in ``map`` function.
77 def __init__(self, nwalkers, dim, lnpostfn, a=2.0, args=[], postargs=None,
78 threads=1, pool=
None, live_dangerously=
False):
84 if postargs
is not None:
86 super(EnsembleSampler, self).__init__(dim, lnpostfn, args=args)
92 assert self.
k % 2 == 0,
"The number of walkers must be even."
93 if not live_dangerously:
94 assert self.
k >= 2 * self.dim, (
95 "The number of walkers needs to be more than twice the "
96 "dimension of your parameter space... unless you're "
104 Clear the ``chain`` and ``lnprobability`` arrays. Also reset the
105 bookkeeping parameters.
108 super(EnsembleSampler, self).
reset()
110 self.
_chain = np.empty((self.
k, 0, self.dim))
111 self.
_lnprob = np.empty((self.
k, 0))
116 def sample(self, p0, lnprob0=None, rstate0=None, blobs0=None,
117 iterations=1, thin=1, storechain=
True, mh_proposal=
None):
119 Advance the chain ``iterations`` steps as a generator.
122 A list of the initial positions of the walkers in the
123 parameter space. It should have the shape ``(nwalkers, dim)``.
125 :param lnprob0: (optional)
126 The list of log posterior probabilities for the walkers at
127 positions given by ``p0``. If ``lnprob0 is None``, the initial
128 values are calculated. It should have the shape ``(nwalkers,)``.
130 :param rstate0: (optional)
131 The state of the random number generator.
132 See the :attr:`Sampler.random_state` property for details.
134 :param iterations: (optional)
135 The number of steps to run.
137 :param thin: (optional)
138 If you only want to store and yield every ``thin``-th sample in the
139 chain, set ``thin`` to an integer greater than 1.
141 :param storechain: (optional)
142 By default, the sampler stores (in memory) the positions and
143 log-probabilities of the samples in the chain. If you are
144 using another method to store the samples to a file or if you
145 don't need to analyse the samples after the fact (for burn-in
146 for example) set ``storechain`` to ``False``.
148 :param mh_proposal: (optional)
149 A function that returns a list of positions for ``nwalkers``
150 walkers given a current list of positions of the same size. See
151 :class:`utils.MH_proposal_axisaligned` for an example.
153 At each iteration, this generator yields:
155 * ``pos`` — A list of the current positions of the walkers in the
156 parameter space. The shape of this object will be
159 * ``lnprob`` — The list of log posterior probabilities for the
160 walkers at positions given by ``pos``. The shape of this object
161 is ``(nwalkers,)``.
163 * ``rstate`` — The current state of the random number generator.
165 * ``blobs`` — (optional) The metadata "blobs" associated with the
166 current position. The value is only returned if ``lnpostfn``
177 halfk = int(self.
k / 2)
188 if np.any(np.isnan(lnprob)):
189 raise ValueError(
"The initial lnprob was NaN.")
192 i0 = self._chain.shape[1]
197 N = int(iterations / thin)
199 np.zeros((self.
k, N, self.dim))),
202 np.zeros((self.
k, N))), axis=1)
204 for i
in range(int(iterations)):
209 if mh_proposal
is not None:
215 acc = (newlnp > lnprob)
218 worse = np.flatnonzero(~acc)
219 acc[worse] = ((newlnp[worse] - lnprob[worse]) >
220 np.log(self._random.rand(len(worse))))
224 lnprob[acc] = newlnp[acc]
229 assert blobs
is not None, (
230 "If you start sampling with a given lnprob, you also "
231 "need to provide the current list of blobs at that "
233 ind = np.arange(self.
k)[acc]
242 first, second = slice(halfk), slice(halfk, self.
k)
243 for S0, S1
in [(first, second), (second, first)]:
249 lnprob[S0][acc] = newlnp[acc]
254 assert blobs
is not None, (
255 "If you start sampling with a given lnprob, "
256 "you also need to provide the current list of "
257 "blobs at that position.")
258 ind = np.arange(len(acc))[acc]
259 indfull = np.arange(self.
k)[S0][acc]
260 for j
in range(len(ind)):
261 blobs[indfull[j]] = blob[ind[j]]
263 if storechain
and i % thin == 0:
264 ind = i0 + int(i / thin)
265 self.
_chain[:, ind, :] = p
267 if blobs
is not None:
268 self._blobs.append(list(blobs))
272 if blobs
is not None:
278 def _propose_stretch(self, p0, p1, lnprob0):
280 Propose a new position for one sub-ensemble given the positions of
284 The positions from which to jump.
287 The positions of the other ensemble.
290 The log-probabilities at ``p0``.
294 * ``q`` — The new proposed positions for the walkers in ``p0``.
296 * ``newlnprob`` — The vector of log-probabilities at the positions
299 * ``accept`` — A vector of type ``bool`` indicating whether or not
300 the proposed position for each walker should be accepted.
302 * ``blob`` — The new metadata blobs or ``None`` if nothing was
303 returned by ``lnpostfn``.
306 s = np.atleast_2d(p0)
308 c = np.atleast_2d(p1)
313 zz = ((self.
a - 1.) * self._random.rand(Ns) + 1) ** 2. / self.
a
314 rint = self._random.randint(Nc, size=(Ns,))
317 q = c[rint] - zz[:, np.newaxis] * (c[rint] - s)
321 lnpdiff = (self.dim - 1.) * np.log(zz) + newlnprob - lnprob0
322 accept = (lnpdiff > np.log(self._random.rand(len(lnpdiff))))
324 return q, newlnprob, accept, blob
326 def _get_lnprob(self, pos=None):
328 Calculate the vector of log-probabilities for the walkers.
330 :param pos: (optional)
331 The position vector in parameter space where the probability
332 should be calculated. This defaults to the current position
333 unless a different one is provided.
337 * ``lnprob`` — A vector of log-probabilities with one entry for each
338 walker in this sub-ensemble.
340 * ``blob`` — The list of metadata returned by the ``lnpostfn`` at
341 this position or ``None`` if nothing was returned.
350 if np.any(np.isinf(p)):
351 raise ValueError(
"At least one parameter value was infinite.")
352 if np.any(np.isnan(p)):
353 raise ValueError(
"At least one parameter value was NaN.")
358 if self.
pool is not None:
364 results = list(M(self.
lnprobfn, [p[i]
for i
in range(len(p))]))
367 lnprob = np.array([float(l[0])
for l
in results])
368 blob = [l[1]
for l
in results]
369 except (IndexError, TypeError):
370 lnprob = np.array([float(l)
for l
in results])
374 if np.any(np.isnan(lnprob)):
375 raise ValueError(
"lnprob returned NaN.")
382 Get the list of "blobs" produced by sampling. The result is a list
383 (of length ``iterations``) of ``list`` s (of length ``nwalkers``) of
384 arbitrary objects. **Note**: this will actually be an empty list if
385 your ``lnpostfn`` doesn't return any metadata.
393 A pointer to the Markov chain itself. The shape of this array is
394 ``(k, iterations, dim)``.
397 return super(EnsembleSampler, self).chain
402 A shortcut for accessing chain flattened along the zeroth (walker)
407 return self.chain.reshape(s[0] * s[1], s[2])
412 A pointer to the matrix of the value of ``lnprobfn`` produced at each
413 step for each walker. The shape is ``(k, iterations)``.
416 return super(EnsembleSampler, self).lnprobability
421 A shortcut to return the equivalent of ``lnprobability`` but aligned
422 to ``flatchain`` rather than ``chain``.
425 return super(EnsembleSampler, self).lnprobability.flatten()
430 An array (length: ``k``) of the fraction of steps accepted for each
434 return super(EnsembleSampler, self).acceptance_fraction
439 The autocorrelation time of each parameter in the chain (length:
440 ``dim``) as estimated by the ``acor`` module.
444 raise ImportError(
"acor")
448 t[i] = acor.acor(self.
chain[:, :, i])[0]
454 This is a hack to make the likelihood function pickleable when ``args``
458 def __init__(self, f, args):
462 def __call__(self, x):
464 return self.
f(x, *self.
args)
467 print(
"emcee: Exception while calling your likelihood function:")
469 print(
" args:", self.
args)
471 traceback.print_exc()