src/triage/component/architect/label_generators.py

Summary

Maintainability
A
1 hr
Test Coverage
import verboselogs, logging
logger = verboselogs.VerboseLogger(__name__)

import textwrap
from triage.database_reflection import table_row_count, table_exists, table_has_duplicates

# Fallback label name: LabelGenerator uses this when it is constructed
# without an explicit (or with an empty) label_name.
DEFAULT_LABEL_NAME = "outcome"


class LabelGeneratorNoOp:
    """Stand-in label generator used when no label configuration exists.

    It exposes the same three methods as ``LabelGenerator`` but never
    touches a database: the generating methods only log a warning and
    ``clean_up`` does nothing.
    """

    # Single copy of the warning emitted by both generating methods.
    _NO_CONFIG_WARNING = (
        "No label configuration is available, so no labels will be created"
    )

    def generate_all_labels(self, labels_table, as_of_dates, label_timespans):
        """Log a warning; all arguments are accepted and ignored."""
        logger.warning(self._NO_CONFIG_WARNING)

    def generate(self, start_date, label_timespan, labels_table):
        """Log a warning; all arguments are accepted and ignored."""
        logger.warning(self._NO_CONFIG_WARNING)

    def clean_up(self, labels_table_name):
        """Do nothing: no table was ever created, so nothing is dropped."""
        return None


