"""Helpers for vizualization."""
import matplotlib.pyplot as plt
from matplotlib.patches import Ellipse
import matplotlib.transforms as transforms
import numpy as np

from ..embedding import SpectralEmbedding, LocallyLinearEmbedding

def plot_embedding(X,
                   y=None,
                   metric="riemann",
                   title="Embedding of covariances",
                   embd_type="Spectral",
                   normalize=True):
    """Plot 2D embedding of SPD matrices.

    Parameters
    ----------
    X : ndarray, shape (n_matrices, n_channels, n_channels)
        Set of SPD matrices.
    y : None | ndarray, shape (n_matrices,), default=None
        Labels for each matrix. If None, all matrices share the same label
        and are drawn as a single group.
    metric : string, default="riemann"
        Metric used in the embedding. Can be {"riemann", "logeuclid",
        "euclid"} for Locally Linear Embedding, and {"riemann", "logeuclid",
        "euclid", "logdet", "kullback", "kullback_right", "kullback_sym"}
        for Spectral Embedding.
    title : str, default="Embedding of covariances"
        Title of figure. It is prefixed by ``embd_type``.
    embd_type : {"Spectral", "LocallyLinear"}, default="Spectral"
        Embedding type.
    normalize : bool, default=True
        If True, the ticks of both axes are set from -1 to +1.

    Returns
    -------
    fig : matplotlib figure
        Figure of embedding.

    Raises
    ------
    ValueError
        If ``embd_type`` is neither "Spectral" nor "LocallyLinear".

    Notes
    -----
    .. versionadded:: 0.2.6
    """
    if embd_type == "Spectral":
        lapl = SpectralEmbedding(n_components=2, metric=metric)
    elif embd_type == "LocallyLinear":
        # NOTE(review): the original call was truncated after
        # ``n_components=2,``; only ``metric`` could be restored from the
        # documented contract. Confirm whether further arguments (e.g. a
        # number of neighbors) were passed here.
        lapl = LocallyLinearEmbedding(n_components=2, metric=metric)
    else:
        raise ValueError(f"Unknown embedding type {embd_type}. "
                         "Valid types are: 'Spectral', 'LocallyLinear'.")

    embd = lapl.fit_transform(X)

    if y is None:
        y = np.ones(embd.shape[0])
    # A plain list would make ``y == label`` a scalar instead of a mask.
    y = np.asarray(y)

    fig, ax = plt.subplots(figsize=(7, 7), facecolor="white")
    # One scatter call per label, so that each label gets its own color.
    for label in np.unique(y):
        idx = y == label
        ax.scatter(embd[idx, 0], embd[idx, 1], s=36)

    ax.set_xlabel(r"$\varphi_1$", fontsize=16)
    ax.set_ylabel(r"$\varphi_2$", fontsize=16)
    ax.set_title(f"{embd_type} {title}", fontsize=16)
    if normalize:
        ticks = [-1.0, -0.5, 0.0, +0.5, 1.0]
        ax.set_xticks(ticks)
        ax.set_yticks(ticks)

    return fig

