gwpy/plot/plot.py

Summary

Maintainability
A
2 hrs
Test Coverage
# -*- coding: utf-8 -*-
# Copyright (C) Louisiana State University (2014-2017)
#               Cardiff University (2017-2022)
#
# This file is part of GWpy.
#
# GWpy is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# GWpy is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with GWpy.  If not, see <http://www.gnu.org/licenses/>.

"""Extension of the basic matplotlib Figure for GWpy
"""

import itertools
import importlib
from collections.abc import (KeysView, ValuesView)
from itertools import zip_longest

import numpy

from matplotlib import (figure, get_backend, _pylab_helpers)
from matplotlib.artist import setp
from matplotlib.gridspec import GridSpec
from matplotlib.ticker import LogFormatterSciNotation
from matplotlib.projections import get_projection_class

from . import (colorbar as gcbar, utils)
from .gps import GPS_SCALES
from .log import LogFormatter
from .rc import (rcParams, MPL_RCPARAMS, get_subplot_params)

__all__ = ['Plot']

# Flag whether this module is running inside an IPython session: the bare
# name ``__IPYTHON__`` is only expected to be defined under IPython, so the
# lookup below raises `NameError` everywhere else.
IPYTHON = True
try:
    __IPYTHON__
except NameError:
    IPYTHON = False

# container types which, when found among the inputs to `_group_axes_data`,
# make it default to one `Axes` per input
iterable_types = (list, tuple, KeysView, ValuesView)


def interactive_backend():
    """Returns `True` if the current backend is interactive

    The list of interactive backends is taken from matplotlib's backend
    registry when available (matplotlib >= 3.9), otherwise from the
    legacy ``matplotlib.rcsetup.interactive_bk`` list (deprecated in
    matplotlib 3.9).

    The comparison is case-insensitive, because backend names may be
    reported in either case (e.g. ``'TkAgg'`` vs ``'tkagg'``).

    Returns
    -------
    interactive : `bool`
        `True` if the current backend is a builtin interactive backend
    """
    try:
        from matplotlib.backends import BackendFilter, backend_registry
    except ImportError:  # matplotlib < 3.9
        from matplotlib.rcsetup import interactive_bk
    else:
        interactive_bk = backend_registry.list_builtin(
            BackendFilter.INTERACTIVE,
        )
    current = get_backend().lower()
    return current in {name.lower() for name in interactive_bk}


def get_backend_mod(name=None):
    """Import and return the module implementing a matplotlib backend

    Parameters
    ----------
    name : `str`, optional
        the backend name (e.g. ``'agg'``), or a ``'module://...'``
        reference to an external backend module; defaults to the current
        backend as reported by :func:`matplotlib.get_backend`

    Returns
    -------
    backend_mod: `module`
        the module as returned by :func:`importlib.import_module`

    Raises
    ------
    ImportError
        if the backend module cannot be imported

    Examples
    --------
    >>> from gwpy.plot.plot import get_backend_mod
    >>> print(get_backend_mod('agg'))
    <module 'matplotlib.backends.backend_agg' from ... >
    """
    prefix = "module://"
    backend = get_backend() if name is None else name
    if backend.startswith(prefix):
        # external backend: the remainder is a fully-qualified module path
        modname = backend[len(prefix):]
    else:
        # builtin backend: matplotlib's backend modules are lowercase
        modname = f"matplotlib.backends.backend_{backend.lower()}"
    return importlib.import_module(modname)


