LucaCappelletti94/plot_keras_history

View on GitHub
plot_keras_history/plot_keras_history.py

Summary

Maintainability
F
4 days
Test Coverage
A
92%
"""Methods for plotting a keras model training history."""
import warnings
import matplotlib.pyplot as plt
from typing import List, Dict, Union, Tuple, Callable, Optional
import os
import numpy as np
import pandas as pd
from matplotlib.axes import Axes
from matplotlib.figure import Figure

from sanitize_ml_labels import sanitize_ml_labels, is_normalized_metric, is_absolutely_normalized_metric
from .utils import to_dataframe, get_figsize, filter_signal, get_column_tuples, filter_columns, History


def _warn_if_out_of_range(metric: str, column: pd.Series) -> None:
    """Warn when a normalized metric has values outside of its expected range.

    Parameters
    -------------------------------
    metric: str
        Name of the (training) metric, used to decide which range is expected.
    column: pd.Series
        Recorded values for the metric, either training or validation ones.
    """
    # Min/max are only computed when a range check applies, so that
    # non-normalized metrics are never inspected.
    if is_normalized_metric(metric):
        min_value = column.values.min()
        max_value = column.values.max()
        if min_value < 0.0 or max_value > 1.0:
            warnings.warn(
                (
                    "Please be advised that you have provided a metric called `{metric}` "
                    "that is expected to be normalized, i.e. between 0 and 1. The values "
                    "you have provided for this metric were between {min_value:0.3f} and "
                    "{max_value:0.3f}."
                ).format(
                    metric=metric,
                    min_value=min_value,
                    max_value=max_value
                )
            )
    if is_absolutely_normalized_metric(metric):
        min_value = column.values.min()
        max_value = column.values.max()
        if min_value < -1.0 or max_value > 1.0:
            warnings.warn(
                (
                    "Please be advised that you have provided a metric called `{metric}` "
                    "that is expected to be absolutely normalized, i.e. between -1 and 1. The values "
                    "you have provided for this metric were between {min_value:0.3f} and "
                    "{max_value:0.3f}."
                ).format(
                    metric=metric,
                    min_value=min_value,
                    max_value=max_value
                )
            )


def _draw_best_point(axis: Axes, best_point_x: int, best_point_y: float, color) -> None:
    """Mark the best point on the axis with a dot and dashed guides towards both axes.

    Parameters
    -------------------------------
    axis: Axes
        Axis to draw on.
    best_point_x: int
        Position of the best point on the x axis.
    best_point_y: float
        Value of the metric at the best point.
    color
        Matplotlib color to use, normally the color of the associated line.
    """
    axis.scatter(
        [best_point_x],
        [best_point_y],
        s=30,
        alpha=0.9,
        color=color,
        zorder=10000
    )
    axis.hlines(
        best_point_y,
        0,
        best_point_x,
        linestyles="dashed",
        color=color,
        alpha=0.5,
    )
    axis.vlines(
        best_point_x,
        0,
        best_point_y,
        linestyles="dashed",
        color=color,
        alpha=0.6,
    )


