neuropsychology/NeuroKit

View on GitHub
neurokit2/signal/signal_timefrequency.py

Summary

Maintainability
A
0 mins
Test Coverage
# -*- coding: utf-8 -*-
import matplotlib.pyplot as plt
import numpy as np
import scipy.signal

from ..signal.signal_detrend import signal_detrend


def signal_timefrequency(
    signal,
    sampling_rate=1000,
    min_frequency=0.04,
    max_frequency=None,
    method="stft",
    window=None,
    window_type="hann",
    mode="psd",
    nfreqbin=None,
    overlap=None,
    analytical_signal=True,
    show=True,
):
    """**Quantify changes of a nonstationary signal’s frequency over time**
    The objective of time-frequency analysis is to offer a more informative description of the
    signal which reveals the temporal variation of its frequency contents.

    There are many different Time-Frequency Representations (TFRs) available:

    * Linear TFRs: efficient but create tradeoff between time and frequency resolution

        * Short Time Fourier Transform (STFT): the time-domain signal is windowed into short
          segments and FT is applied to each segment, mapping the signal into the TF plane. This
          method assumes that the signal is quasi-stationary (stationary over the duration of the
          window). The width of the window is the trade-off between good time (requires short
          duration window) versus good frequency resolution (requires long duration windows)

        * Wavelet Transform (WT): similar to STFT but instead of a fixed duration window function,
          a varying window length by scaling the axis of the window is used. At low frequency, WT
          provides high spectral resolution but poor temporal resolution. On the other hand, for
          high frequencies, the WT provides high temporal resolution but poor spectral resolution.

    * Quadratic TFRs: better resolution but computationally expensive and suffers from having
      cross terms between multiple signal components

        * Wigner Ville Distribution (WVD): while providing very good resolution in time and
          frequency of the underlying signal structure, because of its bilinear nature, existence
          of negative values, the WVD has misleading TF results in the case of multi-component
          signals such as EEG due to the presence of cross terms and interference terms. Cross WVD
          terms can be reduced by using smoothing kernel functions as well as analyzing the
          analytic signal (instead of the original signal)

        * Smoothed Pseudo Wigner Ville Distribution (SPWVD): to address the problem of cross-terms
          suppression, SPWVD allows two independent analysis windows, one in time and the other in
          frequency domains.

    Parameters
    ----------
    signal : Union[list, np.array, pd.Series]
        The signal (i.e., a time series) in the form of a vector of values.
    sampling_rate : int
        The sampling frequency of the signal (in Hz, i.e., samples/second).
    method : str
        Time-Frequency decomposition method (case-insensitive). Can be ``"stft"``, ``"cwt"``
        (or ``"wavelet"``), ``"wvd"`` (or ``"WignerVille"``), ``"pwvd"`` (or
        ``"pseudoWignerVille"``).
    min_frequency : float
        The minimum frequency. A value of ``0`` is replaced by ``0.04``.
    max_frequency : float
        The maximum frequency. If ``None`` (default), ``sampling_rate // 2`` is used.
    window : int
        Length of each segment in seconds. If ``None`` (default), window will be automatically
        calculated. For ``"STFT" method``.
    window_type : str
        Type of window to create, defaults to ``"hann"``. See :func:`.scipy.signal.get_window` to
        see full options of windows. For ``"STFT" method``.
    mode : str
        Type of return values for ``"STFT" method``. Can be ``"psd"``, ``"complex"``,
        ``"magnitude"``, ``"angle"``, ``"phase"``. Defaults to ``"psd"``.
    nfreqbin : int, float
        Number of frequency bins. Only forwarded to the Wigner Ville methods; if ``None``
        (default), the value is chosen by the decomposition function.
    overlap : int
        Number of points to overlap between segments. If ``None``, ``noverlap = nperseg // 8``.
        Defaults to ``None``. For ``"STFT" method``.
    analytical_signal : bool
        If ``True``, analytical signal instead of actual signal is used in `Wigner Ville
        Distribution` methods.
    show : bool
        If ``True``, the time-frequency representation is plotted.

    Returns
    -------
    frequency : np.array
        Frequencies (restricted to ``[min_frequency, max_frequency]``).
    time : np.array
        Time array.
    tfr : np.array
        Time-frequency representation. Time increases across its columns and frequency increases
        down the rows.

    Raises
    ------
    ValueError
        If ``method`` is not one of the supported names.

    Examples
    --------
    .. ipython:: python

      import neurokit2 as nk

      sampling_rate = 100
      signal = nk.signal_simulate(100, sampling_rate, frequency=[3, 10])

      # STFT Method
      @savefig p_signal_timefrequency1.png scale=100%
      f, t, stft = nk.signal_timefrequency(signal,
                                           sampling_rate,
                                           max_frequency=20,
                                           method="stft",
                                           show=True)
      @suppress
      plt.close()

    .. ipython:: python

      # CWTM Method
      @savefig p_signal_timefrequency2.png scale=100%
      f, t, cwtm = nk.signal_timefrequency(signal,
                                           sampling_rate,
                                           max_frequency=20,
                                           method="cwt",
                                           show=True)
      @suppress
      plt.close()

    .. ipython:: python

      # WVD Method
      @savefig p_signal_timefrequency3.png scale=100%
      f, t, wvd = nk.signal_timefrequency(signal,
                                          sampling_rate,
                                          max_frequency=20,
                                          method="wvd",
                                          show=True)
      @suppress
      plt.close()

    .. ipython:: python

      # PWVD Method
      @savefig p_signal_timefrequency4.png scale=100%
      f, t, pwvd = nk.signal_timefrequency(signal,
                                           sampling_rate,
                                           max_frequency=20,
                                           method="pwvd",
                                           show=True)
      @suppress
      plt.close()

    """
    # Resolve every accepted spelling (case-insensitive) to one canonical name up front, so
    # that the dispatch below and the plotting helper agree on what was requested and an
    # unsupported name fails with a clear error instead of an unbound local variable.
    method_aliases = {
        "stft": "stft",
        "cwt": "cwt",
        "wavelet": "cwt",
        "wvd": "wvd",
        "wignerville": "wvd",
        "pwvd": "pwvd",
        "pseudowignerville": "pwvd",
    }
    canonical_method = method_aliases.get(str(method).lower())
    if canonical_method is None:
        raise ValueError(
            "NeuroKit error: signal_timefrequency(): 'method' should be one of "
            "'stft', 'cwt', 'wvd' or 'pwvd'."
        )

    # Frequency bounds
    if min_frequency == 0:
        min_frequency = 0.04  # sanitize lowest frequency to lf
    if max_frequency is None:
        max_frequency = sampling_rate // 2  # nyquist

    if canonical_method == "stft":
        frequency, time, tfr = short_term_ft(
            signal,
            sampling_rate=sampling_rate,
            overlap=overlap,
            window=window,
            mode=mode,
            min_frequency=min_frequency,
            window_type=window_type,
        )
    elif canonical_method == "cwt":
        frequency, time, tfr = continuous_wt(
            signal,
            sampling_rate=sampling_rate,
            min_frequency=min_frequency,
            max_frequency=max_frequency,
        )
    else:
        # Both Wigner Ville flavours share one implementation
        wvd_method = "WignerVille" if canonical_method == "wvd" else "pseudoWignerVille"
        frequency, time, tfr = wvd(
            signal,
            sampling_rate=sampling_rate,
            n_freqbins=nfreqbin,
            analytical_signal=analytical_signal,
            method=wvd_method,
        )

    # Sanitize output: keep only the rows whose frequency lies in the requested band
    in_band = (frequency >= min_frequency) & (frequency <= max_frequency)
    f = frequency[in_band]
    z = tfr[in_band]

    if show:
        plot_timefrequency(
            z,
            time,
            f,
            signal=signal,
            method=canonical_method,
        )

    return f, time, z


