gwastro/gwin

View on GitHub
bin/gwin_plot_acf

Summary

Maintainability
Test Coverage
#!/usr/bin/env python
 
# Copyright (C) 2016 Christopher M. Biwer
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
 
import argparse
import logging
import sys
 
from matplotlib import use
use('agg')
from matplotlib import pyplot as plt
 
import pycbc
from pycbc import results
 
from gwin import (__version__, option_utils)
 
# Build the command-line interface for the ACF plotting script.
parser = argparse.ArgumentParser(
    description="Plots autocorrelation function from inference samples.")

# General options. "store_true" flags already default to False, so no
# explicit default is given.
parser.add_argument(
    "--verbose", action="store_true",
    help="Print logging info.")
parser.add_argument(
    "--version", action="version", version=__version__,
    help="show version number and exit")

# Destination of the plot. argparse keeps values as strings when no type is
# given.
parser.add_argument(
    "--output-file", required=True,
    help="Path to output plot.")

# Optional y-axis limits; both are None when not given.
parser.add_argument(
    "--ymin", type=float,
    help="Minimum value to plot on y-axis.")
parser.add_argument(
    "--ymax", type=float,
    help="Maximum value to plot on y-axis.")

# Options for selecting the inference results; these are registered on the
# parser by the shared helper in gwin's option_utils.
option_utils.add_inference_results_option_group(parser)

# Options controlling which chains are used and how they are drawn.
parser.add_argument(
    "--per-walker", action="store_true",
    help="Plot the ACF for each walker separately. Default is to average "
         "parameter values over the walkers, then compute the ACF for the "
         "averaged chain. Warning: turning this option on can significantly "
         "increase the run time.")
parser.add_argument(
    "--walkers", type=int, nargs="+",
    help="Only include the given walkers when computing the ACF.")
parser.add_argument(
    "--temps", nargs="+",
    help="For multi-tempered samplers, plot the given temperatures. May "
         "provide either a sequence of integers specifying the temperatures "
         "to plot, or 'all' for all temperatures. Default is to only plot "
         "the coldest (= 0) temperature chain.")
parser.add_argument(
    "--no-legend", action="store_true",
    help="Do not add a legend to the plot.")

# Read the options from sys.argv.
opts = parser.parse_args()
 
# setup log
pycbc.init_logging(opts.verbose)

# Collect the optional keyword arguments that are forwarded to the sampler's
# ACF calculation. 'temps' is only included when --temps was given on the
# command line.
additional_args = {}
if opts.temps is not None:
    if 'all' in opts.temps:
        temps = 'all'
    else:
        # Build a real list instead of calling map(): on Python 3 map()
        # returns a one-shot iterator that has no len() and is exhausted
        # after a single pass, whereas these temperatures are both handed to
        # the ACF calculation and reused later for sizing and drawing the
        # figure.
        temps = [int(temp) for temp in opts.temps]
    additional_args['temps'] = temps
 
# Open the results selected on the command line. load_samples=False is
# passed, and the fourth returned value is not needed here.
fp, parameters, labels, _ = option_utils.results_from_cli(
    opts, load_samples=False)

# Compute the autocorrelation functions with the sampler class associated
# with the results file. The temperature selection, if any, is forwarded
# through additional_args.
logging.info("Calculating autocorrelation functions")
sampler = fp.sampler_class
acfs = sampler.compute_acfs(
    fp,
    start_index=opts.thin_start,
    end_index=opts.thin_end,
    per_walker=opts.per_walker,
    walkers=opts.walkers,
    parameters=parameters,
    **additional_args)
 
# The ACFs are treated as multi-tempered when they have three dimensions in
# per-walker mode, or two dimensions otherwise.
expected_ndim = 3 if opts.per_walker else 2
multi_tempered = acfs.ndim == expected_ndim

# Temperatures to plot: whatever was requested on the command line, falling
# back to the single chain 0. 'all' is expanded to one index per entry along
# the first axis of the ACFs.
temps = additional_args.get('temps', [0])
if temps == 'all':
    temps = range(acfs.shape[0])

# Three units of height per temperature panel, but never less than six.
fig_height = max(6, 3*len(temps))
 
# Plot the autocorrelation functions, one subplot row per temperature.
logging.info("Plotting autocorrelation functions")
fig = plt.figure(figsize=(8, fig_height))
ntemps = len(temps)
for kk, tk in enumerate(temps):
    if multi_tempered:
        logging.info("Temperature {}".format(tk))
    # Fill the grid from the bottom up: the first temperature is drawn in the
    # last row and the final temperature ends up in row 1 (the top).
    axnum = ntemps - kk
    ax = fig.add_subplot(ntemps, 1, axnum)
    for pi, param in enumerate(parameters):
        logging.info("Parameter {}".format(param))
        # kk is the position of this temperature within the requested set,
        # which is how the leading axis of the ACFs is indexed here.
        if multi_tempered:
            acf = acfs[param][kk, ...]
        else:
            acf = acfs[param]
        if opts.per_walker:
            # one curve per walker; label with the user-supplied walker
            # number when a subset of walkers was requested
            lbl = '{}, walker {}'
            for wi in range(acf.shape[0]):
                wnum = opts.walkers[wi] if opts.walkers is not None else wi
                ax.plot(acf[wi, :], label=lbl.format(labels[pi], wnum))
        else:
            ax.plot(acf, label=labels[pi])
    if multi_tempered:
        # the temperature is the inverse of the beta stored in the file
        t = 1./fp.attrs['betas'][tk]
        ax.set_title(r'$T = {}$ (chain {})'.format(t, tk))
    # only the bottom panel carries the x-axis label
    if axnum == ntemps:
        ax.set_xlabel("iteration")
    ax.set_ylabel("autocorrelation function")
    # Test against None rather than truthiness so that an explicit limit of
    # 0 is applied instead of being silently ignored.
    if opts.ymin is not None:
        ax.set_ylim(ymin=opts.ymin)
    if opts.ymax is not None:
        ax.set_ylim(ymax=opts.ymax)
# Add the legend to the top axis: after the loop, ax is the last subplot
# created, which occupies row 1.
if not opts.no_legend:
    ax.legend()

plt.tight_layout()
 
# Save the figure, passing along the command line that produced it together
# with a title and caption.
caption_kwargs = dict(parameters=", ".join(labels))
title = "Autocorrelation Function for {parameters}".format(**caption_kwargs)
caption = "Autocorrelation function (ACF)."
command_line = " ".join(sys.argv)
results.save_fig_with_metadata(
    fig, opts.output_file,
    cmd=command_line,
    title=title,
    caption=caption)
plt.close()

# Close the results file and report completion.
fp.close()
logging.info("Done")