# gwsumm/data/coherence.py
# -*- coding: utf-8 -*-
# Copyright (C) Duncan Macleod (2013)
#
# This file is part of GWSumm.
#
# GWSumm is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# GWSumm is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with GWSumm. If not, see <http://www.gnu.org/licenses/>.
"""Utilities for data handling and display
"""
import operator
import warnings
from functools import reduce
from itertools import zip_longest
from collections import OrderedDict
import numpy
from astropy import units
from gwpy.segments import (DataQualityFlag, SegmentList,
Segment, SegmentListDict)
from gwpy.frequencyseries import FrequencySeries
from gwpy.spectrogram import SpectrogramList
from .. import globalv
from ..utils import (vprint, safe_eval)
from ..channels import get_channel
from .utils import (use_segmentlist, get_fftparams, make_globalv_key)
from .timeseries import (get_timeseries, get_timeseries_dict)
__author__ = 'Duncan Macleod <duncan.macleod@ligo.org>'
@use_segmentlist
def get_coherence_spectrogram(channel_pair, segments, config=None,
                              cache=None, query=True, nds=None,
                              return_=True, frametype=None, nproc=1,
                              datafind_error='raise', return_components=False,
                              **fftparams):
    """Return the coherence spectrogram(s) for a pair of channels

    Public entry point: every argument is handed unchanged to
    `_get_coherence_spectrogram`, and whatever that function returns is
    returned from here.
    """
    # collect the named options once, then forward them together with any
    # extra FFT parameters supplied by the caller
    options = {
        'config': config,
        'cache': cache,
        'query': query,
        'nds': nds,
        'return_': return_,
        'frametype': frametype,
        'nproc': nproc,
        'datafind_error': datafind_error,
        'return_components': return_components,
    }
    return _get_coherence_spectrogram(
        channel_pair, segments, **options, **fftparams)
