gwsumm/data/spectral.py

Summary

Maintainability
F
3 days
Test Coverage
# -*- coding: utf-8 -*-
# Copyright (C) Duncan Macleod (2013)
#
# This file is part of GWSumm.
#
# GWSumm is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# GWSumm is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with GWSumm.  If not, see <http://www.gnu.org/licenses/>.

"""Get spectrograms and spectra
"""

import os.path
import operator
import warnings
from functools import reduce
from collections import OrderedDict

# imports for filter
from math import pi  # noqa: F401

import numpy

from scipy import interpolate

from astropy import units

from gwpy.frequencyseries import FrequencySeries
from gwpy.spectrogram import SpectrogramList

from .. import (globalv, io)
from ..utils import (vprint, safe_eval)
from ..channels import (
    get_channel,
    split_combination as split_channel_combination,
)
from .utils import (use_segmentlist, make_globalv_key, get_fftparams)
from .mathutils import (get_with_math, parse_math_definition)
from .timeseries import (get_timeseries, get_timeseries_dict)

# Binary arithmetic operators used in channel-combination definitions,
# keyed by their textual symbol
OPERATOR = dict([
    ('*', operator.mul),
    ('-', operator.sub),
    ('+', operator.add),
    ('/', operator.truediv),
])

__author__ = 'Duncan Macleod <duncan.macleod@ligo.org>'


# -- spectrogram --------------------------------------------------------------

@use_segmentlist
def get_spectrogram(channel, segments, config=None, cache=None,
                    query=True, nds=None, format='power', return_=True,
                    frametype=None, nproc=1, datafind_error='raise',
                    **fftparams):
    """Retrieve the time-series and generate a spectrogram of the given
    channel

    The channel may be a combination of several data channels. Each
    constituent is processed individually by `_get_spectrogram` and, when
    more than one exists, the results are combined via `get_with_math`
    using the already-cached data (``query=False``).

    Returns
    -------
    out : `~gwpy.spectrogram.SpectrogramList` or `None`
        the spectrogram data, or `None` when ``return_=False``
    """
    channel = get_channel(channel)
    subchannels = split_channel_combination(channel)

    # process (and cache) every constituent data channel
    results = [
        _get_spectrogram(sub, segments, config=config, cache=cache,
                         query=query, nds=nds, format=format,
                         return_=return_, frametype=frametype, nproc=nproc,
                         datafind_error=datafind_error, **fftparams)
        for sub in subchannels
    ]

    if not return_:
        return None
    if len(subchannels) == 1:
        return results[0]
    # combine the cached constituent spectrograms per the channel's math
    return get_with_math(
        channel, segments, _get_spectrogram, _get_spectrogram,
        config=config, query=False, format=format, return_=True,
        **fftparams)


