src/tests/utils.py

Summary

Maintainability
B
5 hrs
Test Coverage
import datetime
import functools
import importlib
import os
import random
import tempfile
from contextlib import contextmanager
from unittest import mock

import matplotlib
import numpy as np
import pandas as pd
import testing.postgresql
from descriptors import cachedproperty
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

from triage.component.catwalk.db import ensure_db
from triage.component.catwalk.storage import MatrixStore, ProjectStorage
from triage.component.catwalk.utils import filename_friendly_hash
from triage.component.results_schema import Model, Matrix
from triage.experiments import CONFIG_VERSION
from triage.util.structs import FeatureNameList

from tests.results_tests.factories import init_engine, session, MatrixFactory

matplotlib.use("Agg")

from matplotlib import pyplot as plt  # noqa


# SQL templates shared by the sample configs in this module. Each entry holds
# the raw query text ("query") and a relative file path ("filepath");
# sample_cohort_config() and sample_config() select one of the two by key, and
# MOCK_FILES maps the file paths back to the query text.
# The queries contain str.format-style placeholders ({as_of_date},
# {label_timespan}) -- presumably filled in by the code under test.
CONFIG_QUERY_DATA = {
    "cohort": {
        # entities having at least one event dated on or before the as-of date
        "query": """
            select distinct(entity_id)
            from events
            where '{as_of_date}'::date >= outcome_date
        """,
        "filepath": "cohorts/file.sql",
    },
    "label": {
        # per entity: 1 if any event in [as_of_date, as_of_date + label_timespan)
        # has a true outcome, else 0
        "query": """
            select
                events.entity_id,
                bool_or(outcome::bool)::integer as outcome
            from events
            where '{as_of_date}'::date <= outcome_date
                and outcome_date < '{as_of_date}'::date + interval '{label_timespan}'
            group by entity_id
        """,
        "filepath": "labels/file.sql",
    },
}

# Fake file contents keyed by absolute path: each configured relative
# "filepath" is resolved against the current working directory and mapped to
# the query text it is supposed to contain.
MOCK_FILES = {
    os.path.join(
        os.path.abspath(os.getcwd()), query_spec["filepath"]
    ): query_spec["query"]
    for query_spec in (CONFIG_QUERY_DATA["label"], CONFIG_QUERY_DATA["cohort"])
}


def open_side_effect(name):
    """Return a mock file handle whose contents come from MOCK_FILES.

    Args:
        name (str): path used as the lookup key in MOCK_FILES

    Returns: the ``return_value`` of a ``mock.mock_open`` primed with the
        text registered for ``name``

    Raises:
        KeyError: if ``name`` is not a key of MOCK_FILES
    """
    file_contents = MOCK_FILES[name]
    opener = mock.mock_open(read_data=file_contents)
    return opener.return_value


def fake_labels(length):
    """Return a numpy array of ``length`` randomly chosen boolean labels.

    Args:
        length (int): number of labels to draw

    Returns: (numpy.ndarray) one ``random.choice`` of True/False per position
    """
    outcomes = [True, False]
    drawn = []
    for _ in range(length):
        drawn.append(random.choice(outcomes))
    return np.array(drawn)


class MockTrainedModel:
    """Stand-in for a trained model; its predictions are random numbers."""

    def predict_proba(self, dataset):
        """Return a square array of random values sized by ``len(dataset)``.

        Args:
            dataset: any sized object; only its length is used

        Returns: (numpy.ndarray) ``np.random.rand`` output of shape
            (len(dataset), len(dataset))
        """
        row_count = len(dataset)
        return np.random.rand(row_count, row_count)


