gwsumm/plot/builtin.py

Summary

Maintainability
F
1 wk
Test Coverage
# -*- coding: utf-8 -*-
# Copyright (C) Duncan Macleod (2013)
#
# This file is part of GWSumm.
#
# GWSumm is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# GWSumm is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with GWSumm.  If not, see <http://www.gnu.org/licenses/>.

"""Definitions for the standard plots
"""

import os.path
import warnings
from itertools import cycle

import numpy

from matplotlib.colors import LogNorm

from astropy.units import Quantity

from gwpy.plot.colors import tint
from gwpy.plot.gps import GPSTransform
from gwpy.segments import SegmentList

from gwdetchar.plot import texify

from .. import (globalv, io)
from ..mode import (Mode, get_mode)
from ..utils import re_cchar
from ..data import (get_timeseries, get_spectrogram,
                    get_coherence_spectrogram, get_range_spectrogram,
                    get_spectrum, get_coherence_spectrum, get_range_spectrum)
from ..state import ALLSTATE
from .registry import (get_plot, register_plot)
from .mixins import DataLabelSvgMixin

__author__ = 'Duncan Macleod <duncan.macleod@ligo.org>'

# base class shared by every data-driven plot in this module, looked up
# from the plot registry
DataPlot = get_plot('data')
# RGB colour triple, each component a fraction in [0, 1]
GREEN = (.2, .8, .2)


class TimeSeriesDataPlot(DataLabelSvgMixin, DataPlot):
    """DataPlot of some `TimeSeries` data.

    All channel groups are drawn on a single set of axes with a GPS time
    x-axis (``xscale='auto-gps'`` by default), optionally followed by a
    state-segments bar and shading of any times in the future.
    """
    type = 'timeseries'
    data = 'timeseries'
    defaults = DataPlot.defaults.copy()
    defaults.update({
        'xscale': 'auto-gps',
        'yscale': 'linear',
    })

    def __init__(self, *args, **kwargs):
        super(TimeSeriesDataPlot, self).__init__(*args, **kwargs)
        # subclasses override ``data`` (e.g. 'spectrogram'), so only flag
        # the channels of genuine time-series plots; the ``_timeseries``
        # flag is presumably consumed elsewhere -- not visible here
        if self.data == 'timeseries':
            for c in self.channels:
                c._timeseries = True

    def _update_defaults_from_channels(self):
        """Default ``ylim`` to the first non-`None` channel
        ``amplitude_range``, unless ``ylim`` is already set
        """
        for chan in self.channels:
            if getattr(chan, 'amplitude_range', None) is not None:
                self.pargs.setdefault('ylim', chan.amplitude_range)
                break

    # -- utilities ------------------------------

    def add_state_segments(self, ax, visible=None, **kwargs):
        """Add an `Axes` below the given ``ax`` displaying the `SummaryState`
        for this `TimeSeriesDataPlot`.

        Parameters
        ----------
        ax : `Axes`
            the set of `Axes` below which to display the state segments.

        visible : `bool`, optional
            whether or not to display the axes, or just make space for them,
            default is `None`, meaning the bar is shown only when this plot
            has a state whose name is not `ALLSTATE`

        **kwargs
            other keyword arguments will be passed to the
            :meth:`~gwpy.plot.Plot.add_segments_bar`
            method.

        Returns
        -------
        sax : `Axes` or `None`
            the new segments-bar axes, or `None` if no bar was drawn

        Notes
        -----
        The ``no-state-segments`` plot parameter is always popped from
        ``self.pargs``; when true it forces ``visible=False``.
        """
        # allow user to disable the state segments axes
        if self.pargs.pop('no-state-segments', False):
            visible = False
        # dynamic choice: only show the bar for a named, non-'all' state
        if visible is None and self.state is not None and (
                self.state.name.lower() != ALLSTATE):
            visible = True
        if visible:
            sax = self.plot.add_segments_bar(self.state, ax, height=.14,
                                             pad=.1,  **kwargs)
            sax.tick_params(axis='y', which='major', labelsize=12)
            sax.yaxis.set_ticks_position('none')
            sax.set_ylim(-.4, .4)
            return sax
        else:
            # no bar: just leave room at the bottom of the figure
            self.plot.subplots_adjust(bottom=0.17)
            return None

    def add_future_shade(self, gps=None, facecolor='gray', alpha=.1,
                         **kwargs):
        """Shade those parts of the figure that display times in the future

        Parameters
        ----------
        gps : `float`, optional
            the GPS time to treat as 'now', defaults to ``globalv.NOW``

        facecolor : `str`, optional
            colour of the shading

        alpha : `float`, optional
            opacity of the shading

        **kwargs
            other keyword arguments are passed to ``Axes.axvspan``

        Notes
        -----
        Nothing is drawn if the ``no-future-shade`` plot parameter is true
        (it is always popped from ``self.pargs``) or if this plot ends at
        or before ``gps``.  Only axes with a `GPSTransform` x-axis are
        shaded.
        """
        end = float(self.end)
        # get time 'now'
        if gps is None:
            gps = globalv.NOW
        # allow user to override
        if self.pargs.pop('no-future-shade', False) or end <= gps:
            return
        # shade time axes
        for ax in filter(
                lambda ax: isinstance(ax.xaxis.get_transform(), GPSTransform),
                self.plot.axes):
            ax.axvspan(gps, end, facecolor=facecolor, alpha=alpha, **kwargs)

    # -- init/finalize --------------------------

    def init_plot(self, *args, **kwargs):
        """Initialise the Figure and Axes objects for this
        `TimeSeriesDataPlot`.

        An ``epoch`` keyword argument takes precedence over the ``epoch``
        plot parameter (which is always popped from ``self.pargs``); the
        epoch of every GPS-scaled axes defaults to the plot start.  In
        month mode the x-axis scale is switched to 'days', and autoscaled
        GPS x-limits are pinned to ``[self.start, self.end]``.
        """
        # NOTE: the default expression is evaluated unconditionally, so the
        # 'epoch' plot parameter is always removed from self.pargs
        epoch = kwargs.pop('epoch', self.pargs.pop('epoch', None))
        plot = super(TimeSeriesDataPlot, self).init_plot(*args, **kwargs)
        for ax in plot.axes:
            if get_mode() == Mode.month:
                ax.set_xscale('days')
            if isinstance(ax.xaxis.get_transform(), GPSTransform):
                ax.set_epoch(float(epoch if epoch is not None else self.start))
                if ax.get_autoscalex_on():
                    ax.set_xlim(float(self.start), float(self.end))
            ax.grid(visible=True, which='both')
        return plot

    # -- main draw method -----------------------

    def draw(self, outputfile=None):
        """Read in all necessary data, and generate the figure.

        Parameters
        ----------
        outputfile : `str`, optional
            passed through to `finalize`

        Returns
        -------
        the return value of `finalize`
        """
        plot = self.init_plot()
        ax = plot.gca()

        plotargs = self.parse_plot_kwargs()
        legendargs = self.parse_legend_kwargs()

        # add data
        # get_channel_groups() yields 2-tuples; the second element (a list
        # of channels) is what gets drawn, paired with one set of plotargs
        channels, groups = list(zip(*self.get_channel_groups()))
        for clist, pargs in list(zip(groups, plotargs)):
            # get data
            valid = self._get_data_segments(clist[0])
            data = [get_timeseries(c, valid, query=False)
                    for c in clist]

            # multi-channel groups: join each channel's segments into a
            # single series, padding gaps with NaN
            if len(clist) > 1:
                data = [tsl.join(gap='pad', pad=numpy.nan) for tsl in data]
            flatdata = [ts for tsl in data for ts in tsl]
            # validate parameters
            for ts in flatdata:
                # double-check empty
                if ts.x0 is None:
                    ts.epoch = self.start
                # double-check log scales
                if self.logy:
                    ts.value[ts.value == 0] = 1e-100
            # set label
            try:
                label = pargs.pop('label')
            except KeyError:
                try:
                    label = texify(flatdata[0].name)
                except IndexError:
                    label = clist[0]
                else:
                    # SVG output: append the channel name when the label
                    # does not already start with it
                    if self.fileformat == 'svg' and not label.startswith(
                            texify(
                            str(flatdata[0].channel)).split('.')[0]):
                        label += ' [%s]' % (
                            texify(str(flatdata[0].channel)))

            # plot groups or single TimeSeries
            if len(clist) > 1:
                # presumably (central, lower, upper) traces for plot_mmm;
                # only the first keeps a legend label
                data[1].name = None  # force no labels for shades
                data[2].name = None
                ax.plot_mmm(*data, label=label, **pargs)
            elif len(flatdata) == 0:
                # no data: plot an empty series so the legend entry remains
                ax.plot(data[0].EntryClass([], epoch=self.start, unit='s',
                                           name=label),
                        label=label, **pargs)
            else:
                # one line per segment, sharing the first line's colour and
                # labelling only the first
                for ts in data[0]:
                    line, = ax.plot(ts, label=label, **pargs)
                    label = None
                    pargs['color'] = line.get_color()

        # customise plot
        self.add_hvlines()
        self.apply_parameters(ax, **self.pargs)

        # add legend
        if ax.get_legend_handles_labels()[0]:
            ax.legend(**legendargs)

        self.add_state_segments(ax)
        self.add_future_shade()

        return self.finalize(outputfile=outputfile)

    def _get_data_segments(self, channel):
        """Get data segments for this plot

        Returns the active segments of ``self.state`` (unless ``all_data``
        is set), otherwise the full plot span, widened at each end by one
        sample period (``1 / sample_rate``) when ``channel`` has a known
        ``sample_rate``.
        """
        if self.state and not self.all_data:
            return self.state.active
        if channel.sample_rate is not None:
            return SegmentList([self.span.protract(
                1/channel.sample_rate.value)])
        return SegmentList([self.span])


