gstlal-inspiral  0.4.2
ensemble.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
An affine invariant Markov chain Monte Carlo (MCMC) sampler.

Goodman & Weare, Ensemble Samplers With Affine Invariance
    Comm. App. Math. Comp. Sci., Vol. 5 (2010), No. 1, 65–80

"""

from __future__ import (division, print_function, absolute_import,
                        unicode_literals)

__all__ = ["EnsembleSampler"]

import multiprocessing
import numpy as np

try:
    import acor
    acor = acor
except ImportError:
    acor = None

from .sampler import Sampler


class EnsembleSampler(Sampler):
    """
    A generalized ensemble sampler that uses two ensembles for
    parallelization. The ``__init__`` function will raise an
    ``AssertionError`` if ``k < 2 * dim`` (and you haven't set the
    ``live_dangerously`` parameter) or if ``k`` is odd.

    **Warning**: The :attr:`chain` member of this object has the shape:
    ``(nwalkers, nlinks, dim)`` where ``nlinks`` is the number of steps
    taken by the chain and ``nwalkers`` (referred to internally as ``k``)
    is the number of walkers. Use the :attr:`flatchain` property to get the
    chain flattened along the walker axis to ``(nwalkers * nlinks, dim)``.
    For users of pre-1.0 versions, this shape is different, so be careful!

    :param nwalkers:
        The number of Goodman & Weare "walkers".

    :param dim:
        Number of dimensions in the parameter space.

    :param lnpostfn:
        A function that takes a vector in the parameter space as input and
        returns the natural logarithm of the posterior probability for that
        position.

    :param a: (optional)
        The proposal scale parameter. (default: ``2.0``)

    :param args: (optional)
        A list of extra arguments for ``lnpostfn``. ``lnpostfn`` will be
        called with the sequence ``lnpostfn(p, *args)``.

    :param postargs: (optional)
        Alias of ``args`` for backwards compatibility.

    :param threads: (optional)
        The number of threads to use for parallelization. If ``threads == 1``,
        then the ``multiprocessing`` module is not used but if
        ``threads > 1``, then a ``Pool`` object is created and calls to
        ``lnpostfn`` are run in parallel.

    :param pool: (optional)
        An alternative method of using the parallelized algorithm. If
        provided, the value of ``threads`` is ignored and the
        object provided by ``pool`` is used for all parallelization. It
        can be any object with a ``map`` method that follows the same
        calling sequence as the built-in ``map`` function.

    """
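    # A minimal usage sketch, illustrative only: ``lnpostfn`` below is a
    # hypothetical standard-normal log-posterior, not something defined in
    # this module.
    #
    #     def lnpostfn(p):
    #         return -0.5 * np.sum(p ** 2)
    #
    #     nwalkers, dim = 20, 5
    #     sampler = EnsembleSampler(nwalkers, dim, lnpostfn)
    #     p0 = np.random.randn(nwalkers, dim)
    #     for pos, lnprob, rstate in sampler.sample(p0, iterations=100):
    #         pass
    #     samples = sampler.flatchain  # shape (nwalkers * 100, dim)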
    def __init__(self, nwalkers, dim, lnpostfn, a=2.0, args=[], postargs=None,
                 threads=1, pool=None, live_dangerously=False):
        self.k = nwalkers
        self.a = a
        self.threads = threads
        self.pool = pool

        if postargs is not None:
            args = postargs
        super(EnsembleSampler, self).__init__(dim, lnpostfn, args=args)

        # Do a little bit of _magic_ to make the likelihood call with
        # ``args`` pickleable.
        self.lnprobfn = _function_wrapper(self.lnprobfn, self.args)

        assert self.k % 2 == 0, "The number of walkers must be even."
        if not live_dangerously:
            assert self.k >= 2 * self.dim, (
                "The number of walkers needs to be more than twice the "
                "dimension of your parameter space... unless you're "
                "crazy!")

        if self.threads > 1 and self.pool is None:
            self.pool = multiprocessing.Pool(self.threads)

    def reset(self):
        """
        Clear the ``chain`` and ``lnprobability`` arrays and reset the
        bookkeeping parameters.

        """
        super(EnsembleSampler, self).reset()
        self.naccepted = np.zeros(self.k)
        self._chain = np.empty((self.k, 0, self.dim))
        self._lnprob = np.empty((self.k, 0))

        # Initialize list for storing optional metadata blobs.
        self._blobs = []

    def sample(self, p0, lnprob0=None, rstate0=None, blobs0=None,
               iterations=1, thin=1, storechain=True, mh_proposal=None):
        """
        Advance the chain ``iterations`` steps as a generator.

        :param p0:
            A list of the initial positions of the walkers in the
            parameter space. It should have the shape ``(nwalkers, dim)``.

        :param lnprob0: (optional)
            The list of log posterior probabilities for the walkers at
            positions given by ``p0``. If ``lnprob0 is None``, the initial
            values are calculated. It should have the shape ``(nwalkers,)``.

        :param rstate0: (optional)
            The state of the random number generator.
            See the :attr:`Sampler.random_state` property for details.

        :param iterations: (optional)
            The number of steps to run.

        :param thin: (optional)
            If you only want to store and yield every ``thin`` samples in the
            chain, set ``thin`` to an integer greater than 1.

        :param storechain: (optional)
            By default, the sampler stores (in memory) the positions and
            log-probabilities of the samples in the chain. If you are
            using another method to store the samples to a file or if you
            don't need to analyse the samples after the fact (for burn-in
            for example), set ``storechain`` to ``False``.

        :param mh_proposal: (optional)
            A function that returns a list of positions for ``nwalkers``
            walkers given a current list of positions of the same size. See
            :class:`utils.MH_proposal_axisaligned` for an example.

        At each iteration, this generator yields:

        * ``pos`` - A list of the current positions of the walkers in the
          parameter space. The shape of this object will be
          ``(nwalkers, dim)``.

        * ``lnprob`` - The list of log posterior probabilities for the
          walkers at positions given by ``pos``. The shape of this object
          is ``(nwalkers,)``.

        * ``rstate`` - The current state of the random number generator.

        * ``blobs`` - (optional) The metadata "blobs" associated with the
          current position. The value is only returned if ``lnpostfn``
          returns blobs too.

        """
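        # Illustrative driver loop, not executed here: it assumes ``sampler``
        # and ``p0`` as in the class-level sketch above, thins the stored
        # chain by a factor of 10, and tracks the mean acceptance fraction.
        #
        #     for pos, lnprob, rstate in sampler.sample(p0, iterations=1000,
        #                                               thin=10):
        #         print(np.mean(sampler.acceptance_fraction))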
        # Try to set the initial value of the random number generator. This
        # fails silently if it doesn't work but that's what we want because
        # we'll just interpret any garbage as letting the generator stay in
        # its current state.
        self.random_state = rstate0

        p = np.array(p0)
        halfk = int(self.k / 2)

        # If the initial log-probabilities were not provided, calculate them
        # now.
        lnprob = lnprob0
        blobs = blobs0
        if lnprob is None:
            lnprob, blobs = self._get_lnprob(p)

        # Check to make sure that the probability function didn't return
        # ``np.nan``.
        if np.any(np.isnan(lnprob)):
            raise ValueError("The initial lnprob was NaN.")

        # Store the initial size of the stored chain.
        i0 = self._chain.shape[1]

        # Here, we resize chain in advance for performance. This actually
        # makes a pretty big difference.
        if storechain:
            N = int(iterations / thin)
            self._chain = np.concatenate((self._chain,
                                          np.zeros((self.k, N, self.dim))),
                                         axis=1)
            self._lnprob = np.concatenate((self._lnprob,
                                           np.zeros((self.k, N))), axis=1)

        for i in range(int(iterations)):
            self.iterations += 1

            # If we were passed a Metropolis-Hastings proposal
            # function, use it.
            if mh_proposal is not None:
                # Draw proposed positions & evaluate lnprob there
                q = mh_proposal(p)
                newlnp, blob = self._get_lnprob(q)

                # Accept if newlnp is better; and ...
                acc = (newlnp > lnprob)

                # ... sometimes accept for steps that got worse
                worse = np.flatnonzero(~acc)
                acc[worse] = ((newlnp[worse] - lnprob[worse]) >
                              np.log(self._random.rand(len(worse))))
                del worse

                # Update the accepted walkers.
                lnprob[acc] = newlnp[acc]
                p[acc] = q[acc]
                self.naccepted[acc] += 1

                if blob is not None:
                    assert blobs is not None, (
                        "If you start sampling with a given lnprob, you also "
                        "need to provide the current list of blobs at that "
                        "position.")
                    ind = np.arange(self.k)[acc]
                    for j in ind:
                        blobs[j] = blob[j]

            else:
                # Loop over the two ensembles, calculating the proposed
                # positions.

                # Slices for the first and second halves
                first, second = slice(halfk), slice(halfk, self.k)
                for S0, S1 in [(first, second), (second, first)]:
                    q, newlnp, acc, blob = self._propose_stretch(p[S0], p[S1],
                                                                 lnprob[S0])
                    if np.any(acc):
                        # Update the positions, log probabilities and
                        # acceptance counts.
                        lnprob[S0][acc] = newlnp[acc]
                        p[S0][acc] = q[acc]
                        self.naccepted[S0][acc] += 1

                        if blob is not None:
                            assert blobs is not None, (
                                "If you start sampling with a given lnprob, "
                                "you also need to provide the current list of "
                                "blobs at that position.")
                            ind = np.arange(len(acc))[acc]
                            indfull = np.arange(self.k)[S0][acc]
                            for j in range(len(ind)):
                                blobs[indfull[j]] = blob[ind[j]]

            if storechain and i % thin == 0:
                ind = i0 + int(i / thin)
                self._chain[:, ind, :] = p
                self._lnprob[:, ind] = lnprob
                if blobs is not None:
                    self._blobs.append(list(blobs))

            # Yield the result as an iterator so that the user can do all
            # sorts of fun stuff with the results so far.
            if blobs is not None:
                # This is a bit of a hack to keep things backwards compatible.
                yield p, lnprob, self.random_state, blobs
            else:
                yield p, lnprob, self.random_state

    def _propose_stretch(self, p0, p1, lnprob0):
        """
        Propose a new position for one sub-ensemble given the positions of
        another.

        :param p0:
            The positions from which to jump.

        :param p1:
            The positions of the other ensemble.

        :param lnprob0:
            The log-probabilities at ``p0``.

        This method returns:

        * ``q`` - The new proposed positions for the walkers in the
          sub-ensemble ``p0``.

        * ``newlnprob`` - The vector of log-probabilities at the positions
          given by ``q``.

        * ``accept`` - A vector of type ``bool`` indicating whether or not
          the proposed position for each walker should be accepted.

        * ``blob`` - The new metadata blobs or ``None`` if nothing was
          returned by ``lnprobfn``.

        """
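        # How the stretch move works (Goodman & Weare 2010): each walker at
        # position s picks a complementary walker at position c and proposes
        # q = c + z * (s - c), with z drawn from g(z) proportional to
        # 1 / sqrt(z) on [1/a, a]. The (dim - 1) * log(z) term in the
        # acceptance ratio below is what makes the move affine invariant.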
        s = np.atleast_2d(p0)
        Ns = len(s)
        c = np.atleast_2d(p1)
        Nc = len(c)

        # Generate the vectors of random numbers that will produce the
        # proposal.
        zz = ((self.a - 1.) * self._random.rand(Ns) + 1) ** 2. / self.a
        rint = self._random.randint(Nc, size=(Ns,))

        # Calculate the proposed positions and the log-probability there.
        q = c[rint] - zz[:, np.newaxis] * (c[rint] - s)
        newlnprob, blob = self._get_lnprob(q)

        # Decide whether or not the proposals should be accepted.
        lnpdiff = (self.dim - 1.) * np.log(zz) + newlnprob - lnprob0
        accept = (lnpdiff > np.log(self._random.rand(len(lnpdiff))))

        return q, newlnprob, accept, blob

    def _get_lnprob(self, pos=None):
        """
        Calculate the vector of log-probability for the walkers.

        :param pos: (optional)
            The position vector in parameter space where the probability
            should be calculated. This defaults to the current position
            unless a different one is provided.

        This method returns:

        * ``lnprob`` - A vector of log-probabilities with one entry for each
          walker in this sub-ensemble.

        * ``blob`` - The list of metadata returned by ``lnpostfn`` at this
          position or ``None`` if nothing was returned.

        """
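        # Illustrative only: ``lnpostfn`` may return either a bare float or a
        # tuple whose first element is the log-probability and whose second
        # element is an arbitrary metadata "blob". A hypothetical example:
        #
        #     def lnpostfn(p):
        #         lnp = -0.5 * np.sum(p ** 2)
        #         return lnp, {"chi2": -2.0 * lnp}
        #
        # The try/except below unpacks whichever form is returned.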
        if pos is None:
            p = self.pos
        else:
            p = pos

        # Check that the parameters are in physical ranges.
        if np.any(np.isinf(p)):
            raise ValueError("At least one parameter value was infinite.")
        if np.any(np.isnan(p)):
            raise ValueError("At least one parameter value was NaN.")

        # If the `pool` property of the sampler has been set (i.e. we want
        # to use `multiprocessing`), use the `pool`'s map method. Otherwise,
        # just use the built-in `map` function.
        if self.pool is not None:
            M = self.pool.map
        else:
            M = map

        # Run the log-probability calculations (optionally in parallel).
        results = list(M(self.lnprobfn, [p[i] for i in range(len(p))]))

        try:
            lnprob = np.array([float(l[0]) for l in results])
            blob = [l[1] for l in results]
        except (IndexError, TypeError):
            lnprob = np.array([float(l) for l in results])
            blob = None

        # Check for lnprob returning NaN.
        if np.any(np.isnan(lnprob)):
            raise ValueError("lnprob returned NaN.")

        return lnprob, blob

    @property
    def blobs(self):
        """
        Get the list of "blobs" produced by sampling. The result is a list
        (of length ``iterations``) of ``list`` s (of length ``nwalkers``) of
        arbitrary objects. **Note**: this will actually be an empty list if
        your ``lnpostfn`` doesn't return any metadata.

        """
        return self._blobs

    @property
    def chain(self):
        """
        A pointer to the Markov chain itself. The shape of this array is
        ``(k, iterations, dim)``.

        """
        return super(EnsembleSampler, self).chain

    @property
    def flatchain(self):
        """
        A shortcut for accessing chain flattened along the zeroth (walker)
        axis.

        """
        s = self.chain.shape
        return self.chain.reshape(s[0] * s[1], s[2])

    @property
    def lnprobability(self):
        """
        A pointer to the matrix of the value of ``lnprobfn`` produced at each
        step for each walker. The shape is ``(k, iterations)``.

        """
        return super(EnsembleSampler, self).lnprobability

    @property
    def flatlnprobability(self):
        """
        A shortcut to return the equivalent of ``lnprobability`` but aligned
        to ``flatchain`` rather than ``chain``.

        """
        return super(EnsembleSampler, self).lnprobability.flatten()

    @property
    def acceptance_fraction(self):
        """
        An array (length: ``k``) of the fraction of steps accepted for each
        walker.

        """
        return super(EnsembleSampler, self).acceptance_fraction

    @property
    def acor(self):
        """
        The autocorrelation time of each parameter in the chain (length:
        ``dim``) as estimated by the ``acor`` module.

        """
        if acor is None:
            raise ImportError("acor")
        s = self.dim
        t = np.zeros(s)
        for i in range(s):
            t[i] = acor.acor(self.chain[:, :, i])[0]
        return t


class _function_wrapper(object):
    """
    This is a hack to make the likelihood function pickleable when ``args``
    are also included.

    """
    def __init__(self, f, args):
        self.f = f
        self.args = args

    def __call__(self, x):
        try:
            return self.f(x, *self.args)
        except:
            import traceback
            print("emcee: Exception while calling your likelihood function:")
            print("  params:", x)
            print("  args:", self.args)
            print("  exception:")
            traceback.print_exc()
            raise
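
# Illustrative only: because the wrapped ``lnprobfn`` is pickleable, walker
# evaluations can run in parallel either with ``threads=4`` or with any
# user-supplied object exposing a ``map`` method. The names below (``pool``,
# ``lnpostfn``) are hypothetical, and ``lnpostfn`` must itself be defined at
# module level so that it can be pickled.
#
#     pool = multiprocessing.Pool(4)
#     sampler = EnsembleSampler(nwalkers, dim, lnpostfn, pool=pool)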