gstlal-inspiral  0.4.2
 All Classes Namespaces Files Functions Variables Pages
mh.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
A vanilla Metropolis-Hastings sampler.

Provides :class:`MHSampler`, a single-chain sampler whose proposals are
drawn from a multivariate normal distribution centred on the current
position, with a fixed user-supplied covariance matrix.

"""

from __future__ import (division, print_function, absolute_import,
                        unicode_literals)

__all__ = ["MHSampler"]

import numpy as np

from .sampler import Sampler
16 
17 
18 # === MHSampler ===
class MHSampler(Sampler):
    """
    The most basic possible Metropolis-Hastings style MCMC sampler.

    Proposals are drawn from a multivariate normal distribution centred on
    the current position, using the fixed covariance matrix ``cov``.

    :param cov:
        The covariance matrix to use for the proposal distribution.

    :param dim:
        Number of dimensions in the parameter space.

    :param lnpostfn:
        A function that takes a vector in the parameter space as input and
        returns the natural logarithm of the posterior probability for that
        position.

    :param args: (optional)
        A list of extra arguments for ``lnpostfn``. ``lnpostfn`` will be
        called with the sequence ``lnpostfn(p, *args)``.

    """
    def __init__(self, cov, *args, **kwargs):
        # Everything except ``cov`` is forwarded to the ``Sampler`` base.
        super(MHSampler, self).__init__(*args, **kwargs)
        self.cov = cov

    def reset(self):
        """
        Reset the base-class state and discard the stored chain and
        log-probabilities.

        """
        super(MHSampler, self).reset()
        self._chain = np.empty((0, self.dim))
        self._lnprob = np.empty(0)

    def sample(self, p0, lnprob=None, randomstate=None, thin=1,
               storechain=True, iterations=1):
        """
        Advances the chain ``iterations`` steps as an iterator

        :param p0:
            The initial position vector.

        :param lnprob: (optional)
            The log posterior probability at position ``p0``. If ``lnprob``
            is not provided, the initial value is calculated.

        :param randomstate: (optional)
            The state of the random number generator. See the
            :func:`random_state` property for details.

        :param iterations: (optional)
            The number of steps to run.

        :param thin: (optional)
            If you only want to store every ``thin`` samples in the chain,
            set thin to an integer greater than 1. Samples from steps
            ``0, thin, 2*thin, ...`` of this call are stored. Must be a
            positive integer; otherwise ``ValueError`` is raised.

        :param storechain: (optional)
            By default, the sampler stores (in memory) the positions and
            log-probabilities of the samples in the chain. If you are
            using another method to store the samples to a file or if you
            don't need to analyse the samples after the fact (for burn-in
            for example) set ``storechain`` to ``False``.

        At each iteration, this generator yields:

        * ``pos`` — The current positions of the chain in the parameter
          space.

        * ``lnprob`` — The value of the log posterior at ``pos`` .

        * ``rstate`` — The current state of the random number generator.

        """
        if thin < 1 or thin != int(thin):
            raise ValueError("thin must be a positive integer")
        thin = int(thin)
        iterations = int(iterations)

        self.random_state = randomstate

        p = np.array(p0)
        if lnprob is None:
            lnprob = self.get_lnprob(p)

        # Resize the chain in advance. Steps i = 0, thin, 2*thin, ... are
        # stored, which needs ceil(iterations / thin) rows. New rows start
        # at the current chain length; that is not ``self.iterations`` once
        # a previous call used thin > 1.
        ind = 0
        if storechain:
            ind = self._chain.shape[0]
            nstore = max(0, (iterations + thin - 1) // thin)
            self._chain = np.concatenate((self._chain,
                                          np.zeros((nstore, self.dim))),
                                         axis=0)
            self._lnprob = np.append(self._lnprob, np.zeros(nstore))

        # Use range instead of xrange for python 3 compatability
        for i in range(iterations):
            self.iterations += 1

            # Draw a proposal from the Gaussian centred on the current point.
            q = self._random.multivariate_normal(p, self.cov)
            newlnprob = self.get_lnprob(q)
            diff = newlnprob - lnprob

            # M-H acceptance: accept with probability min(1, exp(diff)).
            if diff < 0:
                # Downhill move: accept with probability exp(diff).
                accept = self._random.rand() < np.exp(diff)
            else:
                # Uphill or equal moves are always accepted. NaN (e.g. from
                # -inf - -inf) fails ``>= 0`` and is therefore rejected.
                accept = diff >= 0

            if accept:
                p = q
                lnprob = newlnprob
                self.naccepted += 1

            if storechain and i % thin == 0:
                self._chain[ind, :] = p
                self._lnprob[ind] = lnprob
                ind += 1

            # Heavy duty iterator action going on right here...
            yield p, lnprob, self.random_state
125 
126  # Heavy duty iterator action going on right here...
127  yield p, lnprob, self.random_state