# altair_recipes/scatterplot.py
"""Scatterplots."""
from .common import choose_kwargs, hue_scale_dark, hue_scale_light
from .signatures import bivariate_recipe, multivariate_recipe, color, tooltip
import altair as alt
from autosig import autosig, Signature, param
from numbers import Number
# Signature fragment shared by the recipes in this module: it is added to
# `bivariate_recipe` for `scatterplot` and to `multivariate_recipe` for
# `multiscatterplot`, and declares the `color`, `opacity` and `tooltip`
# parameters those functions accept.
# NOTE(review): the `position` values (3, 4, 5) presumably fix where each
# parameter lands in the combined signature -- confirm against autosig docs.
scatterplot_sig = Signature(
color=color(default=None, position=3),
opacity=param(
default=1,
position=4,
# `float` is registered as the converter for this parameter (presumably
# applied by autosig to the caller-supplied value -- confirm).
converter=float,
docstring="""`float`
A constant value for the opacity of the mark""",
),
tooltip=tooltip(default=None, position=5),
)
@autosig(bivariate_recipe + scatterplot_sig)
def scatterplot(
    data=None, x=0, y=1, color=None, opacity=1, tooltip=None, height=600, width=800
):
    """Generate a scatterplot.

    Builds an ``alt.Chart`` over ``data`` with ``x`` and ``y`` encoded on
    scales that are not forced to include zero.

    When ``color`` is given and the first value of ``data[color]`` is a
    ``numbers.Number``, the color encoding is wrapped in ``alt.Color`` with
    ``hue_scale_light`` (opacity equal to 1) or ``hue_scale_dark`` (any other
    opacity). Otherwise ``color`` is passed through unchanged.

    The mark is a ``"point"`` when opacity equals 1 and a ``"circle"``
    otherwise; in both cases the mark carries the requested ``opacity``.

    ``color`` and ``tooltip`` are forwarded to ``encode`` via
    ``choose_kwargs``.

    Returns the encoded chart.
    """
    # Deliberately no new local names before the locals() snapshot below, so
    # choose_kwargs receives exactly the function's parameters.
    if color is not None and isinstance(data[color].iloc[0], Number):
        color = alt.Color(
            color, scale=hue_scale_dark if opacity != 1 else hue_scale_light
        )
    opt_args = choose_kwargs(locals(), ["color", "tooltip"])

    mark = alt.MarkDef(type="circle" if opacity != 1 else "point", opacity=opacity)
    chart = alt.Chart(data, height=height, width=width, mark=mark)
    return chart.encode(
        x=alt.X(x, scale=alt.Scale(zero=False)),
        y=alt.Y(y, scale=alt.Scale(zero=False)),
        **opt_args
    )
@autosig(multivariate_recipe + scatterplot_sig)
def multiscatterplot(
    data=None,
    columns=None,
    group_by=None,
    color=None,
    opacity=1,
    tooltip=None,
    height=600,
    width=800,
):
    """Generate many scatterplots.

    Based on several columns, pairwise: the chart is repeated with
    ``columns`` along both the row and the column dimension, each panel
    plotting one column against another as quantitative fields on scales
    that are not forced to include zero.

    ``height`` and ``width`` are the overall budget; each panel gets the
    integer quotient of that budget by ``len(columns)``. The point size is
    ``1 / len(columns)`` and the mark uses the given ``opacity``.

    ``color`` and ``tooltip`` are forwarded to ``encode`` via
    ``choose_kwargs``.

    Raises:
        AssertionError: if ``group_by`` is not ``None`` (long format is not
            supported yet).
    """
    # Snapshot locals() before introducing any helper names so that
    # choose_kwargs sees only the function's parameters.
    opt_args = choose_kwargs(locals(), ["color", "tooltip"])

    # Explicit raise instead of `assert`: asserts are stripped under
    # `python -O`, which would silently ignore an unsupported `group_by`.
    # AssertionError is kept so existing callers see the same exception.
    if group_by is not None:
        raise AssertionError("Long format not supported yet")

    n_columns = len(columns)
    return (
        alt.Chart(data, height=height // n_columns, width=width // n_columns)
        .mark_point(size=1 / n_columns, opacity=opacity)
        .encode(
            alt.X(
                alt.repeat("column"), type="quantitative", scale=alt.Scale(zero=False)
            ),
            alt.Y(alt.repeat("row"), type="quantitative", scale=alt.Scale(zero=False)),
            **opt_args
        )
        .repeat(row=columns, column=columns)
    )