gstlal-inspiral  0.4.2
 All Classes Namespaces Files Functions Variables Pages
svd_bank.py
Go to the documentation of this file.
1 # Copyright (C) 2010 Kipp Cannon, Chad Hanna, Leo Singer
2 # Copyright (C) 2009 Kipp Cannon, Chad Hanna
3 #
4 # This program is free software; you can redistribute it and/or modify it
5 # under the terms of the GNU General Public License as published by the
6 # Free Software Foundation; either version 2 of the License, or (at your
7 # option) any later version.
8 #
9 # This program is distributed in the hope that it will be useful, but
10 # WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General
12 # Public License for more details.
13 #
14 # You should have received a copy of the GNU General Public License along
15 # with this program; if not, write to the Free Software Foundation, Inc.,
16 # 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
17 
18 ## @file
19 # The module to implement SVD decomposition of CBC waveforms
20 #
21 # ### Review Status
22 #
23 # | Names | Hash | Date |
24 # | ------------------------------------------- | ------------------------------------------- | ---------- |
25 # | Florent, Sathya, Duncan Me, Jolien, Kipp, Chad | 7536db9d496be9a014559f4e273e1e856047bf71 | 2014-04-30 |
26 #
27 # #### Actions
28 # - Consider a study of how to supply the svd / time slice boundaries
29 #
30 
31 ## @package svd_bank
32 
33 
34 #
35 # =============================================================================
36 #
37 # Preamble
38 #
39 # =============================================================================
40 #
41 
42 
43 import numpy
44 import sys
45 
46 from glue.ligolw import ligolw
47 from glue.ligolw import lsctables
48 from glue.ligolw import array as ligolw_array
49 from glue.ligolw import param as ligolw_param
50 from glue.ligolw import utils as ligolw_utils
51 from glue.ligolw import types as ligolw_types
52 from glue.ligolw.utils import process as ligolw_process
53 from pylal import series
54 
# Alias for the SAX attributes implementation reached through
# glue.ligolw.ligolw; write_bank() uses it to give each per-bank LIGO_LW
# container element its Name attribute.
Attributes = ligolw.sax.xmlreader.AttributesImpl
56 
57 from gstlal import cbc_template_fir
58 from gstlal import misc as gstlalmisc
59 from gstlal import templates
60 
61 
class DefaultContentHandler(ligolw.LIGOLWContentHandler):
	"""
	LIGO_LW content handler used by default when this module loads
	documents (build_bank() and read_banks()).

	It is registered with the array, param and lsctables modules through
	their use_in() hooks immediately after the class definition.

	FIXME: require calling code to provide the content handler
	"""


for _module in (ligolw_array, ligolw_param, lsctables):
	_module.use_in(DefaultContentHandler)
del _module
68 
69 
70 #
71 # =============================================================================
72 #
73 # Utilities
74 #
75 # =============================================================================
76 #
77 
78 #
79 # Read approximant
80 #
81 
def read_approximant(xmldoc, programs = ("gstlal_bank_splitter",)):
	"""
	Return the value of the --approximant option recorded in xmldoc.

	The process table is searched for processes run by any of the given
	programs, and the process_params table is searched for their
	"--approximant" entries.

	@param xmldoc LIGO_LW document containing process and process_params tables.
	@param programs Sequence of program names whose --approximant option is consulted.

	@returns The single approximant value (the param row's pyvalue).

	Raises ValueError if no process from programs is present, if none of
	those processes recorded --approximant, or if more than one distinct
	approximant was recorded.
	"""
	# materialize once: programs is iterated again when building error messages
	programs = tuple(programs)

	process_table = lsctables.ProcessTable.get_table(xmldoc)
	process_ids = set()
	for program in programs:
		process_ids |= process_table.get_ids_by_program(program)
	if not process_ids:
		raise ValueError("document must contain process entries from %s" % ", ".join(programs))

	approximant = set(row.pyvalue for row in lsctables.ProcessParamsTable.get_table(xmldoc) if (row.process_id in process_ids) and (row.param == "--approximant"))
	if not approximant:
		# quote each program name (the format operator must be applied per item)
		raise ValueError("document must contain an 'approximant' process_params entry from %s" % ", ".join("'%s'" % program for program in programs))
	if len(approximant) > 1:
		raise ValueError("document must contain only one approximant")
	return approximant.pop()
