BlueBrain/eFEL

View on GitHub
efel/pyfeatures/extrafeats.py

Summary

Maintainability
C
1 day
Test Coverage
File `extrafeats.py` has 385 lines of code (exceeds 250 allowed). Consider refactoring.
"""Extracellular features functions"""
 
"""
Copyright (c) 2024, EPFL/Blue Brain Project
 
This file is part of eFEL <https://github.com/BlueBrain/eFEL>
 
This library is free software; you can redistribute it and/or modify it under
the terms of the GNU Lesser General Public License version 3.0 as published
by the Free Software Foundation.
 
This library is distributed in the hope that it will be useful, but WITHOUT
ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
details.
 
You should have received a copy of the GNU Lesser General Public License
along with this library; if not, write to the Free Software Foundation, Inc.,
51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
"""
 
import numpy as np
from scipy.stats import linregress
from scipy.signal import resample_poly
 
# Names of all per-waveform features this module can compute.
# calculate_features validates requested feature names against this list and
# falls back to the complete list when no names are given.
all_1D_features = [
    "peak_to_valley",
    "halfwidth",
    "peak_trough_ratio",
    "repolarization_slope",
    "recovery_slope",
    "neg_peak_relative",
    "pos_peak_relative",
    "neg_peak_diff",
    "pos_peak_diff",
    "neg_image",
    "pos_image",
]
 
 
def _get_slope(x, y):
"""
Return the slope of x and y data, using scipy.signal.linregress
"""
 
slope = linregress(x, y)
return slope
 
 
def _get_trough_and_peak_idx(waveform, after_max_trough=False):
"""
Return the indices of the detected troughs (minimum of waveform)
and peaks (maximum of waveform, after trough) of the input waveforms.
 
Assumes negative troughs and positive peaks
 
Returns 0 if not detected
"""
if after_max_trough:
max_trough_idx = np.unravel_index(
np.argmin(waveform),
waveform.shape)[1]
trough_idx = (
np.argmin(waveform[:, max_trough_idx:], axis=1) + max_trough_idx
)
peak_idx = (
np.argmax(waveform[:, max_trough_idx:], axis=1) + max_trough_idx
)
else:
trough_idx = np.argmin(waveform, axis=1)
peak_idx = np.argmax(waveform, axis=1)
 
return trough_idx, peak_idx
 
 
Function `calculate_features` has a Cognitive Complexity of 16 (exceeds 5 allowed). Consider refactoring.
Function `calculate_features` has 5 arguments (exceeds 4 allowed). Consider refactoring.
def calculate_features(
    waveforms,
    sampling_frequency,
    upsample=None,
    feature_names=None,
    recovery_slope_window=0.7
):
    """Calculate features for all waveforms

    Args:
        waveforms : numpy.ndarray (num_waveforms x num_samples)
            waveforms to compute features for
        sampling_frequency : float
            rate at which the waveforms are sampled (Hz)
        upsample : int or None
            if not None, the waveforms are upsampled by this integer factor
            before any feature is computed and the sampling frequency is
            scaled by the same factor. Non-integer values are truncated
            towards zero; the truncated factor must be positive.
        feature_names : list or None (if None, compute all)
            features to compute
        recovery_slope_window : float
            window length in ms after peak wherein recovery slope is computed

    Returns:
        dict (num_waveforms x num_metrics): Dictionary with computed metrics.
        Keys are the metric names, values are the computed features

    Raises:
        AssertionError: if a requested feature name is unknown or the
            upsampling factor is not positive
    """
    if feature_names is None:
        feature_names = all_1D_features
    else:
        for name in feature_names:
            assert name in all_1D_features, f"{name} not in {all_1D_features}"

    if upsample is not None:
        # Truncate once and use the same factor for the waveforms and for
        # the sampling frequency, so both always describe the same time axis.
        upsample = int(upsample)
        assert upsample > 0
        waveforms = _upsample_wf(waveforms, upsample)
        sampling_frequency = upsample * sampling_frequency

    # Ordered table of feature name -> deferred computation. Only the
    # requested entries are evaluated; the order fixes the key order of the
    # returned dictionary.
    computations = {
        "peak_to_valley": lambda: peak_to_valley(
            waveforms=waveforms, sampling_frequency=sampling_frequency
        ),
        "peak_trough_ratio": lambda: peak_trough_ratio(waveforms=waveforms),
        "halfwidth": lambda: halfwidth(
            waveforms=waveforms, sampling_frequency=sampling_frequency
        ),
        "repolarization_slope": lambda: repolarization_slope(
            waveforms=waveforms,
            sampling_frequency=sampling_frequency,
        ),
        "recovery_slope": lambda: recovery_slope(
            waveforms=waveforms,
            sampling_frequency=sampling_frequency,
            window=recovery_slope_window,
        ),
        "neg_peak_diff": lambda: peak_time_diff(
            waveforms=waveforms, fs=sampling_frequency, sign="negative"
        ),
        "pos_peak_diff": lambda: peak_time_diff(
            waveforms=waveforms, fs=sampling_frequency, sign="positive"
        ),
        "neg_peak_relative": lambda: relative_amplitude(
            waveforms=waveforms, sign="negative"
        ),
        "pos_peak_relative": lambda: relative_amplitude(
            waveforms=waveforms, sign="positive"
        ),
        "neg_image": lambda: peak_image(
            waveforms=waveforms, sign="negative"
        ),
        "pos_image": lambda: peak_image(
            waveforms=waveforms, sign="positive"
        ),
    }

    return {
        name: compute()
        for name, compute in computations.items()
        if name in feature_names
    }
 
 