@use_segmentlist
def _get_spectrogram(channel, segments, config=None, cache=None,
                     query=True, nds=None, format='power', return_=True,
                     frametype=None, nproc=1,
                     datafind_error='raise', **fftparams):
    """Generate, cache, and optionally return the spectrogram of one channel

    Spectrograms are stored in ``globalv.SPECTROGRAMS`` under a key built
    from the channel and its FFT parameters. Data are only read and
    processed for the parts of ``segments`` not already covered by that
    cache entry.

    Parameters
    ----------
    channel : `str` or channel object
        the (single, non-combined) channel to process
    segments : segment list
        the segments for which to return data
    format : `str`, optional
        ``'power'`` (default) returns the data as stored,
        ``'amplitude'`` or ``'asd'`` returns the square root,
        ``'rayleigh'`` computes a Rayleigh spectrogram and normalises each
        returned chunk by its median
    return_ : `bool`, optional
        if `False`, only populate the cache and return `None`
    **fftparams
        FFT parameters (e.g. ``method``, ``fftlength``, ``overlap``,
        ``stride``); cleaned via `get_fftparams`

    Returns
    -------
    out : `~gwpy.spectrogram.SpectrogramList` or `None`
        coalesced spectrograms cropped to ``segments``

    Raises
    ------
    TypeError
        if no spectrogram stride can be parsed from the FFT parameters
    ZeroDivisionError
        if the stride or the FFT length is zero
    """
    channel = get_channel(channel)

    # if we aren't given a method, check to see whether data have already
    # been processed, if so, choose that one
    if fftparams.get('method', None) is None:
        methods = set([key.split(';')[1] for key in globalv.SPECTROGRAMS
                       if key.startswith('%s;' % channel.ndsname)])
        try:
            fftparams['method'] = list(methods)[0]
        except IndexError:
            fftparams['method'] = 'welch'

    # clean fftparams dict using channel default values
    fftparams = get_fftparams(channel, **fftparams)
    # override special-case methods
    if format in ['rayleigh']:
        fftparams.method = format

    # key used to store this spectrogram in globalv
    key = make_globalv_key(channel, fftparams)

    # keep FftParams as a dict for convenience
    fftparams = fftparams.dict()

    # read segments from global memory
    havesegs = globalv.SPECTROGRAMS.get(key, SpectrogramList()).segments
    new = segments - havesegs
    query &= abs(new) != 0

    globalv.SPECTROGRAMS.setdefault(key, SpectrogramList())

    if query:
        # extract spectrogram stride from dict
        try:
            stride = float(fftparams.pop('stride'))
        except (TypeError, KeyError) as e:
            msg = ('cannot parse a spectrogram stride from the kwargs '
                   'given, please give some or all of fftlength, overlap, '
                   'stride')
            if isinstance(e, TypeError):
                e.args = (msg,)
                raise
            raise TypeError(msg)

        # read channel information
        try:
            filter_ = channel.frequency_response
        except AttributeError:
            filter_ = None
        else:
            if isinstance(filter_, str) and os.path.isfile(filter_):
                filter_ = io.read_frequencyseries(filter_)
            elif isinstance(filter_, str):
                filter_ = safe_eval(filter_, strict=True)

        # get time-series data
        timeserieslist = get_timeseries(channel, new, config=config,
                                        cache=cache, frametype=frametype,
                                        nproc=nproc, query=query,
                                        datafind_error=datafind_error, nds=nds)
        # calculate spectrograms
        if len(timeserieslist):
            vprint("    Calculating (%s) spectrograms for %s"
                   % (fftparams['method'], str(channel)))
        for ts in timeserieslist:
            # if too short for a single segment, continue
            if abs(ts.span) < (stride + fftparams.get('overlap', 0)):
                continue
            # truncate timeseries to integer number of strides
            d = size_for_spectrogram(ts.duration.to('s').value, stride,
                                     fftparams['fftlength'],
                                     fftparams.get('overlap', 0))
            ts = ts.crop(ts.span[0], ts.span[0] + d, copy=False)
            # calculate spectrogram
            try:
                # rayleigh spectrogram has its own instance method
                if fftparams.get('method', None) == 'rayleigh':
                    spec_kw = fftparams.copy()
                    for fftkey in ('method', 'scheme',):  # remove ASD keys
                        spec_kw.pop(fftkey, None)
                    spec_func = ts.rayleigh_spectrogram
                else:
                    spec_kw = fftparams
                    spec_func = ts.spectrogram
                specgram = spec_func(stride, nproc=nproc, **spec_kw)
            except ZeroDivisionError:
                if stride == 0:
                    raise ZeroDivisionError("Spectrogram stride is 0")
                elif fftparams['fftlength'] == 0:
                    raise ZeroDivisionError("FFT length is 0")
                else:
                    raise
            except ValueError as e:
                # data without a unit: compute in 'count' then restore
                if 'has no unit' in str(e):
                    unit = ts.unit
                    ts._unit = units.Unit('count')
                    specgram = ts.spectrogram(stride, nproc=nproc, **fftparams)
                    specgram._unit = unit ** 2 / units.Hertz
                else:
                    raise
            if isinstance(filter_, FrequencySeries) and (
                    fftparams['method'] not in ['rayleigh']):
                specgram = apply_transfer_function_series(specgram, filter_)
            elif filter_ and fftparams['method'] not in ['rayleigh']:
                # manually setting x0 is a hack against precision error
                # somewhere inside the **(1/2.) operation (Quantity)
                x0 = specgram.x0.value
                specgram = (specgram ** (1/2.)).filter(*filter_,
                                                       inplace=True) ** 2
                specgram.x0 = x0
            # keep units consistent with the channel / existing cache entry
            if specgram.unit is None:
                specgram._unit = channel.unit
            elif len(globalv.SPECTROGRAMS[key]):
                specgram._unit = globalv.SPECTROGRAMS[key][-1].unit
            add_spectrogram(specgram, key=key)
            vprint('.')
        if len(timeserieslist):
            vprint('\n')

    if not return_:
        return

    # return correct data
    out = SpectrogramList()
    for specgram in globalv.SPECTROGRAMS[key]:
        for seg in segments:
            # skip segments shorter than a single stride
            if abs(seg) < specgram.dt.value:
                continue
            if specgram.span.intersects(seg):
                common = specgram.span & type(seg)(seg[0],
                                                   seg[1] + specgram.dt.value)
                # NOTE: crop() returns a view onto the cached data, so any
                # normalisation below must not operate in place
                s = specgram.crop(*common)
                if format in ['amplitude', 'asd']:
                    s = s**(1/2.)
                elif format in ['rayleigh']:
                    # XXX FIXME: this corrects the bias offset in Rayleigh
                    med = numpy.median(s.value)
                    # out-of-place so the cached spectrogram is not rescaled
                    s = s / med
                if s.shape[0]:
                    out.append(s)
    return out.coalesce()