96 
97 #
98 # check final frequency is populated and return the max final frequency
99 #
100 
def check_ffinal_and_find_max_ffinal(xmldoc):
	"""
	Return the largest f_final in the sngl_inspiral table of xmldoc.

	@param xmldoc LIGO_LW document containing a sngl_inspiral table.

	@returns max(f_final) over all rows.

	Raises ValueError if the table has no rows, or if any row's f_final is
	falsy (e.g. 0), i.e. the column has not been populated.
	"""
	f_final = list(lsctables.SnglInspiralTable.get_table(xmldoc).getColumnByName("f_final"))
	# all([]) is True, so an empty table must be rejected explicitly
	# rather than falling through to an opaque max() error
	if not f_final:
		raise ValueError("sngl_inspiral table is empty")
	if not all(f_final):
		raise ValueError("f_final column not populated")
	return max(f_final)
106 
107 #
108 # sum-of-squares false alarm probability
109 #
110 
111 
def sum_of_squares_threshold_from_fap(fap, coefficients):
	"""
	Return the sum-of-squares gate threshold for a false-alarm probability.

	Thin wrapper around gstlalmisc.max_stat_thresh(); note that the
	argument order is swapped relative to that function.

	@param fap Target false-alarm probability.
	@param coefficients Array of weights for the sum-of-squares statistic
	(Bank.__init__ passes the squared sum-of-squares weights of its fragments).
	"""
	threshold = gstlalmisc.max_stat_thresh(coefficients, fap)
	return threshold
115 
116 
117 #
118 # =============================================================================
119 #
120 # Pipeline Metadata
121 #
122 # =============================================================================
123 #
124 
125 
class BankFragment(object):
	"""
	One time slice of a template bank.

	Stores the slice's sample rate and start/end times. After
	set_template_bank() it also stores the decomposition of the templates
	for that slice.
	"""
	def __init__(self, rate, start, end):
		self.rate, self.start, self.end = rate, start, end

	def set_template_bank(self, template_bank, tolerance, snr_thresh, identity_transform = False, verbose = False):
		"""
		Decompose template_bank and store the result on this fragment.

		Sets orthogonal_template_bank, singular_values, mix_matrix and
		chifacs from cbc_template_fir.decompose_templates().

		Also sets sum_of_squares_weights to sqrt(mean(chifacs) * ss_coeffs),
		or to None when no singular values are returned.

		@param template_bank Templates for this slice; presumably a 2-D
		(templates x samples) array, as the verbose message assumes.
		@param tolerance SVD tolerance passed to decompose_templates().
		@param snr_thresh SNR threshold passed to gstlalmisc.ss_coeffs().
		@param identity_transform Passed to decompose_templates() as identity.
		@param verbose Report progress on stderr.
		"""
		if verbose:
			print >>sys.stderr, "\t%d templates of %d samples" % template_bank.shape

		decomposition = cbc_template_fir.decompose_templates(template_bank, tolerance, identity = identity_transform)
		(self.orthogonal_template_bank, self.singular_values, self.mix_matrix, self.chifacs) = decomposition

		if self.singular_values is None:
			# nothing to build the sum-of-squares weights from
			self.sum_of_squares_weights = None
		else:
			mean_chifac = self.chifacs.mean()
			coefficients = gstlalmisc.ss_coeffs(self.singular_values, snr_thresh)
			self.sum_of_squares_weights = numpy.sqrt(mean_chifac * coefficients)

		if verbose:
			print >>sys.stderr, "\tidentified %d components" % self.orthogonal_template_bank.shape[0]
			print >>sys.stderr, "\tsum-of-squares expectation value is %g" % self.chifacs.mean()