class Plot(figure.Figure):
    """An extension of the core matplotlib `~matplotlib.figure.Figure`

    The `Plot` provides a number of methods to simplify generating
    figures from GWpy data objects, and modifying them on-the-fly in
    interactive mode.

    Parameters
    ----------
    *data
        data objects (or groups of data objects) to draw; these are
        grouped into `Axes` by ``_group_axes_data``

    **kwargs
        keywords listed in ``utils.FIGURE_PARAMS`` are used to create the
        figure itself; all others are passed to the axes-creation step,
        which only runs if ``data`` or ``geometry`` is given
    """
    def __init__(self, *data, **kwargs):

        # get default x-axis scale if all axes have the same x-axis units;
        # after this 'xscale' is always present in kwargs (maybe as `None`)
        kwargs.setdefault('xscale', _parse_xscale(
            _group_axes_data(data, flat=True)))

        # set default size for time-axis figures
        if (
            kwargs.get('projection', None) == 'segments'
            or kwargs.get('xscale') in GPS_SCALES
        ):
            kwargs.setdefault('figsize', (12, 6))

        # initialise figure
        figure_kw = {key: kwargs.pop(key) for key in utils.FIGURE_PARAMS if
                     key in kwargs}
        self._init_figure(**figure_kw)

        # initialise axes with data
        if data or kwargs.get("geometry"):
            self._init_axes(data, **kwargs)

    def _init_figure(self, **kwargs):
        """Initialise the underlying `~matplotlib.figure.Figure`

        The new figure is also registered with pyplot's figure manager
        (``_pylab_helpers.Gcf``) so that it behaves like a figure created
        via :func:`matplotlib.pyplot.figure`.
        """
        from matplotlib import pyplot

        # add new attributes
        self.colorbars = []
        self._coloraxes = []

        # create Figure, numbered one higher than any existing pyplot figure
        num = kwargs.pop('num', max(pyplot.get_fignums() or {0}) + 1)
        self._parse_subplotpars(kwargs)
        super().__init__(**kwargs)
        self.number = num

        # add interactivity (scraped from pyplot.figure())
        backend_mod = get_backend_mod()
        try:
            manager = backend_mod.new_figure_manager_given_figure(num, self)
        except AttributeError:
            # fall back to the base canvas/manager classes; register the
            # manager under this figure's own number so that it doesn't
            # replace whichever figure is already registered as number 1
            upstream_mod = importlib.import_module(
                pyplot.new_figure_manager.__module__)
            canvas = upstream_mod.FigureCanvasBase(self)
            manager = upstream_mod.FigureManagerBase(canvas, num)
        manager._cidgcf = manager.canvas.mpl_connect(
            'button_press_event',
            lambda ev: _pylab_helpers.Gcf.set_active(manager))
        _pylab_helpers.Gcf.set_active(manager)
        pyplot.draw_if_interactive()

    def _init_axes(self, data, method='plot',
                   xscale=None, sharex=False, sharey=False,
                   geometry=None, separate=None, **kwargs):
        """Populate this figure with data, creating `Axes` as necessary

        Parameters
        ----------
        data : `tuple`
            the data objects passed to the constructor

        method : `str`, optional
            name of the `Axes` method used to draw the data;
            ``'imshow'`` and ``'pcolormesh'`` are called once per data
            object, all other methods once per group of data

        xscale : `str`, optional
            x-axis scale for all new `Axes`; if not given, a scale is
            chosen for each group of data separately

        sharex, sharey : `bool` or `str`, optional
            axis sharing mode, one of ``'none'``, ``'all'``, ``'row'``,
            or ``'col'``; `True` maps to ``'all'`` and `False` to
            ``'none'``

        geometry : `tuple` of `int`, optional
            ``(nrows, ncols)`` shape of the grid of `Axes`; defaults to
            one row per group of data, in a single column

        separate : `bool`, optional
            whether to draw each data object on its own `Axes`; forced
            to `True` when ``geometry`` has exactly one cell per data
            object

        **kwargs
            keywords listed in ``utils.AXES_PARAMS`` are used to create
            each `Axes`, all others are passed to the drawing method

        Returns
        -------
        axes : `list` of `~matplotlib.axes.Axes`
            all `Axes` of this figure

        Raises
        ------
        ValueError
            if the number of data groups does not match the grid size
        """
        if isinstance(sharex, bool):
            sharex = "all" if sharex else "none"
        if isinstance(sharey, bool):
            sharey = "all" if sharey else "none"

        # parse keywords
        axes_kw = {key: kwargs.pop(key) for key in utils.AXES_PARAMS if
                   key in kwargs}

        # handle geometry and group axes
        if geometry is not None and geometry[0] * geometry[1] == len(data):
            separate = True
        axes_groups = _group_axes_data(data, separate=separate)
        if geometry is None:
            geometry = (len(axes_groups), 1)
        nrows, ncols = geometry
        if axes_groups and nrows * ncols != len(axes_groups):
            # mismatching data and geometry
            raise ValueError(
                f"cannot group data into {len(axes_groups)} with "
                f"a {nrows}x{ncols} grid"
            )

        # create grid spec
        gs = GridSpec(nrows, ncols)
        axarr = numpy.empty((nrows, ncols), dtype=object)

        # set default labels (record whether the user gave labels first)
        defxlabel = 'xlabel' not in axes_kw
        defylabel = 'ylabel' not in axes_kw
        flatdata = [s for group in axes_groups for s in group]
        for axis in ('x', 'y'):
            unit = _common_axis_unit(flatdata, axis=axis)
            if unit:
                axes_kw.setdefault(
                    f"{axis}label",
                    unit.to_string('latex_inline_dimensional'),
                )

        # create axes for each group and draw each data object
        for group, (row, col) in zip_longest(
                axes_groups, itertools.product(range(nrows), range(ncols)),
                fillvalue=[]):
            # create Axes
            shared_with = {"none": None, "all": axarr[0, 0],
                           "row": axarr[row, 0], "col": axarr[0, col]}
            axes_kw["sharex"] = shared_with[sharex]
            axes_kw["sharey"] = shared_with[sharey]
            axes_kw['xscale'] = xscale if xscale else _parse_xscale(group)
            ax = axarr[row, col] = self.add_subplot(gs[row, col], **axes_kw)

            # plot data
            plot_func = getattr(ax, method)
            if method in ('imshow', 'pcolormesh'):
                for obj in group:
                    plot_func(obj, **kwargs)
            elif group:
                plot_func(*group, **kwargs)

            # set default axis labels; when an axis is shared across the
            # whole grid only the outer panels keep their label: the bottom
            # row for the x-axis, and the left-hand column for the y-axis
            for axis, share, inner, def_ in (
                    (ax.xaxis, sharex, row < nrows - 1, defxlabel),
                    (ax.yaxis, sharey, col > 0, defylabel),
            ):
                # hide label if shared axis and not bottom left panel
                if share == 'all' and inner:
                    axis.set_label_text('')
                # otherwise set default status
                else:
                    axis.isDefault_label = def_

        return self.axes

    @staticmethod
    def _parse_subplotpars(kwargs):
        """Set default ``subplotpars`` in ``kwargs`` (in place)

        The subplot positions are set dynamically based on the figure size,
        but only if the user hasn't given ``subplotpars`` and hasn't
        customised any of the ``figure.subplot.*`` rcParams.
        """
        figsize = kwargs.get('figsize') or rcParams['figure.figsize']
        subplotpars = get_subplot_params(figsize)
        use_subplotpars = (
            'subplotpars' not in kwargs
            and all(
                rcParams[f"figure.subplot.{pos}"]
                == MPL_RCPARAMS[f"figure.subplot.{pos}"]
                for pos in ('left', 'bottom', 'right', 'top')
            )
        )
        if use_subplotpars:
            kwargs['subplotpars'] = subplotpars

    # -- Plot methods ---------------------------

    def refresh(self):
        """Refresh the current figure

        Each recorded colorbar is redrawn with its ``draw_all`` method when
        that method exists (``Colorbar.draw_all`` was deprecated in
        matplotlib 3.6 and later removed), then the canvas is redrawn.
        """
        for cbar in self.colorbars:
            draw_all = getattr(cbar, "draw_all", None)
            if draw_all is not None:
                draw_all()
        self.canvas.draw()

    def show(self, block=None, warn=True):
        """Display the current figure (if possible).

        If blocking, this method replicates the behaviour of
        :func:`matplotlib.pyplot.show()`, otherwise it just calls up to
        :meth:`~matplotlib.figure.Figure.show`.

        This method also supports repeatedly showing the same figure, even
        after closing the display window, which isn't supported by
        `pyplot.show` (AFAIK).

        Parameters
        ----------
        block : `bool`, optional
            open the figure and block until the figure is closed, otherwise
            open the figure as a detached window, default: `None`.
            If `None`, block if using an interactive backend and _not_
            inside IPython.

        warn : `bool`, optional
            print a warning if matplotlib is not running in an interactive
            backend and cannot display the figure, default: `True`.
        """
        # this method tries to reproduce the functionality of pyplot.show,
        # mainly for user convenience. However, as of matplotlib-3.0.0,
        # pyplot.show() ends up calling _back_ to Plot.show(),
        # so we have to be careful not to end up in a recursive loop
        import inspect
        try:
            callframe = inspect.currentframe().f_back
        except AttributeError:
            # currentframe() can return None without frame support
            pass
        else:
            if 'matplotlib' in callframe.f_code.co_filename:
                block = False

        # render
        super().show(warn=warn)

        # don't block on ipython with interactive backends
        if block is None and interactive_backend():
            block = not IPYTHON

        # block in GUI loop (stolen from mpl.backend_bases._Backend.show)
        if block:
            backend_mod = get_backend_mod()
            backend_mod.Show().mainloop()

    def save(self, *args, **kwargs):
        """Save the figure to disk.

        This method is an alias to :meth:`~matplotlib.figure.Figure.savefig`,
        all arguments are passed directly to that method.
        """
        self.savefig(*args, **kwargs)

    def close(self):
        """Close the plot and release its memory.
        """
        from matplotlib.pyplot import close
        for ax in self.axes[::-1]:
            # avoid matplotlib/matplotlib#9970
            ax.set_xscale('linear')
            ax.set_yscale('linear')
            # clear the axes
            ax.cla()
        # close the figure
        close(self)

    # -- axes manipulation ----------------------

    def get_axes(self, projection=None):
        """Find all `Axes`, optionally matching the given projection

        Parameters
        ----------
        projection : `str`
            name of axes types to return (compared case-insensitively
            against each ``Axes.name``)

        Returns
        -------
        axlist : `list` of `~matplotlib.axes.Axes`
        """
        if projection is None:
            return self.axes
        return [ax for ax in self.axes if ax.name == projection.lower()]

    # -- colour bars ----------------------------

    def colorbar(
        self,
        mappable=None,
        cax=None,
        ax=None,
        fraction=0.,
        use_axesgrid=True,
        emit=True,
        **kwargs,
    ):
        """Add a colorbar to the current `Plot`.

        This method differs from the default
        :meth:`matplotlib.figure.Figure.colorbar` in that it doesn't
        resize the parent `Axes` to accommodate the colorbar, but rather
        draws a new Axes alongside it.

        Parameters
        ----------
        mappable : matplotlib data collection
            Collection against which to map the colouring

        cax : `~matplotlib.axes.Axes`
            Axes on which to draw colorbar

        ax : `~matplotlib.axes.Axes`
            Axes relative to which to position colorbar

        fraction : `float`, optional
            Fraction of original axes to use for colorbar.
            The default (``fraction=0``) is to not resize the
            original axes at all.

        use_axesgrid : `bool`, optional
            accepted for backwards compatibility; this argument is not
            used by this method

        emit : `bool`, optional
            If `True` update all mappables on `Axes` to match the same
            colouring as the colorbar.

        **kwargs
            other keyword arguments to be passed to the
            :meth:`~matplotlib.figure.Figure.colorbar`

        Returns
        -------
        cbar : `~matplotlib.colorbar.Colorbar`
            the newly added `Colorbar`

        Notes
        -----
        To make room for the colorbar by shrinking the parent axes, as
        matplotlib does by default, pass a non-zero ``fraction``, e.g.
        ``fraction=0.15``.

        See also
        --------
        matplotlib.figure.Figure.colorbar
        matplotlib.colorbar.Colorbar

        Examples
        --------
        >>> import numpy
        >>> from gwpy.plot import Plot

        To plot a simple image and add a colorbar:

        >>> plot = Plot()
        >>> ax = plot.gca()
        >>> ax.imshow(numpy.random.randn(120).reshape((10, 12)))
        >>> plot.colorbar(label='Value')
        >>> plot.show()

        Colorbars can also be generated by directly referencing the parent
        axes:

        >>> plot = Plot()
        >>> ax = plot.gca()
        >>> ax.imshow(numpy.random.randn(120).reshape((10, 12)))
        >>> ax.colorbar(label='Value')
        >>> plot.show()
        """
        # pre-process kwargs (and maybe create new Axes)
        mappable, kwargs = gcbar.process_colorbar_kwargs(
            self,
            mappable,
            ax,
            cax=cax,
            fraction=fraction,
            **kwargs,
        )

        # generate colour bar
        cbar = super().colorbar(mappable, **kwargs)

        # force the minor ticks to be the same as the major ticks;
        # in practice, this normally swaps out LogFormatterSciNotation to
        # gwpy's LogFormatter;
        # this is hacky, and would be improved using a
        # subclass of Colorbar in the first place, but matplotlib's
        # cbar_factory doesn't support that
        longaxis = (
            cbar.ax.yaxis if cbar.orientation == "vertical"
            else cbar.ax.xaxis
        )
        if (
            isinstance(cbar.formatter, LogFormatter)
            and isinstance(
                longaxis.get_minor_formatter(),
                LogFormatterSciNotation,
            )
        ):
            longaxis.set_minor_formatter(type(cbar.formatter)())

        # record colorbar in parent object
        self.colorbars.append(cbar)

        # update mappables for this axis
        if emit:
            ax = kwargs.pop('ax')
            norm = mappable.norm
            cmap = mappable.get_cmap()
            # chain() works whether the artist containers are plain lists
            # or matplotlib ArtistList objects
            for map_ in itertools.chain(ax.collections, ax.images):
                map_.set_norm(norm)
                map_.set_cmap(cmap)

        return cbar

    # -- extra methods --------------------------

    def add_segments_bar(self, segments, ax=None, height=0.14, pad=0.1,
                         sharex=True, location='bottom', **plotargs):
        """Add a segment bar `Plot` indicating state information.

        By default, segments are displayed in a thin horizontal set of Axes
        sitting immediately below the x-axis of the main,
        similarly to a colorbar.

        Parameters
        ----------
        segments : `~gwpy.segments.DataQualityFlag`
            A data-quality flag, or `SegmentList` denoting state segments
            about this Plot

        ax : `Axes`, optional
            Specific `Axes` relative to which to position new `Axes`,
            defaults to :func:`~matplotlib.pyplot.gca()`

        height : `float`, optional
            Height of the new axes, as a fraction of the anchor axes

        pad : `float`, optional
            Padding between the new axes and the anchor, as a fraction of
            the anchor axes dimension

        sharex : `True`, `~matplotlib.axes.Axes`, optional
            Either `True` to set ``sharex=ax`` for the new segment axes,
            or an `Axes` to use directly

        location : `str`, optional
            Location for new segment axes, defaults to ``'bottom'``,
            acceptable values are ``'top'`` or ``'bottom'``.

        **plotargs
            extra keyword arguments are passed to
            :meth:`~gwpy.plot.SegmentAxes.plot`

        Returns
        -------
        segax : `~gwpy.plot.SegmentAxes`
            the newly created segment axes

        Raises
        ------
        ValueError
            if ``location`` is not ``'top'`` or ``'bottom'``
        """
        # validate first, so that nothing is created for a bad location
        if location not in {'top', 'bottom'}:
            raise ValueError("Segments can only be positioned at 'top' or "
                             "'bottom'.")

        # get axes to anchor against
        if not ax:
            ax = self.gca()

        # set options for new axes
        axes_kw = {
            'pad': pad,
            'sharex': ax if sharex is True else sharex or None,
            'axes_class': get_projection_class('segments'),
        }

        # map X-axis limit from old axes
        if axes_kw['sharex'] is ax and not ax.get_autoscalex_on():
            axes_kw['xlim'] = ax.get_xlim()

        # if axes uses GPS scaling, copy the epoch as well
        try:
            axes_kw['epoch'] = ax.get_epoch()
        except AttributeError:
            pass

        # add new axes
        try:
            divider = ax.get_axes_locator()._axes_divider
        except AttributeError:
            # get_axes_locator() is None _or_ the _axes_divider property
            # has been removed
            from mpl_toolkits.axes_grid1 import make_axes_locatable
            divider = make_axes_locatable(ax)
        segax = divider.append_axes(location, height, **axes_kw)

        # update anchor axes
        if axes_kw['sharex'] is ax and location == 'bottom':
            # map label
            segax.set_xlabel(ax.get_xlabel())
            segax.xaxis.isDefault_label = ax.xaxis.isDefault_label
            ax.set_xlabel("")
            # hide ticks on original axes
            setp(ax.get_xticklabels(), visible=False)

        # plot segments
        segax.plot(segments, **plotargs)
        segax.grid(False, which='both', axis='y')
        segax.autoscale(axis='y', tight=True)

        return segax


