src/tests/test_defaults.py

Summary

Maintainability
A
0 mins
Test Coverage
from sqlalchemy import create_engine
import testing.postgresql
import pytest

from triage.component.catwalk.db import ensure_db

from tests.utils import sample_config, populate_source_data
from triage.experiments.defaults import (
    fill_timechop_config_missing,
    fill_feature_group_definition,
    fill_model_grid_presets,
    model_grid_preset,
)


def test_fill_feature_group_definition():
    config = sample_config()
    fg_definition = fill_feature_group_definition(config)
    assert sorted(fg_definition['all']) == [True]


def test_fill_timechop_config_missing():
    remove_keys = [
        'model_update_frequency',
        'training_as_of_date_frequencies',
        'test_as_of_date_frequencies',
        'max_training_histories',
        'test_durations',
        'feature_start_time',
        'feature_end_time',
        'label_start_time',
        'label_end_time',
        'training_label_timespans',
        'test_label_timespans'
        ]

    # ensure redundant keys properly raise errors
    config = sample_config()
    config['temporal_config']['label_timespans'] = '1y'
    with pytest.raises(KeyError):
        timechop_config = fill_timechop_config_missing(config, None)

    with testing.postgresql.Postgresql() as postgresql:
        db_engine = create_engine(postgresql.url())
        ensure_db(db_engine)
        populate_source_data(db_engine)
        config = sample_config()

        for key in remove_keys:
            config['temporal_config'].pop(key)
        config['temporal_config']['label_timespans'] = '1y'

        timechop_config = fill_timechop_config_missing(config, db_engine)

        assert timechop_config['model_update_frequency'] == '100y'
        assert timechop_config['training_as_of_date_frequencies'] == '100y'
        assert timechop_config['test_as_of_date_frequencies'] == '100y'
        assert timechop_config['max_training_histories'] == '0d'
        assert timechop_config['test_durations'] == '0d'
        assert timechop_config['training_label_timespans'] == '1y'
        assert timechop_config['test_label_timespans'] == '1y'
        assert 'label_timespans' not in timechop_config.keys()
        assert timechop_config['feature_start_time'] == '2010-10-01'
        assert timechop_config['feature_end_time'] == '2013-10-01'
        assert timechop_config['label_start_time'] == '2010-10-01'
        assert timechop_config['label_end_time'] == '2013-10-01'


def test_model_grid_preset():
    test_cases = {
        'quickstart': {'total_items': 8, 'model_types': 3},
        'small': {'total_items': 35, 'model_types': 4},
        'medium': {'total_items': 64, 'model_types': 6},
        'large': {'total_items': 83, 'model_types': 6},
        'texas': {'total_items': 115, 'model_types': 9},
    }

    for grid_type, exp_results in test_cases.items():
        preset_grid = model_grid_preset(grid_type)
        assert len(preset_grid) == exp_results['model_types']
        total_items = sum([len(y) for x in preset_grid.values() for y in x.values()])
        assert total_items == exp_results['total_items']


def test_fill_model_grid_presets():

    # case 1: has grid, no preset
    config = sample_config()
    fill_grid = fill_model_grid_presets(config)
    assert fill_grid == config['grid_config']

    # case 2: has preset, no grid
    config = sample_config()
    config.pop('grid_config')
    config['model_grid_preset'] = 'quickstart'
    fill_grid = fill_model_grid_presets(config)
    assert len(fill_grid) == 3

    # case 3: neither
    config = sample_config()
    config.pop('grid_config')
    fill_grid = fill_model_grid_presets(config)
    assert fill_grid is None

    # case 4: both
    config = sample_config()
    config['model_grid_preset'] = 'quickstart'
    fill_grid = fill_model_grid_presets(config)
    assert len(fill_grid) == 3
    assert len(fill_grid.get('sklearn.tree.DecisionTreeClassifier', {}).get('max_depth', [])) == 3
    assert len(fill_grid.get('sklearn.tree.DecisionTreeClassifier', {}).get('criterion', [])) == 1