src/triage/component/architect/utils.py

import verboselogs, logging
logger = verboselogs.VerboseLogger(__name__)

import datetime
import shutil
import random
from contextlib import contextmanager
import functools
import operator
import tempfile
import subprocess

import sqlalchemy

import pandas as pd
import numpy as np
from sqlalchemy.orm import sessionmaker

from triage.component.results_schema import Model
from triage.util.structs import FeatureNameList




def str_in_sql(values):
    """Quote and comma-join values for use in a SQL IN clause"""
    return ",".join("'{}'".format(value) for value in values)
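

# A minimal usage sketch for str_in_sql; the values and table name below are
# invented. The returned fragment is meant to be pasted into an IN clause.
def _demo_str_in_sql():
    fragment = str_in_sql(["a", "b", "c"])
    assert fragment == "'a','b','c'"
    return "select * from events where state in ({})".format(fragment)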


def feature_list(feature_dictionary):
    """Convert a feature dictionary to a sorted list

    Args:
        feature_dictionary (dict) lists of feature names, keyed by feature table name

    Returns:
        (FeatureNameList) sorted list of feature names
    """
    if not feature_dictionary:
        return FeatureNameList()
    return FeatureNameList(
        sorted(functools.reduce(operator.concat, feature_dictionary.values()))
    )
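

# A quick sketch of feature_list in action; the table and feature names are
# made up. All feature lists are concatenated, then sorted as one list.
def _demo_feature_list():
    features = feature_list({"bookings": ["f2", "f1"], "charges": ["f3"]})
    assert list(features) == ["f1", "f2", "f3"]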


def convert_string_column_to_date(column):
    """Parse an iterable of 'YYYY-MM-DD' strings into datetime.date objects"""
    return [datetime.datetime.strptime(date, "%Y-%m-%d").date() for date in column]
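

# Example with invented input: ISO-formatted strings become datetime.date objects.
def _demo_convert_string_column_to_date():
    dates = convert_string_column_to_date(["2016-01-01", "2016-02-01"])
    assert dates[0] == datetime.date(2016, 1, 1)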


def create_features_table(table_number, table, engine):
    """Create and populate a numbered two-feature table in the features schema"""
    engine.execute(
        """
            create table features.features{} (
                entity_id int, as_of_date date, f{} int, f{} int
            )
        """.format(
            table_number, (table_number * 2) + 1, (table_number * 2) + 2
        )
    )
    for row in table:
        engine.execute(
            """
                insert into features.features{} values (%s, %s, %s, %s)
            """.format(
                table_number
            ),
            row,
        )
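

# A usage sketch, assuming a PostgreSQL engine with a `features` schema; the
# rows are invented. Table number 0 gets columns f1 and f2.
def _demo_create_features_table(engine):
    create_features_table(
        table_number=0,
        table=[(1, "2016-01-01", 0, 1), (2, "2016-01-01", 1, 0)],
        engine=engine,
    )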


def create_entity_date_df(
    labels,
    states,
    as_of_dates,
    state_one,
    state_two,
    label_name,
    label_type,
    label_timespan,
):
    """ This function makes a pandas DataFrame that mimics the entity-date table
    for testing against.
    """
    # example label row: (0, "2016-02-01", "1 month", "booking", "binary", 0)
    labels_table = pd.DataFrame(
        labels,
        columns=[
            "entity_id",
            "as_of_date",
            "label_timespan",
            "label_name",
            "label_type",
            "label",
        ],
    )
    states_table = pd.DataFrame(
        states, columns=["entity_id", "as_of_date", "state_one", "state_two"]
    ).set_index(["entity_id", "as_of_date"])
    as_of_dates = [date.date() for date in as_of_dates]
    labels_table = labels_table[labels_table["label_name"] == label_name]
    labels_table = labels_table[labels_table["label_type"] == label_type]
    labels_table = labels_table[labels_table["label_timespan"] == label_timespan]
    labels_table = labels_table.join(other=states_table, on=("entity_id", "as_of_date"))
    labels_table = labels_table[labels_table["state_one"] & labels_table["state_two"]]
    ids_dates = labels_table[["entity_id", "as_of_date"]]
    ids_dates = ids_dates.sort_values(["entity_id", "as_of_date"])
    ids_dates["as_of_date"] = [
        datetime.datetime.strptime(date, "%Y-%m-%d").date()
        for date in ids_dates["as_of_date"]
    ]
    ids_dates = ids_dates[ids_dates["as_of_date"].isin(as_of_dates)]
    logger.spam(ids_dates)

    return ids_dates.reset_index(drop=True)
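

# A minimal sketch of calling create_entity_date_df; every literal below
# (entity ids, dates, state and label names) is invented for illustration.
def _demo_create_entity_date_df():
    labels = [(0, "2016-02-01", "1 month", "booking", "binary", 0)]
    states = [(0, "2016-02-01", True, True)]
    as_of_dates = [datetime.datetime(2016, 2, 1)]
    return create_entity_date_df(
        labels,
        states,
        as_of_dates,
        state_one="state_one",
        state_two="state_two",
        label_name="booking",
        label_type="binary",
        label_timespan="1 month",
    )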


def change_datetimes_on_metadata(metadata):
    """Stringify every metadata value whose key ends in '_time'"""
    for key in list(metadata.keys()):
        if key.endswith("_time"):
            metadata[key] = str(metadata[key])

    return metadata
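

# Example with a made-up metadata dict: only the *_time entry is converted.
def _demo_change_datetimes_on_metadata():
    metadata = {"end_time": datetime.datetime(2016, 1, 1), "label_name": "booking"}
    converted = change_datetimes_on_metadata(metadata)
    assert converted["end_time"] == "2016-01-01 00:00:00"
    assert converted["label_name"] == "booking"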


def NamedTempFile():
    # the old Python 2 fallback is gone; any supported interpreter takes this path
    return tempfile.NamedTemporaryFile(mode="w+", newline="")


@contextmanager
def TemporaryDirectory():
    """Context manager yielding a temporary directory path, removed on exit"""
    name = tempfile.mkdtemp()
    try:
        yield name
    finally:
        shutil.rmtree(name)
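

# A short usage sketch for the two temp helpers above; the file contents
# written here are arbitrary.
def _demo_temp_helpers():
    with TemporaryDirectory() as dirname:
        assert isinstance(dirname, str)
    with NamedTempFile() as f:
        f.write("entity_id,as_of_date\n")
        f.flush()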


def fake_labels(length):
    """Generate a random boolean label array of the given length"""
    return np.array([random.choice([True, False]) for _ in range(length)])


class MockTrainedModel:
    def predict_proba(self, dataset):
        # one column per class, matching the shape a real binary classifier returns
        return np.random.rand(len(dataset), 2)
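

# Sketch of how the mock stands in for a real classifier; the feature matrix
# here is all zeros and purely illustrative.
def _demo_mock_trained_model():
    model = MockTrainedModel()
    scores = model.predict_proba(np.zeros((5, 3)))[:, 1]
    assert scores.shape == (5,)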


def fake_trained_model(project_path, model_storage_engine, db_engine):
    """Creates and stores a trivial trained model

    Args:
        project_path (string) a desired fs/s3 project path
        model_storage_engine (triage.storage.ModelStorageEngine)
        db_engine (sqlalchemy.engine)

    Returns:
        (tuple) the trained model and an (int) model id for database retrieval
    """
    trained_model = MockTrainedModel()
    model_storage_engine.write(trained_model, "abcd")
    session = sessionmaker(db_engine)()
    db_model = Model(model_hash="abcd")
    session.add(db_model)
    session.commit()
    return trained_model, db_model.model_id


def assert_index(engine, table, column):
    """Assert that a table has an index on a given column

    Does not care which position the column is in the index
    Modified from https://www.gab.lc/articles/index_on_id_with_postgresql

    Args:
        engine (sqlalchemy.engine) a database engine
        table (string) the name of a table
        column (string) the name of a column
    """
    query = """
        SELECT 1
        FROM pg_class t
             JOIN pg_index ix ON t.oid = ix.indrelid
             JOIN pg_class i ON i.oid = ix.indexrelid
             JOIN pg_attribute a ON a.attrelid = t.oid
        WHERE
             a.attnum = ANY(ix.indkey) AND
             t.relkind = 'r' AND
             t.relname = '{table_name}' AND
             a.attname = '{column_name}'
    """.format(
        table_name=table, column_name=column
    )
    num_results = len([row for row in engine.execute(query)])
    assert num_results >= 1
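

# A usage sketch for assert_index, assuming a PostgreSQL engine; the table
# and index here are created only for illustration.
def _demo_assert_index(engine):
    engine.execute("create table if not exists demo_table (entity_id int)")
    engine.execute("create index on demo_table (entity_id)")
    assert_index(engine, "demo_table", "entity_id")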


def create_dense_state_table(db_engine, table_name, data):
    """Create and populate a table of dense state spells for testing"""
    db_engine.execute(
        """create table {} (
        entity_id int,
        state text,
        start_time timestamp,
        end_time timestamp
    )""".format(
            table_name
        )
    )

    for row in data:
        db_engine.execute(
            "insert into {} values (%s, %s, %s, %s)".format(table_name), row
        )


def create_binary_outcome_events(db_engine, table_name, events_data):
    """Create and populate a binary-outcome events table for testing"""
    db_engine.execute(
        "create table {} (entity_id int, outcome_date date, outcome bool)".format(
            table_name
        )
    )
    for event in events_data:
        db_engine.execute(
            "insert into {} values (%s, %s, %s::bool)".format(table_name), event
        )
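

# A usage sketch for the two table helpers above, assuming a PostgreSQL
# engine; all rows are invented.
def _demo_event_tables(db_engine):
    create_dense_state_table(
        db_engine,
        "states",
        [(1, "active", "2016-01-01", "2016-06-01")],
    )
    create_binary_outcome_events(
        db_engine,
        "events",
        [(1, "2016-02-01", True)],
    )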


def retry_if_db_error(exception):
    """Retry predicate: True only for database operational errors"""
    return isinstance(exception, sqlalchemy.exc.OperationalError)
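

# Sketch of how this predicate plugs into the `retrying` package (an
# assumption about intended usage; the wait/stop numbers are arbitrary).
def _demo_retry_decorator():
    from retrying import retry

    @retry(
        retry_on_exception=retry_if_db_error,
        wait_exponential_multiplier=1000,  # exponential backoff in 1s multiples
        stop_max_attempt_number=3,
    )
    def flaky_query(engine):
        return engine.execute("select 1")

    return flaky_query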


def _num_elements(x):
    """Extract the number of rows from `wc -l` subprocess output"""
    # split() with no argument also tolerates the leading whitespace some wc builds emit
    return int(str(x.stdout, encoding="utf-8").split()[0])
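

# Example of the parsing above, using a CompletedProcess built by hand rather
# than a real `wc -l` call.
def _demo_num_elements():
    fake = subprocess.CompletedProcess(
        args="wc -l demo.csv", returncode=0, stdout=b"  42 demo.csv\n"
    )
    assert _num_elements(fake) == 42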


def check_rows_in_files(filenames, matrix_uuid):
    """Checks that the number of rows is the same across all the CSV files
    (features and label) belonging to a matrix uuid.

    Args:
        filenames (list) CSV files in which to check the number of rows
        matrix_uuid (string) uuid prefix identifying the matrix's files

    Returns:
        (bool) True if every file has the same number of rows
    """
    outputs = []
    for element in filenames:
        logger.debug(f"filename: {element}")
        just_filename = element.split("/")[-1]
        if element.endswith(".csv") and just_filename.startswith(matrix_uuid):
            cmd_line = "wc -l " + element
            outputs.append(subprocess.run(cmd_line, shell=True, capture_output=True))

    # collect the distinct row counts reported by wc
    rows_set = {_num_elements(output) for output in outputs}
    logger.debug(f"number of rows in files {rows_set}")

    return len(rows_set) == 1


def check_entity_ids_in_files(filenames, matrix_uuid):
    """Verifies that all the features and label files have exactly the same
    entity ids and knowledge dates"""
    # extract and sort the first 2 columns of each file (entity_id, knowledge_date)
    sorted_files = []
    for element in filenames:
        logger.debug(f"getting entity id and knowledge date from features {element}")
        just_filename = element.split("/")[-1]
        prefix = element.split(".")[0]
        if element.endswith(".csv") and just_filename.startswith(matrix_uuid):
            cmd_line = f"cut -d ',' -f 1,2 {element} | sort -k 1,2 > {prefix}_sorted.csv"
            subprocess.run(cmd_line, shell=True)
            sorted_files.append(f"{prefix}_sorted.csv")

    if not sorted_files:
        return True

    # diff every sorted file against the first; identical files produce no output
    base_file = sorted_files[0]
    for other_file in sorted_files[1:]:
        comparison = subprocess.run(
            f"diff {base_file} {other_file}", shell=True, capture_output=True
        )
        if comparison.stdout:
            return False
    return True


def remove_entity_id_and_knowledge_dates(filenames, matrix_uuid):
    """Drop entity id and knowledge date from all features and label files but one"""
    correct_filenames = []

    for filename in filenames:
        just_filename = filename.split("/")[-1]
        prefix = filename.split(".")[0]
        if not just_filename.endswith("_sorted.csv") and just_filename.startswith(matrix_uuid):
            if prefix.endswith("_0"):
                # only the first file keeps entity_id and knowledge date, but it still needs sorting
                cmd_line = f"sort -k 1,2 {filename} > {prefix}_fixed.csv"
            else:
                cmd_line = f"sort -k 1,2 {filename} | cut -d ',' -f 3- > {prefix}_fixed.csv"
            subprocess.run(cmd_line, shell=True)
            # after sorting, each file has its header in the last row; move it back to the top
            # from https://www.unix.com/shell-programming-and-scripting/128416-use-sed-move-last-line-top.html
            cmd_line = f"sed -i '1h;1d;$!H;$!d;G' {prefix}_fixed.csv"
            subprocess.run(cmd_line, shell=True)
            correct_filenames.append(f"{prefix}_fixed.csv")

    return correct_filenames

    
def generate_list_of_files_to_remove(filenames, matrix_uuid):
    """Generate the list of all files (originals plus the intermediate
    _sorted and _fixed files) that need to be removed"""
    rm_files = []

    for element in filenames:
        rm_files.append(element)
        if element.split("/")[-1].startswith(matrix_uuid):
            prefix = element.split(".")[0]
            # the intermediate files created by the helpers above
            rm_files.append(prefix + "_sorted.csv")
            rm_files.append(prefix + "_fixed.csv")

    logger.debug(f"Files to be removed {rm_files}")
    return rm_files
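

# A sketch tying the CSV-validation helpers together, roughly in the order a
# matrix build would use them; the uuid and any paths passed in are hypothetical.
def _demo_csv_validation_pipeline(filenames, matrix_uuid="abcd1234"):
    import os

    assert check_rows_in_files(filenames, matrix_uuid)
    assert check_entity_ids_in_files(filenames, matrix_uuid)
    fixed_files = remove_entity_id_and_knowledge_dates(filenames, matrix_uuid)
    # ... downstream code would paste fixed_files together into one matrix ...
    for stale in generate_list_of_files_to_remove(filenames, matrix_uuid):
        if os.path.exists(stale):
            os.remove(stale)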