def plot_cospectra(X, freqs, *, ylabels=None, title="Cospectra"):
    """Plot cospectral matrices.

    Each cospectral matrix is drawn as an image in its own subplot, on a
    grid of 8 columns.

    Parameters
    ----------
    X : ndarray, shape (n_freqs, n_channels, n_channels)
        Cospectral matrices.
    freqs : ndarray, shape (n_freqs,)
        The frequencies associated to cospectra.
    ylabels : list of str, default=None
        ylabels of figure. Labels of even index are displayed on the first
        subplot, labels of odd index on the first subplot of the second row.
    title : str, default="Cospectra"
        Title of figure.

    Returns
    -------
    fig : matplotlib figure
        Figure of cospectra.

    Raises
    ------
    ValueError
        If ``X`` is not 3-dimensional, or if ``freqs`` does not contain one
        frequency per matrix of ``X``.

    Notes
    -----
    .. versionadded:: 0.2.7
    """
    if X.ndim != 3:
        raise ValueError("Input X has not 3 dimensions")
    n_freqs, n_channels, _ = X.shape
    if freqs.shape != (n_freqs,):
        raise ValueError(
            "Input freqs has not the same number of frequencies as X")

    n_cols = 8
    n_rows = (n_freqs - 1) // n_cols + 1
    # Explicit length test: the truth value of an array of labels is
    # ambiguous, while a list and an array are both handled this way.
    has_labels = ylabels is not None and len(ylabels) > 0

    fig = plt.figure(figsize=(12, 7))
    fig.suptitle(title)
    for f in range(n_freqs):
        ax = plt.subplot(n_rows, n_cols, f + 1)
        plt.imshow(X[f], cmap=plt.get_cmap("Reds"))
        plt.title(f"{freqs[f]} Hz")
        if has_labels and f == 0:
            # Even-indexed labels on the first subplot of the first row...
            plt.yticks(np.arange(0, len(ylabels), 2), ylabels[::2])
            ax.tick_params(axis="both", which="major", labelsize=7)
        elif has_labels and f == n_cols:
            # ... odd-indexed labels on the first subplot of the second row.
            plt.yticks(np.arange(1, len(ylabels), 2), ylabels[1::2])
            ax.tick_params(axis="both", which="major", labelsize=7)

    return fig

def plot_waveforms(X, display, *, times=None, color="gray", alpha=0.5,
                   linewidth=1.5, color_mean="k", color_std="gray", n_bins=50,
                   cmap=None):
    """Plot repetitions of a multichannel waveform.

    Parameters
    ----------
    X : ndarray, shape (n_reps, n_channels, n_times)
        Repetitions of the multichannel waveform.
    display : {"all", "mean", "mean+/-std", "hist"}
        Type of display:

        * "all" for all the repetitions;
        * "mean" for the mean of the repetitions;
        * "mean+/-std" for the mean +/- standard deviation of the repetitions;
        * "hist" for the 2D histogram of the repetitions.
    times : None | ndarray, shape (n_times,), default=None
        Values to display on x-axis. If None, sample indices are used.
    color : matplotlib color, optional
        Color of the lines, when ``display=all``.
    alpha : float, optional
        Alpha value used to cumulate repetitions, when ``display=all``.
    linewidth : float, optional
        Line width in points, when ``display=mean``.
    color_mean : matplotlib color, optional
        Color of the mean line, when ``display=mean``.
    color_std : matplotlib color, optional
        Color of the standard deviation area, when ``display=mean+/-std``.
    n_bins : int, optional
        Number of vertical bins for the 2D histogram, when ``display=hist``.
    cmap : Colormap or str, optional
        Color map for the histogram, when ``display=hist``.

    Returns
    -------
    fig : matplotlib figure
        Figure of waveform (one subplot by channel).

    Raises
    ------
    ValueError
        If ``X`` is not 3-dimensional, if ``times`` does not contain one
        value per time sample of ``X``, or if ``display`` is unknown.

    Notes
    -----
    .. versionadded:: 0.3
    """
    if X.ndim != 3:
        raise ValueError("Input X has not 3 dimensions")
    n_reps, n_channels, n_times = X.shape
    if times is None:
        times = np.arange(n_times)
    elif times.shape != (n_times,):
        raise ValueError(
            "Parameter times has not the same number of values as X")
    # Checked before creating the figure, so that an invalid display does
    # not leave an empty figure open.
    if display not in ("all", "mean", "mean+/-std", "hist"):
        raise ValueError(f"Unknown parameter display {display}")

    fig, axes = plt.subplots(nrows=n_channels, ncols=1)
    if n_channels == 1:
        # plt.subplots returns a bare Axes for a single subplot.
        axes = [axes]

    if display == "all":
        for channel, ax in enumerate(axes):
            # 2D input: one line is drawn per column, i.e. per repetition.
            ax.plot(times, X[:, channel].T, c=color, alpha=alpha)

    elif display in ("mean", "mean+/-std"):
        mean = np.mean(X, axis=0)
        std = np.std(X, axis=0) if display == "mean+/-std" else None
        for channel, ax in enumerate(axes):
            ax.plot(times, mean[channel], c=color_mean, lw=linewidth)
            if std is not None:
                ax.fill_between(times, mean[channel] - std[channel],
                                mean[channel] + std[channel], color=color_std)

    else:  # display == "hist"
        # Time values repeated for each repetition, in the same order as
        # the flattened (n_reps, n_times) waveform of a channel.
        times_rep = np.tile(times, n_reps)
        for channel, ax in enumerate(axes):
            ax.hist2d(times_rep, X[:, channel, :].ravel(),
                      bins=(n_times, n_bins), cmap=cmap)

    # Only the bottom subplot keeps its x tick labels.
    for ax in axes[:-1]:
        ax.set_xticklabels([])
    return fig