register_plot(TimeSeriesDataPlot)


class SpectrogramDataPlot(TimeSeriesDataPlot):
    """DataPlot a Spectrogram

    The optional ``ratio`` plot parameter normalises each spectrogram by a
    reference spectrum before display: ``'median'``, ``'mean'``, an
    integer percentile computed from the data, or the path of a file
    readable by ``io.read_frequencyseries``.
    """
    type = 'spectrogram'
    data = 'spectrogram'
    defaults = TimeSeriesDataPlot.defaults.copy()
    defaults.update({
        'yscale': 'log',
        'ylabel': 'Frequency [Hz]',
        'ratio': None,
        'format': None,
        'rasterized': True,
    })

    def __init__(self, *args, **kwargs):
        super(SpectrogramDataPlot, self).__init__(*args, **kwargs)
        self.ratio = self.pargs.pop('ratio')
        # set default colour-map for median/mean ratio
        if self.ratio in ('median', 'mean'):
            self.pargs.setdefault('cmap', 'Spectral_r')

    def _update_defaults_from_channels(self):
        """Default ``ylim`` from the channel ``frequency_range`` and, when
        no ratio is requested, ``clim`` from ``asd_range``/``psd_range``
        """
        for channel in self.channels:
            self.pargs.setdefault('ylim', channel.frequency_range)
            if isinstance(self.pargs['ylim'], Quantity):
                self.pargs['ylim'] = self.pargs['ylim'].value

            if self.ratio is None and self.pargs.get('clim') is None:
                if (self.pargs.get('format') in ('amplitude', 'asd') and
                        hasattr(channel, 'asd_range')):
                    self.pargs['clim'] = channel.asd_range
                elif hasattr(channel, 'psd_range'):
                    self.pargs['clim'] = channel.psd_range

    @property
    def pid(self):
        """Plot identifier, suffixed with the ratio type (if any)
        """
        try:
            return self._pid
        except AttributeError:
            # the parent getter is evaluated for its side effect: it must
            # set ``self._pid``, which is extended (and re-read) below
            super(SpectrogramDataPlot, self).pid
            if (isinstance(self.ratio, str) and
                    os.path.isfile(self.ratio)):
                self._pid += '_REFERENCE_RATIO'
            elif self.ratio:
                self._pid += '_%s_RATIO' % re_cchar.sub(
                    '_', str(self.ratio).upper())
            return self.pid

    @pid.setter
    def pid(self, id_):
        self._pid = str(id_)

    @pid.deleter
    def pid(self):
        del self._pid

    def get_ratio(self, specgrams):
        """Return the reference spectrum by which to normalise ``specgrams``

        Parameters
        ----------
        specgrams : `list`
            the spectrograms to be plotted; used to compute a
            ``'median'``, ``'mean'`` or integer-percentile reference

        Returns
        -------
        ratio
            the spectrum computed from ``specgrams`` or read from file;
            `None` if the reference file exists but cannot be read;
            otherwise ``self.ratio`` unchanged
        """
        ratio = self.ratio
        # calculate ratio spectrum from the data themselves
        if len(specgrams) and (
                ratio in ['median', 'mean'] or isinstance(ratio, int)):
            try:
                allspec = specgrams.join(gap='ignore')
            except ValueError as e:
                if 'units do not match' in str(e):
                    # force a common unit and join again
                    warnings.warn(str(e))
                    for spec in specgrams[1:]:
                        spec.unit = specgrams[0].unit
                    allspec = specgrams.join(gap='ignore')
                else:
                    raise
            if isinstance(ratio, int):
                return allspec.percentile(ratio)
            return getattr(allspec, ratio)(axis=0)
        if isinstance(ratio, str) and os.path.isfile(ratio):
            try:
                return io.read_frequencyseries(ratio)
            except IOError as e:
                # skip normalisation if the file can't be read; returning
                # the path itself would make `Spectrogram.ratio` raise
                warnings.warn('IOError: %s' % str(e))
                return None
        return ratio

    def draw(self, outputfile=None):
        """Read in all necessary data, and generate the figure.

        Parameters
        ----------
        outputfile : `str`, optional
            passed through to `finalize`, matching the signature of
            `TimeSeriesDataPlot.draw`

        Returns
        -------
        the return value of `finalize`
        """
        # initialise
        plot = self.init_plot()
        ax = plot.gca()
        ax.grid(visible=True, axis='y', which='major')
        channel = self.channels[0]

        # parse data arguments
        sdform = self.pargs.pop('format')

        # parse colorbar arguments
        clabel = self.pargs.pop('colorlabel', '')
        clim = self.pargs.pop('clim', None)
        clog = self.pargs.pop('logcolor', False)

        # parse plotting arguments
        if clim:  # clim -> (vmin, vmax)
            vmin, vmax = clim
            self.pargs.setdefault('vmin', vmin)
            self.pargs.setdefault('vmax', vmax)
        if clog:  # logcolor -> norm
            self.pargs.setdefault('norm', 'log')
        plotargs = self.parse_plot_kwargs()[0]  # only one channel

        # rework norm='log' into a LogNorm object
        # (this is only 'required' in the case of no data, when
        #  ax.pcolormesh gets called manually below)
        if plotargs.get('norm', None) == 'log':
            plotargs['norm'] = LogNorm(vmin=self.pargs.get('vmin'),
                                       vmax=self.pargs.get('vmax'))

        # get data
        if self.state and not self.all_data:
            valid = self.state.active
        else:
            valid = SegmentList([self.span])
        if self.type == 'coherence-spectrogram':
            specgrams = get_coherence_spectrogram(
                self.channels, valid, query=False)
        elif self.type == 'range-spectrogram':
            specgrams = get_range_spectrogram(
                channel, valid, query=False)
        else:
            try:
                specgrams = get_spectrogram(
                    channel, valid, query=False, format=sdform)
            except ValueError as exc:
                if 'need more than 0 values' not in str(exc):
                    raise
                # attempted to do math but one input has zero size
                specgrams = []

        # get ratio as FrequencySeries
        ratio = self.get_ratio(specgrams)

        # the frequency limits are the same for every spectrogram
        ylim = self.pargs.get('ylim', None)

        # plot data
        for specgram in specgrams:
            # calculate ratio
            if ratio is not None:
                specgram = specgram.ratio(ratio)

            # undo demodulation and crop frequencies
            specgram = undo_demodulation(specgram, channel, ylim)
            if ylim is not None:
                specgram = specgram.crop_frequencies(*ylim)

            # plot
            ax.imshow(specgram, **plotargs)

        # add colorbar (draw an invisible mesh when there is no data, so
        # the colorbar still has a mappable to attach to)
        if len(specgrams) == 0:
            ax.pcolormesh([1, 10], [1, 10], [[1, 10], [1, 10]],
                          visible=False, **plotargs)
        ax.colorbar(label=clabel)

        # customise and finalise
        self.apply_parameters(ax, **self.pargs)
        self.add_state_segments(ax)
        self.add_future_shade()

        return self.finalize(outputfile=outputfile)


