gstlal-inspiral  0.4.2
 All Classes Namespaces Files Functions Variables Pages
sampler.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3 """
4 The base sampler class implementing various helpful functions.
5 
6 """
7 
8 from __future__ import (division, print_function, absolute_import,
9  unicode_literals)
10 
11 __all__ = ["Sampler"]
12 
13 import numpy as np
14 
15 try:
16  import acor
17  acor = acor
18 except ImportError:
19  acor = None
20 
21 
class Sampler(object):
    """
    An abstract sampler object that implements various helper functions.

    Concrete samplers are expected to subclass this and implement
    :func:`sample`; the base class only provides bookkeeping, a private
    random number generator and the :func:`run_mcmc` driver loop.

    :param dim:
        The number of dimensions in the parameter space.

    :param lnprobfn:
        A function that takes a vector in the parameter space as input and
        returns the natural logarithm of the posterior probability for that
        position.

    :param args: (optional)
        A list of extra arguments for ``lnprobfn``. ``lnprobfn`` will be
        called with the sequence ``lnprobfn(p, *args)``. Defaults to an
        empty list.

    """
    def __init__(self, dim, lnprobfn, args=None):
        self.dim = dim
        self.lnprobfn = lnprobfn
        # Build a fresh list per instance: a literal ``[]`` default would be
        # shared (and mutable) across every sampler created without ``args``.
        self.args = [] if args is None else args

        # This is a random number generator that we can easily set the state
        # of without affecting the numpy-wide generator
        self._random = np.random.mtrand.RandomState()

        self.reset()

    @property
    def random_state(self):
        """
        The state of the internal random number generator. In practice, it's
        the result of calling ``get_state()`` on a
        ``numpy.random.mtrand.RandomState`` object. You can try to set this
        property but be warned that if you do this and it fails, it will do
        so silently.

        """
        return self._random.get_state()

    @random_state.setter  # NOQA
    def random_state(self, state):
        """
        Try to set the state of the random number generator but fail silently
        if it doesn't work. Don't say I didn't warn you...

        """
        try:
            self._random.set_state(state)
        except Exception:
            # Deliberately best-effort: an invalid state is ignored and the
            # generator keeps its current state. ``Exception`` (rather than a
            # bare ``except``) lets KeyboardInterrupt/SystemExit propagate.
            pass

    @property
    def acceptance_fraction(self):
        """
        The fraction of proposed steps that were accepted.

        Computed as ``naccepted / iterations``; it is only meaningful once
        at least one iteration has been run.

        """
        return self.naccepted / self.iterations

    @property
    def chain(self):
        """
        A pointer to the Markov chain.

        """
        return self._chain

    @property
    def flatchain(self):
        """
        Alias of ``chain`` provided for compatibility.

        """
        return self._chain

    @property
    def lnprobability(self):
        """
        A list of the log-probability values associated with each step in
        the chain.

        """
        return self._lnprob

    @property
    def acor(self):
        """
        The autocorrelation time of each parameter in the chain (length:
        ``dim``) as estimated by the ``acor`` module.

        :raises ImportError:
            If the optional ``acor`` module could not be imported.

        """
        # Inside this method ``acor`` resolves to the module-level name (the
        # optional import), not to this property.
        if acor is None:
            raise ImportError("acor")
        return acor.acor(self._chain.T)[0]

    def get_lnprob(self, p):
        """Return the log-probability at the given position."""
        return self.lnprobfn(p, *self.args)

    def reset(self):
        """
        Reset the bookkeeping parameters (``iterations`` and ``naccepted``).

        The base class does not store ``chain`` or ``lnprobability`` itself,
        so only the counters are cleared here.

        """
        self.iterations = 0
        self.naccepted = 0

    def clear_chain(self):
        """An alias for :func:`reset` kept for backwards compatibility."""
        return self.reset()

    def sample(self, *args, **kwargs):
        """
        Advance the chain; must be overridden by subclasses.

        :raises NotImplementedError:
            Always, in this base class.

        """
        raise NotImplementedError("The sampling routine must be implemented "
                                  "by subclasses")

    def run_mcmc(self, pos0, N, rstate0=None, lnprob0=None, **kwargs):
        """
        Iterate :func:`sample` for ``N`` iterations and return the result.

        :param pos0:
            The initial position vector.

        :param N:
            The number of steps to run.

        :param lnprob0: (optional)
            The log posterior probability at position ``pos0``. It is passed
            through to :func:`sample` unchanged.

        :param rstate0: (optional)
            The state of the random number generator. See the
            :func:`random_state` property for details.

        :param kwargs: (optional)
            Other parameters that are directly passed to :func:`sample`.

        :returns:
            The last value produced by :func:`sample`, or ``None`` if it
            produced nothing (for example when ``N`` is zero).

        """
        # Pre-bind the result so that an empty run returns None instead of
        # failing with an unbound local variable.
        results = None
        for results in self.sample(pos0, lnprob0, rstate0, iterations=N,
                                   **kwargs):
            pass
        return results