def _plot_history(
    histories: List[pd.DataFrame],
    average_history: Optional[pd.DataFrame] = None,
    standard_deviation_history: Optional[pd.DataFrame] = None,
    style: str = "-",
    interpolate: bool = False,
    side: float = 5,
    graphs_per_row: int = 4,
    dpi: int = 100,
    customization_callback: Optional[Callable] = None,
    path: Optional[str] = None,
    log_scale_metrics: bool = False,
    show_standard_deviation: bool = False,
    show_average: bool = True,
    monitor: Optional[str] = None,
    best_point_x: Optional[int] = None,
    title: Optional[str] = None,
    custom_defaults: Optional[Dict[str, Union[List[str], str]]] = None
) -> Tuple[Figure, Axes]:
    """Plot given training histories.

    Parameters
    -------------------------------
    histories: List[pd.DataFrame]
        The histories to plot.
    average_history: pd.DataFrame = None
        Average histories, if multiple histories were given.
    standard_deviation_history: Optional[pd.DataFrame] = None
        Standard deviation histories, if multiple histories were given.
    style: str = "-"
        The style to use when plotting the graphs.
    interpolate: bool = False
        Whether to reduce the graphs noise.
    side: int = 5
        The side of every sub-graph.
    graphs_per_row: int = 4
        Number of graphs per row.
    dpi: int = 100
        DPI value to be used for the plots.
        By default we use matplotlib defaults, 100.
    customization_callback: Callable = None
        Callback for customising axis.
    path:str = None
        Where to save the graphs, by default nowhere.
    log_scale_metrics: bool = False
        Whether to use log scale for the metrics.
    show_standard_deviation: bool = False
        Whether to show the standard deviation when
        plotting multiple training histories.
    show_average: bool = True
        Whether to show the average when
        plotting multiple training histories.
    monitor: str = None
        Metric to use to display best points.
        For example you may use "loss" or "val_loss".
        By default None, to not display any best point.
    best_point_x: int = None
        Point to be highlighted as best.
    title: str = None
        Title to put on top of the subplots.
    custom_defaults: Dict[str, Union[List[str], str]] = None
        Dictionary of custom mapping to use to sanitize metric names.

    Returns
    -------------------------------
    The created figure and the (possibly multi-dimensional) axes returned by
    `plt.subplots`.
    """
    x_label = "Epochs" if histories[0].index.name is None else histories[0].index.name
    metrics = [
        column_tuple[0]
        for column_tuple in get_column_tuples(histories[0])
    ]
    w, h = get_figsize(len(metrics), graphs_per_row)
    fig, axes = plt.subplots(h, w, figsize=(
        side*w, side*h), constrained_layout=True, dpi=dpi)
    flat_axes = np.array(axes).flatten()

    # The first history plotted is the "main" one: thick line, legend and
    # best point. When available, the average takes that role.
    if show_average and average_history is not None:
        histories = [average_history] + histories

    for i, history in enumerate(histories):
        for metric, axis in zip(metrics, flat_axes):
            validation_metric = f"val_{metric}"
            if validation_metric in history:
                series = zip(
                    (metric, validation_metric),
                    ("Train", "Test"),
                    ("tab:blue", "tab:orange")
                )
            else:
                series = zip((metric, ), ("", ), ("tab:blue",))

            for name, kind, color in series:
                col = history[name]
                _warn_if_out_of_range(metric, col)

                if i > 0:
                    # Secondary histories: faded lines without legend entries.
                    axis.plot(
                        col.index.values,
                        filter_signal(
                            col.values) if interpolate else col.values,
                        style,
                        color=color,
                        alpha=0.5
                    )
                    continue

                if best_point_x is not None:
                    highlighted_value = col.values[best_point_x]
                    label_prefix = (
                        f"{kind} best value ({monitor})"
                        if kind
                        else f"Best value ({monitor})"
                    )
                else:
                    highlighted_value = col.iloc[-1]
                    label_prefix = f"{kind} last value" if kind else "Last value"

                values = filter_signal(
                    col.values
                ) if interpolate else col.values

                if show_standard_deviation and standard_deviation_history is not None:
                    deviation = standard_deviation_history[name].values
                    lower_bound = values - deviation
                    upper_bound = values + deviation
                    axis.fill_between(
                        col.index.values,
                        lower_bound,
                        upper_bound,
                        color=color,
                        alpha=0.1
                    )
                    for bound in (lower_bound, upper_bound):
                        axis.plot(
                            col.index.values,
                            bound,
                            color=color,
                            linewidth=0.5,
                            alpha=0.1
                        )

                line = axis.plot(
                    col.index.values,
                    values,
                    style,
                    label=f"{label_prefix}: {highlighted_value:0.4f}",
                    linewidth=2 if len(histories) > 1 else 1,
                    color=color,
                    zorder=10000
                )[0]

                if best_point_x is not None:
                    _draw_best_point(axis, best_point_x, highlighted_value, line.get_color())

    # Decide on integer ticks from the longest plotted history, rather than
    # from whichever history the loop above happened to visit last: histories
    # may have different lengths.
    longest_history_length = max(len(history) for history in histories)

    for metric, axis in zip(metrics, flat_axes):
        alias = sanitize_ml_labels(metric, custom_defaults=custom_defaults)
        axis.set_xlabel(x_label)
        if log_scale_metrics:
            axis.set_yscale("log")
        axis.set_ylabel("{alias}{scale}".format(
            alias=alias,
            scale=" (Log scale)" if log_scale_metrics else ""
        ))
        axis.set_title(alias)
        axis.grid(True)
        axis.legend()
        if is_normalized_metric(metric):
            axis.set_ylim(-0.05, 1.05)
        elif is_absolutely_normalized_metric(metric):
            axis.set_ylim(-1.05, 1.05)
        if longest_history_length <= 4:
            axis.set_xticks(range(longest_history_length))
        if customization_callback is not None:
            customization_callback(axis)

    # Hide the unused trailing cells of the subplot grid.
    for axis in flat_axes[len(metrics):]:
        axis.axis("off")

    if title is not None:
        fig.suptitle(title, fontsize=20)

    if path is not None:
        fig.savefig(path)

    return fig, axes