@use_segmentlist
def _get_coherence_spectrogram(channel_pair, segments, config=None,
                               cache=None, query=True, nds=None,
                               return_=True, frametype=None, nproc=1,
                               datafind_error='raise', return_components=False,
                               **fftparams):
    """Compute, cache, and return coherence spectrograms for a channel pair

    The coherence is built from three component spectrograms -- the
    cross-spectral density ``Cxy`` and the two auto-spectra ``Cxx`` and
    ``Cyy`` -- as ``abs(Cxy)**2 / Cxx / Cyy``. Components are cached in
    ``globalv.COHERENCE_COMPONENTS`` and the resulting coherence
    spectrograms in ``globalv.SPECTROGRAMS``.

    Parameters
    ----------
    channel_pair : sequence
        two channels (anything accepted by `get_channel`); element 0 is
        channel 1 and element 1 is channel 2
    segments : segment list
        the segments over which data are requested
    config, cache, nds, frametype, nproc, datafind_error
        forwarded unchanged to `get_timeseries` (``nproc`` is also passed
        to the spectrogram methods)
    query : `bool`
        if `False`, no new data are read or computed; only what is already
        in the global caches is returned
    return_ : `bool`
        if `False`, populate the caches and return `None`
    return_components : `bool`
        if `True`, return the component spectrograms instead of the
        coherence
    **fftparams
        FFT parameters, normalised through `get_fftparams` using channel 1;
        ``method`` defaults to ``'welch'``

    Returns
    -------
    `None` if ``return_`` is false; otherwise, if ``return_components`` is
    true, a list of three `SpectrogramList` in the order
    ``(Cxy, Cxx, Cyy)``; otherwise a single coalesced `SpectrogramList` of
    coherence spectrograms cropped to ``segments``.
    """
    channel1 = get_channel(channel_pair[0])
    channel2 = get_channel(channel_pair[1])

    # clean fftparams dict using channel 1 default values
    fftparams.setdefault('method', 'welch')
    fftparams = get_fftparams(channel1, **fftparams)

    # key used to store the coherence spectrogram in globalv
    key = make_globalv_key(channel_pair, fftparams)

    # keys used to store component spectrograms in globalv
    # (same order as `components`: cross-spectrum, then each auto-spectrum)
    components = ('Cxy', 'Cxx', 'Cyy')
    ckeys = [
        make_globalv_key([channel1, channel2], fftparams),
        make_globalv_key(channel1, fftparams),
        make_globalv_key(channel2, fftparams),
    ]

    # convert fftparams to regular dict
    fftparams = fftparams.dict()

    # work out what new segments are needed
    # need to truncate to segments of integer numbers of strides
    stride = float(fftparams.pop('stride'))
    overlap = float(fftparams['overlap'])
    new = type(segments)()
    for seg in segments - globalv.SPECTROGRAMS.get(
            key, SpectrogramList()).segments:
        # round the duration down to a whole number of strides and skip
        # anything shorter than one stride plus the overlap
        dur = float(abs(seg)) // stride * stride
        if dur < stride + overlap:
            continue
        new.append(type(seg)(seg[0], seg[0]+dur))

    # extract FFT params for TimeSeries.spectrogram
    # (the copy keeps 'method'/'scheme'; they are removed from `fftparams`,
    # which is what gets passed to csd_spectrogram below)
    spec_fftparams = fftparams.copy()
    for fftkey in ('method', 'scheme',):
        fftparams.pop(fftkey, None)

    # if there are no existing spectrogram, initialize as a list
    globalv.SPECTROGRAMS.setdefault(key, SpectrogramList())

    # XXX HACK: use dummy timeseries to find lower sampling rate
    # (reads one second of each channel at the start of the first segment)
    if len(segments) > 0:
        s = segments[0].start
        dts1 = get_timeseries(channel1, SegmentList([Segment(s, s+1)]),
                              config=config, cache=cache, frametype=frametype,
                              nproc=nproc, query=query,
                              datafind_error=datafind_error, nds=nds)
        dts2 = get_timeseries(channel2, SegmentList([Segment(s, s+1)]),
                              config=config, cache=cache, frametype=frametype,
                              nproc=nproc, query=query,
                              datafind_error=datafind_error, nds=nds)
        try:
            sampling = min(dts1[0].sample_rate.value,
                           dts2[0].sample_rate.value)
        except IndexError:
            # no data came back for at least one of the channels
            sampling = None
    else:
        sampling = None

    # initialize component lists if they don't exist yet
    # store components in a backup variable
    coherence_bkp = {}
    for ck in ckeys:
        globalv.COHERENCE_COMPONENTS.setdefault(ck, SpectrogramList())
        coherence_bkp[ck] = globalv.COHERENCE_COMPONENTS.get(
            ck, SpectrogramList())

    # When coherence components contain different segments,
    # computing coherence for new segments can result in a
    # ValueError traceback due to incompatible shapes.
    # To prevent this issue, we collect segments from all components,
    # clear them from the global variable, and then restore from
    # coherence_bkp only the data for segments that are available in
    # all coherence components.

    # get the segment spans from all components
    spans = SegmentListDict()
    for ck in ckeys:
        spans[ck] = SegmentList(
            [spec.span for spec in globalv.COHERENCE_COMPONENTS[ck]])

    # keep only the intersection of the segments
    spans = spans.intersection(list(ckeys))

    # clean the components in the global variable
    globalv.COHERENCE_COMPONENTS.update(
        {ck: SpectrogramList() for ck in ckeys})

    # restore the data for segments available in all components
    for seg in spans:
        for ck in ckeys:
            spec = _get_from_list(coherence_bkp[ck], seg)
            add_coherence_component_spectrogram(spec, key=ck)

    # explicitly delete the backup variable to decrease RAM consumption
    del coherence_bkp

    # get data if query=True or if there are new segments
    # (as written this is an AND: `query` stays true only when it was
    # already true and there is a non-zero amount of new time to process)
    query &= abs(new) != 0

    if query:

        # the intersecting segments will be calculated when needed
        intersection = None

        # loop over components needed to calculate coherence
        for comp in components:

            # key used to store this component in globalv (incl sample rate)
            ckey = ckeys[components.index(comp)]

            # optional response filter, read from channel 1 only; a string
            # value is evaluated into the filter definition
            # NOTE(review): the channel-1 response is applied to every
            # component, including Cyy -- confirm this is intended
            try:
                filter_ = channel1.frequency_response
            except AttributeError:
                filter_ = None
            else:
                if isinstance(filter_, str):
                    filter_ = safe_eval(filter_, strict=True)

            # check how much of this component still needs to be calculated
            req = new - globalv.COHERENCE_COMPONENTS.get(
                ckey, SpectrogramList()).segments

            # get data if there are new segments
            if abs(req) != 0:

                # calculate intersection of segments lazily
                # this should occur on first pass (Cxy)
                if intersection is None:
                    total1 = get_timeseries(
                        channel1, req, config=config,
                        cache=cache, frametype=frametype,
                        nproc=nproc, query=query,
                        datafind_error=datafind_error,
                        nds=nds)
                    total2 = get_timeseries(
                        channel2, req, config=config,
                        cache=cache, frametype=frametype,
                        nproc=nproc, query=query,
                        datafind_error=datafind_error,
                        nds=nds)
                    intersection = total1.segments & total2.segments

                # get required timeseries data (using intersection)
                # only the channel(s) this component needs are read; the
                # other list stays empty
                tslist1, tslist2 = [], []
                if comp in ('Cxy', 'Cxx'):
                    tslist1 = get_timeseries(
                        channel1, intersection, config=config,
                        cache=cache, frametype=frametype,
                        nproc=nproc, query=query,
                        datafind_error=datafind_error,
                        nds=nds)
                if comp in ('Cxy', 'Cyy'):
                    tslist2 = get_timeseries(
                        channel2, intersection, config=config,
                        cache=cache, frametype=frametype,
                        nproc=nproc, query=query,
                        datafind_error=datafind_error,
                        nds=nds)

                # calculate component
                # NOTE(review): `sampling` can be None (see above), in which
                # case the '%d' format below would raise TypeError -- confirm
                # that data can never be present here when it is None
                if len(tslist1) + len(tslist2):
                    vprint(" Calculating component %s for coherence "
                           "spectrogram for %s and %s @ %d Hz" % (
                               comp, str(channel1), str(channel2), sampling))

                # zip_longest pads the empty list with None, so for Cxx
                # `ts2` is None and for Cyy `ts1` is None
                for ts1, ts2 in zip_longest(tslist1, tslist2):

                    # ensure there is enough data to do something with
                    if comp in ('Cxx', 'Cxy') and abs(ts1.span) < stride:
                        continue
                    elif comp in ('Cyy', 'Cxy') and abs(ts2.span) < stride:
                        continue

                    # downsample if necessary
                    if ts1 is not None and ts1.sample_rate.value != sampling:
                        ts1 = ts1.resample(sampling)
                    if ts2 is not None and ts2.sample_rate.value != sampling:
                        ts2 = ts2.resample(sampling)

                    # ignore units when calculating coherence
                    if ts1 is not None:
                        ts1._unit = units.Unit('count')
                    if ts2 is not None:
                        ts2._unit = units.Unit('count')

                    # calculate the component spectrogram
                    if comp == 'Cxy':
                        specgram = ts1.csd_spectrogram(
                            ts2, stride, nproc=nproc, **fftparams)
                    elif comp == 'Cxx':
                        specgram = ts1.spectrogram(stride, nproc=nproc,
                                                   **spec_fftparams)
                    elif comp == 'Cyy':
                        specgram = ts2.spectrogram(stride, nproc=nproc,
                                                   **spec_fftparams)

                    # apply the response filter to the square root of the
                    # component, then square the result again
                    if filter_:
                        specgram = (specgram**(1/2.)).filter(*filter_,
                                                             inplace=True) ** 2

                    add_coherence_component_spectrogram(specgram, key=ckey)

                    vprint('.')

                if len(tslist1) + len(tslist2):
                    vprint('\n')

        spans = [SegmentList([  # record spectrogram spans
            spec.span for spec in globalv.COHERENCE_COMPONENTS[ck]
        ]) for ck in ckeys]
        # restrict the new segments to times covered by all three components
        new = reduce(operator.and_, spans, new).coalesce()

        for seg in new:  # compute coherence from components
            cxy, cxx, cyy = [_get_from_list(
                globalv.COHERENCE_COMPONENTS[ck], seg) for ck in ckeys]
            csg = abs(cxy)**2 / cxx / cyy
            globalv.SPECTROGRAMS[key].append(csg)
            globalv.SPECTROGRAMS[key].coalesce()

    if not return_:
        return

    elif return_components:

        # return list of component spectrogram lists
        out = [SpectrogramList(), SpectrogramList(), SpectrogramList()]
        for comp in components:
            index = components.index(comp)
            ckey = ckeys[index]
            for specgram in globalv.COHERENCE_COMPONENTS[ckey]:
                for seg in segments:
                    # skip segments shorter than one spectrogram time bin
                    if abs(seg) < specgram.dt.value:
                        continue
                    if specgram.span.intersects(seg):
                        # the requested segment is extended by one time bin
                        # at its end before intersecting with the data span
                        common = specgram.span & type(seg)(
                            seg[0], seg[1] + specgram.dt.value)
                        s = specgram.crop(*common)
                        # keep only crops with at least one time bin
                        if s.shape[0]:
                            out[index].append(s)
            out[index] = out[index].coalesce()
        return out

    else:

        # return list of coherence spectrograms
        out = SpectrogramList()
        for specgram in globalv.SPECTROGRAMS[key]:
            for seg in segments:
                # skip segments shorter than one spectrogram time bin
                if abs(seg) < specgram.dt.value:
                    continue
                if specgram.span.intersects(seg):
                    # the requested segment is extended by one time bin
                    # at its end before intersecting with the data span
                    common = specgram.span & type(seg)(
                        seg[0], seg[1] + specgram.dt.value)
                    s = specgram.crop(*common)
                    # keep only crops with at least one time bin
                    if s.shape[0]:
                        out.append(s)
        return out.coalesce()