145 
146 
class Bank(object):
	"""
	A template bank split into time slices.

	Holds one BankFragment per time slice, plus the per-template
	autocorrelations, autocorrelation mask and sigmasq values. It also keeps
	the sngl_inspiral table, the processed PSD and the sum-of-squares gate
	threshold.
	"""
	def __init__(self, bank_xmldoc, psd, time_slices, gate_fap, snr_threshold, tolerance, flow = 40.0, autocorrelation_length = None, logname = None, identity_transform = False, verbose = False, bank_id = None):
		"""
		Generate the templates in bank_xmldoc and decompose them slice by slice.

		@param bank_xmldoc LIGO_LW document with the sngl_inspiral table and
		the process metadata that read_approximant() needs.
		@param psd PSD passed to cbc_template_fir.generate_templates().
		@param time_slices Iterable of (rate, start, end) records that also
		supports ['end'] field access (presumably a numpy record array).
		@param gate_fap False-alarm probability for the sum-of-squares gate.
		@param snr_threshold SNR threshold, also passed to every fragment.
		@param tolerance SVD tolerance passed to every fragment.
		@param flow Lower frequency cutoff for template generation.
		@param autocorrelation_length Number of autocorrelation samples.
		@param logname Name stored with the bank.
		@param identity_transform Skip the SVD and keep the time-sliced templates.
		@param verbose Report progress on stderr.
		@param bank_id Identifier stored with the bank.
		"""
		# FIXME: remove template_bank_filename when no longer needed
		# by trigger generator element
		self.template_bank_filename = None
		self.filter_length = max(time_slices['end'])
		self.snr_threshold = snr_threshold
		self.logname = logname
		self.bank_id = bank_id

		sngl_inspiral_table = lsctables.SnglInspiralTable.get_table(bank_xmldoc)

		# Generate downsampled templates
		template_bank, self.autocorrelation_bank, self.autocorrelation_mask, self.sigmasq, processed_psd = cbc_template_fir.generate_templates(
			sngl_inspiral_table,
			read_approximant(bank_xmldoc),
			psd,
			flow,
			time_slices,
			autocorrelation_length = autocorrelation_length,
			verbose = verbose)

		# Include the sngl inspiral table
		self.sngl_inspiral_table = sngl_inspiral_table
		# Include the processed psd
		self.processed_psd = processed_psd

		# Assign template banks to fragments
		self.bank_fragments = [BankFragment(rate, begin, end) for rate, begin, end in time_slices]
		for i, bank_fragment in enumerate(self.bank_fragments):
			if verbose:
				print >>sys.stderr, "constructing template decomposition %d of %d: %g s ... %g s" % (i + 1, len(self.bank_fragments), -bank_fragment.end, -bank_fragment.start)
			bank_fragment.set_template_bank(template_bank[i], tolerance, self.snr_threshold, identity_transform = identity_transform, verbose = verbose)

		# Pool the sum-of-squares weights of every fragment that has them
		# rather than testing only the variable left over from the loop above.
		# That only checked the last fragment, and would be unbound with no
		# fragments.
		fragment_weights = [frag.sum_of_squares_weights for frag in self.bank_fragments if frag.sum_of_squares_weights is not None]
		if fragment_weights:
			self.gate_threshold = sum_of_squares_threshold_from_fap(gate_fap, numpy.array([weight**2 for weights in fragment_weights for weight in weights], dtype = "double"))
		else:
			self.gate_threshold = 0
		if verbose:
			print >>sys.stderr, "sum-of-squares threshold for false-alarm probability of %.16g: %.16g" % (gate_fap, self.gate_threshold)

	def get_rates(self):
		"""Return the set of sample rates used by this bank's fragments."""
		return set(bank_fragment.rate for bank_fragment in self.bank_fragments)

	# FIXME: remove set_template_bank_filename when no longer needed
	# by trigger generator element
	def set_template_bank_filename(self, name):
		"""Record the name of the template bank file (build_bank() passes its input file name)."""
		self.template_bank_filename = name