# -- utilities ----------------------------------------------------------------

def _group_axes_data(inputs, separate=None, flat=False):
    """Determine the number of axes from the input args to this `Plot`

    Parameters
    ----------
    inputs : `list` of array-like data sets
        A list of data arrays, or a list of lists of data sets

    separate : `bool`, optional
        Plot each set of data on a separate `Axes`. If `None` (default),
        separate `Axes` are used when any input is a container (`list`,
        `tuple`, `dict`, or a dict keys/values view), or when the inputs
        are not all of the same type.

    flat : `bool`, optional
        Return a flattened list of data objects

    Returns
    -------
    axesdata : `list` of lists of array-like data
        A `list` with one element per required `Axes` containing the
        array-like data sets for those `Axes`, unless ``flat=True``
        is given.

    Notes
    -----
    The logic for this method is as follows:

    - if a `list` of data arrays are given, and `separate=False`, use 1 `Axes`
    - if a `list` of data arrays are given, and `separate=True`, use N `Axes`,
      one for each data array
    - a nested `list` (or `tuple`) of data arrays always starts its own
      group, regardless of ``separate``
    - a `dict` input is treated as the list of its values
    - a `list` or `tuple` whose first element is a scalar is treated as a
      single data set, not as a group

    Every group in the output is a new `list`, so the caller's containers
    are never modified.

    Examples
    --------
    >>> from gwpy.plot.plot import _group_axes_data
    >>> _group_axes_data([1, 2], separate=False)
    [[1, 2]]
    >>> _group_axes_data([1, 2], separate=True)
    [[1], [2]]
    >>> _group_axes_data([[1, 2], 3])
    [[[1, 2]], [3]]
    """
    # determine auto-separation
    if separate is None and inputs:
        # if given a nested list of data, multiple axes are required
        if any(isinstance(x, iterable_types + (dict,)) for x in inputs):
            separate = True
        # if data are of different types, default to separate
        elif not all(type(x) is type(inputs[0]) for x in inputs):  # noqa: E721
            separate = True

    # build list of lists
    out = []
    for x in inputs:
        if isinstance(x, dict):  # unwrap dict
            x = list(x.values())

        # new group from iterable, notes:
        #     the iterable is presumed to be a list of independent data
        #     structures, unless it's a list of scalars in which case we
        #     should plot them all as one
        if (
                isinstance(x, (KeysView, ValuesView))
                or isinstance(x, (list, tuple)) and (
                    not x
                    or not numpy.isscalar(x[0])
                )
        ):
            # copy into a new list: later datasets may be appended to this
            # group, which must neither mutate the caller's container nor
            # fail for tuples / dict views (which have no ``append``)
            out.append(list(x))

        # dataset starts a new group
        elif separate or not out:
            out.append([x])

        # dataset joins current group
        else:  # append input to most recent group
            out[-1].append(x)

    if flat:
        return [s for group in out for s in group]

    return out