def add_spectrogram(specgram, key=None, coalesce=True):
    """Store a `Spectrogram` in the global memory cache

    Parameters
    ----------
    specgram : `~gwpy.spectrogram.Spectrogram`
        the data to store
    key : `str`, optional
        cache key; defaults to the spectrogram's name, or the string form
        of its channel if it has no name
    coalesce : `bool`, optional
        whether to coalesce the cached list after appending
    """
    if key is None:
        key = specgram.name or str(specgram.channel)
    stored = globalv.SPECTROGRAMS.setdefault(key, SpectrogramList())
    stored.append(specgram)
    if coalesce:
        stored.coalesce()


@use_segmentlist
def get_spectrograms(channels, segments, config=None, cache=None, query=True,
                     nds=None, format='power', return_=True, frametype=None,
                     nproc=1, datafind_error='raise', **fftparams):
    """Get spectrograms for multiple channels

    When ``query`` is `True`, the time-series for all underlying data
    channels are read in bulk first (only for segments not already covered
    by cached spectrograms), then each channel's spectrogram is generated
    via `get_spectrogram`.

    Returns
    -------
    out : `~collections.OrderedDict`
        mapping of channel to the output of `get_spectrogram` (empty if no
        channels are given)
    """
    channels = list(map(get_channel, channels))
    out = OrderedDict()
    # nothing to do; this also avoids reduce() over an empty key list below
    if not channels:
        return out

    # get timeseries data in bulk
    if query:
        # get underlying list of data channels to read
        qchannels = list(map(get_channel,
                         set([c for group in
                              map(split_channel_combination, channels)
                              for c in group])))

        # work out FFT params and storage keys for each data channel
        keys = []
        for channel in qchannels:
            fftparams_ = get_fftparams(channel, **fftparams)
            keys.append(make_globalv_key(channel, fftparams_))

        # restrict segments to those big enough to hold >= 1 stride
        strides = set([getattr(c, 'stride', 0) for c in qchannels])
        if len(strides) == 1:
            stride = strides.pop()
            segments = type(segments)(s for s in segments if abs(s) >= stride)

        # work out new segments for which to read data
        havesegs = reduce(operator.and_, (globalv.SPECTROGRAMS.get(
            key, SpectrogramList()).segments for key in keys))
        new = segments - havesegs

        # read data for new segments
        get_timeseries_dict(qchannels, new, config=config, cache=cache,
                            nproc=nproc, frametype=frametype,
                            datafind_error=datafind_error, nds=nds,
                            return_=False)

    # loop over channels and generate spectrograms; forward frametype so
    # any data not already read in bulk is found with the same frame type
    for channel in channels:
        out[channel] = get_spectrogram(
            channel, segments, config=config, cache=cache, query=query,
            nds=nds, format=format, frametype=frametype, nproc=nproc,
            return_=return_, datafind_error=datafind_error,
            **fftparams)
    return out