193 
194 
195 
def build_bank(template_bank_filename, psd, flow, ortho_gate_fap, snr_threshold, svd_tolerance, padding = 1.5, identity_transform = False, verbose = False, autocorrelation_length = 201, samples_min = 1024, samples_max_256 = 1024, samples_max_64 = 2048, samples_max = 4096, bank_id = None, contenthandler = DefaultContentHandler):
	"""!
	Return an instance of a Bank class.

	@param template_bank_filename The template bank filename containing a subbank of templates to decompose in a single inspiral table.
	@param psd Mapping from instrument name to PSD; the entry for the ifo of the bank's first template is used.
	@param flow The lower frequency cutoff.
	@param ortho_gate_fap The FAP threshold for the sum of squares threshold, see http://arxiv.org/abs/1101.0584
	@param snr_threshold The SNR threshold for the search
	@param svd_tolerance The target SNR loss of the SVD, see http://arxiv.org/abs/1005.0012
	@param padding The padding from Nyquist for any template time slice, e.g., if a time slice has a Nyquist of 256 Hz and the padding is set to 2, only allow the template frequency to extend to 128 Hz.
	@param identity_transform Don't do the SVD, just do time slices and keep the raw waveforms
	@param verbose Be verbose
	@param autocorrelation_length The number of autocorrelation samples to use in the chisquared test. Must be odd
	@param samples_min The minimum number of samples to use in any time slice
	@param samples_max_256 The maximum number of samples to have in any time slice greater than or equal to 256 Hz
	@param samples_max_64 The maximum number of samples to have in any time slice greater than or equal to 64 Hz
	@param samples_max The maximum number of samples in any time slice below 64 Hz
	@param bank_id The id of the bank in question
	@param contenthandler The ligolw content handler for file I/O
	"""
	# Load the template bank and pull out its single-inspiral rows
	xmldoc = ligolw_utils.load_filename(template_bank_filename, contenthandler = contenthandler, verbose = verbose)
	sngl_table = lsctables.SnglInspiralTable.get_table(xmldoc)

	# Choose how to break up templates in time; the highest frequency any
	# slice must support is the largest f_final in the bank
	max_f_final = check_ffinal_and_find_max_ffinal(xmldoc)
	slice_bounds = templates.time_slices(
		sngl_table,
		fhigh = max_f_final,
		flow = flow,
		padding = padding,
		samples_min = samples_min,
		samples_max_256 = samples_max_256,
		samples_max_64 = samples_max_64,
		samples_max = samples_max,
		verbose = verbose)

	# Generate templates, perform SVD, get orthogonal basis and store it
	# as a Bank object, using the PSD of the first template's instrument
	instrument = sngl_table[0].ifo
	svd_bank = Bank(
		xmldoc,
		psd[instrument],
		slice_bounds,
		gate_fap = ortho_gate_fap,
		snr_threshold = snr_threshold,
		tolerance = svd_tolerance,
		flow = flow,
		autocorrelation_length = autocorrelation_length,	# samples
		identity_transform = identity_transform,
		verbose = verbose,
		bank_id = bank_id
	)

	# FIXME: remove this when no longer needed
	# by trigger generator element.
	svd_bank.set_template_bank_filename(template_bank_filename)
	return svd_bank