def get_coherence_spectrum(channel_pair, segments, config=None,
                           cache=None, query=True, nds=None, return_=True,
                           **fftparams):
    """Retrieve the time-series and generate coherence spectra for a pair
    of channels

    The median, lower, and upper coherence spectra are computed from the
    component spectrograms (``Cxy``, ``Cxx``, ``Cyy``) returned by
    `get_coherence_spectrogram`, and cached in
    ``globalv.COHERENCE_SPECTRUM`` under ``name``, ``name + '.min'`` and
    ``name + '.max'``. Cached results are reused on later calls.

    Parameters
    ----------
    channel_pair : sequence
        two channels (anything accepted by `get_channel`)
    segments : `DataQualityFlag` or segment list
        the times to include; for a `DataQualityFlag` the ``active``
        segments are used and the flag name becomes part of the cache key
    config, cache, query, nds
        forwarded unchanged to `get_coherence_spectrogram`
    return_ : `bool`
        if `False`, populate the cache and return `None`
    **fftparams
        forwarded unchanged to `get_coherence_spectrogram`

    Returns
    -------
    `None` if ``return_`` is false, otherwise a 3-tuple
    ``(median, min, max)`` of `FrequencySeries`. If the components cannot
    be averaged (`ValueError` or `IndexError`), all three are the same
    empty `FrequencySeries`.
    """
    channel1 = get_channel(channel_pair[0])
    channel2 = get_channel(channel_pair[1])

    if isinstance(segments, DataQualityFlag):
        name = ','.join([channel1.ndsname, channel2.ndsname, segments.name])
        segments = segments.active
    else:
        name = channel1.ndsname + ',' + channel2.ndsname
    # NOTE(review): `format` is not a parameter here, so this resolves to the
    # builtin function and appends a constant suffix. It is left unchanged so
    # that existing cache keys keep their current form -- confirm before
    # cleaning up.
    name += ',%s' % format
    cmin = name + '.min'
    cmax = name + '.max'

    if name not in globalv.COHERENCE_SPECTRUM:
        vprint(" Calculating 5/50/95 percentile spectra for %s"
               % name.rsplit(',', 1)[0])

        # ask for a list of component spectrograms (a list of SpectrogramLists)
        speclist = get_coherence_spectrogram(
            channel_pair, segments, config=config, cache=cache,
            query=query, nds=nds, return_components=True,
            **fftparams)

        cdict = {}
        components = ('Cxy', 'Cxx', 'Cyy')

        # join spectrograms in each list so we can average it
        for index, comp in enumerate(components):
            try:
                cdict[comp] = speclist[index].join(gap='ignore')
            except ValueError as e:
                if 'units do not match' in str(e):
                    warnings.warn(str(e))
                    # align every spectrogram of THIS component with the
                    # unit of its first spectrogram, then retry the join
                    # (speclist[0] is a whole list and has no unit)
                    for spec in speclist[index][1:]:
                        spec.unit = speclist[index][0].unit
                    cdict[comp] = speclist[index].join(gap='ignore')
                else:
                    raise

        # average spectrograms to get PSDs and CSD
        try:
            Cxy = complex_percentile(cdict['Cxy'], 50)
            Cxx = cdict['Cxx'].percentile(50)
            Cyy = cdict['Cyy'].percentile(50)
            globalv.COHERENCE_SPECTRUM[name] = FrequencySeries(
                numpy.abs(Cxy)**2 / Cxx / Cyy, f0=Cxx.f0, df=Cxx.df)
        except (ValueError, IndexError):
            # nothing to average: store one empty series under all three keys
            globalv.COHERENCE_SPECTRUM[name] = FrequencySeries(
                [], channel=channel1, f0=0, df=1, unit=units.Unit(''))
            globalv.COHERENCE_SPECTRUM[cmin] = globalv.COHERENCE_SPECTRUM[name]
            globalv.COHERENCE_SPECTRUM[cmax] = globalv.COHERENCE_SPECTRUM[name]
        else:
            # FIXME: how to calculate percentiles correctly?
            # lower bound: smallest cross-spectrum over largest auto-spectra
            globalv.COHERENCE_SPECTRUM[cmin] = FrequencySeries(
                abs(complex_percentile(cdict['Cxy'], 5))**2 /
                cdict['Cxx'].percentile(95) / cdict['Cyy'].percentile(95),
                f0=Cxx.f0, df=Cxx.df)
            # upper bound: largest cross-spectrum over smallest auto-spectra
            globalv.COHERENCE_SPECTRUM[cmax] = FrequencySeries(
                abs(complex_percentile(cdict['Cxy'], 95))**2 /
                cdict['Cxx'].percentile(5) / cdict['Cyy'].percentile(5),
                f0=Cxx.f0, df=Cxx.df)

        # set the spectrum's name manually; this will be used for the legend
        globalv.COHERENCE_SPECTRUM[name].name = (
            channel1.ndsname + '\n' + channel2.ndsname)

        vprint(".\n")

    if not return_:
        return

    # return the median together with the lower and upper spectra
    return (globalv.COHERENCE_SPECTRUM[name],
            globalv.COHERENCE_SPECTRUM[cmin],
            globalv.COHERENCE_SPECTRUM[cmax])
