bin/gwin_extract_samples
#!/usr/bin/env python

# Copyright (C) 2017 Collin Capano
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.

#
# =============================================================================
#
#                                   Preamble
#
# =============================================================================
#
"""Extracts some or all samples from an InferenceFile, writing results to a new
InferenceFile.
"""

import os
import argparse

import numpy

import pycbc

from gwin import (__version__, option_utils)
from gwin.io.hdf import InferenceFile
from gwin.io.txt import InferenceTXTFile

# -----------------------------------------------------------------------------
# Command line
# -----------------------------------------------------------------------------
parser = argparse.ArgumentParser(description=__doc__)

parser.add_argument('-V', '--version', action='version', version=__version__,
                    help='show version number and exit')
parser.add_argument("--output-file", type=str, required=True,
                    help="Output file to create.")
parser.add_argument("--force", action="store_true", default=False,
                    help="If the output-file already exists, overwrite it. "
                         "Otherwise, an IOError is raised.")
parser.add_argument("--posterior-only", action="store_true", default=False,
                    help="Write copied parameter samples and likelihood stats "
                         "as flattened arrays. For example, if the sampler "
                         "wrote a parameter's samples to "
                         "{samples_group}/{param}/walker{x}/, then the "
                         "copied file will have all selected samples from "
                         "all walkers in {samples_group}/{param}/. Default is "
                         "for the copied file to have the same structure "
                         "as the input file.")
# adds the shared inference-results options; the code below reads
# opts.input_file, opts.parameters, opts.thin_start, opts.thin_end,
# opts.thin_interval and opts.iteration, which are presumably defined by this
# option group -- confirm against gwin.option_utils
option_utils.add_inference_results_option_group(parser)
parser.add_argument("--temperature", nargs="+", default=None,
                    help="For parallel-tempered samplers, which temperature "
                         "chain(s) to extract. Options are 'all', or one or "
                         "more indices indicating the temperature chain "
                         "(0 = coldest). Default is to only extract from "
                         "the coldest chain. ")
parser.add_argument("--verbose", action="store_true", default=False,
                    help="Be verbose")

opts = parser.parse_args()

pycbc.init_logging(opts.verbose)

# check that the output doesn't already exist
if os.path.exists(opts.output_file) and not opts.force:
    raise IOError("output file already exists; use --force if you wish to "
                  "overwrite.")

# add any sampler specific options
sampler_kwargs = {}
if opts.temperature is not None:
    if opts.temperature == ['all']:
        # pass the bare string 'all' through rather than a one-element list
        temps = opts.temperature[0]
    else:
        # build a concrete list: on Python 3 ``map`` returns a one-shot
        # iterator, not the list of indices that Python 2 produced
        temps = [int(temp) for temp in opts.temperature]
    sampler_kwargs['temps'] = temps

# save TXT file
if option_utils.get_file_type(opts.output_file) == InferenceTXTFile:
    fp, parameters, labels, samples = option_utils.results_from_cli(
        opts, **sampler_kwargs)
    likelihood_stats = fp.read_likelihood_stats(
        thin_start=opts.thin_start, thin_end=opts.thin_end,
        thin_interval=opts.thin_interval, iteration=opts.iteration)
    # materialise the names as a list: on Python 3 ``keys()`` returns a view,
    # which cannot be joined to ``labels`` with ``+`` below (``labels`` is
    # presumably a list -- confirm against option_utils.results_from_cli)
    stat_names = list(fp[fp.stats_group].keys())
    # parameter columns first, then the likelihood-stat columns
    out = numpy.hstack([samples.to_array(fields=parameters).T,
                        likelihood_stats.to_array(fields=stat_names).T])
    InferenceTXTFile.write(opts.output_file, out, labels + stat_names)

# otherwise save HDF file
else:
    # open the input and parse parameters; only one input file can be copied
    if len(opts.input_file) > 1:
        raise ValueError("Too many input files.")
    source = InferenceFile(opts.input_file[0])
    parameters, pnames = option_utils.parse_parameters_opt(opts.parameters)

    # copy to output
    target = source.copy(opts.output_file, parameters=parameters,
                         parameter_names=pnames,
                         posterior_only=opts.posterior_only,
                         thin_start=opts.thin_start, thin_end=opts.thin_end,
                         thin_interval=opts.thin_interval,
                         iteration=opts.iteration, **sampler_kwargs)
    target.close()