class MockMatrixStore(MatrixStore):
    """In-memory MatrixStore for tests.

    The parent initializer is not called here; the attributes are set
    directly from the constructor arguments, and a row for ``matrix_uuid``
    is added to the matrices table through ``db_engine``.
    """

    def __init__(
        self,
        matrix_type,
        matrix_uuid,
        label_count,
        db_engine,
        init_labels=None,
        metadata_overrides=None,
        matrix=None,
        init_as_of_dates=None,
    ):
        """Set up the fake store and register its matrix in the database.

        Args:
            matrix_type (str): stored under metadata["matrix_type"]
            matrix_uuid (str): uuid of the store, also written to the database
            label_count (int): how many random labels ``labels`` generates
                when no ``init_labels`` are given
            db_engine (sqlalchemy.engine): engine used to insert the Matrix row
            init_labels (list, optional): fixed labels to serve from ``labels``
            metadata_overrides (dict, optional): entries layered on top of the
                default metadata
            matrix (pandas.DataFrame, optional): matrix including a "label"
                column. NOTE: the "label" column is popped off the object
                passed in, so the caller's frame is modified. Defaults to a
                two-row sample indexed by MatrixStore.indices.
            init_as_of_dates (list, optional): dates served from
                ``as_of_dates`` instead of metadata["as_of_times"]
        """
        base_metadata = {
            "feature_start_time": datetime.date(2014, 1, 1),
            "end_time": datetime.date(2015, 1, 1),
            "as_of_date_frequency": "1y",
            "matrix_id": "some_matrix",
            "label_name": "label",
            "label_timespan": "3month",
            "indices": MatrixStore.indices,
            "matrix_type": matrix_type,
            "as_of_times": [datetime.date(2014, 10, 1), datetime.date(2014, 7, 1)],
        }
        metadata_overrides = metadata_overrides or {}
        base_metadata.update(metadata_overrides)
        if matrix is None:
            matrix = pd.DataFrame.from_dict(
                {
                    "entity_id": [1, 2],
                    "as_of_date": [pd.Timestamp(2014, 10, 1), pd.Timestamp(2014, 7, 1)],
                    "feature_one": [3, 4],
                    "feature_two": [5, 6],
                    "label": [7, 8],
                }
            ).set_index(MatrixStore.indices)
        if init_labels is None:
            init_labels = []
        labels = matrix.pop("label")
        self.matrix_label_tuple = matrix, labels
        self.metadata = base_metadata
        self.label_count = label_count
        self.init_labels = pd.Series(init_labels, dtype="float64")
        self.matrix_uuid = matrix_uuid
        self.init_as_of_dates = init_as_of_dates or []

        # Use a distinct local name so the module-level ``session`` import is
        # not shadowed, and always close the session, even if the insert fails.
        db_session = sessionmaker(db_engine)()
        try:
            db_session.add(Matrix(matrix_uuid=matrix_uuid))
            db_session.commit()
        finally:
            db_session.close()

    @property
    def as_of_dates(self):
        """The list of as-of-dates in the matrix

        Returns ``init_as_of_dates`` when it is non-empty, otherwise the
        "as_of_times" entry of the metadata.
        """
        return self.init_as_of_dates or self.metadata["as_of_times"]

    @property
    def labels(self):
        """The fixed ``init_labels`` if any were given, else random labels.

        The random fallback calls ``fake_labels`` on every access, so
        successive reads are not guaranteed to return the same values.
        """
        if len(self.init_labels) > 0:
            return self.init_labels
        else:
            return fake_labels(self.label_count)


def fake_trained_model(
    db_engine, train_matrix_uuid="efgh", train_end_time=datetime.datetime(2016, 1, 1)
):
    """Creates and stores a trivial trained model and training matrix

    Args:
        db_engine (sqlalchemy.engine)
        train_matrix_uuid (str): uuid of the training matrix row to merge into
            the database and reference from the model row
        train_end_time (datetime.datetime): train end time stored on the model

    Returns:
        (tuple) (MockTrainedModel, int): the in-memory fake model and the
            model id for database retrieval
    """
    # Distinct local name: ``session`` is also a module-level import here.
    db_session = sessionmaker(db_engine)()
    try:
        # merge (rather than add) so an already-present matrix row is reused
        db_session.merge(Matrix(matrix_uuid=train_matrix_uuid))

        # Create the fake trained model's database record
        db_model = Model(
            model_hash="abcd",
            train_matrix_uuid=train_matrix_uuid,
            train_end_time=train_end_time,
        )
        db_session.add(db_model)
        db_session.commit()
        # read the id before closing, while the instance is still attached
        model_id = db_model.model_id
    finally:
        # release the session even when the merge/add/commit above fails
        db_session.close()
    return MockTrainedModel(), model_id


