src/tests/catwalk_tests/test_model_trainers.py

import pandas as pd
import random
import pytest


from triage.component.catwalk.model_grouping import ModelGrouper
from triage.component.catwalk.model_trainers import ModelTrainer
from triage.component.catwalk.utils import save_experiment_and_get_hash
from triage.tracking import initialize_tracking_and_get_run_id
from tests.utils import get_matrix_store


@pytest.fixture
def grid_config():
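    # 2 min_samples_split values x 2 max_depth values x 1 criterion value
    # expand to the 4 model configurations the assertions below count on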
    return {
        "sklearn.tree.DecisionTreeClassifier": {
            "min_samples_split": [10, 100],
            "max_depth": [3, 5],
            "criterion": ["gini"],
        }
    }


@pytest.fixture(scope="function")
def default_model_trainer(db_engine_with_results_schema, project_storage):
    model_storage_engine = project_storage.model_storage_engine()
    experiment_hash = save_experiment_and_get_hash(
        config={'foo': 'bar'},
        db_engine=db_engine_with_results_schema,
    )
    run_id = initialize_tracking_and_get_run_id(
        experiment_hash,
        experiment_class_path="",
        random_seed=5,
        experiment_kwargs={},
        db_engine=db_engine_with_results_schema
    )
    trainer = ModelTrainer(
        experiment_hash=experiment_hash,
        model_storage_engine=model_storage_engine,
        db_engine=db_engine_with_results_schema,
        model_grouper=ModelGrouper(),
        run_id=run_id,
    )
    yield trainer


def test_model_trainer(grid_config, default_model_trainer):
    trainer = default_model_trainer
    db_engine = trainer.db_engine
    project_storage = trainer.model_storage_engine.project_storage
    model_storage_engine = trainer.model_storage_engine

    def set_test_seed():
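        # re-seed the global RNG so the per-model random seeds drawn during
        # training are reproducible across repeated train_models() calls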
        random.seed(5)
    set_test_seed()
    model_ids = trainer.train_models(
        grid_config=grid_config,
        misc_db_parameters=dict(),
        matrix_store=get_matrix_store(project_storage),
    )
    # assert
    # 1. that the models and feature importances table entries are present
    records = [
        row
        for row in db_engine.execute(
            "select * from train_results.feature_importances"
        )
    ]
    assert len(records) == 4 * 2  # 4 models x 2 features each

    records = [
        row
        for row in db_engine.execute("select model_hash from triage_metadata.models")
    ]
    assert len(records) == 4
    hashes = [row[0] for row in records]

    # 2. that the model groups are distinct
    records = [
        row
        for row in db_engine.execute(
            "select distinct model_group_id from triage_metadata.models"
        )
    ]
    assert len(records) == 4

    # 3. that the random seeds are distinct
    records = [
        row
        for row in db_engine.execute(
            "select distinct random_seed from triage_metadata.models"
        )
    ]
    assert len(records) == 4

    # 4. that the model sizes are saved in the table and all are < 1 kB
    records = [
        row
        for row in db_engine.execute("select model_size from triage_metadata.models")
    ]
    assert len(records) == 4
    for i in records:
        size = i[0]
        assert size < 1

    # 5. that all four models are cached
    model_pickles = [model_storage_engine.load(model_hash) for model_hash in hashes]
    assert len(model_pickles) == 4
    assert len([x for x in model_pickles if x is not None]) == 4

    # 6. that predictions can be made with the trained models
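    # a tiny two-row frame with the feature columns the models were trained on,
    # so every cached model should be able to score it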
    test_matrix = pd.DataFrame.from_dict(
        {"entity_id": [3, 4], "feature_one": [4, 4], "feature_two": [6, 5]}
    ).set_index("entity_id")

    for model_pickle in model_pickles:
        predictions = model_pickle.predict(test_matrix)
        assert len(predictions) == 2

    # 7. when run again with the same starting seed, same models are returned
    set_test_seed()
    new_model_ids = trainer.train_models(
        grid_config=grid_config,
        misc_db_parameters=dict(),
        matrix_store=get_matrix_store(project_storage),
    )
    assert (
        len(
            [
                row
                for row in db_engine.execute(
                    "select model_hash from triage_metadata.models"
                )
            ]
        )
        == 4
    )
    assert model_ids == new_model_ids

    # 8. if replace is set, update non-unique attributes and feature importances
    max_batch_run_time = [
        row[0]
        for row in db_engine.execute(
            "select max(batch_run_time) from triage_metadata.models"
        )
    ][0]
    experiment_hash = save_experiment_and_get_hash(
        config={'foo': 'bar'},
        db_engine=db_engine,
    )
    run_id = initialize_tracking_and_get_run_id(
        experiment_hash,
        experiment_class_path="",
        random_seed=5,
        experiment_kwargs={},
        db_engine=db_engine
    )
    trainer = ModelTrainer(
        experiment_hash=experiment_hash,
        model_storage_engine=model_storage_engine,
        model_grouper=ModelGrouper(
            model_group_keys=["label_name", "label_timespan"]
        ),
        db_engine=db_engine,
        replace=True,
        run_id=run_id,
    )
    set_test_seed()
    new_model_ids = trainer.train_models(
        grid_config=grid_config,
        misc_db_parameters=dict(),
        matrix_store=get_matrix_store(project_storage),
    )
    assert model_ids == new_model_ids
    assert [
        row["model_id"]
        for row in db_engine.execute(
            "select model_id from triage_metadata.models order by 1 asc"
        )
    ] == model_ids
    new_max_batch_run_time = [
        row[0]
        for row in db_engine.execute(
            "select max(batch_run_time) from triage_metadata.models"
        )
    ][0]
    assert new_max_batch_run_time > max_batch_run_time

    records = [
        row
        for row in db_engine.execute(
            "select * from train_results.feature_importances"
        )
    ]
    assert len(records) == 4 * 2  # 4 models x 2 features each

    # 9. if the cache is missing but the metadata is still there, reuse the metadata
    set_test_seed()
    for row in db_engine.execute("select model_hash from triage_metadata.models"):
        model_storage_engine.delete(row[0])
    new_model_ids = trainer.train_models(
        grid_config=grid_config,
        misc_db_parameters=dict(),
        matrix_store=get_matrix_store(project_storage),
    )
    assert model_ids == sorted(new_model_ids)

    # 10. that the generator interface works the same way
    set_test_seed()
    new_model_ids = trainer.generate_trained_models(
        grid_config=grid_config,
        misc_db_parameters=dict(),
        matrix_store=get_matrix_store(project_storage),
    )
    assert model_ids == sorted([model_id for model_id in new_model_ids])


def test_baseline_exception_handling(default_model_trainer):
    grid_config = {
        "triage.component.catwalk.baselines.rankers.PercentileRankOneFeature": {
            "feature": ["feature_one", "feature_three"]
        }
    }
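    # the test matrix has no "feature_three" column, so the second baseline is
    # expected to fail and yield no model id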
    trainer = default_model_trainer
    project_storage = trainer.model_storage_engine.project_storage

    train_tasks = trainer.generate_train_tasks(
        grid_config, dict(), get_matrix_store(project_storage)
    )

    model_ids = []
    for train_task in train_tasks:
        model_ids.append(trainer.process_train_task(**train_task))
    assert model_ids == [1, None]


def test_custom_groups(grid_config, db_engine_with_results_schema, project_storage):
    model_storage_engine = project_storage.model_storage_engine()
    experiment_hash = save_experiment_and_get_hash(
        config={'foo': 'bar'},
        db_engine=db_engine_with_results_schema,
    )
    run_id = initialize_tracking_and_get_run_id(
        experiment_hash,
        experiment_class_path="",
        random_seed=5,
        experiment_kwargs={},
        db_engine=db_engine_with_results_schema
    )
    trainer = ModelTrainer(
        experiment_hash=experiment_hash,
        model_storage_engine=model_storage_engine,
        model_grouper=ModelGrouper(["class_path"]),
        db_engine=db_engine_with_results_schema,
        run_id=run_id,
    )
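    # grouping on class_path alone should collapse all four
    # DecisionTreeClassifier configurations into a single model group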
    # create training set
    model_ids = trainer.train_models(
        grid_config=grid_config,
        misc_db_parameters=dict(),
        matrix_store=get_matrix_store(project_storage),
    )
    # expect only one model group now
    records = [
        row[0]
        for row in db_engine_with_results_schema.execute(
            "select distinct model_group_id from triage_metadata.models"
        )
    ]
    assert len(records) == 1
    assert records[0] == model_ids[0]


def test_reuse_model_random_seeds(grid_config, default_model_trainer):
    trainer = default_model_trainer
    db_engine = trainer.db_engine
    project_storage = trainer.model_storage_engine.project_storage
    model_storage_engine = trainer.model_storage_engine

    # re-using the random seeds requires the association between experiments and
    # models to exist; these tests don't go through the full experiment
    # architecture that would normally create it, so back-fill the associations
    # after each train_models() run
    def update_experiment_models(db_engine):
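        # anti-join: only insert (experiment_hash, model_hash) pairs that are
        # not already present in experiment_models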
        sql = """
            INSERT INTO triage_metadata.experiment_models(experiment_hash,model_hash) 
            SELECT er.run_hash, m.model_hash
            FROM triage_metadata.models m
            LEFT JOIN triage_metadata.triage_runs er
                ON m.built_in_triage_run = er.id
            LEFT JOIN triage_metadata.experiment_models em 
                ON m.model_hash = em.model_hash
                AND er.run_hash = em.experiment_hash
            WHERE em.experiment_hash IS NULL
            """
        db_engine.execute(sql)
        db_engine.execute('COMMIT;')

    random.seed(5)
    model_ids = trainer.train_models(
        grid_config=grid_config,
        misc_db_parameters=dict(),
        matrix_store=get_matrix_store(project_storage),
    )
    update_experiment_models(db_engine)

    # simulate running a new experiment where the experiment hash has changed
    # (e.g. because the model grid is different), but experiment seed is the
    # same, so previously-trained models should not get new seeds
    experiment_hash = save_experiment_and_get_hash(
        config={'baz': 'qux'},
        db_engine=db_engine,
    )
    run_id = initialize_tracking_and_get_run_id(
        experiment_hash,
        experiment_class_path="",
        random_seed=5,
        experiment_kwargs={},
        db_engine=db_engine
    )
    trainer = ModelTrainer(
        experiment_hash=experiment_hash,
        model_storage_engine=model_storage_engine,
        db_engine=db_engine,
        model_grouper=ModelGrouper(),
        run_id=run_id,
    )
    new_grid = grid_config.copy()
    new_grid['sklearn.tree.DecisionTreeClassifier']['min_samples_split'] = [3, 10, 100]
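    # note: dict.copy() is shallow, so this also mutates the inner dict shared
    # with grid_config; harmless here since grid_config is not reused below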
    random.seed(5)
    new_model_ids = trainer.train_models(
        grid_config=new_grid,
        misc_db_parameters=dict(),
        matrix_store=get_matrix_store(project_storage),
    )
    update_experiment_models(db_engine)

    # should now have 6 models (3 min_samples_split values x 2 max_depth values)
    assert len(new_model_ids) == 6

    # all the original model ids should be in the new set
    assert len(set(new_model_ids) & set(model_ids)) == len(model_ids)

    # however, we should NOT re-use the random seeds (and so get new model_ids)
    # if the experiment-level seed is different
    experiment_hash = save_experiment_and_get_hash(
        config={'lorem': 'ipsum'},
        db_engine=db_engine,
    )
    run_id = initialize_tracking_and_get_run_id(
        experiment_hash,
        experiment_class_path="",
        random_seed=42,
        experiment_kwargs={},
        db_engine=db_engine
    )
    trainer = ModelTrainer(
        experiment_hash=experiment_hash,
        model_storage_engine=model_storage_engine,
        db_engine=db_engine,
        model_grouper=ModelGrouper(),
        run_id=run_id,
    )
    random.seed(42) # different from above
    newer_model_ids = trainer.train_models(
        grid_config=new_grid,
        misc_db_parameters=dict(),
        matrix_store=get_matrix_store(project_storage),
    )
    update_experiment_models(db_engine)

    # should get entirely new models now (different IDs)
    assert len(newer_model_ids) == 6
    assert len(set(new_model_ids) & set(newer_model_ids)) == 0


def test_n_jobs_not_new_model(default_model_trainer):
    grid_config = {
        "sklearn.ensemble.AdaBoostClassifier": {"n_estimators": [10, 100, 1000]},
        "sklearn.ensemble.RandomForestClassifier": {
            "n_estimators": [10, 100],
            "max_features": ["sqrt", "log2"],
            "max_depth": [5, 10, 15, 20],
            "criterion": ["gini", "entropy"],
            "n_jobs": [12],
        },
    }

    trainer = default_model_trainer
    project_storage = trainer.model_storage_engine.project_storage
    db_engine = trainer.db_engine

    # generate train tasks, with a specific random seed so that we can compare
    # apples to apples later
    random.seed(5)
    train_tasks = trainer.generate_train_tasks(
        grid_config, dict(), get_matrix_store(project_storage)
    )

    for train_task in train_tasks:
        trainer.process_train_task(**train_task)

    # since n_jobs is a runtime attribute of the model, it should not make it
    # into the model group
    for row in db_engine.execute(
        "select hyperparameters from triage_metadata.model_groups"
    ):
        assert "n_jobs" not in row[0]

    hashes = set(task['model_hash'] for task in train_tasks)
    # generate the grid again with a different n_jobs (but the same random seed!)
    # and make sure that the hashes are the same as before
    random.seed(5)
    grid_config['sklearn.ensemble.RandomForestClassifier']['n_jobs'] = [24]
    new_train_tasks = trainer.generate_train_tasks(
        grid_config, dict(), get_matrix_store(project_storage)
    )
    assert hashes == set(task['model_hash'] for task in new_train_tasks)


def test_cache_models(default_model_trainer):
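    # cache_models() is expected to act as a context manager that enables model
    # caching only while the with-block is active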
    assert not default_model_trainer.model_storage_engine.should_cache
    with default_model_trainer.cache_models():
        assert default_model_trainer.model_storage_engine.should_cache
    assert not default_model_trainer.model_storage_engine.should_cache