gstlal-inspiral  0.4.2
 All Classes Namespaces Files Functions Variables Pages
utils.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3 
4 from __future__ import (division, print_function, absolute_import,
5  unicode_literals)
6 
7 __all__ = ["sample_ball", "MH_proposal_axisaligned"]
8 
9 
10 import numpy as np
11 
12 # If mpi4py is installed, import it.
13 try:
14  from mpi4py import MPI
15  MPI = MPI
16 except ImportError:
17  MPI = None
18 
19 
def sample_ball(p0, std, size=1):
    """
    Produce a ball of walkers around an initial parameter value.

    Every entry of every row is ``p0[j] + std[j] * z`` with ``z`` drawn
    independently from a standard normal distribution (global NumPy RNG).

    :param p0: The initial parameter value (sequence of length ``ndim``).
    :param std: The axis-aligned standard deviation (same length as ``p0``).
    :param size: The number of samples to produce.

    :returns: An array of shape ``(size, ndim)``.

    :raises ValueError: If ``p0`` and ``std`` have different lengths.

    """
    ndim = len(p0)
    if ndim != len(std):
        raise ValueError("p0 and std must have the same length "
                         "(got {0} and {1})".format(ndim, len(std)))
    # One (size, ndim) draw fills row by row, so it consumes the global RNG
    # exactly like drawing one row per sample; it also handles size == 0.
    noise = np.random.normal(size=(size, ndim))
    return np.asarray(p0) + np.asarray(std) * noise
32 
33 
class MH_proposal_axisaligned(object):
    """
    A Metropolis-Hastings proposal, with axis-aligned Gaussian steps,
    for convenient use as the ``mh_proposal`` option to
    :func:`EnsembleSampler.sample` .

    :param stdev:
        Standard deviation of the Gaussian step along each parameter axis;
        needs one entry per column of the positions passed to the proposal.

    """
    def __init__(self, stdev):
        self.stdev = stdev

    def __call__(self, X):
        """
        Return new positions: ``X`` plus independent Gaussian steps scaled
        per column by ``self.stdev``. ``X`` itself is not modified.

        :param X:
            2-D array of positions with shape ``(nw, npar)`` (rows are
            presumably walkers — confirm against the sampler).

        :raises ValueError:
            If ``X`` is not 2-D or ``npar != len(self.stdev)``.

        """
        X = np.asarray(X)
        if X.ndim != 2:
            raise ValueError("Expected a 2-D array of positions, got shape "
                             "{0}".format(X.shape))
        npar = X.shape[1]
        if len(self.stdev) != npar:
            raise ValueError("stdev has {0} entries but positions have {1} "
                             "parameters".format(len(self.stdev), npar))
        return X + self.stdev * np.random.normal(size=X.shape)