def matrix_metadata_creator(**override_kwargs):
    """Build sample matrix metadata, letting callers replace or add entries.

    Args:
    **override_kwargs: metadata keys and the values that should win over
        (or be added to) the defaults

    Returns: (dict) the default metadata with the overrides applied
    """
    defaults = {
        "feature_start_time": datetime.date(2012, 12, 20),
        "end_time": datetime.date(2016, 12, 20),
        "label_name": "label",
        "as_of_date_frequency": "1w",
        "max_training_history": "5y",
        "matrix_id": "tester-1",
        "state": "active",
        "cohort_name": "default",
        "label_timespan": "1y",
        "metta-uuid": "1234",
        "matrix_type": "test",
        "feature_names": FeatureNameList(["ft1", "ft2"]),
        "feature_groups": ["all: True"],
        "indices": MatrixStore.indices,
        "as_of_times": [datetime.date(2016, 12, 20)],
    }
    # later entries win, so overrides replace defaults in place and any
    # brand-new keys are appended after them
    return {**defaults, **override_kwargs}


def matrix_creator():
    """Build the two-row sample matrix (entity, date, two features, label).

    Returns: (pandas.DataFrame) a fresh, un-indexed frame on every call
    """
    first_of_2016 = pd.Timestamp(2016, 1, 1)
    columns = {
        "entity_id": [1, 2],
        "as_of_date": [first_of_2016, first_of_2016],
        "feature_one": [3, 4],
        "feature_two": [5, 6],
        "label": [0, 1],
    }
    return pd.DataFrame(data=columns)


def get_matrix_store(project_storage, matrix=None, metadata=None, write_to_db=True):
    """Save a matrix into the given project storage and return its store.

    Optionally also makes sure the matrices table has a row for the store.

    Note that ``matrix`` is modified in place: its "as_of_date" column is
    converted to timestamps and the frame is re-indexed on
    MatrixStore.indices.

    Args:
        project_storage (triage.component.catwalk.storage.ProjectStorage) A project's storage
        matrix (dataframe, optional): A matrix to store. Defaults to the output of matrix_creator()
        metadata (dict, optional): matrix metadata.
            defaults to the output of matrix_metadata_creator()
        write_to_db (bool, optional): when true, create a Matrix row for the
            store's uuid unless one is already present

    Returns: the saved matrix store, with its cache cleared
    """
    matrix = matrix_creator() if matrix is None else matrix
    metadata = metadata or matrix_metadata_creator()

    # normalise the dates and index the caller's frame in place
    matrix["as_of_date"] = matrix["as_of_date"].apply(pd.Timestamp)
    matrix.set_index(MatrixStore.indices, inplace=True)

    storage_engine = project_storage.matrix_storage_engine()
    matrix_store = storage_engine.get_store(filename_friendly_hash(metadata))
    matrix_store.metadata = metadata

    # split features from labels on a copy so the indexed frame keeps its label
    design_matrix = matrix.copy()
    label_column = design_matrix.pop(matrix_store.label_column_name)
    matrix_store.matrix_label_tuple = design_matrix, label_column
    matrix_store.save()
    matrix_store.clear_cache()

    if not write_to_db:
        return matrix_store

    existing_rows = (
        session.query(Matrix)
        .filter(Matrix.matrix_uuid == matrix_store.uuid)
        .count()
    )
    if existing_rows == 0:
        MatrixFactory(matrix_uuid=matrix_store.uuid)
        session.commit()
    return matrix_store


@contextmanager
def rig_engines():
    """Provide a throwaway database engine and project storage for a test.

    A temporary Postgres instance and a temporary directory are created on
    entry and torn down (directory first, then database) on exit.

    Yields (tuple) (database engine, project storage engine)
    """
    with testing.postgresql.Postgresql() as pg_server:
        engine = create_engine(pg_server.url())
        ensure_db(engine)
        # bind the test factories' shared session to this engine
        init_engine(engine)
        with tempfile.TemporaryDirectory() as storage_root:
            yield engine, ProjectStorage(storage_root)


