src/tests/audition_tests/utils.py

Summary

Maintainability
A
1 hr
Test Coverage
from triage.component.audition.distance_from_best import DistanceFromBestTable
from triage.component.catwalk.db import ensure_db

from tests.results_tests.factories import (
    ModelFactory,
    ModelGroupFactory,
    init_engine,
    session,
)


def create_sample_distance_table(engine):
    ensure_db(engine)
    init_engine(engine)
    model_groups = {
        "stable": ModelGroupFactory(model_type="myStableClassifier"),
        "spiky": ModelGroupFactory(model_type="mySpikeClassifier"),
    }

    class StableModelFactory(ModelFactory):
        model_group_rel = model_groups["stable"]

    class SpikyModelFactory(ModelFactory):
        model_group_rel = model_groups["spiky"]

    models = {
        "stable_3y_ago": StableModelFactory(train_end_time="2014-01-01"),
        "stable_2y_ago": StableModelFactory(train_end_time="2015-01-01"),
        "stable_1y_ago": StableModelFactory(train_end_time="2016-01-01"),
        "spiky_3y_ago": SpikyModelFactory(train_end_time="2014-01-01"),
        "spiky_2y_ago": SpikyModelFactory(train_end_time="2015-01-01"),
        "spiky_1y_ago": SpikyModelFactory(train_end_time="2016-01-01"),
    }
    session.commit()
    distance_table = DistanceFromBestTable(
        db_engine=engine, models_table="models", distance_table="dist_table", agg_type="worst"
    )
    distance_table._create()
    stable_grp = model_groups["stable"].model_group_id
    spiky_grp = model_groups["spiky"].model_group_id
    stable_3y_id = models["stable_3y_ago"].model_id
    stable_3y_end = models["stable_3y_ago"].train_end_time
    stable_2y_id = models["stable_2y_ago"].model_id
    stable_2y_end = models["stable_2y_ago"].train_end_time
    stable_1y_id = models["stable_1y_ago"].model_id
    stable_1y_end = models["stable_1y_ago"].train_end_time
    spiky_3y_id = models["spiky_3y_ago"].model_id
    spiky_3y_end = models["spiky_3y_ago"].train_end_time
    spiky_2y_id = models["spiky_2y_ago"].model_id
    spiky_2y_end = models["spiky_2y_ago"].train_end_time
    spiky_1y_id = models["spiky_1y_ago"].model_id
    spiky_1y_end = models["spiky_1y_ago"].train_end_time
    distance_rows = [
        (
            stable_grp,
            stable_3y_end,
            "precision@",
            "100_abs",
            0.5,
            0.6,
            0.1,
            0.5,
            0.15,
        ),
        (
            stable_grp,
            stable_2y_end,
            "precision@",
            "100_abs",
            0.5,
            0.84,
            0.34,
            0.5,
            0.18,
        ),
        (
            stable_grp,
            stable_1y_end,
            "precision@",
            "100_abs",
            0.46,
            0.67,
            0.21,
            0.5,
            0.11,
        ),
        (
            spiky_grp,
            spiky_3y_end,
            "precision@",
            "100_abs",
            0.45,
            0.6,
            0.15,
            0.5,
            0.19,
        ),
        (
            spiky_grp,
            spiky_2y_end,
            "precision@",
            "100_abs",
            0.84,
            0.84,
            0.0,
            0.5,
            0.3,
        ),
        (
            spiky_grp,
            spiky_1y_end,
            "precision@",
            "100_abs",
            0.45,
            0.67,
            0.22,
            0.5,
            0.12,
        ),
        (
            stable_grp,
            stable_3y_end,
            "recall@",
            "100_abs",
            0.4,
            0.4,
            0.0,
            0.4,
            0.0,
        ),
        (
            stable_grp,
            stable_2y_end,
            "recall@",
            "100_abs",
            0.5,
            0.5,
            0.0,
            0.5,
            0.0,
        ),
        (
            stable_grp,
            stable_1y_end,
            "recall@",
            "100_abs",
            0.6,
            0.6,
            0.0,
            0.6,
            0.0,
        ),
        (
            spiky_grp,
            spiky_3y_end,
            "recall@",
            "100_abs",
            0.65,
            0.65,
            0.0,
            0.65,
            0.0,
        ),
        (
            spiky_grp,
            spiky_2y_end,
            "recall@",
            "100_abs",
            0.55,
            0.55,
            0.0,
            0.55,
            0.0,
        ),
        (
            spiky_grp,
            spiky_1y_end,
            "recall@",
            "100_abs",
            0.45,
            0.45,
            0.0,
            0.45,
            0.0,
        ),
    ]
    for dist_row in distance_rows:
        engine.execute(
            "insert into dist_table values (%s, %s, %s, %s, %s, %s, %s, %s, %s)",
            dist_row,
        )
    return distance_table, model_groups