gwastro/gwin

View on GitHub
bin/gwin_plot_inj_recovery

Summary

Maintainability
Test Coverage
#!/usr/bin/env python
""" Plots the recovered versus injected parameter values from a population
of injections.
"""
 
import argparse
import logging
import matplotlib as mpl; mpl.use("Agg")
import matplotlib.colorbar as cbar
import matplotlib.pyplot as plt
import numpy
import pycbc
import pycbc.version
from matplotlib import cm
from pycbc import inject
from pycbc import transforms
from gwin import (__version__, option_utils)
 
# build the command line interface: the options specific to this script are
# registered first, then the option groups added by option_utils
parser = argparse.ArgumentParser(
    description=__doc__,
    usage=__file__ + " [--options]",
)
script_options = [
    ("--version",
     {"action": "version", "version": __version__,
      "help": "Prints version information."}),
    ("--output-file",
     {"required": True, "type": str,
      "help": "Path to save output plot."}),
    ("--verbose",
     {"action": "store_true",
      "help": "Allows print statements."}),
    ("--quantiles",
     {"nargs": 2, "type": float, "default": [0.05, 0.95],
      "help": "Quantiles to use as limits."}),
    ("--injection-hdf-group",
     {"default": "H1/injections",
      "help": "HDF group that contains injection values."}),
]
for flag, settings in script_options:
    parser.add_argument(flag, **settings)
option_utils.add_inference_results_option_group(parser)
option_utils.add_scatter_option_group(parser)

# read the options given by the user
opts = parser.parse_args()
 
# set logging
pycbc.init_logging(opts.verbose)

# read results
fp, parameters, labels, samples = option_utils.results_from_cli(opts)

# only plot one parameter; a wrong count is reported as a usage error
# instead of with assert, which is skipped when Python runs with -O
if len(opts.parameters) != 1:
    parser.error("exactly one parameter can be plotted, got {}".format(
        len(opts.parameters)))

# results_from_cli may return either a list or a single object, so pick the
# first entry only when a list was returned
parameter = parameters[0] if isinstance(parameters, list) else parameters
label = labels[0][0] if isinstance(labels, list) else labels
 
# set up a figure that holds a single set of axes
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)

# wrap a lone result in a list so that the loops below can always iterate
if not isinstance(samples, list):
    samples = [samples]
if not isinstance(fp, list):
    fp = [fp]
 
# if user wants a colorbar
if opts.z_arg:

    # median z-axis value of each input file (used to color that file's
    # point) and the z-axis label (used to label the colorbar)
    zvals = []
    zlabel = None

    # loop over input files
    logging.info("Reading %s values", opts.z_arg)
    for i, input_fp in enumerate(fp):

        # get z-axis values and label
        likelihood_stats = input_fp.read_likelihood_stats(
            thin_start=opts.thin_start, thin_end=opts.thin_end,
            thin_interval=opts.thin_interval, iteration=opts.iteration)
        vals, zlabel = option_utils.get_zvalues(input_fp, opts.z_arg,
                                                likelihood_stats)
        zvals.append(numpy.median(vals))

        # track the smallest and largest z-axis value seen in any file;
        # these are the default limits of the colorbar
        min_zval = vals.min() if i == 0 else min(min_zval, vals.min())
        max_zval = vals.max() if i == 0 else max(max_zval, vals.max())

    # create colormap; pyplot.get_cmap is used because cm.get_cmap was
    # deprecated and later removed from matplotlib
    cmap = plt.get_cmap(opts.scatter_cmap)

    # limits given on the command line take precedence over the data range;
    # compare against None so that an explicit limit of 0 is honored
    vmin = opts.vmin if opts.vmin is not None else min_zval
    vmax = opts.vmax if opts.vmax is not None else max_zval
    norm = mpl.colors.Normalize(vmin, vmax)
 
# loop over the input files together with their samples
logging.info("Plotting")
for i, (input_file, input_fp, input_samples) in enumerate(
        zip(opts.input_file, fp, samples)):

    # read injections from HDF input file
    injs = inject.InjectionSet(input_file, hdf_group=opts.injection_hdf_group)

    # get the transforms needed for requested parameters that are not
    # already fields of the injection table
    _, ts = transforms.get_common_cbc_transforms(opts.parameters,
                                                 injs.table.fieldnames)

    # add parameters not included in injection file
    inj_parameters = transforms.apply_transforms(injs.table, ts)

    # get parameter values; the first element of every injection entry is
    # used as the injected value
    sampled_vals = input_samples[parameter].to_array()
    injected_vals = [e[0] for e in inj_parameters[parameter]]

    # compute quantiles of sampled results
    quantiles = numpy.array([numpy.percentile(sampled_vals, 100 * q)
                             for q in opts.quantiles])

    # get median and lowest and highest quantiles for plotting
    med = numpy.median(sampled_vals)
    high = quantiles.max()
    low = quantiles.min()

    # get color: mapped from the file's median z-value when a colorbar was
    # requested, plain black otherwise
    if opts.z_arg:
        color = cmap(norm(zvals[i]))
    else:
        color = "black"

    # warn when the file holds several injections; logging.warning is used
    # because logging.warn is a deprecated alias
    if len(injected_vals) > 1:
        logging.warning("More than one injection in file %s", input_file)

    # plot the offset of the recovered median from the injected value, with
    # error bars reaching down to the low and up to the high quantile
    ax.errorbar([injected_vals],
                [med - injected_vals],
                yerr=[[(med - low)], [(high - med)]],
                ecolor=color, linestyle="None", zorder=10)
 
# draw the colorbar next to the plot when z-axis values were requested
if opts.z_arg:
    # make_axes returns the new axes followed by keyword arguments; only
    # the axes are needed here
    colorbar_axes = cbar.make_axes(ax)[0]
    colorbar = cbar.ColorbarBase(colorbar_axes, cmap=cmap, norm=norm)
    colorbar.set_label("Recovered Median " + zlabel)
 
# label the axes: injected value on x, recovered minus injected on y
# NOTE(review): the y label has no space before "-"; confirm whether
# " - Injected " was intended
ax.set_xlabel("Injected {}".format(label))
ax.set_ylabel("Recovered " + label + "- Injected " + label)

# dashed reference line at zero, where the recovered value equals the
# injected one
ax.axhline(0, linestyle="dashed", color="gray", zorder=9)

# show the grid
ax.grid()

# write the figure to disk
plt.savefig(opts.output_file)

# finished
logging.info("Done")