register_plot(SpectrogramDataPlot)


class CoherenceSpectrogramDataPlot(SpectrogramDataPlot):
    """Spectrogram of the coherence between a pair of channels.
    """
    type = 'coherence-spectrogram'
    data = 'coherence-spectrogram'
    # inherit the spectrogram defaults, overriding the ratio, format and
    # colour-bar parameters
    defaults = SpectrogramDataPlot.defaults.copy()
    defaults.update(
        ratio=None,
        format=None,
        clim=None,
        logcolor=False,
        colorlabel=None,
    )


register_plot(CoherenceSpectrogramDataPlot)


class SpectrumDataPlot(DataPlot):
    """Spectrum plot for a `SummaryTab`
    """
    type = 'spectrum'
    data = 'spectrum'
    defaults = DataPlot.defaults.copy()
    defaults.update({
        'xscale': 'log',
        'yscale': 'log',
        'format': None,
        'zorder': 1,
        'no-percentiles': False,
        'reference-linestyle': '--',
    })

    def _update_defaults_from_channels(self):
        """Default ``xlim`` from the channel ``frequency_range`` and
        ``ylim`` from ``asd_range`` (amplitude formats) or ``psd_range``
        """
        for channel in self.channels:
            if getattr(channel, 'frequency_range', None) is not None:
                self.pargs.setdefault('xlim', channel.frequency_range)
                if isinstance(self.pargs['xlim'], Quantity):
                    self.pargs['xlim'] = self.pargs['xlim'].value
            if (self.pargs.get('format') in ['amplitude', 'asd'] and
                    hasattr(channel, 'asd_range')):
                self.pargs.setdefault('ylim', channel.asd_range)
            elif hasattr(channel, 'psd_range'):
                self.pargs.setdefault('ylim', channel.psd_range)

    def draw(self, outputfile=None):
        """Generate this plot, retrying once if drawing overflows

        Parameters
        ----------
        outputfile : `str`, optional
            forwarded to `_draw` (and on to `finalize`) only when given,
            so subclasses whose ``_draw`` takes no arguments keep working

        Returns
        -------
        the return value of `_draw`
        """
        drawargs = {} if outputfile is None else {'outputfile': outputfile}
        pargs = self.pargs.copy()
        try:
            return self._draw(**drawargs)
        except OverflowError:
            # _draw() consumes self.pargs, so restore the original
            # parameters before retrying with alpha forced to zero
            self.pargs = pargs
            self.pargs['alpha'] = 0.0
            return self._draw(**drawargs)

    def _draw(self, outputfile=None):
        """Load all data, and generate this `SpectrumDataPlot`

        Parameters
        ----------
        outputfile : `str`, optional
            passed through to `finalize`

        Returns
        -------
        the return value of `finalize`
        """
        plot = self.init_plot()
        ax = plot.gca()
        ax.grid(visible=True, axis='both', which='both')

        # default super-title describes the span and state
        if self.state:
            self.pargs.setdefault(
                'suptitle',
                '[%s-%s, state: %s]' % (self.span[0], self.span[1],
                                        texify(str(self.state))))
        suptitle = self.pargs.pop('suptitle', None)
        if suptitle:
            plot.suptitle(suptitle, y=0.993, va='top')

        # get spectrum format: 'amplitude' or 'power'
        sdform = self.pargs.pop('format')
        method = 'rayleigh' if sdform == 'rayleigh' else None
        # compared as a string so both False and 'false' are accepted
        use_percentiles = str(
            self.pargs.pop('no-percentiles')).lower() == 'false'

        # parse plotting arguments
        plotargs = self.parse_plot_kwargs()
        legendargs = self.parse_legend_kwargs()
        use_legend = False

        # get reference arguments
        refs = self.parse_references()

        # coherence spectra consume channels in consecutive pairs
        if self.type == 'coherence-spectrum':
            iterator = list(zip(self.channels[0::2], self.channels[1::2],
                                plotargs))
        else:
            iterator = list(zip(self.channels, plotargs))

        # segments over which to read data (the same for every channel)
        if self.state and not self.all_data:
            valid = self.state
        else:
            valid = SegmentList([self.span])

        for chantuple in iterator:
            channel = chantuple[0]
            pargs = chantuple[-1]
            # a partner channel only exists for coherence spectra
            channel2 = chantuple[1] if len(chantuple) == 3 else None

            if self.type == 'coherence-spectrum':
                data = get_coherence_spectrum(
                    [str(channel), str(channel2)], valid, query=False)
            elif self.type == 'range-spectrum':
                data = get_range_spectrum(str(channel), valid, query=False,
                                          state=valid)
            elif self.type == 'cumulative-range-spectrum':
                data = get_range_spectrum(str(channel), valid, query=False,
                                          which='mean', state=valid)
                # accumulate in volume (Mpc) or quadrature, then normalise
                # to a percentage of the total
                if str(data.unit) == 'Mpc':
                    data = (data**3).cumsum() ** (1/3.)
                else:
                    data = (data**2).cumsum() ** (1/2.)
                try:
                    data = (100 * data / data[-1],)
                except IndexError:
                    data = tuple()
            else:
                try:
                    data = get_spectrum(str(channel), valid, query=False,
                                        format=sdform, method=method,
                                        state=valid)
                except ValueError as exc:
                    # math op failed because one of the datasets is empty
                    if (
                            'could not be broadcast' in str(exc) and
                            '(0,)' in str(exc)
                    ):
                        data = []
                    else:
                        raise

            # undo demodulation
            data = [undo_demodulation(spec, channel,
                                      self.pargs.get('xlim', None))
                    for spec in data]

            # anticipate log problems
            if self.logx:
                data = [s[1:] for s in data]
            if self.logy:
                for sp in data:
                    sp.value[sp.value == 0] = 1e-100

            if 'label' in pargs:
                use_legend = True

            if data and self.type == 'cumulative-range-spectrum':
                pargs_reverse = pargs.copy()
                pargs_reverse.pop('label', None)
                pargs_reverse['linestyle'] = 'dashed'
                # plot cumulative spectrum and its reverse
                ax.plot(data[0], **pargs)
                ax.plot(100 - data[0], **pargs_reverse)
            elif data and use_percentiles:
                _, minline, maxline, _ = ax.plot_mmm(*data, **pargs)
                # make min, max lines lighter:
                minline.set_alpha(pargs.get('alpha', .1) * 2)
                maxline.set_alpha(pargs.get('alpha', .1) * 2)
            elif data:
                ax.plot(data[0], **pargs)

        # display references
        for source, refparams in refs.items():
            refspec = io.read_frequencyseries(source)
            refparams.setdefault('zorder', -len(refs) + 1)
            if 'filter' in refparams:
                refspec = refspec.filter(*refparams.pop('filter'))
            if 'scale' in refparams:
                refspec *= refparams.pop('scale', 1)
            if 'label' in refparams:
                use_legend = True
            ax.plot(refspec, **refparams)

        # customise
        self.add_hvlines()
        self.apply_parameters(ax, **self.pargs)
        if use_legend or len(self.channels) > 1 or ax.legend_ is not None:
            ax.legend(**legendargs)

        return self.finalize(outputfile=outputfile)

    def parse_references(self, prefix=r'reference(\d+)?\Z'):
        """Parse parameters for displaying one or more reference traces

        The ``prefix`` argument is currently unused; references are always
        parsed from the ``reference`` plot parameters.
        """
        return self.parse_list('reference')