def _common_axis_unit(data, axis='x'):
    """Return the unit shared by every object in ``data`` along ``axis``

    Parameters
    ----------
    data : iterable
        objects that may define an ``<axis>unit`` attribute
        (e.g. ``xunit``); objects without it count as having unit `None`

    axis : `str`, optional
        axis name used to build the attribute name, default: ``'x'``

    Returns
    -------
    unit : object or `None`
        the single distinct unit value if exactly one is found,
        otherwise `None` (including for empty ``data``)
    """
    attr = f"{axis}unit"
    found = {getattr(obj, attr, None) for obj in data}
    return found.pop() if len(found) == 1 else None


def _parse_xscale(data):
    """Choose a default x-axis scale for ``data``

    Parameters
    ----------
    data : iterable
        data objects, each optionally defining an ``xunit`` attribute

    Returns
    -------
    xscale : `str` or `None`
        ``'auto-gps'`` if all objects share a common x-axis unit whose
        ``physical_type`` is ``'time'``; `None` if there is no common
        x-axis unit (including when no object defines ``xunit``).
        NOTE(review): for any other common unit the visible code falls
        through without an explicit return (i.e. `None`) — confirm
        nothing further follows.
    """
    # NOTE(review): assumes a non-None unit has a ``physical_type``
    # attribute (e.g. an astropy unit) — TODO confirm against callers
    unit = _common_axis_unit(data, axis='x')
    if unit is None:
        return None
    if unit.physical_type == 'time':
        return 'auto-gps'