piccolbo/altair_recipes

View on GitHub
altair_recipes/autoplot.py

Summary

Maintainability
C
7 hrs
Test Coverage
"""Automatic plot selection."""
from .barchart import barchart
from .boxplot import boxplot
from .common import col_cardinality
from .heatmap import heatmap
from .histogram import histogram
from .scatterplot import scatterplot
from .signatures import multivariate_recipe
from .stripplot import stripplot
from autosig import autosig
from numbers import Number
import pandas as pd


def bin_midpoints(xx, n):
    """Divide a vector in equal-width bins and return the midpoints if it is Numeric, otherwise the vector itself.

    Parameters
    ----------
    xx : Iterable
        The vector to be binned or other iterable to return as-is.
    n : integer
        Number of bins.

    Returns
    -------
    list or same as xx
        The midpoints or xx itself.

    """
    return (
        list(map(lambda x: x.mid, list(pd.cut(xx, n))))
        if issubclass(type(xx.iloc[0]), Number)
        else xx
    )


def overlap(data):
    """Find how dense a set of points is in relation to plotting them.

    Parameters
    ----------
    data : DataFrame
        The data whose density needs to be evaluated.

    Returns
    -------
    integer
        The maximum number of data points having identical coordinates on categorical dimensions or falling within the same bin when Numeric, after an equal-width binning operation into 100 bins (100 is somewhat arbitrary, experience based).

    """
    return data.apply(bin_midpoints, n=100).groupby(list(data.columns)).size().max()


def is_cat(xx):
    """Whether a list or vector is categorical (non Numeric).

    Parameters
    ----------
    xx : list or vector
        A list or vector.

    Returns
    -------
    bool
        Whether `xx` is categorical.

    """
    return not issubclass(type(xx.iloc[0]), Number)


@multivariate_recipe
def autoplot(data=None, columns=None, group_by=None, height=600, width=800):
    """Automatically choose and produce a statistical graphics based on up to three columns of data."""

    assert group_by is None, "long data not supported yet"
    vars_n = len(columns)
    assert vars_n <= 3, "Only up to three vars supported at this time"
    data = data[columns]
    columns = sorted(columns, key=lambda x: -len(data[x].unique()))
    y, x, z, *_ = columns + 2 * [None]
    overlap_deg = overlap(data[columns[: min(vars_n, 2)]])
    max_overlap = 10
    high_overlap = overlap_deg >= max_overlap
    no_overlap = overlap_deg == 1
    low_overlap = 1 < overlap_deg < max_overlap
    cat_vars_n = sum(map(lambda col: is_cat(data[col]), columns))
    numeric_vars_n = vars_n - cat_vars_n

    chart_type_selection = [
        (scatterplot, numeric_vars_n >= 2 and not high_overlap),
        (heatmap, numeric_vars_n >= 2 and high_overlap),
        (stripplot, numeric_vars_n == 1 and not high_overlap),
        (barchart, numeric_vars_n == 0),
        (histogram, numeric_vars_n == 1 and cat_vars_n == 0 and high_overlap),
        (boxplot, numeric_vars_n == 1 and cat_vars_n >= 1 and high_overlap),
    ]
    chart_type = list(filter(lambda x: x[1], chart_type_selection))
    assert len(chart_type) == 1
    chart_type = chart_type[0][0]
    use_facet = cat_vars_n >= 2 or (
        cat_vars_n == 1 and numeric_vars_n == 2 and not no_overlap
    )
    use_color = vars_n == 3 and not (use_facet and chart_type is scatterplot)
    use_opacity = (chart_type in (scatterplot, stripplot) and low_overlap) or (
        chart_type is heatmap and high_overlap and numeric_vars_n == 3
    )  # heat

    if chart_type is barchart:
        args = dict(x=y, y="count()", height=height, width=width)
        if use_facet:
            args.update(x=x, color=x, width=width // col_cardinality(data, y))
            facet_args = dict(column=y)
            if z is not None:
                args.update(height=height // col_cardinality(data, z))
                facet_args.update(row=z)
        chart = barchart(data, **args)
        if use_facet:
            chart = chart.facet(**facet_args)

    if chart_type is boxplot:
        chart = boxplot(
            data,
            columns=y,
            group_by=x,
            color=use_facet,  # this is a redundant use of color, different from necessary uses. Encodes x again
            height=height,
            width=width // col_cardinality(data, z),
        )
        if use_facet:
            chart = chart.facet(column=z)
    if chart_type is heatmap:
        args = dict(color="count()")
        if use_opacity:
            args.update(color=z, opacity="count()")
        chart = heatmap(
            data,
            x=x,
            y=y,
            aggregate="average" if use_opacity else "count",
            height=height // col_cardinality(data, z, use_facet),
            width=width // col_cardinality(data, z, use_facet),
            **args
        )
        if use_facet:
            chart = chart.facet(row=z)
    if chart_type is histogram:
        chart = histogram(data, column=y, height=height, width=width)
    if chart_type is stripplot:
        chart = stripplot(
            data,
            columns=y,
            group_by=x,
            color=x if use_facet else None,
            opacity=1 / overlap_deg if use_opacity else 1,
            height=height,
            width=width // col_cardinality(data, z, use_facet),
        )
        if use_facet:
            chart = chart.facet(column=z)
    if chart_type is scatterplot:
        chart = scatterplot(
            data,
            x=x,
            y=y,
            color=z if use_color else None,
            opacity=1 / overlap_deg if use_opacity else 1,
            height=height // col_cardinality(data, z, use_facet),
            width=width // col_cardinality(data, z, use_facet)
        )
        if use_facet:
            chart = chart.facet(row=z)
    return chart