gwastro/gwin
bin/gwin_extract_samples
#!/usr/bin/env python
 
# Copyright (C) 2017 Collin Capano
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
 
 
#
# =============================================================================
#
# Preamble
#
# =============================================================================
#
"""Extracts some or all samples from an InferenceFile, writing results to a new
InferenceFile.
"""
 
import os
import argparse
 
import numpy
 
import pycbc
 
from gwin import (__version__, option_utils)
from gwin.io.hdf import InferenceFile
from gwin.io.txt import InferenceTXTFile
 
parser = argparse.ArgumentParser(description=__doc__)
 
parser.add_argument('-V', '--version', action='version', version=__version__,
                    help='show version number and exit')
parser.add_argument("--output-file", type=str, required=True,
                    help="Output file to create.")
parser.add_argument("--force", action="store_true", default=False,
                    help="If the output-file already exists, overwrite it. "
                         "Otherwise, an IOError is raised.")
parser.add_argument("--posterior-only", action="store_true", default=False,
                    help="Write copied parameter samples and likelihood stats "
                         "as flattened arrays. For example, if the sampler "
                         "wrote a parameter's samples to "
                         "{samples_group}/{param}/walker{x}/, then the "
                         "copied file will have all selected samples from "
                         "all walkers in {samples_group}/{param}/. Default is "
                         "for the copied file to have the same structure "
                         "as the input file.")
option_utils.add_inference_results_option_group(parser)
parser.add_argument("--temperature", nargs="+", default=None,
help="For parallel-tempered samplers, which temperature "
"chain(s) to extract. Options are 'all', or one or "
"more indices indicating the temperature chain "
"(0 = coldest). Default is to only extract from "
"the coldest chain. ")
parser.add_argument("--verbose", action="store_true", default=False,
help="Be verbose")
 
opts = parser.parse_args()
 
pycbc.init_logging(opts.verbose)
 
# check that the output doesn't already exist
if os.path.exists(opts.output_file) and not opts.force:
    raise IOError("output file already exists; use --force if you wish to "
                  "overwrite.")
 
# add any sampler specific options
sampler_kwargs = {}
if opts.temperature is not None:
    temps = opts.temperature
    if temps == ['all']:
        temps = temps[0]
    else:
        # convert the requested chain indices to integers; wrap in list() so
        # the result is not a one-shot iterator on Python 3
        temps = list(map(int, temps))
    sampler_kwargs['temps'] = temps
 
# save TXT file
if option_utils.get_file_type(opts.output_file) == InferenceTXTFile:
    fp, parameters, labels, samples = option_utils.results_from_cli(
        opts, **sampler_kwargs)
    likelihood_stats = fp.read_likelihood_stats(
        thin_start=opts.thin_start,
        thin_end=opts.thin_end,
        thin_interval=opts.thin_interval,
        iteration=opts.iteration)
    # list() so the stat names can be concatenated with the parameter labels
    # below (the keys are returned as a view on Python 3)
    stat_names = list(fp[fp.stats_group].keys())
    n_cols = len(parameters) + len(stat_names)
    n_rows = samples.size
    # one row per sample; one column per parameter, followed by one column per
    # likelihood stat
    out = numpy.hstack([samples.to_array(fields=parameters).T,
                        likelihood_stats.to_array(fields=stat_names).T])
    InferenceTXTFile.write(opts.output_file, out, labels + stat_names)
 
# otherwise save HDF file
else:
 
    # open the input and parse parameters
    if len(opts.input_file) > 1:
        raise ValueError("Too many input files.")
    source = InferenceFile(opts.input_file[0])
    parameters, pnames = option_utils.parse_parameters_opt(opts.parameters)

    # copy to output
    target = source.copy(opts.output_file, parameters=parameters,
                         parameter_names=pnames,
                         posterior_only=opts.posterior_only,
                         thin_start=opts.thin_start, thin_end=opts.thin_end,
                         thin_interval=opts.thin_interval,
                         iteration=opts.iteration,
                         **sampler_kwargs)
    target.close()