src/tests/audition_tests/test_distance_from_best.py

Summary

Maintainability: A (0 mins)
Test Coverage
from datetime import datetime, timedelta
from unittest.mock import patch

import factory
import numpy as np
import testing.postgresql
from sqlalchemy import create_engine

from triage.component.audition.distance_from_best import (
    DistanceFromBestTable,
    BestDistancePlotter,
)
from triage.component.catwalk.db import ensure_db

from tests.results_tests.factories import (
    EvaluationFactory,
    ModelFactory,
    ModelGroupFactory,
    init_engine,
    session,
)

from .utils import create_sample_distance_table


def _sql_add_days(sql_date, days):
    return datetime.strftime(
        datetime.strptime(sql_date, "%Y-%m-%d") + timedelta(days=days), "%Y-%m-%d"
    )


def test_DistanceFromBestTable():
    """Populate a distance-from-best table and check rankings and bounds.

    Three model groups ("stable", "bad", "spiky") each get one model per train
    end time (2014, 2015, 2016). Each model is scored on precision@/100_abs and
    recall@/100_abs twice: once in an evaluation window starting at its train
    end time, and once in a window starting 31 days later, with every score
    shifted by -0.15. The test then checks the per-time ordering of groups by
    distance from the best group, and the observed bounds recorded per metric.
    """
    with testing.postgresql.Postgresql() as postgresql:
        engine = create_engine(postgresql.url())
        ensure_db(engine)
        init_engine(engine)

        end_times = ["2014-01-01", "2015-01-01", "2016-01-01"]
        ages = ["3y_ago", "2y_ago", "1y_ago"]

        model_groups = {
            "stable": ModelGroupFactory(model_type="myStableClassifier"),
            "bad": ModelGroupFactory(model_type="myBadClassifier"),
            "spiky": ModelGroupFactory(model_type="mySpikeClassifier"),
        }

        # One model per (group, train end time), keyed like "stable_3y_ago".
        # Models are created group by group, oldest train end time first.
        models = {}
        for group_name, group in model_groups.items():
            for age, end_time in zip(ages, end_times):
                models[group_name + "_" + age] = ModelFactory(
                    model_group_rel=group, train_end_time=end_time
                )

        class ImmediateEvalFactory(EvaluationFactory):
            # evaluation window: train_end_time .. train_end_time + 1 day
            evaluation_start_time = factory.LazyAttribute(
                lambda o: o.model_rel.train_end_time
            )
            evaluation_end_time = factory.LazyAttribute(
                lambda o: _sql_add_days(o.model_rel.train_end_time, 1)
            )

        class MonthOutEvalFactory(EvaluationFactory):
            # evaluation window: train_end_time + 31 days .. + 32 days
            evaluation_start_time = factory.LazyAttribute(
                lambda o: _sql_add_days(o.model_rel.train_end_time, 31)
            )
            evaluation_end_time = factory.LazyAttribute(
                lambda o: _sql_add_days(o.model_rel.train_end_time, 32)
            )

        # Scores for the immediate evaluations. The month-out evaluations
        # reuse these same scores shifted by -0.15.
        scores = {
            "precision@": {
                "stable_3y_ago": 0.6,
                "stable_2y_ago": 0.57,
                "stable_1y_ago": 0.59,
                "bad_3y_ago": 0.4,
                "bad_2y_ago": 0.39,
                "bad_1y_ago": 0.43,
                "spiky_3y_ago": 0.8,
                "spiky_2y_ago": 0.4,
                "spiky_1y_ago": 0.4,
            },
            "recall@": {
                "stable_3y_ago": 0.55,
                "stable_2y_ago": 0.56,
                "stable_1y_ago": 0.55,
                "bad_3y_ago": 0.35,
                "bad_2y_ago": 0.34,
                "bad_1y_ago": 0.36,
                "spiky_3y_ago": 0.35,
                "spiky_2y_ago": 0.8,
                "spiky_1y_ago": 0.36,
            },
        }
        for offset, eval_factory in (
            (0, ImmediateEvalFactory),
            (-0.15, MonthOutEvalFactory),
        ):
            for metric, model_scores in scores.items():
                for model_key, score in model_scores.items():
                    eval_factory(
                        model_rel=models[model_key],
                        metric=metric,
                        parameter="100_abs",
                        stochastic_value=score + offset,
                    )
        session.commit()

        distance_table = DistanceFromBestTable(
            db_engine=engine,
            models_table="models",
            distance_table="dist_table",
            agg_type="worst",
        )
        metrics = [
            {"metric": "precision@", "parameter": "100_abs"},
            {"metric": "recall@", "parameter": "100_abs"},
        ]
        distance_table.create_and_populate(
            [group.model_group_id for group in model_groups.values()],
            end_times,
            metrics,
        )

        # get an ordered list of the model groups for a particular metric/time
        query = """
            select model_group_id, raw_value, dist_from_best_case, dist_from_best_case_next_time
            from dist_table where metric = %s and parameter = %s and train_end_time = %s
            order by dist_from_best_case
        """

        def ranking(metric, end_time):
            # all rows for one metric/time, ordered by dist_from_best_case
            return list(engine.execute(query, (metric, "100_abs", end_time)))

        assert ranking("precision@", "2014-01-01") == [
            (models["spiky_3y_ago"].model_group_id, 0.8, 0, 0.17),
            (models["stable_3y_ago"].model_group_id, 0.6, 0.2, 0),
            (models["bad_3y_ago"].model_group_id, 0.4, 0.4, 0.18),
        ]

        assert ranking("recall@", "2015-01-01") == [
            (models["spiky_2y_ago"].model_group_id, 0.8, 0, 0.19),
            (models["stable_2y_ago"].model_group_id, 0.56, 0.24, 0),
            (models["bad_2y_ago"].model_group_id, 0.34, 0.46, 0.19),
        ]

        assert distance_table.observed_bounds == {
            ("precision@", "100_abs"): (0.39, 0.8),
            ("recall@", "100_abs"): (0.34, 0.8),
        }