def peak_to_valley(waveforms, sampling_frequency):
    """
    Delay from the trough to the peak of each waveform.

    The value is negative when the peak comes before the trough and NaN
    when trough and peak fall on the same sample.

    Args:
        waveforms : numpy.ndarray (num_waveforms x num_samples)
            waveforms to compute feature for
        sampling_frequency : float
            rate at which the waveforms are sampled (Hz)

    Returns:
        np.ndarray (num_waveforms): peak_to_valley in seconds
    """
    troughs, peaks = _get_trough_and_peak_idx(waveforms)
    sample_period = 1 / sampling_frequency
    delay = (peaks - troughs) * sample_period

    # A zero delay means no separate trough and peak were found
    return np.where(delay == 0, np.nan, delay)
 
 
def peak_trough_ratio(waveforms):
    """
    Absolute ratio of the peak value to the trough value of each waveform.

    Assumes baseline is 0

    Args:
        waveforms : numpy.ndarray (num_waveforms x num_samples)
            waveforms to compute feature for

    Returns:
        np.ndarray (num_waveforms): Peak to trough ratio; NaN for waveforms
        whose trough and peak are both located at sample 0
    """
    troughs, peaks = _get_trough_and_peak_idx(waveforms)

    ratio = np.full(troughs.shape[0], np.nan)

    # Rows where at least one of the two extrema is not at sample 0
    rows = np.flatnonzero((peaks != 0) | (troughs != 0))
    peak_vals = waveforms[rows, peaks[rows]]
    trough_vals = waveforms[rows, troughs[rows]]
    ratio[rows] = np.abs(peak_vals / trough_vals)

    return ratio
 
 
Function `halfwidth` has a Cognitive Complexity of 13 (exceeds 5 allowed). Consider refactoring.
def _half_amplitude_crossings(trace, extremum_idx, beyond):
    """
    Find the samples bracketing the half-amplitude excursion of a waveform.

    The threshold is half of the value at extremum_idx (assuming baseline
    is 0).

    Args:
        trace : numpy.ndarray (num_samples)
            single waveform
        extremum_idx : int
            index of the trough or peak the width is measured around
        beyond : callable
            elementwise comparison telling which samples lie beyond the
            threshold: np.less for a trough, np.greater for a peak

    Returns:
        (int, int) or None: index just before the first sample beyond the
        threshold ahead of the extremum, and index just after the last
        sample beyond the threshold from the extremum onwards. None if
        either side has no sample beyond the threshold.
    """
    threshold = 0.5 * trace[extremum_idx]

    before = np.where(beyond(trace[:extremum_idx], threshold))[0]
    after = np.where(beyond(trace[extremum_idx:], threshold))[0]

    if len(before) == 0 or len(after) == 0:
        return None

    # `after` is relative to the slice starting at the extremum, so it is
    # shifted by extremum_idx to become an index into the full trace.
    return before[0] - 1, after[-1] + 1 + extremum_idx