# =============================================================================
# Short-Time Fourier Transform (STFT)
# =============================================================================


def short_term_ft(
    signal,
    sampling_rate=1000,
    min_frequency=0.04,
    overlap=None,
    window=None,
    window_type="hann",
    mode="psd",
):
    """Short-term Fourier Transform.

    Thin wrapper around :func:`scipy.signal.spectrogram` (density scaling, no detrending).

    Parameters
    ----------
    signal : Union[list, np.array, pd.Series]
        The signal to decompose.
    sampling_rate : int
        Sampling frequency (Hz).
    min_frequency : float
        Lowest frequency of interest; only used to derive the segment length when ``window``
        is ``None``.
    overlap : int
        Number of overlapping points between segments, forwarded as ``noverlap``.
    window : float
        Segment length in seconds. If ``None``, ``2 / min_frequency`` seconds are used.
    window_type : str
        Window name forwarded to ``scipy.signal.spectrogram``.
    mode : str
        Spectrogram mode forwarded to ``scipy.signal.spectrogram``.

    Returns
    -------
    frequency, time, tfr : np.array
        Frequencies, segment times and the absolute value of the spectrogram.

    """
    # Segment duration in seconds: either the caller's choice, or two periods of the slowest
    # frequency of interest.
    segment_seconds = (2 / min_frequency) if window is None else window
    samples_per_segment = int(segment_seconds * sampling_rate)

    spectrogram_options = {
        "fs": sampling_rate,
        "window": window_type,
        "scaling": "density",
        "nperseg": samples_per_segment,
        "nfft": None,
        "detrend": False,
        "noverlap": overlap,
        "mode": mode,
    }
    frequency, time, spectrum = scipy.signal.spectrogram(signal, **spectrogram_options)

    return frequency, time, np.abs(spectrum)