def add_coherence_component_spectrogram(specgram, key=None, coalesce=True):
    """Store a coherence component spectrogram in the global memory cache

    The spectrogram is appended to the `SpectrogramList` held in
    ``globalv.COHERENCE_COMPONENTS`` under ``key`` (created on first use).
    When ``key`` is `None`, the spectrogram's name is used, falling back to
    the string form of its channel. If ``coalesce`` is true the list is
    coalesced after the append.
    """
    if key is None:
        key = specgram.name or str(specgram.channel)
    # fetch (or create) the cached list once and work on it directly
    cached = globalv.COHERENCE_COMPONENTS.setdefault(key, SpectrogramList())
    cached.append(specgram)
    if not coalesce:
        return
    cached.coalesce()
@use_segmentlist
def get_coherence_spectrograms(channel_pairs, segments, config=None,
                               cache=None, query=True, nds=None,
                               return_=True, frametype=None, nproc=1,
                               datafind_error='raise', **fftparams):
    """Get coherence spectrograms for multiple channel pairs

    Parameters
    ----------
    channel_pairs : sequence
        a flat sequence of channels; consecutive elements are paired up as
        ``(0, 1), (2, 3), ...``
    segments : segment list
        the segments over which data are requested
    config, cache, nds, frametype, nproc, datafind_error
        forwarded to `get_timeseries_dict` for the bulk read and to
        `get_coherence_spectrogram` for each pair
    query : `bool`
        if true, time-series data for all channels are read in bulk before
        the per-pair spectrograms are generated
    return_ : `bool`
        forwarded to `get_coherence_spectrogram`
    **fftparams
        FFT parameters; ``method`` must be ``'welch'`` (or omitted)

    Returns
    -------
    `collections.OrderedDict` mapping each ``(channel1, channel2)`` pair to
    the result of `get_coherence_spectrogram` for that pair.

    Raises
    ------
    ValueError
        if ``method`` is given and is not ``'welch'``
    """
    if fftparams.get('method', 'welch') != 'welch':
        raise ValueError("Cannot process coherence data with method=%r"
                         % fftparams.get('method'))
    fftparams['method'] = 'welch'

    # resolve every channel once, then pair them up in order
    channels = list(map(get_channel, channel_pairs))
    pairs = list(zip(channels[0::2], channels[1::2]))

    # get timeseries data in bulk
    # (skipped when there are no pairs: there is nothing to read, and
    # reduce() over an empty list would raise TypeError)
    if query and pairs:
        qchannels = []
        havesegs = []
        for c1, c2 in pairs:
            fftparams_ = get_fftparams(c1, **fftparams)
            key = make_globalv_key((c1, c2), fftparams_)
            qchannels.extend((c1, c2))
            havesegs.append(globalv.SPECTROGRAMS.get(
                key, SpectrogramList()).segments)
        # only times already processed for every pair count as done
        havesegs = reduce(operator.and_, havesegs)
        new = segments - havesegs
        # if all channels share one stride, drop segments too short for it
        strides = {getattr(c, 'stride', 0) for c in qchannels}
        if len(strides) == 1:
            stride = strides.pop()
            new = type(new)([s for s in new if abs(s) >= stride])
        get_timeseries_dict(qchannels, new, config=config, cache=cache,
                            nproc=nproc, frametype=frametype,
                            datafind_error=datafind_error, nds=nds,
                            return_=False)

    # loop over channels and generate spectrograms
    out = OrderedDict()
    for channel_pair in pairs:
        # frametype is forwarded so any per-pair read uses the same frame
        # type as the bulk read above
        out[channel_pair] = get_coherence_spectrogram(
            channel_pair, segments, config=config, cache=cache, query=query,
            nds=nds, frametype=frametype, nproc=nproc, return_=return_,
            datafind_error=datafind_error, **fftparams)
    return out
def _get_from_list(serieslist, segment):
"""Internal function to crop a series from a serieslist
Should only be used in situations where the existence of the target
data within the list is guaranteed
"""
for series in serieslist:
if segment.intersects(series.span):
outseg = segment & series.span
return series.crop(*outseg)
raise ValueError("Cannot crop series for segment %s from list"
% str(segment))
def complex_percentile(array, percentile):
    """Return a percentile of complex data, taken part by part

    The given percentile is evaluated separately on the real part and on
    the imaginary part of ``array``, and the two results are recombined
    as ``real + imag * 1j``.
    """
    return (array.real.percentile(percentile)
            + array.imag.percentile(percentile) * 1j)