src/tests/test_tracking_experiments.py

Summary

Maintainability
A
1 hr
Test Coverage
from sqlalchemy.orm import Session
import pytest
import datetime
from unittest import mock

from triage.tracking import (
    initialize_tracking_and_get_run_id,
    get_run_for_update,
    increment_field,
)
from triage.util.db import scoped_session
from triage.experiments import MultiCoreExperiment, SingleThreadedExperiment
from triage.component.results_schema import TriageRun, TriageRunStatus
from tests.results_tests.factories import (
    ExperimentFactory,
    TriageRunFactory,
    session as factory_session,
)
from tests.utils import sample_config, populate_source_data, open_side_effect


@pytest.fixture(name="test_engine", scope="module")
def shared_db_engine_with_source_data(shared_db_engine):
    """Provide the shared database engine after loading source data into it.

    Registered under the fixture name ``test_engine`` with module scope, so
    tests request it as ``test_engine`` and ``populate_source_data`` is called
    once per module rather than once per test.

    Yields: the same ``shared_db_engine`` object that was passed in.
    """
    populate_source_data(shared_db_engine)
    # Nothing follows the yield, so this fixture does no teardown of its own.
    yield shared_db_engine


def test_experiment_tracker(test_engine, project_path):
    """Check the TriageRun row of a MultiCoreExperiment before and after ``run()``.

    Immediately after construction the run row should be in the ``started``
    state with run metadata filled in and every matrix/model counter at zero.
    After ``run()`` it should be ``completed``, record ``"run"`` as the start
    method, carry timestamps, have no stacktrace, and report counters that
    match the experiment's own task lists.

    Args:
        test_engine: database engine fixture with source data loaded.
        project_path: project storage location fixture.
    """
    with mock.patch("triage.util.conf.open", side_effect=open_side_effect):
        experiment = MultiCoreExperiment(
            config=sample_config(),
            db_engine=test_engine,
            project_path=project_path,
            n_processes=4,
        )

    # Read the row through scoped_session, as the other tests in this module
    # do, instead of an inline Session that is never closed. The assertions
    # stay inside the block so the loaded row is still attached to a live
    # session when its attributes are read.
    with scoped_session(test_engine) as session:
        experiment_run = session.query(TriageRun).get(experiment.run_id)
        assert experiment_run.current_status == TriageRunStatus.started
        assert experiment_run.run_hash == experiment.experiment_hash
        assert experiment_run.run_type == "experiment"
        assert (
            experiment_run.experiment_class_path
            == "triage.experiments.multicore.MultiCoreExperiment"
        )
        assert experiment_run.platform
        assert experiment_run.os_user
        assert experiment_run.installed_libraries
        assert experiment_run.matrices_skipped == 0
        assert experiment_run.matrices_errored == 0
        assert experiment_run.matrices_made == 0
        assert experiment_run.models_skipped == 0
        assert experiment_run.models_errored == 0
        assert experiment_run.models_made == 0

    experiment.run()

    # One model is expected per train/test task. The model_hash lookup is kept
    # so a task missing that key still fails the test loudly.
    expected_model_hashes = [
        task["train_kwargs"]["model_hash"]
        for batch in experiment._all_train_test_batches()
        for task in batch.tasks
    ]

    # A fresh session, so the values are read from the database after the run.
    with scoped_session(test_engine) as session:
        experiment_run = session.query(TriageRun).get(experiment.run_id)
        assert experiment_run.start_method == "run"
        assert experiment_run.matrices_made == len(experiment.matrix_build_tasks)
        assert experiment_run.matrices_skipped == 0
        assert experiment_run.matrices_errored == 0
        assert experiment_run.models_skipped == 0
        assert experiment_run.models_errored == 0
        assert experiment_run.models_made == len(expected_model_hashes)
        assert isinstance(experiment_run.matrix_building_started, datetime.datetime)
        assert isinstance(experiment_run.model_building_started, datetime.datetime)
        assert isinstance(experiment_run.last_updated_time, datetime.datetime)
        assert not experiment_run.stacktrace
        assert experiment_run.current_status == TriageRunStatus.completed