# =============================================================================
# Continuous Wavelet Transform (CWT) - Morlet
# =============================================================================


def continuous_wt(
    signal, sampling_rate=1000, min_frequency=0.04, max_frequency=None, nfreqbin=None
):
    """**Continuous Wavelet Transform**

    Computes the magnitude of a complex Morlet continuous wavelet transform on a linearly
    spaced frequency grid. The transform is computed directly (one convolution per frequency)
    so that it does not depend on ``scipy.signal.cwt`` / ``scipy.signal.morlet2``, which are
    not available in recent SciPy releases.

    Parameters
    ----------
    signal : Union[list, np.array, pd.Series]
        The signal (i.e., a time series) in the form of a vector of values.
    sampling_rate : int
        The sampling frequency of the signal (in Hz, i.e., samples/second).
    min_frequency : float
        Lowest analysed frequency (must be strictly positive).
    max_frequency : float
        Highest analysed frequency. If ``None`` (default), ``sampling_rate // 2`` is used.
    nfreqbin : int
        Number of frequency bins. If ``None`` (default), ``sampling_rate // 2`` is used.

    Returns
    -------
    frequency : np.array
        Analysed frequencies, ``nfreqbin`` values from ``min_frequency`` to ``max_frequency``.
    time : np.array
        Time of each sample (in seconds).
    tfr : np.array
        Wavelet magnitude of shape ``(len(frequency), len(signal))``.

    References
    ----------
    * Neto, O. P., Pinheiro, A. O., Pereira Jr, V. L., Pereira, R., Baltatu, O. C., & Campos, L.
      A. (2016). Morlet wavelet transforms of heart rate variability for autonomic nervous system
      activity. Applied and Computational Harmonic Analysis, 40(1), 200-206.

    * Wachowiak, M. P., Wachowiak-Smolíková, R., Johnson, M. J., Hay, D. C., Power, K. E.,
      & Williams-Bell, F. M. (2018). Quantitative feature analysis of continuous analytic wavelet
      transforms of electrocardiography and electromyography. Philosophical Transactions of the
      Royal Society A: Mathematical, Physical and Engineering Sciences, 376(2126), 20170250.

    """
    signal = np.asarray(signal)
    n_samples = len(signal)

    # central frequency
    w = 6.0  # recommended

    # Default to the Nyquist frequency instead of passing None on to np.linspace
    if max_frequency is None:
        max_frequency = sampling_rate // 2
    if nfreqbin is None:
        nfreqbin = sampling_rate // 2

    # frequency (np.linspace requires an integer number of points)
    frequency = np.linspace(min_frequency, max_frequency, int(nfreqbin))

    # time
    time = np.arange(n_samples) / sampling_rate

    # Wavelet width (in samples) that centres the Morlet wavelet on each frequency
    widths = w * sampling_rate / (2 * frequency * np.pi)

    # Mother wavelet = Morlet. One row per width: convolve the signal with the conjugated,
    # time-reversed wavelet, truncated to at most 10 widths or the signal length.
    tfr = np.empty((len(widths), n_samples), dtype=complex)
    for row, width in enumerate(widths):
        n_points = min(10 * width, n_samples)
        x = (np.arange(0, n_points) - (n_points - 1.0) / 2) / width
        wavelet = np.sqrt(1 / width) * np.pi ** (-0.25) * np.exp(1j * w * x) * np.exp(-0.5 * x ** 2)
        tfr[row] = scipy.signal.convolve(signal, np.conj(wavelet[::-1]), mode="same")

    return frequency, time, np.abs(tfr)


