bin/gwin_plot_inj_recovery
#!/usr/bin/env python
"""Plots the recovered versus injected parameter values from a population
of injections.
"""

import argparse
import logging

import matplotlib as mpl
mpl.use("Agg")  # select the non-interactive Agg backend before pyplot loads
import matplotlib.colorbar as cbar
import matplotlib.pyplot as plt
import numpy
import pycbc
import pycbc.version
from matplotlib import cm
from pycbc import inject
from pycbc import transforms

from gwin import (__version__, option_utils)


def _parse_command_line():
    """Build the command-line parser and return the parsed options.

    Exits through ``parser.error`` unless exactly one parameter was requested,
    because this script plots a single parameter.
    """
    parser = argparse.ArgumentParser(usage=__file__ + " [--options]",
                                     description=__doc__)
    parser.add_argument("--version", action="version", version=__version__,
                        help="Prints version information.")
    parser.add_argument("--output-file", required=True, type=str,
                        help="Path to save output plot.")
    parser.add_argument("--verbose", action="store_true",
                        help="Allows print statements.")
    parser.add_argument("--quantiles", nargs=2, type=float,
                        default=[0.05, 0.95],
                        help="Quantiles to use as limits.")
    parser.add_argument("--injection-hdf-group", default="H1/injections",
                        help="HDF group that contains injection values.")
    option_utils.add_inference_results_option_group(parser)
    option_utils.add_scatter_option_group(parser)
    opts = parser.parse_args()

    # validated explicitly: ``assert`` is stripped under ``python -O`` and
    # ``len(None)`` would raise a TypeError instead of a usage message
    if not opts.parameters or len(opts.parameters) != 1:
        parser.error("exactly one parameter must be given to plot")
    return opts


def _first_entry(value):
    """Return the first element of ``value``, descending into nested lists.

    Accepts a bare value, a flat list, or a list of per-file lists, so the
    bare parameter name / label is obtained regardless of nesting depth.
    NOTE(review): the exact nesting returned by
    ``option_utils.results_from_cli`` is not visible here; this handles all
    of the above shapes.
    """
    while isinstance(value, (list, tuple)):
        value = value[0]
    return value


def _read_zvalues(opts, fps):
    """Read the colour-axis statistic ``opts.z_arg`` from every result file.

    Returns
    -------
    medians : list
        Median z value of each file, in the order of ``fps``.
    zlabel : str
        Label returned by ``option_utils.get_zvalues`` (last file read).
    zmin, zmax : float
        Smallest and largest z value over all samples of all files; used as
        the default colour-bar range.
    """
    medians = []
    zlabel = None
    zmin = numpy.inf
    zmax = -numpy.inf
    logging.info("Reading %s values", opts.z_arg)
    for input_fp in fps:
        likelihood_stats = input_fp.read_likelihood_stats(
            thin_start=opts.thin_start, thin_end=opts.thin_end,
            thin_interval=opts.thin_interval, iteration=opts.iteration)
        vals, zlabel = option_utils.get_zvalues(input_fp, opts.z_arg,
                                                likelihood_stats)
        medians.append(numpy.median(vals))
        zmin = min(zmin, numpy.min(vals))
        zmax = max(zmax, numpy.max(vals))
    return medians, zlabel, zmin, zmax


def _plot_file(ax, input_file, input_samples, parameter, quantiles,
               hdf_group, color):
    """Draw one error bar per injection found in ``input_file``.

    The bar is centred on (median recovered - injected) and spans the lower
    and upper ``quantiles`` of the sampled values.
    """
    # read injections from the HDF input file
    injs = inject.InjectionSet(input_file, hdf_group=hdf_group)

    # derive the parameter if the injection file does not store it directly
    _, ts = transforms.get_common_cbc_transforms(
        [parameter], injs.table.fieldnames)
    inj_parameters = transforms.apply_transforms(injs.table, ts)

    # flatten so one injection and many injections are handled the same way
    sampled_vals = numpy.asarray(input_samples[parameter]).ravel()
    injected_vals = numpy.asarray(inj_parameters[parameter]).ravel()
    if injected_vals.size == 0:
        logging.warning("No injections found in file %s; skipping",
                        input_file)
        return
    if injected_vals.size > 1:
        logging.warning("More than one injection in file %s", input_file)

    # median and the lowest/highest requested quantiles of the samples
    med = numpy.median(sampled_vals)
    bounds = numpy.percentile(sampled_vals, 100 * numpy.asarray(quantiles))
    low, high = bounds.min(), bounds.max()

    # errorbar expects a (2, N) yerr: one (lower, upper) pair per point
    yerr = numpy.tile([[med - low], [high - med]], (1, injected_vals.size))
    ax.errorbar(injected_vals, med - injected_vals, yerr=yerr,
                ecolor=color, linestyle="None", zorder=10)


def main():
    """Read inference results and injections, then save the recovery plot."""
    opts = _parse_command_line()

    # set logging
    pycbc.init_logging(opts.verbose)

    # read results
    fp, parameters, labels, samples = option_utils.results_from_cli(opts)
    parameter = _first_entry(parameters)
    label = _first_entry(labels)

    # a single input file may come back unwrapped; make everything a list
    fps = fp if isinstance(fp, list) else [fp]
    samples = samples if isinstance(samples, list) else [samples]
    input_files = opts.input_file
    if isinstance(input_files, str):
        input_files = [input_files]

    fig, ax = plt.subplots()

    # default colour per file; replaced by colour-mapped medians if requested
    colors = ["black"] * len(samples)
    cmap = norm = zlabel = None
    if opts.z_arg:
        zvals, zlabel, zmin, zmax = _read_zvalues(opts, fps)
        cmap = plt.get_cmap(opts.scatter_cmap)
        # NOTE(review): assumes --vmin/--vmax default to None; ``is not None``
        # keeps an explicit 0 from being discarded as falsy
        vmin = opts.vmin if opts.vmin is not None else zmin
        vmax = opts.vmax if opts.vmax is not None else zmax
        norm = mpl.colors.Normalize(vmin, vmax)
        colors = [cmap(norm(zval)) for zval in zvals]

    # loop over input files and their samples
    logging.info("Plotting")
    for input_file, input_samples, color in zip(input_files, samples,
                                                colors):
        _plot_file(ax, input_file, input_samples, parameter, opts.quantiles,
                   opts.injection_hdf_group, color)

    # create a colorbar
    if opts.z_arg:
        cax, _ = cbar.make_axes(ax)
        colorbar = cbar.ColorbarBase(cax, cmap=cmap, norm=norm)
        colorbar.set_label(r"Recovered Median " + zlabel)

    # set labels
    ax.set_ylabel(r"Recovered " + label + r" - Injected " + label)
    ax.set_xlabel(r"Injected " + r"{}".format(label))

    # add grid to plot
    ax.grid()

    # dashed reference line where the recovered value equals the injected one
    ax.axhline(0, linestyle="dashed", color="gray", zorder=9)

    # save plot
    fig.savefig(opts.output_file)
    plt.close(fig)

    # done
    logging.info("Done")


if __name__ == "__main__":
    main()