def test_BestDistancePlotter():
    """Check the shape and selected values of BestDistancePlotter plot data.

    Row selection by distance uses a float tolerance. Each selection is
    asserted to be non-empty, so the value checks cannot pass vacuously when
    no row matches.
    """
    with testing.postgresql.Postgresql() as postgresql:
        engine = create_engine(postgresql.url())
        distance_table, _ = create_sample_distance_table(engine)
        plotter = BestDistancePlotter(distance_table)
        df_dist = plotter.generate_plot_data(
            metric="precision@",
            parameter="100_abs",
            model_group_ids=[1, 2],
            train_end_times=["2014-01-01", "2015-01-01"],
        )
        # assert that we have the right # of columns and a row for each % diff value
        # 202 rows because 101 percentiles (0-100 inclusive), 2 model groups
        assert df_dist.shape == (101 * 2, 5)

        # Distances are floats, so match the tick with a tolerance instead of ==.
        at_035 = df_dist[np.isclose(df_dist["distance"], 0.35)]
        # Guard against an empty selection: np.allclose([] , x) is True.
        assert len(at_035) > 0
        # all of the model groups are within .34 of the best, so at a distance
        # above that every group should qualify all of the time
        assert np.allclose(at_035["pct_of_time"].values, 1.0)

        # model group 1 (stable) should be within 0.11 of the best half of the
        # time; if we included 2016 in the train_end_times, this would be 1/3!
        group_1_at_011 = df_dist[
            (df_dist["model_group_id"] == 1)
            & np.isclose(df_dist["distance"], 0.11)
        ]
        assert len(group_1_at_011) > 0
        assert np.allclose(group_1_at_011["pct_of_time"].values, 0.5)


def test_BestDistancePlotter_plot():
    """Check that plot_all_best_dist passes a distance/pct_of_time frame to plot_cats.

    plot_cats is patched out, so nothing is rendered. The recorded call is
    inspected after both context managers have exited.
    """
    with patch(
        "triage.component.audition.distance_from_best.plot_cats"
    ) as plot_cats_mock, testing.postgresql.Postgresql() as postgresql:
        engine = create_engine(postgresql.url())
        sample_table, _ = create_sample_distance_table(engine)
        BestDistancePlotter(sample_table).plot_all_best_dist(
            [{"metric": "precision@", "parameter": "100_abs"}],
            model_group_ids=[1, 2],
            train_end_times=["2014-01-01", "2015-01-01"],
        )

    # The mock object keeps its call record after the patch is undone.
    assert plot_cats_mock.called
    plot_kwargs = plot_cats_mock.call_args[1]
    plotted_frame = plot_kwargs["frame"]
    for column in ("distance", "pct_of_time"):
        assert column in plotted_frame
    assert plot_kwargs["x_col"] == "distance"
    assert plot_kwargs["y_col"] == "pct_of_time"


def test_BestDistancePlotter_plot_bounds():
    """Check plot_bounds against a stub table of observed metric bounds.

    Two metrics are observed within [0, 1] and one ("false positives@") is
    observed outside it. The expected plot bounds are asserted for each.
    """

    class StubDistanceTable:
        # The stub only provides observed_bounds. As a property, it returns a
        # fresh dict on every access.
        @property
        def observed_bounds(self):
            return {
                ("precision@", "100_abs"): (0.02, 0.87),
                ("recall@", "100_abs"): (0.0, 1.0),
                ("false positives@", "300_abs"): (2, 162),
            }

    plotter = BestDistancePlotter(StubDistanceTable())
    expectations = (
        ("precision@", "100_abs", (0.0, 1.0)),
        ("recall@", "100_abs", (0.0, 1.0)),
        ("false positives@", "300_abs", (2, 178)),
    )
    for metric, parameter, expected in expectations:
        assert plotter.plot_bounds(metric, parameter) == expected