def halfwidth(
    waveforms,
    sampling_frequency,
    return_idx=False
):
    """
    Width of waveform at half of its amplitude.
    If the peak precedes the trough, halfwidth is negative.

    When the trough comes first (or coincides with the peak) the width is
    measured around the trough at half the trough value; otherwise it is
    measured around the peak at half the peak value and negated.

    Args:
        waveforms : numpy.ndarray (num_waveforms x num_samples)
            waveforms to compute features for
        sampling_frequency : float
            rate at which the waveforms are sampled (Hz)
        return_idx : bool
            if true, also returns index of threshold crossing before and
            index of threshold crossing after the extremum

    Returns:
        np.ndarray or (np.ndarray, np.ndarray, np.ndarray):
            Halfwidth of the waveforms or (Halfwidth of the waveforms,
            index_cross_pre_peak, index_cross_post_peak). Waveforms for
            which no crossing is found on either side have a NaN halfwidth
            and both crossing indices set to 0.
    """
    trough_idx, peak_idx = _get_trough_and_peak_idx(waveforms)
    n_waveforms = waveforms.shape[0]

    hw = np.full(n_waveforms, np.nan)
    # Zero-initialised so that rows without a crossing hold a defined value
    cross_pre_pk = np.zeros(n_waveforms, dtype=int)
    cross_post_pk = np.zeros(n_waveforms, dtype=int)

    for i in range(n_waveforms):
        if peak_idx[i] >= trough_idx[i]:
            # Trough first: measure where the waveform is below half trough
            extremum_idx, beyond, sign = trough_idx[i], np.less, 1
        else:
            # Peak first: measure where the waveform is above half peak
            extremum_idx, beyond, sign = peak_idx[i], np.greater, -1

        crossings = _half_amplitude_crossings(
            waveforms[i], extremum_idx, beyond
        )
        if crossings is None:
            continue

        cross_pre_pk[i], cross_post_pk[i] = crossings
        hw[i] = sign * (cross_post_pk[i] - cross_pre_pk[i]) * (
            1 / sampling_frequency
        )

    if not return_idx:
        return hw

    return hw, cross_pre_pk, cross_post_pk
 
 
Function `repolarization_slope` has a Cognitive Complexity of 8 (exceeds 5 allowed). Consider refactoring.
def repolarization_slope(waveforms,
                         sampling_frequency,
                         return_idx=False
                         ):
    """
    Return slope of repolarization period between trough and baseline

    After reaching its maximum polarization, the neuron potential will
    recover. The repolarization slope is defined as the dV/dT of the action
    potential between trough and baseline, obtained from a linear fit over
    the samples from the trough up to (not including) the first sample at
    or above 0.

    Waveforms whose trough is at sample 0, that never return to baseline
    after the trough, or that have fewer than 3 samples between trough and
    baseline get a NaN slope.

    Args:
        waveforms : numpy.ndarray (num_waveforms x num_samples)
            waveforms to compute features for
        sampling_frequency : float
            rate at which the waveforms are sampled (Hz)
        return_idx : bool
            if true, also returns the index per waveform where the
            potential first returns to baseline after the trough

    Returns:
        np.ndarray or (np.ndarray, np.ndarray): Repolarization slope of the
            waveforms or (Repolarization slope of the waveforms, return to
            base index). The index is 0 where no return to baseline was
            found.
    """
    trough_idx, _ = _get_trough_and_peak_idx(waveforms)
    n_waveforms, n_samples = waveforms.shape

    slopes = np.full(n_waveforms, np.nan)
    baseline_idx = np.zeros(n_waveforms, dtype=np.int_)

    time = np.arange(0, n_samples) * (1 / sampling_frequency)  # in s

    for i, (trace, trough) in enumerate(zip(waveforms, trough_idx)):
        if trough == 0:
            continue

        at_or_above_baseline = np.flatnonzero(trace[trough:] >= 0)
        if at_or_above_baseline.size == 0:
            continue

        # First sample after the trough where the waveform is at baseline
        baseline_idx[i] = at_or_above_baseline[0] + trough

        # Too few samples for a meaningful fit
        if baseline_idx[i] - trough < 3:
            continue

        segment = slice(trough, baseline_idx[i])
        slopes[i] = _get_slope(time[segment], trace[segment])[0]

    if return_idx:
        return slopes, baseline_idx

    return slopes
 
 
