neuropsychology/NeuroKit.py

View on GitHub
neurokit/eeg/eeg_erp.py

Summary

Maintainability
D
1 day
Test Coverage
"""
ERP analysis EEG submodule.
"""
from .eeg_data import eeg_select_electrodes
from .eeg_data import eeg_to_df
from .eeg_data import eeg_to_all_evokeds


import numpy as np
import pandas as pd
import mne




# ==============================================================================
# ==============================================================================
# ==============================================================================
# ==============================================================================
# ==============================================================================
# ==============================================================================
# ==============================================================================
# ==============================================================================
def _erp_windows(times):
    """
    Normalise the `times` argument of eeg_erp() into a list of (start, end) pairs.

    Returns None when `times` is not a non-empty list/tuple, which eeg_erp()
    interprets as "no user-defined window".
    """
    if not isinstance(times, (list, tuple)) or len(times) == 0:
        return None
    # Nested container: several windows. Flat container: a single window.
    if isinstance(times[0], (list, tuple)):
        return [(window[0], window[1]) for window in times]
    return [(times[0], times[1])]


def _erp_summarise(df, method):
    """
    Collapse a slice of an epoch into one number: column-wise mean first, then
    the max, min or mean of those column means depending on `method`.
    """
    column_means = df.mean()
    if method == "max":
        return column_means.max()
    if method == "min":
        return column_means.min()
    return column_means.mean()


def eeg_erp(eeg, times=None, index=None, include="all", exclude=None, hemisphere="both", central=True, verbose=True, names="ERP", method="mean"):
    """
    Summarise each epoch as one value per time window.

    The epochs are obtained from `eeg_to_df()`; each one is sliced to the
    requested window(s), averaged column-wise, and the column means are then
    reduced to a single number according to `method`.

    Parameters
    ----------
    eeg
        Object forwarded as-is to `eeg_to_df()` (presumably an mne Epochs
        object - confirm against `eeg_to_df()`).
    times : list, tuple or None
        Either a single window `[start, end]` / `(start, end)`, or a sequence
        of such windows (e.g. `[[0.1, 0.2], [0.3, 0.5]]`). Anything else
        (including None or an empty sequence) selects `epoch[0:]`.
    index, include, exclude, hemisphere, central
        Forwarded unchanged to `eeg_to_df()`.
    verbose : bool
        Currently unused by this function.
    names : str or list
        Column name(s) for the output. Only applied when their number matches
        the number of columns; otherwise the default integer columns are kept.
    method : str
        "max" or "min" to take the largest / smallest column mean within each
        window; any other value takes the mean of the column means.

    Returns
    -------
    erp : pandas.DataFrame
        One row per epoch (indexed by the keys returned by `eeg_to_df()`), one
        column per window.
    """
    windows = _erp_windows(times)

    data = eeg_to_df(eeg, index=index, include=include, exclude=exclude, hemisphere=hemisphere, central=central)

    erp = {}
    for epoch_index, epoch in data.items():
        if windows is None:
            # Default window. NOTE(review): on a float (time) index `0:` is
            # presumably label-based, i.e. "from time 0 onwards" - the slice is
            # deliberately kept exactly as it was.
            erp[epoch_index] = [_erp_summarise(epoch[0:], method)]
        else:
            # `method` is applied to every window, whatever container type was
            # used to describe the windows.
            erp[epoch_index] = [_erp_summarise(epoch[start:end], method) for start, end in windows]

    # Convert to dataframe
    erp = pd.DataFrame.from_dict(erp, orient="index")
    if isinstance(names, str):
        names = [names]
    if len(names) == len(erp.columns):
        erp.columns = names

    return(erp)