class LabelGenerator:
    """Populate a labels table by running a templated outcome query.

    For every (as-of-date, label timespan) pair the configured query is
    formatted, wrapped in an ``insert ... select`` statement and executed
    through ``db_engine``.
    """

    def __init__(self, db_engine, query, label_name=None, replace=True):
        """
        Parameters
        ----------
        db_engine:
            object whose ``execute(sql)`` method runs a SQL string
        query: str
            template expected to select ``entity_id`` and ``outcome``; it is
            rendered with ``str.format`` using the ``as_of_date`` and
            ``label_timespan`` fields
        label_name: str, optional
            value written to the ``label_name`` column; falls back to
            ``DEFAULT_LABEL_NAME`` when None or empty
        replace: bool
            when True the labels table is always dropped and rebuilt; when
            False an existing table is kept and only (as-of-date, timespan)
            pairs without any rows are generated
        """
        self.db_engine = db_engine
        self.replace = replace
        # query is expected to select a number of entity ids
        # and an outcome for each given an as-of-date
        self.query = query
        self.label_name = label_name or DEFAULT_LABEL_NAME

    @staticmethod
    def _quote_literal(value):
        """Render ``value`` as a single-quoted SQL string literal.

        Embedded single quotes are doubled so that a value such as
        ``it's`` cannot terminate the literal early. For values without
        quotes the result is exactly ``'<value>'``.
        NOTE(review): doubling quotes is sufficient only when backslashes
        are not treated as escapes in plain literals (the PostgreSQL
        default, ``standard_conforming_strings = on``) - confirm for the
        target database.
        """
        return "'" + str(value).replace("'", "''") + "'"

    def _create_labels_table(self, labels_table_name):
        """Create the (empty) labels table unless an existing one is kept.

        Parameters
        ----------
        labels_table_name: str
            name of labels table

        Returns
        -------
        bool
            True if the table was dropped and created by this call, False
            if ``replace`` is False and the table already existed.
        """
        if not self.replace and table_exists(labels_table_name, self.db_engine):
            logger.notice(f"Not dropping and recreating {labels_table_name} table because "
                          f"replace flag was set to False and table was found to exist")
            return False

        self.db_engine.execute(f"drop table if exists {labels_table_name}")
        self.db_engine.execute(
            f"""
                create table {labels_table_name} (
                entity_id int,
                as_of_date date,
                label_timespan interval,
                label_name varchar,
                label_type varchar,
                label smallint
                )"""
        )
        return True

    def _labels_exist(self, labels_table, as_of_date, label_timespan):
        """Return True if any row exists for this date, timespan and label name."""
        logger.spam(f"Looking for existing labels for as of date {as_of_date} and label timespan {label_timespan}")
        lookup_query = textwrap.dedent(
            f"""
            select 1 from {labels_table}
            where as_of_date = {self._quote_literal(as_of_date)}
            and label_timespan = {self._quote_literal(label_timespan)}::interval
            and label_name = {self._quote_literal(self.label_name)}
            limit 1
            """
        )
        return bool(list(self.db_engine.execute(lookup_query)))

    def generate_all_labels(self, labels_table, as_of_dates, label_timespans):
        """Fill the labels table for every as-of-date / label timespan pair.

        Parameters
        ----------
        labels_table: str
            name of labels table
        as_of_dates: sized iterable
            as-of-dates to generate labels for
        label_timespans: sized iterable
            label timespans to generate labels for

        Raises
        ------
        ValueError
            if the table holds no rows once generation is done, or if
            ``table_has_duplicates`` reports duplicates over
            (entity_id, as_of_date, label_timespan, label_name, label_type)
        """
        table_created = self._create_labels_table(labels_table)
        logger.spam(f"Creating labels for {len(as_of_dates)} as of dates and {len(label_timespans)} label timespans")
        for as_of_date in as_of_dates:
            for label_timespan in label_timespans:
                # With replace off, pairs that already have rows are kept as-is.
                if not self.replace and self._labels_exist(
                    labels_table, as_of_date, label_timespan
                ):
                    logger.spam("Since nonzero existing labels found, skipping")
                    continue

                logger.debug(
                    f"Generating labels for as of date {as_of_date} and label timespan {label_timespan}",
                )
                self.generate(
                    start_date=as_of_date,
                    label_timespan=label_timespan,
                    labels_table=labels_table,
                )

        # An unnamed "create index" always builds a new index, so running it
        # against a reused table would pile up one identical index per call.
        # Only index a table that was created by this call; a reused table is
        # assumed to carry the index from the run that created it.
        if table_created:
            self.db_engine.execute(
                f"create index on {labels_table} (entity_id, as_of_date)"
            )
            logger.spam("Added index to labels table")

        nrows = table_row_count(labels_table, self.db_engine)

        if nrows == 0:
            logger.warning(f"Done creating labels, but no rows in {labels_table} table!")
            raise ValueError(f"{labels_table} is empty!")

        if table_has_duplicates(
            labels_table,
            ['entity_id', 'as_of_date', 'label_timespan', 'label_name', 'label_type'],
            self.db_engine,
        ):
            raise ValueError(f"Duplicates found in {labels_table}!")

        logger.debug(f"Labels table generated at {labels_table}")
        logger.spam(f"Row count of {labels_table}: {nrows}")

    def generate(self, start_date, label_timespan, labels_table):
        """Generate labels table using a query

        Parameters
        ----------
        start_date: str
            as of date
        label_timespan: str
            postgresql readable time interval
        labels_table: str
            name of labels table
        """
        # The user query is a str.format template: its {as_of_date} and
        # {label_timespan} fields are filled in here, in Python, before the
        # statement is sent to the database.
        query_with_db_variables = self.query.format(
            as_of_date=start_date, label_timespan=label_timespan
        )

        # The constant columns are emitted as quoted SQL literals so that a
        # value containing an apostrophe cannot break the statement.
        full_insert_query = textwrap.dedent(
            f"""
            insert into {labels_table}
            select
                entities_and_outcomes.entity_id,
                {self._quote_literal(start_date)} as as_of_date,
                {self._quote_literal(label_timespan)}::interval as label_timespan,
                {self._quote_literal(self.label_name)} as label_name,
                'binary' as label_type,
                entities_and_outcomes.outcome as label
            from ({query_with_db_variables}) entities_and_outcomes
            """
        )

        logger.spam("Running label insertion query")
        logger.spam(full_insert_query)
        self.db_engine.execute(full_insert_query)

    def clean_up(self, labels_table_name):
        """Drop the labels table if it exists.

        Parameters
        ----------
        labels_table_name: str
            name of labels table
        """
        self.db_engine.execute(f"drop table if exists {labels_table_name}")