# src/tests/postmodeling_tests/test_model_group_evaluator.py
from triage.component.postmodeling.contrast.model_group_evaluator import ModelGroupEvaluator
import pandas as pd
import pytest
from tests.utils import assert_plot_figures_added
@pytest.fixture(scope="module")
def model_group_evaluator(finished_experiment):
    """Provide one ModelGroupEvaluator shared by every test in this module.

    The evaluator is constructed with the tuple ``(1, 1)`` and the database
    engine exposed by ``finished_experiment``.
    NOTE(review): ``(1, 1)`` presumably identifies the model group(s) to
    evaluate -- confirm against the ModelGroupEvaluator constructor.
    """
    engine = finished_experiment.db_engine
    return ModelGroupEvaluator((1, 1), engine)
def test_ModelGroupEvaluator_metadata(model_group_evaluator):
    """Check that ``metadata`` is a list of exactly two dict rows."""
    metadata = model_group_evaluator.metadata
    assert isinstance(metadata, list)
    # 2 models expected for a model_group from basic experiment
    assert len(metadata) == 2
    assert all(isinstance(row, dict) for row in metadata)
def test_ModelGroupEvaluator_model_type(model_group_evaluator):
    """Check that the first ``model_type`` entry is the decision tree classifier path."""
    expected_type = 'sklearn.tree.DecisionTreeClassifier'
    first_type = model_group_evaluator.model_type[0]
    assert first_type == expected_type
def test_ModelGroupEvaluator_predictions(model_group_evaluator):
    """Check that ``predictions`` is a pandas DataFrame."""
    predictions = model_group_evaluator.predictions
    assert isinstance(predictions, pd.DataFrame)
def test_ModelGroupEvaluator_feature_importances(model_group_evaluator):
    """Check that ``feature_importances`` is a pandas DataFrame."""
    importances = model_group_evaluator.feature_importances
    assert isinstance(importances, pd.DataFrame)
def test_ModelGroupEvaluator_metrics(model_group_evaluator):
    """Check that ``metrics`` is a pandas DataFrame."""
    metrics = model_group_evaluator.metrics
    assert isinstance(metrics, pd.DataFrame)
def test_ModelGroupEvaluator_feature_groups(model_group_evaluator):
    """Check that ``feature_groups`` is a pandas DataFrame."""
    feature_groups = model_group_evaluator.feature_groups
    assert isinstance(feature_groups, pd.DataFrame)
def test_ModelGroupEvaluator_same_time_models(model_group_evaluator):
    """Check that ``same_time_models`` is a pandas DataFrame."""
    same_time_models = model_group_evaluator.same_time_models
    assert isinstance(same_time_models, pd.DataFrame)
def test_ModelGroupEvaluator_plot_prec_across_time(model_group_evaluator):
    """Call ``plot_prec_across_time`` with no arguments under the figure check.

    NOTE(review): ``assert_plot_figures_added`` comes from ``tests.utils``;
    presumably it verifies that the wrapped call created new plot figures.
    """
    plot = model_group_evaluator.plot_prec_across_time
    with assert_plot_figures_added():
        plot()
def test_ModelGroupEvaluator_feature_loi_loo(model_group_evaluator):
    """Check that ``feature_loi_loo()`` with default arguments raises IndexError.

    NOTE(review): the reason an IndexError is the expected outcome is not
    visible in this module -- presumably the basic experiment lacks the data
    this analysis needs; confirm against the evaluator implementation.
    """
    # Keep the whole expression inside the context so that any IndexError
    # raised while resolving or invoking the method is captured.
    with pytest.raises(IndexError):
        model_group_evaluator.feature_loi_loo()
def test_ModelGroupEvaluator_plot_ranked_correlation_preds(model_group_evaluator):
    """Call ``plot_ranked_correlation_preds`` under the figure check.

    The plot is requested with ``param_type='rank_abs'``, ``param=10`` and
    ``top_n_features=10``.
    NOTE(review): ``assert_plot_figures_added`` comes from ``tests.utils``;
    presumably it verifies that the wrapped call created new plot figures.
    """
    plot_kwargs = {
        'param_type': 'rank_abs',
        'param': 10,
        'top_n_features': 10,
    }
    with assert_plot_figures_added():
        model_group_evaluator.plot_ranked_correlation_preds(**plot_kwargs)
def test_ModelGroupEvaluator_plot_ranked_correlation_features(model_group_evaluator):
    """Call ``plot_ranked_correlation_features`` under the figure check.

    The plot is requested with ``param_type='rank_abs'``, ``param=10`` and
    ``top_n_features=10``.
    NOTE(review): ``assert_plot_figures_added`` comes from ``tests.utils``;
    presumably it verifies that the wrapped call created new plot figures.
    """
    plot_kwargs = {
        'param_type': 'rank_abs',
        'param': 10,
        'top_n_features': 10,
    }
    with assert_plot_figures_added():
        model_group_evaluator.plot_ranked_correlation_features(**plot_kwargs)
def test_ModelGroupEvaluator_plot_jaccard_preds(model_group_evaluator):
    """Call ``plot_jaccard_preds`` under the figure check.

    The plot is requested with ``param_type='rank_abs'`` and ``param=10``.
    NOTE(review): ``assert_plot_figures_added`` comes from ``tests.utils``;
    presumably it verifies that the wrapped call created new plot figures.
    """
    plot_kwargs = {'param_type': 'rank_abs', 'param': 10}
    with assert_plot_figures_added():
        model_group_evaluator.plot_jaccard_preds(**plot_kwargs)
def test_ModelGroupEvaluator_plot_jaccard_features(model_group_evaluator):
    """Call ``plot_jaccard_features`` with no arguments under the figure check.

    NOTE(review): ``assert_plot_figures_added`` comes from ``tests.utils``;
    presumably it verifies that the wrapped call created new plot figures.
    """
    plot = model_group_evaluator.plot_jaccard_features
    with assert_plot_figures_added():
        plot()
def test_ModelGroupEvaluator_plot_preds_comparison(model_group_evaluator):
    """Call ``plot_preds_comparison`` under the figure check.

    The plot is requested with ``param_type='rank_abs'`` and ``param=10``.
    NOTE(review): ``assert_plot_figures_added`` comes from ``tests.utils``;
    presumably it verifies that the wrapped call created new plot figures.
    """
    plot_kwargs = {'param_type': 'rank_abs', 'param': 10}
    with assert_plot_figures_added():
        model_group_evaluator.plot_preds_comparison(**plot_kwargs)