def populate_source_data(db_engine):
    """Create and fill the sample source tables used by experiment tests.

    Tables created: cat_complaints, entity_zip_codes, zip_code_demographics,
    zip_code_events and events. Every statement is issued through
    ``db_engine.execute``, one call per table definition and one per row.

    Args:
        db_engine: object exposing ``execute(statement, *params)``
    """

    def insert_rows(statement, rows):
        # one execute call per row, in list order
        for row in rows:
            db_engine.execute(statement, row)

    # (entity_id, as_of_date, cat_sightings)
    complaint_rows = [
        (1, "2010-10-01", 5),
        (1, "2011-10-01", 4),
        (1, "2011-11-01", 4),
        (1, "2011-12-01", 4),
        (1, "2012-02-01", 5),
        (1, "2012-10-01", 4),
        (1, "2013-10-01", 5),
        (2, "2010-10-01", 5),
        (2, "2011-10-01", 5),
        (2, "2011-11-01", 4),
        (2, "2011-12-01", 4),
        (2, "2012-02-01", 6),
        (2, "2012-10-01", 5),
        (2, "2013-10-01", 6),
        (3, "2010-10-01", 5),
        (3, "2011-10-01", 3),
        (3, "2011-11-01", 4),
        (3, "2011-12-01", 4),
        (3, "2012-02-01", 4),
        (3, "2012-10-01", 3),
        (3, "2013-10-01", 4),
    ]

    # (entity_id, zip_code)
    entity_zip_code_rows = [(1, "60120"), (2, "60123"), (3, "60123")]

    # (zip_code, ethnicity, as_of_date)
    demographic_rows = [
        ("60120", "hispanic", "2011-01-01"),
        ("60123", "white", "2011-01-01"),
    ]

    # (zip_code, as_of_date, num_events)
    zip_code_event_rows = [("60120", "2012-10-01", 1), ("60123", "2012-10-01", 10)]

    # (entity_id, outcome, outcome_date)
    event_rows = [
        (1, 1, "2011-01-01"),
        (1, 1, "2011-06-01"),
        (1, 1, "2011-09-01"),
        (1, 1, "2012-01-01"),
        (1, 1, "2012-01-10"),
        (1, 1, "2012-06-01"),
        (1, 1, "2013-01-01"),
        (1, 0, "2014-01-01"),
        (1, 1, "2015-01-01"),
        (2, 1, "2011-01-01"),
        (2, 1, "2011-06-01"),
        (2, 1, "2011-09-01"),
        (2, 1, "2012-01-01"),
        (2, 1, "2013-01-01"),
        (2, 1, "2014-01-01"),
        (2, 1, "2015-01-01"),
        (3, 0, "2011-01-01"),
        (3, 0, "2011-06-01"),
        (3, 0, "2011-09-01"),
        (3, 0, "2012-01-01"),
        (3, 0, "2013-01-01"),
        (3, 1, "2014-01-01"),
        (3, 0, "2015-01-01"),
    ]

    create_cat_complaints = """create table cat_complaints (
        entity_id int,
        as_of_date date,
        cat_sightings int
        )"""
    create_entity_zip_codes = """create table entity_zip_codes (
        entity_id int,
        zip_code text
        )"""
    create_zip_code_demographics = (
        "create table zip_code_demographics (zip_code text, ethnicity text, as_of_date date)"
    )
    create_zip_code_events = """create table zip_code_events (
        zip_code text,
        as_of_date date,
        num_events int
    )"""
    create_events = """create table events (
        entity_id int,
        outcome int,
        outcome_date date
    )"""

    # Statement order below is deliberate: each table exists before its rows
    # are inserted.
    db_engine.execute(create_cat_complaints)
    db_engine.execute(create_entity_zip_codes)

    db_engine.execute(create_zip_code_demographics)
    insert_rows(
        "insert into zip_code_demographics values (%s, %s, %s)", demographic_rows
    )
    insert_rows("insert into entity_zip_codes values (%s, %s)", entity_zip_code_rows)

    db_engine.execute(create_zip_code_events)
    insert_rows("insert into zip_code_events values (%s, %s, %s)", zip_code_event_rows)
    insert_rows("insert into cat_complaints values (%s, %s, %s)", complaint_rows)

    db_engine.execute(create_events)
    insert_rows("insert into events values (%s, %s, %s)", event_rows)