register_plot(SpectrumDataPlot)


class CoherenceSpectrumDataPlot(SpectrumDataPlot):
    """Coherence spectrum plot for a `SummaryTab`

    Channels are taken in consecutive pairs, and the coherence between the
    two channels of each pair is plotted.
    """
    type = 'coherence-spectrum'
    data = 'coherence-spectrogram'
    defaults = SpectrumDataPlot.defaults.copy()
    defaults.update({
        'yscale': 'linear',
        'alpha': 0.1,
    })

    # override this to allow us to set the legend manually
    def _parse_labels(self, defaults=None):
        """Return the user-supplied ``labels`` (popped from ``self.pargs``),
        or ``defaults`` if none were given
        """
        return self.pargs.pop('labels', defaults)

    def get_channel_groups(self):
        """Hi-jacked method to return pairs of channels

        For the `CoherenceSpectrumDataPlot` this method is only used in
        determining how to separate lists of plotting arguments given by
        the user.

        Returns
        -------
        groups : `list` of `tuple`
            one ``(first_channel, [first_channel, second_channel])`` tuple
            per consecutive pair in ``self.allchannels``
        """
        all_ = self.allchannels
        return [(all_[i], all_[i:i+2]) for i in range(0, len(all_), 2)]


register_plot(CoherenceSpectrumDataPlot)