def recovery_slope(waveforms, sampling_frequency, window):
    """
    Return the recovery slope of input waveforms.

    After repolarization, the neuron hyperpolarizes until it peaks. The
    recovery slope is the dV/dT of the waveform after that peak, obtained
    from a linear fit over a user-defined window that starts at the peak
    and is clipped to the end of the waveform.

    Waveforms whose peak is at sample 0, or whose window holds fewer than
    3 samples, get a NaN slope.

    Args:
        waveforms : numpy.ndarray (num_waveforms x num_samples)
            waveforms to compute features for
        sampling_frequency : float
            rate at which the waveforms are sampled (Hz)
        window : float
            length after peak wherein to compute recovery slope (ms)

    Returns:
        np.ndarray: Recovery slope of the waveforms
    """
    _, peak_idx = _get_trough_and_peak_idx(waveforms)
    n_waveforms, n_samples = waveforms.shape

    slopes = np.full(n_waveforms, np.nan)
    time = np.arange(0, n_samples) * (1 / sampling_frequency)  # in s

    for i, (trace, peak) in enumerate(zip(waveforms, peak_idx)):
        if peak in (0, n_samples):
            continue

        # Window length converted from ms to samples, clipped to the trace
        stop = int(peak + ((window / 1000) * sampling_frequency))
        stop = min(stop, n_samples)

        fit_times = time[peak:stop]
        if len(fit_times) < 3:
            continue

        slopes[i] = _get_slope(fit_times, trace[peak:stop])[0]

    return slopes
 
 
def peak_image(waveforms, sign="negative"):
    """
    Amplitude of every waveform at the time of the overall extremum,
    normalized by that extremum.

    The overall minimum ("negative") or maximum (any other sign value) is
    located over all waveforms; each waveform is then sampled at that time
    and divided by the extreme value.

    Args:
        waveforms : numpy.ndarray (num_waveforms x num_samples)
            waveforms to compute features for
        sign : str
            "positive" | "negative"

    Returns:
        np.ndarray: Peak images for the waveforms
    """
    assert len(waveforms) > 1

    locate, reduce = (
        (np.argmin, np.min) if sign == "negative" else (np.argmax, np.max)
    )

    flat_position = locate(waveforms)
    channel, sample = np.unravel_index(flat_position, waveforms.shape)
    reference = reduce(waveforms[channel])

    return waveforms[:, sample] / reference
 
 
def relative_amplitude(waveforms, sign="negative"):
    """
    Extreme amplitude of every waveform relative to the largest one.

    The minimum ("negative") or maximum (any other sign value) of each
    waveform is taken in absolute value and divided by the absolute value
    of the same extremum over all waveforms.

    Args:
        waveforms : numpy.ndarray (num_waveforms x num_samples)
            waveforms to compute features for
        sign : str
            "positive" | "negative"

    Returns:
        np.ndarray: Relative amplitudes for the waveforms
    """
    assert len(waveforms) > 1

    extremum = np.min if sign == "negative" else np.max

    overall = np.abs(extremum(waveforms))
    per_waveform = np.abs(extremum(waveforms, 1))

    return per_waveform / overall
 
 
def peak_time_diff(waveforms, fs, sign="negative"):
    """
    Time of each waveform's extremum relative to the extremum of the
    waveform with the largest amplitude.

    The minimum is used for "negative", the maximum for any other sign
    value.

    Args:
        waveforms : numpy.ndarray (num_waveforms x num_samples)
            waveforms to compute features for
        fs : float
            Sampling rate in Hz
        sign : str
            "positive" | "negative"

    Returns:
        np.ndarray: Peak time differences for the waveforms
    """
    assert len(waveforms) > 1

    locate = np.argmin if sign == "negative" else np.argmax

    # Waveform holding the overall extremum and the sample where it occurs
    loudest = np.unravel_index(locate(waveforms), waveforms.shape)[0]
    reference_sample = locate(waveforms[loudest])

    per_waveform_sample = locate(waveforms, 1)

    return (per_waveform_sample - reference_sample) / fs
 
 
def _upsample_wf(waveforms, upsample):
 
ndim = len(waveforms.shape)
waveforms_up = resample_poly(waveforms, up=upsample, down=1, axis=ndim - 1)
 
return waveforms_up