def size_for_spectrogram(size, stride, fftlength, overlap):
    """Return the data duration to keep when building a spectrogram

    The duration is rounded down to a whole number of strides and then
    extended by ``overlap``. If that exceeds ``size``, one ``fftlength``
    is removed.

    Returns `None` if ``size`` is shorter than a single ``stride``.
    """
    if size < stride:
        return None
    nstrides = size // stride
    length = nstrides * stride + overlap
    return length - fftlength if length > size else length


def apply_transfer_function_series(specgram, tfunc):
    """Multiply a spectrogram by a transfer function `FrequencySeries`

    The transfer function is linearly interpolated onto the spectrogram's
    frequency vector. Frequencies outside the transfer function's range
    get zero gain. The gain is applied to the amplitude
    (square root of the spectrogram) and the result is squared again.
    """
    tf_freqs = tfunc.frequencies.value
    sg_freqs = specgram.frequencies.value

    # only frequencies inside the transfer-function support are interpolated
    inside = (sg_freqs >= tf_freqs[0]) & (sg_freqs <= tf_freqs[-1])
    gain = numpy.zeros((1, specgram.frequencies.size))
    gain[0, inside] = interpolate.interp1d(tf_freqs, tfunc.value)(
        sg_freqs[inside])

    # apply to the amplitude, then convert back to power
    return (specgram ** (1/2.) * gain) ** 2


# -- spectrum -----------------------------------------------------------------

@use_segmentlist
def get_spectrum(channel, segments, config=None, cache=None,
                 query=True, nds=None, format='power', return_=True,
                 frametype=None, nproc=1, datafind_error='raise',
                 state=None, **fftparams):
    """Retrieve the time-series and generate a spectrum of the given
    channel

    The channel definition is parsed for arithmetic. Each constituent
    channel is processed by `_get_spectrum`; for a combined channel the
    results are merged via `get_with_math`.

    Returns
    -------
    out : `tuple` or `list` or `None`
        for a single channel, the ``(median, min, max)`` spectra returned
        by `_get_spectrum`; for a combination, a list of the combined
        ``[mean, min, max]`` spectra; `None` when ``return_=False``
    """
    channel = get_channel(channel)

    # read data for all sub-channels
    specs = []
    channels = list(parse_math_definition(str(channel))[0])
    if len(channels) == 0:
        channels = [channel]
    for c in channels:
        specs.append(_get_spectrum(c, segments, config=config, cache=cache,
                                   query=query, nds=nds, format=format,
                                   return_=return_, frametype=frametype,
                                   nproc=nproc,
                                   datafind_error=datafind_error,
                                   state=state,
                                   **fftparams))
    if return_ and len(channels) == 1:
        return specs[0]
    elif return_:
        # forward ``state``: _get_spectrum keys its cache on it, so without
        # it the lookups below would miss the spectra computed above and
        # recompute them with default FFT parameters
        return [get_with_math(
                    channel, segments, _get_spectrum, _get_spectrum,
                    config=config, format=format, return_=True,
                    which=which, state=state)[0]
                for which in ['mean', 'min', 'max']]