def sample_cohort_config(query_source="filepath"):
    """Build the sample cohort config.

    Args:
        query_source (str): which CONFIG_QUERY_DATA["cohort"] field to use;
            the same string becomes the key holding that value in the result

    Returns: (dict) the cohort name plus the selected query source

    Raises:
        KeyError: if ``query_source`` is not a field of the cohort entry
    """
    cohort_source = CONFIG_QUERY_DATA["cohort"][query_source]
    cohort_config = {"name": "has_past_events"}
    cohort_config[query_source] = cohort_source
    return cohort_config


def sample_config(query_source="filepath"):
    """Assemble a complete sample experiment config dictionary.

    Args:
        query_source (str): which CONFIG_QUERY_DATA field supplies the cohort
            and label definitions; the same string is used as the key under
            which the value is placed in those two sub-configs

    Returns: (dict) config with temporal, label, cohort, feature, grid,
        scoring, bias-audit and miscellaneous sections

    Raises:
        KeyError: if ``query_source`` is not a field of the CONFIG_QUERY_DATA
            entries
    """

    def precision_at_top(n):
        # metric group scoring precision among the top ``n`` entities
        return {"metrics": ["precision@"], "thresholds": {"top_n": [n]}}

    def aggregation(prefix, from_obj, quantity, metrics):
        # feature aggregation over a one-year interval, imputing constant zero
        return {
            "prefix": prefix,
            "from_obj": from_obj,
            "knowledge_date_column": "as_of_date",
            "aggregates_imputation": {"all": {"type": "constant", "value": 0}},
            "aggregates": [{"quantity": quantity, "metrics": metrics}],
            "intervals": ["1year"],
        }

    timing = {
        "feature_start_time": "2010-01-01",
        "feature_end_time": "2014-01-01",
        "label_start_time": "2011-01-01",
        "label_end_time": "2014-01-01",
        "model_update_frequency": "1year",
        "training_label_timespans": ["6months"],
        "test_label_timespans": ["6months"],
        "training_as_of_date_frequencies": "1day",
        "test_as_of_date_frequencies": "3months",
        "max_training_histories": ["6months"],
        "test_durations": ["1months"],
    }

    # subset of entities with an even id and an event before the as-of date
    evens_query = """\
                    select distinct entity_id
                    from events
                    where entity_id %% 2 = 0
                    and outcome_date < '{as_of_date}'::date
                """
    scoring = {
        "testing_metric_groups": [precision_at_top(2)],
        "training_metric_groups": [precision_at_top(3)],
        "subsets": [{"name": "evens", "query": evens_query}],
    }

    model_grid = {
        "sklearn.tree.DecisionTreeClassifier": {
            "min_samples_split": [10, 100],
            "max_depth": [3, 5],
            "criterion": ["gini"],
        }
    }

    features = [
        aggregation(
            "entity_features", "cat_complaints", "cat_sightings", ["count", "avg"]
        ),
        aggregation(
            "zip_code_features",
            "entity_zip_codes join zip_code_events using (zip_code)",
            "num_events",
            ["max", "min"],
        ),
    ]

    cohort = sample_cohort_config(query_source)

    labels = {query_source: CONFIG_QUERY_DATA["label"][query_source]}
    labels["name"] = "custom_label_name"
    labels["include_missing_labels_in_train_as"] = False

    bias_audit = {
        "from_obj_query": "select * from zip_code_demographics join entity_zip_codes using (zip_code)",
        "attribute_columns": ["ethnicity"],
        "knowledge_date_column": "as_of_date",
        "entity_id_column": "entity_id",
        "ref_groups_method": "predefined",
        "ref_groups": {"ethnicity": "white"},
        "thresholds": {"percentiles": [], "top_n": [2]},
    }

    grouping_keys = [
        "label_name",
        "label_type",
        "custom_key",
        "class_path",
        "parameters",
    ]

    return {
        "config_version": CONFIG_VERSION,
        "random_seed": 1234,
        "label_config": labels,
        "entity_column_name": "entity_id",
        "model_comment": "test2-final-final",
        "model_group_keys": grouping_keys,
        "feature_aggregations": features,
        "cohort_config": cohort,
        "temporal_config": timing,
        "grid_config": model_grid,
        "bias_audit_config": bias_audit,
        "prediction": {"rank_tiebreaker": "random"},
        "scoring": scoring,
        "user_metadata": {"custom_key": "custom_value"},
        "individual_importance": {"n_ranks": 2},
    }