# =============================================================================
# Wigner-Ville Distribution
# =============================================================================
def wvd(signal, sampling_rate=1000, n_freqbins=None, analytical_signal=True, method="WignerVille"):
    """Wigner Ville Distribution and Pseudo-Wigner Ville Distribution.

    Parameters
    ----------
    signal : Union[list, np.array, pd.Series]
        The signal (i.e., a time series) in the form of a vector of values.
    sampling_rate : int
        The sampling frequency of the signal (in Hz, i.e., samples/second).
    n_freqbins : int
        Number of frequency bins. If ``None`` (default), 256 bins are used.
    analytical_signal : bool
        If ``True``, the distribution is computed on the Hilbert transform of the detrended
        signal rather than on the signal itself.
    method : str
        ``"pseudoWignerVille"`` (or ``"pwvd"``) applies a Hamming window along the lag axis; any
        other value computes the plain Wigner Ville Distribution.

    Returns
    -------
    frequency : np.array
        ``n_freqbins`` frequencies, ``0.5 * k / n_freqbins * sampling_rate``.
    time : np.array
        Time of each sample (in seconds).
    tfr : np.array
        Real-valued distribution of shape ``(n_freqbins, len(signal))``.

    """
    # Compute the analytical signal
    if analytical_signal:
        signal = scipy.signal.hilbert(signal_detrend(signal))
    signal = np.asarray(signal)  # lists are accepted when analytical_signal is False
    n_samples = signal.shape[0]

    # Pre-processing
    n_freqbins = 256 if n_freqbins is None else int(n_freqbins)

    # Lag window of length n_freqbins + 1, indexed as fwindows[fwindows_mpts + lag]
    fwindows_mpts = (n_freqbins + 1) // 2
    if method in ["pseudoWignerVille", "pwvd"]:
        fwindows = np.zeros(n_freqbins + 1)
        windows_length = n_freqbins // 4
        windows_length = windows_length - windows_length % 2 + 1  # force an odd length
        half_length = windows_length // 2
        # Centre the Hamming window on lag 0 so that it is symmetric in +/- lag
        fwindows[fwindows_mpts - half_length : fwindows_mpts + half_length + 1] = np.hamming(
            windows_length
        )
    else:
        fwindows = np.ones(n_freqbins + 1)

    time = np.arange(n_samples) * 1.0 / sampling_rate

    tfr = np.zeros((n_freqbins, n_samples), dtype=complex)  # the time-frequency matrix

    tausec = round(n_freqbins / 2.0)
    winlength = tausec - 1
    # taulens: largest usable lag at each sample, limited by both signal edges and winlength
    sample_idx = np.arange(n_samples)
    taulens = np.minimum(np.minimum(sample_idx, n_samples - 1 - sample_idx), winlength)

    conj_signal = np.conj(signal)
    # iterate and compute the lag products for each time index
    for idx in range(n_samples):
        tau = np.arange(-taulens[idx], taulens[idx] + 1)
        # negative lags wrap to the end of the column: required to use the efficient DFT
        indices = np.remainder(tau, n_freqbins)
        tfr[indices, idx] = (
            fwindows[fwindows_mpts + tau] * signal[idx + tau] * conj_signal[idx - tau]
        )
        # Lag +/- tausec shares a single row: store the average of both products.
        # NOTE(review): for some odd n_freqbins (e.g. 255) fwindows_mpts + tausec equals
        # len(fwindows) and this indexing would fail; even values are unaffected.
        if tausec + 1 <= idx < n_samples - tausec:
            tfr[tausec, idx] = 0.5 * (
                fwindows[fwindows_mpts + tausec] * signal[idx + tausec] * conj_signal[idx - tausec]
                + fwindows[fwindows_mpts - tausec]
                * signal[idx - tausec]
                * conj_signal[idx + tausec]
            )

    # Now tfr contains the product of the signal segments and its conjugate.
    # To find wd we need to apply fft one more time.
    tfr = np.real(np.fft.fft(tfr, axis=0))

    # continuous time frequency
    frequency = 0.5 * np.arange(n_freqbins, dtype=float) / n_freqbins * sampling_rate

    return frequency, time, tfr