def _add_alpha(colors, alphas):
    """Add alphas to RGB channels"""
        from matplotlib.colors import to_rgb
    except ImportError:
        raise ImportError("Install matplotlib to add alpha")

    cols = [to_rgb(c) for c in colors]
    return [(c[0], c[1], c[2], a) for c, a in zip(cols, alphas[-len(cols):])]

def plot_cov_ellipse(ax, X, n_std=2.5, **kwds):
    """Draw a 2x2 covariance matrix as an ellipse on an axis.

    The ellipse is centered on the origin and built following the
    confidence ellipse construction of [1]_.

    Parameters
    ----------
    ax : matplotlib axis
        Axis on which the ellipse is added.
    X : ndarray, shape (2, 2)
        Covariance matrix.
    n_std : float, default=2.5
        Number of standard deviations covered along each axis.
    **kwds : dict
        Extra keyword arguments forwarded to the Ellipse patch.

    Returns
    -------
    patch : matplotlib patch
        Value returned by ``ax.add_patch`` for the ellipse.

    Raises
    ------
    ValueError
        If ``X`` is not of shape (2, 2).

    Notes
    -----
    .. versionadded:: 0.6

    References
    ----------
    .. [1] https://matplotlib.org/stable/gallery/statistics/confidence_ellipse.html
    """  # noqa
    if not X.shape == (2, 2):
        raise ValueError("Input X must be a 2x2 covariance matrix")

    var_x, var_y, cov_xy = X[0, 0], X[1, 1], X[0, 1]
    corr = cov_xy / np.sqrt(var_x * var_y)

    # Ellipse of the normalized data: its size depends only on correlation.
    patch = Ellipse(
        (0, 0),
        width=2 * np.sqrt(1 + corr),
        height=2 * np.sqrt(1 - corr),
        facecolor='none',
        **kwds,
    )

    # Rotate by 45 degrees, then stretch each axis to n_std standard
    # deviations, before mapping to the data coordinates of the axis.
    affine = transforms.Affine2D()
    affine = affine.rotate_deg(45)
    affine = affine.scale(np.sqrt(var_x) * n_std, np.sqrt(var_y) * n_std)
    patch.set_transform(affine + ax.transData)

    return ax.add_patch(patch)