def _get_spectrum(channel, segments, config=None, cache=None, query=True,
                  nds=None, format='power', return_=True, which='all',
                  state=None, **fftparams):
    """Retrieve the time-series and generate a spectrum of the given
    channel

    Results are cached in ``globalv.SPECTRUM`` under
    ``'<ndsname>,<format>'`` or, if ``state`` is given,
    ``'<ndsname>,<state>,<format>'``. The 5th and 95th percentiles are
    cached under the same name plus ``.min`` and ``.max``.
    If ``channel.ndsname`` is an existing file, the spectrum is read from
    it and used for all three entries.

    Parameters
    ----------
    which : `str`, optional
        one of ``'all'`` (default), ``'mean'``, ``'min'``, ``'max'``
    state : `str`, optional
        name used to distinguish the cache entry
    **fftparams
        passed to `get_spectrogram`; ``fftlength`` defaults to 1,
        ``overlap`` to 0.5 and ``stride`` to ``fftlength``

    Returns
    -------
    out : `FrequencySeries` or `tuple` of `FrequencySeries`, or `None`
        the requested spectrum (the 50th percentile for ``'mean'``), the
        ``(median, min, max)`` tuple for ``'all'``, or `None` when
        ``return_=False``

    Raises
    ------
    ValueError
        if ``which`` is not recognised
    """
    channel = get_channel(channel)

    name = f'{channel.ndsname},{format}'
    if state:
        name = f'{channel.ndsname},{state},{format}'
    cmin = f'{name}.min'
    cmax = f'{name}.max'

    if name not in globalv.SPECTRUM:
        if os.path.isfile(channel.ndsname):
            globalv.SPECTRUM[name] = io.read_frequencyseries(channel.ndsname)
            globalv.SPECTRUM[cmin] = globalv.SPECTRUM[name]
            globalv.SPECTRUM[cmax] = globalv.SPECTRUM[name]
        else:
            fftparams.setdefault('fftlength', 1)
            fftparams.setdefault('overlap', 0.5)
            # default stride is one FFT length
            fftparams.setdefault('stride', fftparams['fftlength'])

            speclist = get_spectrogram(channel, segments, config=config,
                                       cache=cache, query=query, nds=nds,
                                       format=format, **fftparams)
            try:
                specgram = speclist.join(gap='ignore')
            except ValueError as e:
                if 'units do not match' in str(e):
                    warnings.warn(str(e))
                    # ``unit`` is read-only on array types, so override the
                    # private attribute (as done elsewhere in this module)
                    for spec in speclist[1:]:
                        spec._unit = speclist[0].unit
                    specgram = speclist.join(gap='ignore')
                else:
                    raise
            try:
                globalv.SPECTRUM[name] = specgram.percentile(50)
            except (ValueError, IndexError):
                # no usable data: cache an empty spectrum
                globalv.SPECTRUM[name] = FrequencySeries(
                    [], channel=channel, f0=0, df=1, unit=units.Unit(''))
                globalv.SPECTRUM[cmin] = globalv.SPECTRUM[name]
                globalv.SPECTRUM[cmax] = globalv.SPECTRUM[name]
            else:
                globalv.SPECTRUM[cmin] = specgram.percentile(5)
                globalv.SPECTRUM[cmax] = specgram.percentile(95)

    if not return_:
        return

    if which == 'all':
        return (globalv.SPECTRUM[name], globalv.SPECTRUM[cmin],
                globalv.SPECTRUM[cmax])
    if which == 'mean':
        return globalv.SPECTRUM[name]
    if which == 'min':
        return globalv.SPECTRUM[cmin]
    if which == 'max':
        return globalv.SPECTRUM[cmax]
    raise ValueError(f"Unrecognised value for `which`: {which}")