# =============================================================================
# Smooth Pseudo-Wigner-Ville Distribution
# =============================================================================


def smooth_pseudo_wvd(
    signal,
    sampling_rate=1000,
    freq_length=None,
    time_length=None,
    segment_step=1,
    nfreqbin=None,
    window_method="hamming",
):
    """**Smoothed Pseudo Wigner Ville Distribution**

    The signal is zero-padded to twice its length, turned into an analytic signal, and then
    smoothed with two independent windows (one along time, one along lag) before a final FFT
    along the lag axis.

    Parameters
    ----------
    signal : Union[list, np.array, pd.Series]
        The signal (i.e., a time series) in the form of a vector of values.
    sampling_rate : int
        The sampling frequency of the signal (in Hz, i.e., samples/second).
    freq_length : np.ndarray
        Length of frequency smoothing window. If ``None`` (default), ``floor(N / 4)`` is used
        (plus one if that value is even), ``N`` being the signal length.
    time_length: np.array
        Length of time smoothing window. If ``None`` (default), ``floor(N / 10)`` is used (plus
        one if that value is even).
    segment_step : int
        The step between samples in ``time_array``. Default to 1.
    nfreqbin : int
        Number of Frequency bins. If ``None`` (default), 300 bins are used.
    window_method : str
        Method used to create smoothing windows. The code handles ``"hamming"`` and
        ``"gaussian"``.

    Returns
    -------
    frequency_array : np.ndarray
        Frequency array.
    time_array : np.ndarray
        Time array.
    pwvd : np.ndarray
        SPWVD. Time increases across its columns and frequency increases
        down the rows.

    References
    ----------
    * J. M. O' Toole, M. Mesbah, and B. Boashash, (2008), "A New Discrete Analytic Signal for
      Reducing Aliasing in the Discrete Wigner-Ville Distribution", IEEE Trans.

    """

    # Define parameters
    N = len(signal)
    # sample_spacing = 1 / sampling_rate
    if nfreqbin is None:
        nfreqbin = 300

    # Zero-padded signal to length 2N
    signal_padded = np.append(signal, np.zeros_like(signal))

    # DFT: double bins 1..N-2 and zero every bin from N onwards
    signal_fft = np.fft.fft(signal_padded)
    signal_fft[1 : N - 1] = signal_fft[1 : N - 1] * 2
    signal_fft[N:] = 0

    # Inverse FFT, discarding the second (padded) half
    signal_ifft = np.fft.ifft(signal_fft)
    signal_ifft[N:] = 0

    # Make analytic signal (still of length 2N)
    signal = scipy.signal.hilbert(signal_detrend(signal_ifft))

    # Create smoothing windows in time and frequency
    # NOTE(review): user-supplied lengths are checked with len(...) below but converted with
    # int(...) further down, so it is unclear whether a scalar or an array is expected --
    # confirm the intended type.
    if freq_length is None:
        freq_length = np.floor(N / 4.0)
        # Plus one if window length is not odd
        if freq_length % 2 == 0:
            freq_length += 1
    elif len(freq_length) % 2 == 0:
        raise ValueError("The length of frequency smoothing window must be odd.")

    if time_length is None:
        time_length = np.floor(N / 10.0)
        # Plus one if window length is not odd
        if time_length % 2 == 0:
            time_length += 1
    elif len(time_length) % 2 == 0:
        raise ValueError("The length of time smoothing window must be odd.")

    # NOTE(review): the window functions are looked up on ``scipy.signal`` directly; check that
    # the supported SciPy versions still expose them there (they also live in
    # ``scipy.signal.windows``).
    if window_method == "hamming":
        freq_window = scipy.signal.hamming(int(freq_length))  # used as returned (no rescaling)
        time_window = scipy.signal.hamming(int(time_length))  # used as returned (no rescaling)
    elif window_method == "gaussian":
        std_freq = freq_length / (6 * np.sqrt(2 * np.log(2)))
        freq_window = scipy.signal.gaussian(freq_length, std_freq)
        freq_window /= max(freq_window)  # normalize by max
        std_time = time_length / (6 * np.sqrt(2 * np.log(2)))
        time_window = scipy.signal.gaussian(time_length, std_time)
        time_window /= max(time_window)  # normalize by max
    # TODO: any other window_method leaves freq_window / time_window undefined, so the lines
    # below fail with a NameError; an explicit error or warning should be raised here.

    # Mid-point index of windows
    midpt_freq = (len(freq_window) - 1) // 2
    midpt_time = (len(time_window) - 1) // 2

    # Create arrays
    time_array = np.arange(start=0, stop=N, step=segment_step, dtype=int) / sampling_rate
    # frequency_array = np.fft.fftfreq(nfreqbin, sample_spacing)[0:nfreqbin / 2]
    frequency_array = 0.5 * np.arange(nfreqbin, dtype=float) / N
    pwvd = np.zeros((nfreqbin, len(time_array)), dtype=complex)

    # Calculate pwvd
    # NOTE(review): ``t`` comes from ``time_array``, which has already been divided by
    # ``sampling_rate``, yet it is used below to build sample indices -- the two only coincide
    # when sampling_rate == 1. Confirm whether sample indices were intended here.
    for i, t in enumerate(time_array):
        # largest lag usable at this time point
        tau_max = np.min(
            [t + midpt_time - 1, N - t + midpt_time, np.round(N / 2.0) - 1, midpt_freq]
        )
        # time-lag list
        tau = np.arange(
            start=-np.min([midpt_time, N - t]), stop=np.min([midpt_time, t - 1]) + 1, dtype="int"
        )
        time_pts = (midpt_time + tau).astype(int)
        # time-smoothing weights, normalised to sum to one
        g2 = time_window[time_pts]
        g2 = g2 / np.sum(g2)
        signal_pts = (t - tau - 1).astype(int)
        # zero frequency
        pwvd[0, i] = np.sum(g2 * signal[signal_pts] * np.conjugate(signal[signal_pts]))
        # other frequencies
        for m in range(int(tau_max)):
            tau = np.arange(
                start=-np.min([midpt_time, N - t - m]),
                stop=np.min([midpt_time, t - m - 1]) + 1,
                dtype="int",
            )
            time_pts = (midpt_time + tau).astype(int)
            g2 = time_window[time_pts]
            g2 = g2 / np.sum(g2)
            signal_pt1 = (t + m - tau - 1).astype(int)
            signal_pt2 = (t - m - tau - 1).astype(int)
            # compute positive half
            rmm = np.sum(g2 * signal[signal_pt1] * np.conjugate(signal[signal_pt2]))
            pwvd[m + 1, i] = freq_window[midpt_freq + m + 1] * rmm
            # compute negative half (stored from the end of the column)
            rmm = np.sum(g2 * signal[signal_pt2] * np.conjugate(signal[signal_pt1]))
            pwvd[nfreqbin - m - 1, i] = freq_window[midpt_freq - m + 1] * rmm

        # NOTE(review): np.round returns a float, and ``m`` is used as an array index inside
        # the branch below -- confirm this branch behaves as intended when it is reached.
        m = np.round(N / 2.0)

        if t <= N - m and t >= m + 1 and m <= midpt_freq:
            tau = np.arange(
                start=-np.min([midpt_time, N - t - m]),
                stop=np.min([midpt_time, t - 1 - m]) + 1,
                dtype="int",
            )
            time_pts = (midpt_time + tau + 1).astype(int)
            g2 = time_window[time_pts]
            g2 = g2 / np.sum(g2)
            signal_pt1 = (t + m - tau).astype(int)
            signal_pt2 = (t - m - tau).astype(int)
            x = np.sum(g2 * signal[signal_pt1] * np.conjugate(signal[signal_pt2]))
            x *= freq_window[midpt_freq + m + 1]
            y = np.sum(g2 * signal[signal_pt2] * np.conjugate(signal[signal_pt1]))
            y *= freq_window[midpt_freq - m + 1]
            # average of the two products
            pwvd[m, i] = 0.5 * (x + y)

    # FFT along the lag axis; only the real part is kept
    pwvd = np.real(np.fft.fft(pwvd, axis=0))

    return frequency_array, time_array, pwvd