def plot_history(
    histories: Union[History, List[History], Dict[str, List[float]], pd.DataFrame, List[pd.DataFrame], str, List[str]],
    style: str = "-",
    interpolate: bool = False,
    side: float = 5,
    graphs_per_row: int = 4,
    dpi: int = 100,
    customization_callback: Optional[Callable] = None,
    path: Optional[str] = None,
    single_graphs: bool = False,
    max_epochs: Union[int, str] = "max",
    monitor: Optional[str] = None,
    monitor_mode: str = "max",
    log_scale_metrics: bool = False,
    show_standard_deviation: bool = False,
    show_average: bool = True,
    title: Optional[str] = None,
    custom_defaults: Optional[Dict[str, Union[List[str], str]]] = None
) -> Tuple[Union[Figure, List[Figure]], Union[Axes, List[Axes]]]:
    """Plot given training histories.

    Parameters
    ----------------------------
    histories
        the histories to plot.
        This parameter can either be a single or multiple dataframes
        or one or more paths to the stored CSVs or JSON of the history.
    style: str = "-"
        the style to use when plotting the graphs.
    interpolate: bool = False
        whether to reduce the graphs noise.
    side: int = 5
        the side of every sub-graph.
    graphs_per_row: int = 4
        number of graphs per row.
    dpi: int = 100
        DPI value to be used for the plots.
        By default we use matplotlib defaults, 100.
    customization_callback: Callable = None
        callback for customising axis.
    path: str = None
        where to save the graphs, by default nowhere.
        When single_graphs is True, this is treated as a directory
        in which one `<metric>.png` file per metric is written.
    single_graphs: bool = False
        whether to create the graphs one by one.
    max_epochs: Union[int, str] = "max"
        Number of epochs to plot. Can either be "max", "min" or a positive integer value.
    monitor: str = None
        Metric to use to display best points.
        For example you may use "loss" or "val_loss".
        By default None, to not display any best point.
    monitor_mode: str = "max"
        Mode to display the monitor metric best point.
        Can either be "max" or "min".
    log_scale_metrics: bool = False
        Whether to use log scale for the metrics.
    show_standard_deviation: bool = False
        Whether to show the standard deviation when
        plotting multiple training histories.
    show_average: bool = True
        Whether to show the average when
        plotting multiple training histories.
    title: str = None,
        Title to put on top of the subplots.
    custom_defaults: Dict[str, Union[List[str], str]] = None
        Dictionary of custom mapping to use to sanitize metric names.

    Returns
    ----------------------------
    The figure and axes, or, when single_graphs is True, a list holding
    the tuple of all figures and the tuple of all axes.

    Raises
    --------------------------
    ValueError
        Currently the monitor metric best point cannot be displayed if interpolation is active.
    ValueError
        If monitor_mode is not either "min" or "max".
    ValueError
        If max_epochs is not either "min", "max" or a positive integer.
    """
    # Some parameters validation
    if interpolate and monitor is not None:
        raise ValueError((
            "Currently the monitor metric best point "
            "cannot be displayed if interpolation is active."
        ))
    if monitor_mode not in ("min", "max"):
        raise ValueError("Given monitor mode '{}' is not supported.".format(
            monitor_mode
        ))
    if isinstance(max_epochs, str):
        valid_max_epochs = max_epochs in ("min", "max")
    else:
        # bool is a subclass of int, but it is not a meaningful epoch count;
        # non-positive values would produce empty or tail-trimmed histories.
        valid_max_epochs = (
            isinstance(max_epochs, int)
            and not isinstance(max_epochs, bool)
            and max_epochs > 0
        )
    if not valid_max_epochs:
        raise ValueError("Given parameter max_epochs '{}' is not supported.".format(
            max_epochs
        ))

    # If the histories are not provided as a list, we normalize them
    # to a list.
    if not isinstance(histories, list):
        histories = [histories]

    # If the path is not None, we prepare the directory where to
    # store the created image(s). In single-graph mode the path itself
    # is the directory that will receive one image per metric.
    if path is not None:
        directory_name = path if single_graphs else os.path.dirname(path)
        # The directory name may be an empty string.
        if directory_name:
            os.makedirs(directory_name, exist_ok=True)

    # Normalize the training histories, keeping only numeric columns.
    histories = [
        to_dataframe(history)._get_numeric_data()
        for history in histories
    ]

    # Filter out the epochs as required.
    if max_epochs == "max":
        max_epochs = max(len(history) for history in histories)
    elif max_epochs == "min":
        max_epochs = min(len(history) for history in histories)

    histories = [
        history[:max_epochs]
        for history in histories
    ]

    if len(histories) > 1:
        grouped_histories = pd.concat(histories)
        grouped_by_epoch = grouped_histories.groupby(grouped_histories.index)
        average_history = grouped_by_epoch.mean()
        standard_deviation_history = grouped_by_epoch.std()
    else:
        average_history = standard_deviation_history = None

    # If we want to plot information relative to the monitored metric,
    # the best point is located on the average history when available.
    if monitor is not None:
        history_to_monitor = (
            histories[0] if average_history is None else average_history)[monitor]
        if monitor_mode == "max":
            best_point_x = history_to_monitor.argmax()
        else:
            best_point_x = history_to_monitor.argmin()
        sanitized_monitor = sanitize_ml_labels(
            monitor,
            custom_defaults=custom_defaults
        )
    else:
        best_point_x = None
        sanitized_monitor = None

    # Arguments shared by both the single-graph and the grid layouts.
    shared_arguments = dict(
        style=style,
        interpolate=interpolate,
        side=side,
        graphs_per_row=graphs_per_row,
        dpi=dpi,
        customization_callback=customization_callback,
        log_scale_metrics=log_scale_metrics,
        show_standard_deviation=show_standard_deviation,
        show_average=show_average,
        monitor=sanitized_monitor,
        best_point_x=best_point_x,
        title=title,
        custom_defaults=custom_defaults,
    )

    if single_graphs:
        return list(zip(*[
            _plot_history(
                filter_columns(histories, columns),
                average_history,
                standard_deviation_history,
                path=None if path is None else os.path.join(
                    path, "{c}.png".format(c=columns[0])
                ),
                **shared_arguments,
            )
            for columns in get_column_tuples(histories[0])
        ]))

    return _plot_history(
        histories,
        average_history,
        standard_deviation_history,
        path=path,
        **shared_arguments,
    )