def test_experiment_tracker_exception(db_engine, project_path):
    """Check that a failing ``run()`` is recorded as a failed TriageRun.

    The experiment is built on ``db_engine``, which this test never populates
    with source data, so ``run()`` is expected to raise. The run row should
    then be in the ``failed`` state with an update timestamp and a non-empty
    stacktrace.

    Args:
        db_engine: database engine fixture without source data loaded.
        project_path: project storage location fixture.
    """
    with mock.patch("triage.util.conf.open", side_effect=open_side_effect):
        experiment = SingleThreadedExperiment(
            config=sample_config(),
            db_engine=db_engine,
            project_path=project_path,
        )

    # No source data means this should blow up. The exact exception type is
    # not pinned here; only the tracking side effect is under test.
    with pytest.raises(Exception):
        experiment.run()

    with scoped_session(db_engine) as session:
        experiment_run = session.query(TriageRun).get(experiment.run_id)
        assert experiment_run.current_status == TriageRunStatus.failed
        assert isinstance(experiment_run.last_updated_time, datetime.datetime)
        assert experiment_run.stacktrace


def test_experiment_tracker_in_parts(test_engine, project_path):
    """Check which entry point is recorded when an experiment is run in stages.

    ``generate_matrices()`` is called first and ``train_and_test_models()``
    second; the run row should still report ``"generate_matrices"`` as its
    start method afterwards.

    Args:
        test_engine: database engine fixture with source data loaded.
        project_path: project storage location fixture.
    """
    with mock.patch("triage.util.conf.open", side_effect=open_side_effect):
        experiment = SingleThreadedExperiment(
            config=sample_config(),
            db_engine=test_engine,
            project_path=project_path,
        )

    experiment.generate_matrices()
    experiment.train_and_test_models()

    with scoped_session(test_engine) as session:
        experiment_run = session.query(TriageRun).get(experiment.run_id)
        # The later train_and_test_models() call must not overwrite this.
        assert experiment_run.start_method == "generate_matrices"


def test_initialize_tracking_and_get_run_id(db_engine_with_results_schema):
    """Check that initializing tracking stores a run row and returns its id.

    A first call should return a truthy run id whose row carries the hash,
    class path, random seed and kwargs that were passed in. A second call for
    the same experiment hash should return a larger run id.
    """
    engine = db_engine_with_results_schema
    stored_experiment = ExperimentFactory()
    factory_session.commit()
    run_hash = stored_experiment.experiment_hash

    def start_run(seed):
        # Build the kwargs fresh on every call so the two calls share nothing.
        return initialize_tracking_and_get_run_id(
            experiment_hash=run_hash,
            experiment_class_path="mymodule.MyClassName",
            random_seed=seed,
            experiment_kwargs={"key": "value"},
            db_engine=engine,
        )

    first_run_id = start_run(1234)
    assert first_run_id

    expected_fields = {
        "run_hash": run_hash,
        "experiment_class_path": "mymodule.MyClassName",
        "random_seed": 1234,
        "experiment_kwargs": {"key": "value"},
    }
    with scoped_session(engine) as session:
        stored_run = session.query(TriageRun).get(first_run_id)
        for field_name, expected_value in expected_fields.items():
            assert getattr(stored_run, field_name) == expected_value, field_name

    second_run_id = start_run(5432)
    assert second_run_id > first_run_id


def test_get_run_for_update(db_engine_with_results_schema):
    """Check that edits made inside ``get_run_for_update`` reach the database.

    A stacktrace is set on the object handed out by the context manager, and
    the row is then re-read through a separate session to confirm the stored
    value.
    """
    engine = db_engine_with_results_schema
    stacktrace_text = "My stacktrace"

    created_run = TriageRunFactory()
    factory_session.commit()
    run_id = created_run.run_id

    with get_run_for_update(db_engine=engine, run_id=run_id) as editable_run:
        editable_run.stacktrace = stacktrace_text

    with scoped_session(engine) as session:
        reloaded_run = session.query(TriageRun).get(run_id)
        assert reloaded_run.stacktrace == stacktrace_text


def test_increment_field(db_engine_with_results_schema):
    """Check that two ``increment_field`` calls leave ``matrices_made`` at 2.

    NOTE(review): this relies on the factory-built run starting with
    ``matrices_made`` at zero -- confirm against ``TriageRunFactory``.
    """
    engine = db_engine_with_results_schema
    created_run = TriageRunFactory()
    factory_session.commit()

    increments = 2
    for _ in range(increments):
        increment_field("matrices_made", created_run.run_id, engine)

    with scoped_session(engine) as session:
        reloaded_run = session.query(TriageRun).get(created_run.run_id)
        assert reloaded_run.matrices_made == 2