class TimeSeriesHistogramPlot(DataPlot):
    """HistogramPlot from a Series

    Each channel's data are histogrammed, either all on a single axes or,
    with the ``sep`` plot parameter, one row of axes per channel.
    """
    type = 'histogram'
    data = 'timeseries'
    defaults = DataPlot.defaults.copy()
    defaults.update({
        'ylabel': 'Rate [Hz]',
        'log': True,
        'histtype': 'stepfilled',
        'rwidth': 1,
        'bottom': 1e-300,
    })

    def _update_defaults_from_channels(self):
        """Default ``xlim`` to the first channel ``amplitude_range`` found
        """
        for channel in self.channels:
            if hasattr(channel, 'amplitude_range'):
                self.pargs.setdefault('xlim', channel.amplitude_range)
                break

    def init_plot(self, geometry=None, **kwargs):
        """Initialise the Figure and Axes objects for this
        `TimeSeriesHistogramPlot`.

        When ``geometry`` is not given, the ``sep`` plot parameter (always
        popped in that case) requests one row of axes per flag/channel;
        otherwise a single axes is created.
        """
        if geometry is None and self.pargs.pop('sep', False):
            if self.data == 'segments':
                geometry = (len(self.flags), 1)
            else:
                geometry = (len(self.channels), 1)
        elif geometry is None:
            geometry = (1, 1)
        plot = super(TimeSeriesHistogramPlot, self).init_plot(
            geometry=geometry, **kwargs)
        for ax in plot.axes:
            ax.grid(visible=True, which='both')
        return plot

    def parse_plot_kwargs(self, **defaults):
        """Parse per-channel histogram keyword arguments

        Adds ``logbins`` (from the x-axis scale), a default ``yscale``
        matching the ``log`` argument, ``range`` from ``xlim``, and a
        default ``alpha`` of 0.7 when there are several channels.
        """
        kwargs = super(TimeSeriesHistogramPlot, self).parse_plot_kwargs(
            **defaults)
        for histargs in kwargs:
            histargs.setdefault('logbins', self.logx)
            logy = histargs.get('log', False)
            self.pargs.setdefault('yscale', 'log' if logy else 'linear')
            # set range as xlim
            if 'range' not in histargs and 'xlim' in self.pargs:
                histargs['range'] = self.pargs.get('xlim')
            # set alpha
            if len(self.channels) > 1:
                histargs.setdefault('alpha', 0.7)
        return kwargs

    def draw(self, outputfile=None):
        """Get data and generate the figure.

        Parameters
        ----------
        outputfile : `str`, optional
            passed through to `finalize`

        Returns
        -------
        the return value of `finalize`
        """
        # get plot and axes
        plot = self.init_plot()
        axes = plot.axes

        if self.state:
            self.pargs.setdefault(
                'suptitle',
                '[%s-%s, state: %s]' % (self.span[0], self.span[1],
                                        texify(str(self.state))))
        suptitle = self.pargs.pop('suptitle', None)
        if suptitle:
            plot.suptitle(suptitle, y=0.993, va='top')

        # extract histogram arguments
        histargs = self.parse_plot_kwargs()

        # get data
        data = []
        for channel in self.channels:
            if self.state and not self.all_data:
                valid = self.state.active
            else:
                valid = SegmentList([self.span])
            data.append(get_timeseries(channel, valid, query=False).join(
                gap='ignore', pad=numpy.nan))

        # plot (cycling the axes, so all data share one axes unless 'sep')
        for ax, arr, pargs in zip(cycle(axes), data, histargs):
            # set range, computed over *all* datasets so every histogram
            # covers the same range
            pargs['range'] = self._get_range(
                data,
                # use range from first dataset if already calculated
                range=histargs[0].get('range'),
                # use xlim if manually set (user or INI)
                xlim=None if ax.get_autoscalex_on() else ax.get_xlim(),
            )

            # plot histogram
            _, _, patches = ax.hist(arr, **pargs)

            # update edge color of histogram to be tinted version of face
            if pargs.get('histtype', None) == 'stepfilled':
                for p in patches:
                    if not p.get_edgecolor()[3]:
                        p.set_edgecolor(tint(p.get_facecolor(), .7))

        # customise plot: title only on the first axes, xlabel only on the
        # last, ylabel on the middle axes (first axes for an even count)
        legendargs = self.parse_legend_kwargs()
        for i, ax in enumerate(axes):
            for key, val in self.pargs.items():
                if key == 'title' and i > 0:
                    continue
                if key == 'xlabel' and i < (len(axes) - 1):
                    continue
                if key == 'ylabel' and (
                        (len(axes) % 2 and i != len(axes) // 2) or
                        (len(axes) % 2 == 0 and i > 0)):
                    continue
                try:
                    getattr(ax, 'set_%s' % key)(val)
                except AttributeError:
                    setattr(ax, key, val)
            if len(self.channels) > 1:
                ax.legend(**legendargs)
        # even number of axes: move the ylabel to the upper of the two
        # middle axes and nudge it downwards
        if len(axes) % 2 == 0 and axes[0].get_ylabel():
            label = axes[0].yaxis.label
            ax = axes[len(axes) // 2 - 1]
            ax.set_ylabel(label.get_text())
            ax.yaxis.label.set_position((0, -.2 / len(axes)))
            if len(axes) != 2:
                label.set_text('')

        # add extra axes and finalise
        return self.finalize(outputfile=outputfile)

    @staticmethod
    def _flatten_data(data):
        """Concatenate every dataset in ``data`` into one flat array

        This allows min/max over datasets of different lengths, which a
        plain ``numpy.min(data)`` cannot handle (ragged input).
        """
        arrays = [numpy.asarray(d).ravel() for d in data]
        if not arrays:
            return numpy.empty(0)
        return numpy.concatenate(arrays)

    def _get_range(self, data, range=None, xlim=None):
        """Determine the histogram ``range``

        Parameters
        ----------
        data : `list`
            the datasets to be histogrammed (may differ in length)

        range : `tuple`, `str`, optional
            requested range; ``'autoscaling'`` means no range, and the
            elements ``'min'``/``'max'`` are replaced by the data extrema

        xlim : `tuple`, optional
            fallback range if ``range`` is missing or invalid

        Returns
        -------
        range : `list`, `tuple`, or `None`
            the requested range if valid (lower < upper), else ``xlim``,
            else the data extrema, else `None` when there is no data
        """
        if range is not None:
            if range == 'autoscaling':
                return None
            range = list(range)
            try:
                if range[0] == 'min':
                    range[0] = numpy.min(self._flatten_data(data))
                if range[1] == 'max':
                    range[1] = numpy.max(self._flatten_data(data))
                if range[0] < range[1]:
                    return range
            except (ValueError, IndexError) as exc:
                if not str(exc).startswith('zero-size array'):
                    raise
        if xlim is not None:
            return xlim
        flat = self._flatten_data(data)
        if not flat.size:
            return None
        return numpy.min(flat), numpy.max(flat)


register_plot(TimeSeriesHistogramPlot)


class TimeSeriesHistogram2dDataPlot(TimeSeriesHistogramPlot):
    """DataPlot of the 2D histogram of two `TimeSeries`.

    With a single channel, that channel is used for both axes.
    """
    type = 'histogram2d'
    data = 'timeseries'
    defaults = TimeSeriesHistogramPlot.defaults.copy()
    defaults.update({
        'yscale': 'linear',
        'grid': 'both',
        'shading': 'flat',
        'cmap': 'inferno_r',
        'alpha': None,
        'edgecolors': 'None',
        'bins': 100,
        'normed': True
    })

    def __init__(self, *args, **kwargs):
        super(TimeSeriesHistogram2dDataPlot, self).__init__(*args, **kwargs)
        channels = self.channels
        if isinstance(channels, (list, tuple)) and len(channels) > 2:
            raise ValueError("Cannot generate TimeSeriesHistogram2dDataPlot "
                             "plot with more than 2 channels")

    def _update_defaults_from_channels(self):
        """Default axis labels and limits from the channels

        The first channel sets ``xlabel``/``xlim`` and the second sets
        ``ylabel``/``ylim``; a single channel is used for both axes,
        matching `draw`.
        """
        channels = list(self.channels)
        if len(channels) == 1:
            channels *= 2
        c1, c2 = channels
        self.pargs.setdefault('xlabel', texify(str(c1)))
        self.pargs.setdefault('ylabel', texify(str(c2)))
        if hasattr(c1, 'amplitude_range'):
            self.pargs.setdefault('xlim', c1.amplitude_range)
        if hasattr(c2, 'amplitude_range'):
            self.pargs.setdefault('ylim', c2.amplitude_range)

    def parse_hist_kwargs(self, **defaults):
        """Build the `numpy.histogram2d` keyword arguments

        ``bins`` and ``normed`` are popped from ``self.pargs``; ``normed``
        is passed as ``density`` because `numpy.histogram2d` no longer
        accepts ``normed`` (removed in NumPy 1.24, it was an alias).
        ``range`` comes from a ``'x0,x1,y0,y1'`` string parameter, else
        from ``xlim``/``ylim``, else is `None`.
        """
        kwargs = {'bins': self.pargs.pop('bins'),
                  'density': self.pargs.pop('normed')}
        if 'range' in self.pargs:
            ranges = [float(r) for r in self.pargs['range'].split(',')]
            kwargs['range'] = [[ranges[0], ranges[1]],
                               [ranges[2], ranges[3]]]
        elif 'xlim' in self.pargs and 'ylim' in self.pargs:
            xlim = self.pargs['xlim']
            ylim = self.pargs['ylim']
            kwargs['range'] = [[xlim[0], xlim[1]], [ylim[0], ylim[1]]]
        else:
            kwargs['range'] = None
        return kwargs

    def parse_pcmesh_kwargs(self, **defaults):
        """Pop the ``pcolormesh`` keyword arguments from ``self.pargs``
        """
        kwargs = {
            'cmap': self.pargs.pop('cmap'),
            'edgecolors': self.pargs.pop('edgecolors'),
            'shading': self.pargs.pop('shading'),
            'alpha': self.pargs.pop('alpha'),
        }
        return kwargs

    def draw(self, outputfile=None):
        """Get data and generate the figure.

        Parameters
        ----------
        outputfile : `str`, optional
            passed through to `finalize`

        Returns
        -------
        the return value of `finalize`
        """
        # get histogram parameters
        plot = self.init_plot()
        ax = plot.gca()

        if self.state:
            self.pargs.setdefault(
                'suptitle',
                '[%s-%s, state: %s]' % (self.span[0], self.span[1],
                                        texify(str(self.state))))
        suptitle = self.pargs.pop('suptitle', None)
        if suptitle:
            plot.suptitle(suptitle, y=0.993, va='top')
        # get data
        data = []
        for channel in self.channels:
            if self.state and not self.all_data:
                valid = self.state.active
            else:
                valid = SegmentList([self.span])
            data.append(get_timeseries(channel, valid, query=False).join(
                gap='ignore', pad=numpy.nan))
        # a single channel is histogrammed against itself
        if len(data) == 1:
            data.append(data[0])
        # histogram, masking empty bins so they are not coloured
        hist_kwargs = self.parse_hist_kwargs()
        h, xedges, yedges = numpy.histogram2d(data[0], data[1], **hist_kwargs)
        h = numpy.ma.masked_where(h == 0, h)
        x, y = numpy.meshgrid(xedges, yedges, copy=False, sparse=True)
        # plot (transpose: histogram2d puts x along the first axis)
        pcmesh_kwargs = self.parse_pcmesh_kwargs()
        ax.pcolormesh(x, y, h.T, **pcmesh_kwargs)

        # customise plot
        self.apply_parameters(ax, **self.pargs)

        return self.finalize(outputfile=outputfile)


register_plot(TimeSeriesHistogram2dDataPlot)


class SpectralVarianceDataPlot(SpectrumDataPlot):
    """SpectralVariance histogram plot for a `DataTab`

    For a single channel, the ASD spectrogram over the valid segments is
    histogrammed per frequency (``Spectrogram.variance``) and drawn as an
    image, overlaid with the median ASD and any configured reference spectra.
    """
    type = 'variance'
    data = 'spectrogram'
    defaults = SpectrumDataPlot.defaults.copy()
    defaults.update({
        'xscale': 'log',
        'yscale': 'log',
        'reference-linestyle': '--',
        'log': True,
        'nbins': 100,
    })

    def __init__(self, channels, *args, **kwargs):
        """Create a new `SpectralVarianceDataPlot`.

        Raises
        ------
        ValueError
            if ``channels`` is a `list` or `tuple` with more than one entry
        """
        if isinstance(channels, (list, tuple)) and len(channels) > 1:
            raise ValueError("Cannot generate SpectralVariance plot with "
                             "more than 1 channel")
        super(SpectralVarianceDataPlot, self).__init__(
            channels, *args, **kwargs)

    def _update_defaults_from_channels(self):
        """Fill in default limits from the (single) channel's attributes.
        """
        chan = self.channels[0]

        if getattr(chan, 'frequency_range', None) is not None:
            self.pargs.setdefault('xlim', chan.frequency_range)
            if isinstance(self.pargs['xlim'], Quantity):
                self.pargs['xlim'] = self.pargs['xlim'].value

        # the ASD range sets the y-axis limits and also the ``low``/``high``
        # arguments later passed to ``Spectrogram.variance``
        if hasattr(chan, 'asd_range'):
            self.pargs.setdefault('ylim', chan.asd_range)
            low, high = chan.asd_range
            self.pargs.setdefault('low', low)
            self.pargs.setdefault('high', high)

    def parse_variance_kwargs(self):
        """Pop the ``Spectrogram.variance`` keyword arguments from ``pargs``.

        Returns
        -------
        varargs : `dict`
            those of ``low``, ``high``, ``log``, ``nbins``, ``bins``,
            ``density`` and ``norm`` that were present in ``self.pargs``
        """
        keys = ('low', 'high', 'log', 'nbins', 'bins', 'density', 'norm')
        return {key: self.pargs.pop(key) for key in keys if key in self.pargs}

    def _draw(self):
        """Load all data, and generate this `SpectralVarianceDataPlot`

        Returns
        -------
        the return value of ``self.finalize()``
        """
        plot = self.init_plot()
        ax = plot.gca()

        if self.state:
            self.pargs.setdefault(
                'suptitle',
                '[%s-%s, state: %s]' % (self.span[0], self.span[1],
                                        texify(str(self.state))))
        suptitle = self.pargs.pop('suptitle', None)
        if suptitle:
            plot.suptitle(suptitle, y=0.993, va='top')

        # parse plotting arguments
        cmap = self.pargs.pop('cmap', None)
        varargs = self.parse_variance_kwargs()
        plotargs = self.parse_plot_kwargs()[0]

        # get reference arguments
        refs = self.parse_references()

        # restrict to the state segments unless all data were requested
        if self.state and not self.all_data:
            valid = self.state.active
        else:
            valid = SegmentList([self.span])
        livetime = float(abs(valid))

        # default colour limits for the normalised variance computed below
        if livetime:
            plotargs.setdefault('vmin', 1/livetime)
        plotargs.setdefault('vmax', 1.)
        # the variance image carries no legend label
        plotargs.pop('label', None)

        specgram = get_spectrogram(self.channels[0], valid, query=False,
                                   format='asd').join(gap='ignore')

        if specgram.size:
            asd = specgram.median(axis=0)
            asd.name = None
            variance = specgram.variance(**varargs)
            # normalize the variance
            variance /= livetime / specgram.dt.value
            # undo demodulation on both the variance and the median ASD so
            # that the two are drawn against the same frequency axis
            xlim = ax.get_xlim()
            variance = undo_demodulation(variance, self.channels[0], xlim)
            asd = undo_demodulation(asd, self.channels[0], xlim)

            # plot
            ax.plot(asd, color='grey', linewidth=0.3)
            ax.imshow(variance, cmap=cmap, **plotargs)

        # display references
        for source, refparams in refs.items():
            refspec = io.read_frequencyseries(source)
            if 'filter' in refparams:
                refspec = refspec.filter(*refparams.pop('filter'))
            if 'scale' in refparams:
                refspec *= refparams.pop('scale', 1)
            ax.plot(refspec, **refparams)

        # customise
        self.add_hvlines()
        self.apply_parameters(ax, **self.pargs)
        ax.grid(visible=True, axis='both', which='both')

        return self.finalize()


register_plot(SpectralVarianceDataPlot)


class RayleighSpectrogramDataPlot(SpectrogramDataPlot):
    """Rayleigh statistic version of `SpectrogramDataPlot`

    Uses the ``'rayleigh'`` data format, no ratio, a diverging ``BrBG_r``
    colour map with colour limits ``[0.25, 4]``, and a 'Rayleigh statistic'
    colour-bar label; all other defaults come from `SpectrogramDataPlot`.
    """
    type = 'rayleigh-spectrogram'
    data = 'rayleigh-spectrogram'
    defaults = SpectrogramDataPlot.defaults.copy()
    defaults.update({
        'ratio': None,
        'format': 'rayleigh',
        'clim': [0.25, 4],
        'cmap': 'BrBG_r',
        'colorlabel': 'Rayleigh statistic',
    })


register_plot(RayleighSpectrogramDataPlot)


class RayleighSpectrumDataPlot(SpectrumDataPlot):
    """Rayleigh statistic version of `SpectrumDataPlot`

    Note that ``defaults`` here is a standalone mapping rather than a copy
    of ``SpectrumDataPlot.defaults``, so only the keys listed below apply.
    """
    type = 'rayleigh-spectrum'
    data = 'rayleigh-spectrum'
    defaults = {'format': 'rayleigh',
                'xscale': 'log',
                'yscale': 'log',
                'alpha': 0.1,
                'zorder': 1,
                'no-percentiles': True,
                'reference-linestyle': '--'}


register_plot(RayleighSpectrumDataPlot)


def undo_demodulation(spec, channel, limits=None):
    """Re-map demodulated frequency-domain data to physical frequencies.

    If ``channel`` has a ``demodulation`` attribute, the returned object has
    its frequency index reset and ``f0`` set to that demodulation frequency.
    If the upper physical frequency limit lies below the demodulation
    frequency, the data and the frequency axis are additionally reversed and
    ``df`` is made negative.

    The caller's data are never modified: the re-mapped object has its own
    metadata, and the reversal is performed on a copy.

    Parameters
    ----------
    spec
        the data to re-map; presumably a gwpy `FrequencySeries` or a 2-D
        gwpy array such as a `Spectrogram` (TODO confirm for other callers)
    channel
        the channel the data belong to; ``channel.demodulation`` and,
        optionally, ``channel.frequency_range`` (a ``(low, high)`` pair)
        are read
    limits : `tuple`, optional
        ``(low, high)`` frequency limits, used only when
        ``channel.frequency_range`` is missing or `None`

    Returns
    -------
    spec
        the input object itself if it is empty or the channel has no
        ``demodulation`` attribute; otherwise a re-mapped object (with no
        reversal if no upper frequency limit could be determined)
    """
    if spec.size == 0:
        return spec
    try:
        demod = channel.demodulation
    except AttributeError:
        return spec

    spec = spec[:]  # new object with its own metadata, but shared data
    del spec.frequencies
    spec.f0 = demod

    # if physical frequency-range is below demod, get negative df
    try:
        _, high = channel.frequency_range
    except (AttributeError, TypeError):
        try:
            _, high = limits
        except TypeError:
            return spec
    high = Quantity(high, 'Hz')
    if high < spec.f0:
        # ``spec[:]`` above still shares its buffer with the caller's object,
        # so take a real copy before reversing the data in place
        spec = spec.copy()
        if spec.ndim > 1:  # Spectrogram
            # NOTE(review): fliplr reverses axis 1, which suits a
            # (time, frequency) Spectrogram; for an input whose frequency
            # axis is axis 0 (SpectralVarianceDataPlot._draw passes a
            # SpectralVariance) this may reverse the wrong axis -- confirm
            spec.value[:] = numpy.fliplr(spec.value)
        else:  # FrequencySeries
            spec.value[:] = spec.value[::-1]
        spec.df *= -1
        spec.frequencies = spec.frequencies[::-1]
    return spec