@contextmanager
def assert_plot_figures_added():
    """Assert that the wrapped code leaves a higher-numbered current figure.

    The number of pyplot's current figure is read on entry and again on
    exit; an AssertionError is raised unless the second is greater.
    """
    figure_number_on_entry = plt.gcf().number
    yield
    figure_number_on_exit = plt.gcf().number
    assert figure_number_on_exit > figure_number_on_entry


class CallSpy:
    """Callable-wrapper and -patcher to record invocations.

    ``CallSpy``, (unlike ``Mock``), makes it easy to wrap callables for
    the express purpose of recording how they're invoked – without
    modifying functionality. And ``CallSpy``, (unlike ``Mock``),
    reproduces the descriptor interface, such that methods can be
    patched and proxied for this purpose, as easily as functions.

    For example, as a context manager::

        with CallSpy('my_module.MyClass.my_method') as spy:
            ...

        assert (('arg0',), {'param0': 0}) in spy.calls

    """

    def __init__(self, signature):
        """Prepare a spy for the callable at the dotted path ``signature``.

        Nothing is imported or patched until ``start`` (or ``with``) is used.
        """
        # recorded invocations, each an ``(args, kwargs)`` pair
        self.calls = []
        self.signature = signature

    @cachedproperty
    def target_path(self):
        """The signature split into its dotted components."""
        return self.signature.split(".")

    @cachedproperty
    def target_name(self):
        """Name of the attribute to patch (last path component)."""
        return self.target_path[-1]

    @cachedproperty
    def target_base(self):
        """The object (module or class) owning the attribute to patch.

        Raises:
            ValueError: if not even the first path component can be imported,
                or if the whole signature is an importable module path
        """
        path = self.target_path
        base = None

        # walk target path until can no longer import it as a module path
        for depth in range(1, len(path) + 1):
            import_path = ".".join(path[:depth])

            try:
                base = importlib.import_module(import_path)
            except ImportError as exc:
                if base is None:
                    # nothing was importable at all: report it as a bad
                    # signature rather than failing on an unbound ``base``
                    raise ValueError(
                        f"cannot patch signature {self.signature!r}"
                    ) from exc

                # we've imported all that we can import
                # walk the remainder by attribute access
                for part in path[depth - 1 : -1]:
                    base = getattr(base, part)

                return base

        raise ValueError(f"cannot patch signature {self.signature!r}")

    @cachedproperty
    def target_object(self):
        """The original callable, looked up once and cached."""
        return getattr(self.target_base, self.target_name)

    @cachedproperty
    def patch(self):
        """Patcher that replaces the target attribute with this spy."""
        return mock.patch.object(self.target_base, self.target_name, new=self)

    def start(self):
        """Install the spy in place of the target.

        Raises:
            TypeError: if the target is not callable
        """
        if not callable(self.target_object):
            # 1. ensure target_object set before patching
            # 2. check that it's sane (needn't be done here but reasonable)
            raise TypeError(f"signature target not callable {self.target_object!r}")

        self.patch.start()

    def stop(self):
        """Undo the patch installed by ``start``."""
        self.patch.stop()

    def __call__(self, *args, **kwargs):
        """Record the invocation, then delegate to the original callable."""
        self.calls.append((args, kwargs))
        return self.target_object(*args, **kwargs)

    def __get__(self, instance, cls=None):
        """Descriptor hook: on instance access, pre-bind the instance."""
        if instance is None:
            return self

        return functools.partial(self, instance)

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, *exc):
        self.stop()