LucaCappelletti94/histograms

View on GitHub
barplots/utils/get_axes.py

Summary

Maintainability
B
5 hrs
Test Coverage
A
100%
from matplotlib.figure import Figure
from matplotlib.axes import Axes
import pandas as pd
import numpy as np
from .get_max_bar_position import get_max_bar_position
from typing import Tuple, Dict, List
import matplotlib.pyplot as plt
from scipy.constants import golden_ratio
from math import ceil
from sanitize_ml_labels import sanitize_ml_labels
from .get_best_match import get_best_match


def swap(*args: object, flag: bool) -> Tuple[object, ...]:
    """Return the given arguments, reversed unless the flag is true.

    Parameters
    ----------
    *args: object,
        Values to be returned.
    flag: bool,
        Whether to keep the arguments in the given order (True)
        or to return them in reversed order (False).

    Returns
    -----------
    Tuple with the arguments in the given order if the flag is true,
    in reversed order otherwise.
    """
    # Slicing (rather than `reversed`) keeps the result a tuple in both
    # branches, so it can be unpacked, indexed, compared and re-iterated.
    return args if flag else args[::-1]


def get_axes(
    df: pd.DataFrame,
    bar_width: float,
    space_width: float,
    height: float,
    dpi: int,
    title: str,
    data_label: str,
    vertical: bool,
    subplots: bool,
    titles: List[str],
    plots_per_row: int,
    custom_defaults: Dict[str, List[str]],
    expected_levels: int,
    scale: str,
    facecolors: Dict[str, str],
    show_title: bool,
    show_column_name: bool,
) -> Tuple[Figure, Axes]:
    """Create and configure the figure and axes used to draw a barplot.

    Parameters
    ----------
    df: pd.DataFrame,
        Dataframe from which the extent of the bars axis is computed
        (through `get_max_bar_position`).
    bar_width: float,
        Width of the bars, forwarded to `get_max_bar_position`.
    space_width: float,
        Width of the spaces, forwarded to `get_max_bar_position`.
    height: float,
        Size of each plot along the values axis.
        When None, it is derived from the bars extent divided by the
        golden ratio raised to 1 (when subplots are used or more than one
        level is expected) or to 1.5 otherwise.
    dpi: int,
        DPI of the created figure.
    title: str,
        Title of the barplot. It is only applied when it is not None,
        a single axis was created and `show_title` is true.
    data_label: str,
        Label for the values axis. None for not showing any label.
    vertical: bool,
        Whether the axes are set up for vertical bars (True)
        or for horizontal bars (False).
    subplots: bool,
        Whether to create one axis per value of the first index level
        of the dataframe instead of a single axis.
    titles: List[str],
        Titles of the axes, one per axis to be set up.
        Axes exceeding the number of titles are turned off.
    plots_per_row: int,
        Number of columns of the subplots grid.
        It is ignored (a single column is used) when `subplots` is false.
    custom_defaults: Dict[str, List[str]],
        Forwarded as `custom_defaults` to `sanitize_ml_labels` whenever
        a label or a title is set.
    expected_levels: int,
        Number of levels expected to be plotted as labels.
    scale: str,
        Scale to use for the values axis.
        Can either be "linear" or "log".
    facecolors: Dict[str, str],
        Background colors; the one used for each axis is chosen by
        `get_best_match` from the axis title.
    show_title: bool,
        Whether to show the titles of the axes.
    show_column_name: bool,
        Whether to show the data label on the values axis.

    Returns
    -----------
    Tuple containing the new figure and the flattened array of its axes.
    """
    # Extent of the bars axis and number of rows in the grid of plots.
    if subplots:
        groups = df.index.levels[0]
        side = max(
            get_max_bar_position(df.loc[group], bar_width, space_width)
            for group in groups
        )
        nrows = ceil(groups.size / plots_per_row)
    else:
        side = get_max_bar_position(df, bar_width, space_width)
        nrows = plots_per_row = 1

    if height is None:
        if subplots or expected_levels > 1:
            exponent = 1
        else:
            exponent = 1.5
        height = side / golden_ratio**exponent

    # Vertical bars are laid out along x with values on y;
    # horizontal bars swap both the plot sizes and the axes roles.
    if vertical:
        plot_width, plot_height = side, height
        values_axis, bars_axis = "y", "x"
    else:
        plot_width, plot_height = height, side
        values_axis, bars_axis = "x", "y"

    fig, axes = plt.subplots(
        nrows=nrows,
        ncols=plots_per_row,
        figsize=(plot_width * plots_per_row, plot_height * nrows),
        dpi=dpi
    )
    fig.patch.set_facecolor("white")

    # A lone Axes is wrapped so that a flat array is always handled below.
    axes = (np.array([axes]) if isinstance(axes, Axes) else axes).flatten()

    show_label = data_label is not None and show_column_name

    for subtitle, axis in zip(titles, axes):
        axis.set_facecolor(get_best_match(facecolors, subtitle))
        # Values axis gets the scale and the grid, bars axis gets
        # fixed limits and no ticks.
        getattr(axis, f"set_{values_axis}scale")(scale)
        getattr(axis, f"set_{bars_axis}lim")(0, side)
        getattr(axis, f"set_{bars_axis}ticks")([])
        getattr(axis, f"{values_axis}axis").grid(True, which="both")
        if show_label:
            getattr(axis, f"set_{values_axis}label")(sanitize_ml_labels(
                data_label,
                custom_defaults=custom_defaults
            ))
        if show_title:
            axis.set_title(sanitize_ml_labels(
                subtitle,
                custom_defaults=custom_defaults
            ))

    # Hide the cells of the grid that have no corresponding title.
    for unused_axis in axes[len(titles):]:
        unused_axis.grid(False)
        unused_axis.axis('off')

    # With a single axis the main title replaces the subtitle.
    if show_title and title is not None and len(axes) == 1:
        axes[0].set_title(sanitize_ml_labels(
            title,
            custom_defaults=custom_defaults
        ))

    return fig, axes