from __future__ import (division, print_function, absolute_import,
                        unicode_literals)

__all__ = ["sample_ball", "MH_proposal_axisaligned", "MPIPool"]

import numpy as np

from mpi4py import MPI


def sample_ball(p0, std, size=1):
    """
    Produce a ball of walkers around an initial parameter value.

    :param p0: The initial parameter value.
    :param std: The axis-aligned standard deviation.
    :param size: The number of samples to produce.

    """
    assert len(p0) == len(std)
    return np.vstack([p0 + std * np.random.normal(size=len(p0))
                      for i in range(size)])
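
# Illustrative sketch (not part of the original source; the numbers are
# hypothetical): initialize 100 walkers in a tight Gaussian ball around a
# two-parameter initial guess.
#
#     p0 = sample_ball(np.array([1.0, 2.0]), np.array([1e-4, 1e-4]), size=100)
#     p0.shape  # (100, 2)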


class MH_proposal_axisaligned(object):
    """
    A Metropolis-Hastings proposal, with axis-aligned Gaussian steps,
    for convenient use as the ``mh_proposal`` option to
    :func:`EnsembleSampler.sample`.

    """
    def __init__(self, stdev):
        self.stdev = stdev

    def __call__(self, X):
        (nw, npar) = X.shape
        assert len(self.stdev) == npar
        return X + self.stdev * np.random.normal(size=X.shape)
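
# Illustrative sketch (not part of the original source): propose a new
# position for each of 10 walkers in 3 dimensions, with a per-axis step
# scale.
#
#     proposal = MH_proposal_axisaligned(np.array([0.1, 0.1, 0.5]))
#     X_new = proposal(np.zeros((10, 3)))  # X_new.shape == (10, 3)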


class _close_pool_message(object):
    def __repr__(self):
        return "<Close pool message>"


class _function_wrapper(object):
    def __init__(self, function):
        self.function = function


def _error_function(task):
    raise RuntimeError("Pool was sent tasks before being told what "
                       "function to apply.")


class MPIPool(object):
    """
    A pool that distributes tasks over a set of MPI processes. MPI is an
    API for distributed memory parallelism. This pool will let you run
    emcee without shared memory, letting you use much larger machines
    with emcee.

    The pool only supports the :func:`map` method at the moment because
    this is the only functionality that emcee needs. That being said,
    this pool is fairly general and it could be used for other purposes.
    A usage sketch is given in the ``__main__`` block at the end of this
    module.

    Contributed by `Joe Zuntz <https://github.com/joezuntz>`_.

    :param comm: (optional)
        The ``mpi4py`` communicator.

    :param debug: (optional)
        If ``True``, print out a lot of status updates at each step.

    """
    def __init__(self, comm=MPI.COMM_WORLD, debug=False):
        self.comm = comm
        self.rank = comm.Get_rank()
        self.size = comm.Get_size() - 1
        self.debug = debug
        self.function = _error_function
        if self.size == 0:
            raise ValueError("Tried to create an MPI pool, but there "
                             "was only one MPI process available. "
                             "Need at least two.")

    def is_master(self):
        """
        Is the current process the master?

        """
        return self.rank == 0

    def wait(self):
        """
        If this isn't the master process, wait for instructions.

        """
        if self.is_master():
            raise RuntimeError("Master node told to await jobs.")

        status = MPI.Status()

        while True:
            # Sit in this event loop and await instructions.
            if self.debug:
                print("Worker {0} waiting for task.".format(self.rank))

            # Blocking receive to wait for instructions.
            task = self.comm.recv(source=0, tag=MPI.ANY_TAG, status=status)
            if self.debug:
                print("Worker {0} got task {1} with tag {2}."
                      .format(self.rank, task, status.tag))

            # Check if the message is the special sentinel signaling the
            # end of work. If so, stop.
            if isinstance(task, _close_pool_message):
                if self.debug:
                    print("Worker {0} told to quit.".format(self.rank))
                break

            # Check if the message is a special wrapper carrying a new
            # function to apply.
            if isinstance(task, _function_wrapper):
                self.function = task.function
                if self.debug:
                    print("Worker {0} replaced its task function: {1}."
                          .format(self.rank, self.function))
                continue

            # If not a special message, run the known function on the
            # input and send the result back asynchronously.
            result = self.function(task)
            if self.debug:
                print("Worker {0} sending answer {1} with tag {2}."
                      .format(self.rank, result, status.tag))
            self.comm.isend(result, dest=0, tag=status.tag)

    def map(self, function, tasks):
        """
        Like the built-in :func:`map` function, apply a function to all
        of the values in a list and return the list of results.

        :param function:
            The function to apply to the list.

        :param tasks:
            The list of elements.

        """
        ntask = len(tasks)

        # If this is not the master process, just wait for instructions.
        if not self.is_master():
            self.wait()
            return

        if function is not self.function:
            if self.debug:
                print("Master replacing pool function with {0}."
                      .format(function))

            self.function = function
            F = _function_wrapper(function)

            # Tell all the workers what function to use.
            requests = []
            for i in range(self.size):
                r = self.comm.isend(F, dest=i + 1)
                requests.append(r)

            # Wait until every worker has received the new function.
            MPI.Request.waitall(requests)

        # Send all the tasks off round-robin and wait for them to be
        # received.
        requests = []
        for i, task in enumerate(tasks):
            worker = i % self.size + 1
            if self.debug:
                print("Sent task {0} to worker {1} with tag {2}."
                      .format(task, worker, i))
            r = self.comm.isend(task, dest=worker, tag=i)
            requests.append(r)

        MPI.Request.waitall(requests)

        # Now collect the answers in task order.
        results = []
        for i in range(ntask):
            worker = i % self.size + 1
            if self.debug:
                print("Master waiting for worker {0} with tag {1}"
                      .format(worker, i))
            result = self.comm.recv(source=worker, tag=i)
            results.append(result)

        return results

    def close(self):
        """
        Just send a message off to all the pool members which contains
        the special :class:`_close_pool_message` sentinel.

        """
        if self.is_master():
            for i in range(self.size):
                self.comm.isend(_close_pool_message(), dest=i + 1)

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()
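

# A minimal usage sketch, not part of the original module. It assumes this
# file is launched under MPI with at least two processes, e.g.:
#
#     mpiexec -n 4 python utils.py
#
# The task function and inputs are hypothetical placeholders.
if __name__ == "__main__":
    def _square(x):
        # Toy task function applied to each element.
        return x * x

    pool = MPIPool(debug=False)
    if not pool.is_master():
        # Workers block here, executing tasks until told to quit.
        pool.wait()
    else:
        # The master farms the tasks out and gathers the results in order.
        results = pool.map(_square, list(range(10)))
        print(results)  # -> [0, 1, 4, 9, ...]
        pool.close()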