256 
257 
def write_bank(filename, banks, cliplefts = None, cliprights = None, contenthandler = DefaultContentHandler, write_psd = False, verbose = False):
	"""Write SVD banks to a LIGO_LW xml file.

	Each bank becomes a LIGO_LW element named "gstlal_svd_bank_Bank",
	the element read_banks() looks for.  Templates may be clipped from
	either end of each bank before it is written.  The Bank and
	BankFragment objects passed in are not modified; clipping is applied
	to local slices of their arrays.

	@param filename Output file name; written gzip-compressed if it ends in ".gz".
	@param banks Sequence of Bank objects.
	@param cliplefts Per-bank number of templates to drop from the start of
	the bank; None means no clipping.
	@param cliprights Per-bank number of templates to drop from the end of
	the bank; None means no clipping.
	@param contenthandler Not used by this function; kept for API compatibility.
	@param write_psd If True, also write the processed PSD of the last bank written.
	@param verbose Be verbose.
	"""
	banks = list(banks)
	# zip() cannot iterate over None, so default to "clip nothing"
	if cliplefts is None:
		cliplefts = [0] * len(banks)
	if cliprights is None:
		cliprights = [0] * len(banks)

	# Create new document
	xmldoc = ligolw.Document()
	lw = ligolw.LIGO_LW()

	last_bank = None
	for bank, clipleft, clipright in zip(banks, cliplefts, cliprights):
		last_bank = bank

		# set up root for this sub bank
		root = ligolw.LIGO_LW(Attributes({u"Name": u"gstlal_svd_bank_Bank"}))
		lw.appendChild(root)

		# FIXME FIXME FIXME move this clipping stuff to the Bank class
		# convert the number of templates clipped on the right into an
		# end index
		stop = len(bank.sngl_inspiral_table) - clipright

		# Apply clipping option to sngl inspiral table and put the
		# possibly clipped table into the output document
		new_sngl_table = lsctables.New(lsctables.SnglInspiralTable)
		for row in bank.sngl_inspiral_table[clipleft:stop]:
			new_sngl_table.append(row)
		root.appendChild(new_sngl_table)

		# Add root-level scalar params
		root.appendChild(ligolw_param.new_param('filter_length', ligolw_types.FromPyType[float], bank.filter_length))
		root.appendChild(ligolw_param.new_param('gate_threshold', ligolw_types.FromPyType[float], bank.gate_threshold))
		root.appendChild(ligolw_param.new_param('logname', ligolw_types.FromPyType[str], bank.logname))
		root.appendChild(ligolw_param.new_param('snr_threshold', ligolw_types.FromPyType[float], bank.snr_threshold))
		root.appendChild(ligolw_param.new_param('template_bank_filename', ligolw_types.FromPyType[str], bank.template_bank_filename))
		root.appendChild(ligolw_param.new_param('bank_id', ligolw_types.FromPyType[str], bank.bank_id))

		# apply clipping to autocorrelations and sigmasq (locally; the
		# bank itself is left untouched)
		autocorrelation_bank = bank.autocorrelation_bank[clipleft:stop,:]
		sigmasq = bank.sigmasq[clipleft:stop]
		# a mask with the same shape as the autocorrelation bank is
		# per-template, so it must be clipped the same way or the two
		# written arrays would disagree in size
		autocorrelation_mask = bank.autocorrelation_mask
		if autocorrelation_mask.shape == bank.autocorrelation_bank.shape:
			autocorrelation_mask = autocorrelation_mask[clipleft:stop]

		# Add root-level arrays
		# FIXME: ligolw format now supports complex-valued data
		root.appendChild(ligolw_array.from_array('autocorrelation_bank_real', autocorrelation_bank.real))
		root.appendChild(ligolw_array.from_array('autocorrelation_bank_imag', autocorrelation_bank.imag))
		root.appendChild(ligolw_array.from_array('autocorrelation_mask', autocorrelation_mask))
		root.appendChild(ligolw_array.from_array('sigmasq', numpy.array(sigmasq)))

		# Write bank fragments
		for frag in bank.bank_fragments:
			# Start new container
			el = ligolw.LIGO_LW()

			# Apply clipping option; two mix-matrix columns / chifacs
			# entries per template (presumably real and imaginary parts).
			# NOTE(review): when mix_matrix is None (identity transform)
			# nothing in the fragment is clipped, matching the previous
			# behaviour -- confirm that is intended.
			mix_matrix = frag.mix_matrix
			chifacs = frag.chifacs
			if mix_matrix is not None:
				mix_matrix = mix_matrix[:,clipleft*2:stop*2]
				chifacs = chifacs[clipleft*2:stop*2]

			# Add scalar params
			el.appendChild(ligolw_param.new_param('start', ligolw_types.FromPyType[float], frag.start))
			el.appendChild(ligolw_param.new_param('end', ligolw_types.FromPyType[float], frag.end))
			el.appendChild(ligolw_param.new_param('rate', ligolw_types.FromPyType[int], frag.rate))

			# Add arrays
			el.appendChild(ligolw_array.from_array('chifacs', chifacs))
			if mix_matrix is not None:
				el.appendChild(ligolw_array.from_array('mix_matrix', mix_matrix))
			el.appendChild(ligolw_array.from_array('orthogonal_template_bank', frag.orthogonal_template_bank))
			if frag.singular_values is not None:
				el.appendChild(ligolw_array.from_array('singular_values', frag.singular_values))
			if frag.sum_of_squares_weights is not None:
				el.appendChild(ligolw_array.from_array('sum_of_squares_weights', frag.sum_of_squares_weights))

			# Add bank fragment container to root container
			root.appendChild(el)

	# put a copy of the processed PSD file in
	# FIXME in principle this could be different for each bank included in
	# this file, but we only put one here (that of the last bank written)
	if write_psd and last_bank is not None:
		series.make_psd_xmldoc({last_bank.sngl_inspiral_table[0].ifo: last_bank.processed_psd}, lw)

	# add top level LIGO_LW to document
	xmldoc.appendChild(lw)

	# Write to file
	ligolw_utils.write_filename(xmldoc, filename, gz = filename.endswith('.gz'), verbose = verbose)