def plot_bihist(X, y, n_bins=10, title="Histogram"):
    """Plot histogram of bi-class predictions.

    Predictions are rescaled so that each row sums to 1, then the values of
    both classes are displayed on a common axis, symmetric around 0.5.

    Parameters
    ----------
    X : ndarray, shape (n_matrices, 2)
        Predictions, distances or probabilities.
    y : ndarray, shape (n_matrices,)
        Labels for each matrix.
    n_bins : int, default=10
        Number of bins of histogram.
    title : str, default="Histogram"
        Title of figure.

    Returns
    -------
    fig : matplotlib figure
        Figure of histogram.

    Raises
    ------
    ValueError
        If ``X`` is not 2-dimensional, if ``X`` has not 2 columns, or if
        ``y`` does not contain exactly 2 distinct labels.

    Notes
    -----
    .. versionadded:: 0.6
    """
    if X.ndim != 2:
        raise ValueError("Input X has not 2 dimensions")
    if X.shape[1] != 2:
        raise ValueError("Input X has not 2 classes")

    # A plain list would make ``y == label`` a scalar instead of a mask.
    y = np.asarray(y)
    classes = np.unique(y)
    if classes.shape[0] != 2:
        raise ValueError("Input y has not 2 labels")

    X = X / np.sum(X, axis=1, keepdims=True)
    X0 = X[y == classes[0], 0]
    # Second class is mirrored, so that both classes share the same axis.
    X1 = 1 - X[y == classes[1], 1]

    def get_bins(X, n_bins, target=0.5):
        """Estimate bins with the guarantee to have target value in bin edges"""
        bins = np.histogram_bin_edges(X, bins=n_bins)
        # Snap the closest edge onto the target value.
        idx = (np.abs(bins - target)).argmin()
        bins[idx] = target
        return bins

    fig, ax = plt.subplots(figsize=(6, 5))
    ax.axvline(x=0.5, c="k", linestyle=":")
    ax.hist(X0, bins=get_bins(X0, n_bins), label=classes[0], alpha=0.5)
    ax.hist(X1, bins=get_bins(X1, n_bins), label=classes[1], alpha=0.5)

    # Make the x-axis limits symmetric around 0.5.
    (Xmin, Xmax) = ax.get_xlim()
    Xm = min(Xmin, 1 - Xmax)
    ax.set_xlim(Xm, 1 - Xm)
    ax.set(xlabel="Rescaled predictions", ylabel="Frequency", title=title)
    ax.legend(title="Classes", loc="upper left")

    return fig

def plot_biscatter(X, y):
    """Plot scatter of bi-class predictions.

    Both columns of the predictions are plotted against each other, with
    identical limits on both axes and a dotted line on the diagonal.

    Parameters
    ----------
    X : ndarray, shape (n_matrices, 2)
        Predictions, distances or probabilities.
    y : ndarray, shape (n_matrices,)
        Labels for each matrix.

    Returns
    -------
    fig : matplotlib figure
        Figure of scatter plot.

    Raises
    ------
    ValueError
        If ``X`` is not 2-dimensional, if ``X`` has not 2 columns, or if
        ``y`` does not contain exactly 2 distinct labels.

    Notes
    -----
    .. versionadded:: 0.6
    """
    if X.ndim != 2:
        raise ValueError("Input X has not 2 dimensions")
    if X.shape[1] != 2:
        raise ValueError("Input X has not 2 classes")

    # A plain list would make ``y == label`` a scalar instead of a mask.
    y = np.asarray(y)
    classes = np.unique(y)
    if classes.shape[0] != 2:
        raise ValueError("Input y has not 2 labels")

    X0 = X[y == classes[0]]
    X1 = X[y == classes[1]]

    fig, ax = plt.subplots(figsize=(7, 7))
    ax.scatter(X0[:, 0], X0[:, 1], label=classes[0], alpha=1)
    ax.scatter(X1[:, 0], X1[:, 1], label=classes[1], alpha=0.5)
    ax.legend(title="Classes", loc="upper left")

    # Same range on both axes, so that the diagonal splits the plot evenly.
    (Xmin, Xmax) = ax.get_xlim()
    (Ymin, Ymax) = ax.get_ylim()
    XYmin, XYmax = min(Xmin, Ymin), max(Xmax, Ymax)
    ax.plot([XYmin, XYmax], [XYmin, XYmax], c="k", linestyle=":")
    ax.set_xlim([XYmin, XYmax])
    ax.set_ylim([XYmin, XYmax])

    return fig