def show_history(
    histories: Union[History, List[History], Dict[str, List[float]], pd.DataFrame, List[pd.DataFrame], str, List[str]],
    style: str = "-",
    interpolate: bool = False,
    side: float = 5,
    graphs_per_row: int = 4,
    dpi: int = 100,
    customization_callback: Optional[Callable] = None,
    path: Optional[str] = None,
    single_graphs: bool = False,
    max_epochs: Union[int, str] = "max",
    monitor: Optional[str] = None,
    monitor_mode: str = "max",
    log_scale_metrics: bool = False,
    show_standard_deviation: bool = False,
    show_average: bool = True,
    title: Optional[str] = None,
    custom_defaults: Optional[Dict[str, Union[List[str], str]]] = None
) -> None:
    """Plot given training histories and display them with `plt.show()`.

    This is a thin wrapper around `plot_history`: every argument is forwarded
    unchanged. Nothing is returned; use `plot_history` directly to get hold
    of the created figures and axes.

    Parameters
    ----------------------------
    histories
        the histories to plot.
        This parameter can either be a single or multiple dataframes
        or one or more paths to the stored CSVs or JSON of the history.
    style: str = "-"
        the style to use when plotting the graphs.
    interpolate: bool = False
        whether to reduce the graphs noise.
    side: int = 5
        the side of every sub-graph.
    graphs_per_row: int = 4
        number of graphs per row.
    dpi: int = 100
        DPI value to be used for the plots.
        By default we use matplotlib defaults, 100.
    customization_callback: Callable = None
        callback for customising axis.
    path: str = None
        where to save the graphs, by default nowhere.
    single_graphs: bool = False
        whether to create the graphs one by one.
    max_epochs: Union[int, str] = "max"
        Number of epochs to plot. Can either be "max", "min" or a positive integer value.
    monitor: str = None
        Metric to use to display best points.
        For example you may use "loss" or "val_loss".
        By default None, to not display any best point.
    monitor_mode: str = "max"
        Mode to display the monitor metric best point.
        Can either be "max" or "min".
    log_scale_metrics: bool = False
        Whether to use log scale for the metrics.
    show_standard_deviation: bool = False
        Whether to show the standard deviation when
        plotting multiple training histories.
    show_average: bool = True
        Whether to show the average when
        plotting multiple training histories.
    title: str = None
        Title to put on top of the subplots.
    custom_defaults: Dict[str, Union[List[str], str]] = None
        Dictionary of custom mapping to use to sanitize metric names.

    Raises
    --------------------------
    ValueError
        Currently the monitor metric best point cannot be displayed if interpolation is active.
    ValueError
        If monitor_mode is not either "min" or "max".
    ValueError
        If max_epochs is not either "min", "max" or a numeric integer.
    """
    plot_history(
        histories=histories,
        style=style,
        interpolate=interpolate,
        side=side,
        graphs_per_row=graphs_per_row,
        dpi=dpi,
        customization_callback=customization_callback,
        path=path,
        single_graphs=single_graphs,
        max_epochs=max_epochs,
        monitor=monitor,
        monitor_mode=monitor_mode,
        log_scale_metrics=log_scale_metrics,
        show_standard_deviation=show_standard_deviation,
        show_average=show_average,
        title=title,
        custom_defaults=custom_defaults,
    )
    plt.show()