# =============================================================================
# Plot function
# =============================================================================
def plot_timefrequency(z, time, f, signal=None, method="stft"):
    """Visualize a time-frequency matrix.

    Parameters
    ----------
    z : np.array
        Time-frequency matrix (frequency along rows, time along columns).
    time : np.array
        Time values matching the columns of ``z``.
    f : np.array
        Frequency values matching the rows of ``z``.
    signal : Union[list, np.array, pd.Series]
        Original signal; only drawn (against ``time``) for the ``"wvd"`` method.
    method : str
        Decomposition method (case-insensitive): ``"stft"``, ``"cwt"`` (or ``"wavelet"``),
        ``"wvd"`` (or ``"WignerVille"``), ``"pwvd"`` (or ``"pseudoWignerVille"``).

    Returns
    -------
    fig
        The matplotlib figure holding the time-frequency colour map. For ``"stft"`` and
        ``"wvd"`` an additional figure is created beforehand (per-segment spectra and raw
        signal, respectively).

    Raises
    ------
    ValueError
        If ``method`` is not one of the supported names.

    """
    # Accept the same spellings as the decomposition entry point, so that an alias or a
    # different letter case cannot leave the figure title undefined.
    method_aliases = {
        "stft": "stft",
        "cwt": "cwt",
        "wavelet": "cwt",
        "wvd": "wvd",
        "wignerville": "wvd",
        "pwvd": "pwvd",
        "pseudowignerville": "pwvd",
    }
    figure_titles = {
        "stft": "Short-time Fourier Transform Magnitude",
        "cwt": "Continuous Wavelet Transform Magnitude",
        "wvd": "Wigner Ville Distribution Spectrogram",
        "pwvd": "Pseudo Wigner Ville Distribution Spectrogram",
    }
    canonical_method = method_aliases.get(str(method).lower())
    if canonical_method is None:
        raise ValueError(
            "NeuroKit error: plot_timefrequency(): 'method' should be one of "
            "'stft', 'cwt', 'wvd' or 'pwvd'."
        )
    figure_title = figure_titles[canonical_method]

    if canonical_method == "stft":
        # Extra figure: one spectrum per time segment
        _, segment_ax = plt.subplots()
        for i in range(len(time)):
            segment_ax.plot(f, z[:, i], label="Segment" + str(i + 1))
        segment_ax.legend()
        segment_ax.set_title("Signal Spectrogram")
        segment_ax.set_ylabel("STFT Magnitude")
        segment_ax.set_xlabel("Frequency (Hz)")

    elif canonical_method == "wvd" and signal is not None:
        # Extra figure: the raw signal (skipped when no signal was provided)
        plt.figure()
        plt.plot(time, signal)
        plt.xlabel("Time (sec)")
        plt.ylabel("Signal")

    # Time-frequency colour map (the returned figure)
    fig, ax = plt.subplots()
    spec = ax.pcolormesh(time, f, z, cmap=plt.get_cmap("magma"), shading="auto")
    plt.colorbar(spec)
    ax.set_title(figure_title)
    ax.set_ylabel("Frequency (Hz)")
    ax.set_xlabel("Time (sec)")
    return fig