# ==============================================================================
# ==============================================================================
# ==============================================================================
# ==============================================================================
# ==============================================================================
# ==============================================================================
# ==============================================================================
# ==============================================================================
def plot_eeg_erp(all_epochs, conditions=None, times=None, include="all", exclude=None, hemisphere="both", central=True, name=None, colors=None, gfp=False, ci=0.95, ci_alpha=0.333, invert_y=False, linewidth=1, linestyle="-", filter_hfreq=None):
    """
    Plot and compare the evoked responses of several conditions.

    The epochs of each participant are optionally filtered and cropped,
    converted with `eeg_to_all_evokeds()`, grouped by condition and handed to
    `mne.viz.plot_compare_evokeds()`.

    Parameters
    ----------
    all_epochs : dict
        Mapping participant -> epochs object. Only a shallow copy is modified;
        the values are replaced by filtered / cropped copies, never edited in
        place.
    conditions
        Forwarded to `eeg_to_all_evokeds()`.
    times : list or tuple of length 2, optional
        `(tmin, tmax)` used to crop the epochs. Ignored otherwise.
    include, exclude, hemisphere, central
        Forwarded to `eeg_select_electrodes()` to choose the plotted channels.
    name : str, optional
        Plot title.
    colors : str, list, dict or None
        One colour for all conditions, one colour per condition (in the order
        the conditions are encountered), or a dict passed as-is to mne. An
        unsupported value triggers a warning and falls back to None.
    gfp, ci, ci_alpha, invert_y
        Forwarded to `mne.viz.plot_compare_evokeds()`.
    linewidth : number
        Line width applied to every condition.
    linestyle : str, list or dict
        One style for all conditions, one per condition, or a dict keyed by
        condition. An unsupported value triggers a warning and falls back
        to "-".
    filter_hfreq : int or float, optional
        Cut-off passed to the epochs' `savgol_filter()` method. No filtering
        when None.

    Returns
    -------
    The value returned by `mne.viz.plot_compare_evokeds()`.

    Raises
    ------
    ValueError
        If no evoked data is available to plot.
    """
    # Preserve original
    all_epochs_current = all_epochs.copy()

    # Filter using Savitzky-Golay polynomial method
    if isinstance(filter_hfreq, (int, float)):
        for participant, epochs in all_epochs_current.items():
            all_epochs_current[participant] = epochs.savgol_filter(filter_hfreq, copy=True)

    # Crop
    if isinstance(times, (list, tuple)) and len(times) == 2:
        for participant, epochs in all_epochs_current.items():
            all_epochs_current[participant] = epochs.copy().crop(times[0], times[1])

    # Transform to evokeds
    all_evokeds = eeg_to_all_evokeds(all_epochs_current, conditions=conditions)

    # Group the evokeds of all participants by condition, remembering the last
    # one seen: its info and channels are used below to select electrodes.
    data = {}
    reference = None
    for participant, evokeds in all_evokeds.items():
        for condition, evoked in evokeds.items():
            data.setdefault(condition, []).append(evoked)
            reference = evoked

    if reference is None:
        raise ValueError("NeuroKit Error: plot_eeg_erp(): no evoked data to plot.")

    conditions = list(data.keys())

    # Line styles
    if isinstance(linestyle, str):
        linestyle = [linestyle] * len(conditions)
    elif isinstance(linestyle, list) and len(linestyle) >= len(conditions):
        pass
    elif isinstance(linestyle, dict) and all(cond in linestyle for cond in conditions):
        linestyle = [linestyle[cond] for cond in conditions]
    else:
        print("NeuroKit Warning: plot_eeg_erp(): linestyle must be either a str, a list or a dict.")
        # Fall back to a usable default rather than failing right after the warning.
        linestyle = ["-"] * len(conditions)

    # Colors
    if isinstance(colors, str):
        colors = {condition: colors for condition in conditions}
    elif isinstance(colors, list) and len(colors) >= len(conditions):
        colors = dict(zip(conditions, colors))
    elif isinstance(colors, dict) and len(colors.keys()) >= len(conditions):
        pass
    elif colors is not None:
        print("NeuroKit Warning: plot_eeg_erp(): colors must be either a str, a list, a dict or None.")
        # Let mne choose its own colours instead of forwarding an invalid value.
        colors = None

    # Modify styles
    styles = {
        condition: {"linewidth": linewidth, "linestyle": style}
        for condition, style in zip(conditions, linestyle)
    }

    # Select electrodes
    picks = mne.pick_types(reference.info, eeg=True, selection=eeg_select_electrodes(reference, include=include, exclude=exclude, hemisphere=hemisphere, central=central))

    # Plot
    try:
        plot = mne.viz.plot_compare_evokeds(data, picks=picks, colors=colors, styles=styles, title=name, gfp=gfp, ci=ci, invert_y=invert_y, ci_alpha=ci_alpha)
    except TypeError:
        # Retry without ci_alpha for mne versions whose signature lacks it.
        print("NeuroKit Warning: plot_eeg_erp(): You're using a version of mne that does not support ci_alpha or ci_method parameters. Leaving defaults.")
        plot = mne.viz.plot_compare_evokeds(data, picks=picks, colors=colors, styles=styles, title=name, gfp=gfp, ci=ci, invert_y=invert_y)

    return(plot)





# ==============================================================================
# ==============================================================================
# ==============================================================================
# ==============================================================================
# ==============================================================================
# ==============================================================================
# ==============================================================================
# ==============================================================================
def plot_eeg_erp_topo(all_epochs, colors=None):
    """
    Plot the grand-averaged evoked response of each condition with
    `mne.viz.plot_evoked_topo()`.

    Parameters
    ----------
    all_epochs
        Forwarded to `eeg_to_all_evokeds()`, whose result is iterated as a
        mapping participant -> {condition: evoked}.
    colors : mapping, optional
        Looked up by condition name (`colors[condition]`) to build the colour
        list handed to mne. When None, mne receives `color=None`.

    Returns
    -------
    The value returned by `mne.viz.plot_evoked_topo()`.
    """
    evokeds_by_participant = eeg_to_all_evokeds(all_epochs)

    # Gather, for every condition, the evokeds of all participants.
    per_condition = {}
    for participant, participant_evokeds in evokeds_by_participant.items():
        for cond, evoked in participant_evokeds.items():
            per_condition.setdefault(cond, []).append(evoked)

    use_colors = colors is not None
    color_list = [] if use_colors else None

    # One grand average per condition, labelled with the condition name.
    grand_averages = []
    for condition, condition_evokeds in per_condition.items():
        average = mne.grand_average(condition_evokeds)
        average.comment = condition
        grand_averages.append(average)
        if use_colors:
            color_list.append(colors[condition])

    return mne.viz.plot_evoked_topo(grand_averages, background_color="w", color=color_list)