48 
49 
if MPI is not None:
    class _close_pool_message(object):
        """Sentinel sent by the master to make workers leave their wait loop."""

        def __repr__(self):
            return "<Close pool message>"

    class _function_wrapper(object):
        """Envelope telling workers to replace the function they apply."""

        def __init__(self, function):
            self.function = function

    def _error_function(task):
        """Initial pool function: fail if a task arrives before a function."""
        raise RuntimeError("Pool was sent tasks before being told what "
                           "function to apply.")

    class MPIPool(object):
        """
        A pool that distributes tasks over a set of MPI processes. MPI is an
        API for distributed memory parallelism. This pool will let you run
        emcee without shared memory, letting you use much larger machines
        with emcee.

        The pool only supports the :func:`map` method at the moment because
        this is the only functionality that emcee needs. That being said,
        this pool is fairly general and it could be used for other purposes.

        Rank 0 is the master. Every other rank is a worker that sits in
        :meth:`wait` (entered directly or via :meth:`map`) until the master
        calls :meth:`close`.

        Contributed by `Joe Zuntz <https://github.com/joezuntz>`_.

        :param comm: (optional)
            The ``mpi4py`` communicator.

        :param debug: (optional)
            If ``True``, print out a lot of status updates at each step.

        """
        def __init__(self, comm=MPI.COMM_WORLD, debug=False):
            self.comm = comm
            self.rank = comm.Get_rank()
            # Number of workers: every rank except the master (rank 0).
            self.size = comm.Get_size() - 1
            self.debug = debug
            self.function = _error_function
            if self.size == 0:
                raise ValueError("Tried to create an MPI pool, but there "
                                 "was only one MPI process available. "
                                 "Need at least two.")

        def is_master(self):
            """
            Is the current process the master?

            """
            return self.rank == 0

        def wait(self):
            """
            If this isn't the master process, wait for instructions.

            Loops receiving messages from rank 0: a :class:`_function_wrapper`
            replaces the applied function, a :class:`_close_pool_message`
            ends the loop, anything else is a task whose result is sent back
            to rank 0 with the same tag.

            :raises RuntimeError: If called on the master.

            """
            if self.is_master():
                raise RuntimeError("Master node told to await jobs.")

            status = MPI.Status()
            # Results are sent with non-blocking isend (as in the original
            # design). MPI requires every non-blocking request to be
            # completed, so keep the requests instead of discarding them.
            pending = []

            while True:
                # Event loop.
                # Sit here and await instructions.
                if self.debug:
                    print("Worker {0} waiting for task.".format(self.rank))

                # Blocking receive to wait for instructions.
                task = self.comm.recv(source=0, tag=MPI.ANY_TAG,
                                      status=status)
                if self.debug:
                    print("Worker {0} got task {1} with tag {2}."
                          .format(self.rank, task, status.tag))

                # Check if message is special sentinel signaling end.
                # If so, stop.
                if isinstance(task, _close_pool_message):
                    if self.debug:
                        print("Worker {0} told to quit.".format(self.rank))
                    break

                # Check if message is special type containing new function
                # to be applied
                if isinstance(task, _function_wrapper):
                    self.function = task.function
                    if self.debug:
                        print("Worker {0} replaced its task function: {1}."
                              .format(self.rank, self.function))
                    continue

                # If not a special message, just run the known function on
                # the input and return it asynchronously.
                result = self.function(task)
                if self.debug:
                    print("Worker {0} sending answer {1} with tag {2}."
                          .format(self.rank, result, status.tag))
                pending.append(self.comm.isend(result, dest=0,
                                               tag=status.tag))
                # Drop requests that have already completed so the list
                # stays small across many tasks.
                pending = [req for req in pending if not req.test()[0]]

            # Complete any result sends that are still in flight.
            MPI.Request.waitall(pending)

        def map(self, function, tasks):
            """
            Like the built-in :func:`map` function, apply a function to all
            of the values in a list and return the list of results.

            On workers this enters :meth:`wait` and returns ``None`` once
            the pool is closed.

            :param function:
                The function to apply to the list.

            :param tasks:
                The elements; any iterable is accepted and is materialized
                into a list on the master.

            """
            # If not the master just wait for instructions.
            if not self.is_master():
                self.wait()
                return

            tasks = list(tasks)
            ntask = len(tasks)

            if function is not self.function:
                if self.debug:
                    print("Master replacing pool function with {0}."
                          .format(function))

                self.function = function
                F = _function_wrapper(function)

                # Tell all the workers what function to use.
                requests = [self.comm.isend(F, dest=i + 1)
                            for i in range(self.size)]

                # Wait until all of the workers have responded. See:
                # https://gist.github.com/4176241
                MPI.Request.waitall(requests)

            # Send all the tasks off and wait for them to be received.
            # Again, see the bug in the above gist.
            # NOTE(review): the task index is used as the MPI tag; MPI only
            # guarantees tags up to 32767 (MPI.TAG_UB) — confirm for very
            # large task lists.
            requests = []
            for i, task in enumerate(tasks):
                worker = i % self.size + 1
                if self.debug:
                    print("Sent task {0} to worker {1} with tag {2}."
                          .format(task, worker, i))
                requests.append(self.comm.isend(task, dest=worker, tag=i))
            MPI.Request.waitall(requests)

            # Now wait for the answers, in task order.
            results = []
            for i in range(ntask):
                worker = i % self.size + 1
                if self.debug:
                    print("Master waiting for worker {0} with tag {1}"
                          .format(worker, i))
                results.append(self.comm.recv(source=worker, tag=i))
            return results

        def close(self):
            """
            Just send a message off to all the pool members which contains
            the special :class:`_close_pool_message` sentinel.

            Does nothing on workers. On the master, blocks until every close
            message send has completed.

            """
            if not self.is_master():
                return
            requests = [self.comm.isend(_close_pool_message(), dest=i + 1)
                        for i in range(self.size)]
            MPI.Request.waitall(requests)

        def __enter__(self):
            return self

        def __exit__(self, *args):
            self.close()