341 
342 
def _get_optional_array(elem, name):
	"""Return the data of the Array called name under elem, or None if get_array() raises ValueError."""
	try:
		return ligolw_array.get_array(elem, name).array
	except ValueError:
		return None


def read_banks(filename, contenthandler = DefaultContentHandler, verbose = False):
	"""Read SVD banks from a LIGO_LW xml file.

	Every LIGO_LW element named "gstlal_svd_bank_Bank" is turned into a
	Bank (built without calling Bank.__init__).  Its child LIGO_LW
	elements become BankFragment objects.  The optional fragment arrays
	mix_matrix, singular_values and sum_of_squares_weights are set to None
	when absent.

	@param filename Name of the file to load.
	@param contenthandler LIGO_LW content handler used for loading.
	@param verbose Be verbose while loading.

	@returns List of Bank objects in document order.
	"""
	xmldoc = ligolw_utils.load_filename(filename, contenthandler = contenthandler, verbose = verbose)

	banks = []
	for root in xmldoc.getElementsByTagName(ligolw.LIGO_LW.tagName):
		if not (root.hasAttribute(u"Name") and root.Name == "gstlal_svd_bank_Bank"):
			continue

		# bypass __init__, which would regenerate and decompose templates
		bank = Bank.__new__(Bank)
		bank.sngl_inspiral_table = lsctables.SnglInspiralTable.get_table(root)

		# root-level scalar parameters
		for param_name in ('filter_length', 'gate_threshold', 'logname', 'snr_threshold', 'template_bank_filename', 'bank_id'):
			setattr(bank, param_name, ligolw_param.get_pyvalue(root, param_name))

		# root-level arrays; the autocorrelations are stored as separate
		# real and imaginary parts
		real_part = ligolw_array.get_array(root, 'autocorrelation_bank_real').array
		imag_part = ligolw_array.get_array(root, 'autocorrelation_bank_imag').array
		bank.autocorrelation_bank = real_part + 1j * imag_part
		bank.autocorrelation_mask = ligolw_array.get_array(root, 'autocorrelation_mask').array
		bank.sigmasq = ligolw_array.get_array(root, 'sigmasq').array

		# bank fragments: one child LIGO_LW element each
		bank.bank_fragments = []
		for el in root.childNodes:
			if el.tagName != ligolw.LIGO_LW.tagName:
				continue
			frag = BankFragment.__new__(BankFragment)

			for param_name in ('start', 'end', 'rate'):
				setattr(frag, param_name, ligolw_param.get_pyvalue(el, param_name))

			frag.chifacs = ligolw_array.get_array(el, 'chifacs').array
			frag.mix_matrix = _get_optional_array(el, 'mix_matrix')
			frag.orthogonal_template_bank = ligolw_array.get_array(el, 'orthogonal_template_bank').array
			frag.singular_values = _get_optional_array(el, 'singular_values')
			frag.sum_of_squares_weights = _get_optional_array(el, 'sum_of_squares_weights')
			bank.bank_fragments.append(frag)

		banks.append(bank)
	return banks