src/triage/component/architect/feature_generators.py

Summary

Maintainability
C
1 day
Test Coverage
import verboselogs, logging
logger = verboselogs.VerboseLogger(__name__)

from collections import OrderedDict

import sqlalchemy
import sqlparse

from triage.util.conf import convert_str_to_relativedelta
from triage.database_reflection import table_exists

from triage.component.collate import (
    Aggregate,
    Categorical,
    Compare,
    SpacetimeAggregation,
    FromObj
)

class FeatureGenerator:
    def __init__(
        self,
        db_engine,
        features_schema_name,
        replace=True,
        feature_start_time=None,
        materialize_subquery_fromobjs=True,
        features_ignore_cohort=False,
    ):
        """Generates aggregate features using collate

        Args:
            db_engine (sqlalchemy.db.engine)
            features_schema_name (string) Name of schema where feature
                tables should be written to
            replace (boolean, optional) Whether or not existing features
                should be replaced
            feature_start_time (string/datetime, optional) point in time before which
                should not be included in features
            features_ignore_cohort (boolean, optional) Whether or not features should be built
                independently of the cohort. Takes longer but means that features can be reused
                for different cohorts.
        """
        self.db_engine = db_engine
        self.features_schema_name = features_schema_name
        self.categorical_cache = {}
        self.replace = replace
        self.feature_start_time = feature_start_time
        self.materialize_subquery_fromobjs = materialize_subquery_fromobjs
        self.features_ignore_cohort = features_ignore_cohort
        self.entity_id_column = "entity_id"
        self.from_objs = {}

    def _validate_keys(self, aggregation_config):
        for key in [
            "from_obj",
            "intervals",
            "knowledge_date_column",
            "prefix",
        ]:
            if key not in aggregation_config:
                raise ValueError(
                    "{} required as key: aggregation config: {}".format(
                        key, aggregation_config
                    )
                )
        if "groups" in aggregation_config:
            if aggregation_config["groups"] != [self.entity_id_column]:
                raise ValueError(
                    "Specifying groupings for feature aggregation is not supported. "
                    "Features can only be grouped at the {} level.".format(self.entity_id_column)
                    )
            else:
                logger.warning(
                    "Specifying groupings for feature aggregation is not supported. "
                    "In the future, please exclude this key from your feature configuration."
                    )

    def _validate_aggregates(self, aggregation_config):
        if (
            "aggregates" not in aggregation_config
            and "categoricals" not in aggregation_config
            and "array_categoricals" not in aggregation_config
        ):
            raise ValueError(
                "Need either aggregates, categoricals, or array_categoricals"
                + " in {}".format(aggregation_config)
            )

    def _validate_categoricals(self, categoricals):
        for categorical in categoricals:
            if "choice_query" in categorical:
                logger.spam("Validating choice query")

                try:
                    with self.db_engine.begin() as conn:
                        conn.execute("explain {}".format(categorical["choice_query"]))
                except Exception as exc:
                    raise ValueError(
                        f"choice query does not run. \n"
                        f'choice query: {categorical["choice_query"]}\n'
                    )

    def _validate_from_obj(self, from_obj):
        logger.spam("Validating from_obj")
        try:
            with self.db_engine.begin() as conn:
                conn.execute("explain select * from {}".format(from_obj))
        except Exception as exc:
            raise ValueError(
                "from_obj query does not run. \n"
                'from_obj: "{from_obj}"\n'
            )

    def _validate_time_intervals(self, intervals):
        logger.spam("Validating time intervals")
        for interval in intervals:
            if interval != "all":
                convert_str_to_relativedelta(interval)

    def _validate_imputation_rule(self, aggregate_type, impute_rule):
        """Validate the imputation rule for a given aggregation type."""
        logger.spam("Validating imputation rule")
        # dictionary of imputation type : required parameters
        valid_imputations = {
            "all": {
                "mean": [],
                "constant": ["value"],
                "zero": [],
                "zero_noflag": [],
                "error": [],
            },
            "aggregates": {"binary_mode": []},
            "categoricals": {"null_category": []},
        }
        valid_imputations["array_categoricals"] = valid_imputations["categoricals"]

        # the valid imputation rules for the specific aggregation type being checked
        valid_types = dict(
            valid_imputations["all"], **valid_imputations[aggregate_type]
        )

        # no imputation rule was specified
        if "type" not in impute_rule.keys():
            raise ValueError("Imputation type must be specified")

        # a rule was specified, but not valid for this type of aggregate
        if impute_rule["type"] not in valid_types.keys():
            raise ValueError(
                f"Invalid imputation type {impute_rule['type']} for {aggregate_type}"
            )

        # check that all required parameters exist in the keys of the imputation rule
        required_params = valid_types[impute_rule["type"]]
        for param in required_params:
            if param not in impute_rule.keys():
                raise ValueError(
                    "Missing param {param} for {imput_rule['type']}"
                )

    def _validate_imputations(self, aggregation_config):
        """Validate the imputation rules in an aggregation config, looping
        through all three types of aggregates. Most of the work here is
        done by _validate_imputation_rule() to check the requirements of
        each imputation rule found
        """
        logger.spam("Validating imputations")
        agg_types = ["aggregates", "categoricals", "array_categoricals"]

        for agg_type in agg_types:
            # base_imp are the top-level rules, `such as aggregates_imputation`
            base_imp = aggregation_config.get(agg_type + "_imputation", {})

            # loop through the individual aggregates
            for agg in aggregation_config.get(agg_type, []):
                # combine any aggregate-level imputation rules with top-level ones
                imp_dict = dict(base_imp, **agg.get("imputation", {}))

                # imputation rules are metric-specific, so check each metric's rule
                for metric in agg["metrics"]:
                    # metric rules may be defined by the metric name (e.g., 'max')
                    # or with the 'all' catch-all, with named metrics taking
                    # precedence. If we fall back to {}, the rule validator will
                    # error out on no metric found.
                    impute_rule = imp_dict.get(metric, imp_dict.get("all", {}))
                    self._validate_imputation_rule(agg_type, impute_rule)

    def _validate_aggregation(self, aggregation_config):
        logger.spam(f"Validating aggregation config {aggregation_config}")
        self._validate_keys(aggregation_config)
        self._validate_aggregates(aggregation_config)
        self._validate_categoricals(aggregation_config.get("categoricals", []))
        self._validate_from_obj(aggregation_config["from_obj"])
        self._validate_time_intervals(aggregation_config["intervals"])
        self._validate_imputations(aggregation_config)

    def validate(self, feature_aggregation_config):
        """Validate a feature aggregation config applied to this object

        The validations range from basic type checks, key presence checks,
        as well as validating the sql in from objects.

        Args:
            feature_aggregation_config (list) all values, except for feature
                date, necessary to instantiate a collate.SpacetimeAggregation

        Raises: ValueError if any part of the config is found to be invalid
        """
        logger.spam(f"Validating feature aggregation configurations")
        for aggregation in feature_aggregation_config:
            self._validate_aggregation(aggregation)

    def _compute_choices(self, choice_query):
        if choice_query not in self.categorical_cache:
            with self.db_engine.begin() as conn:
                self.categorical_cache[choice_query] = [
                    row[0] for row in conn.execute(choice_query)
                ]

            logger.debug(
                "Computed list of categoricals: {self.categorical_cache[choice_query]} for choice query: {choice_query}"
            )

        return self.categorical_cache[choice_query]

    def _build_choices(self, categorical):
        logger.debug(
            f'Building categorical choices for column {categorical["column"]}, metrics {categorical["metrics"]}',
        )
        if "choices" in categorical:
            logger.debug(f'Found list of configured choices: {categorical["choices"]}')
            return categorical["choices"]
        else:
            return self._compute_choices(categorical["choice_query"])

    def _build_categoricals(self, categorical_config, impute_rules):
        # TODO: only include null flag where necessary
        return [
            Categorical(
                col=categorical["column"],
                choices=self._build_choices(categorical),
                function=categorical["metrics"],
                impute_rules=dict(
                    impute_rules,
                    coltype="categorical",
                    **categorical.get("imputation", {})
                ),
                include_null=True,
                coltype=categorical.get('coltype', None),
            )
            for categorical in categorical_config
        ]

    def _build_array_categoricals(self, categorical_config, impute_rules):
        # TODO: only include null flag where necessary
        return [
            Compare(
                col=categorical["column"],
                op="@>",
                choices={
                    choice: "array['{}'::varchar]".format(choice)
                    for choice in self._build_choices(categorical)
                },
                function=categorical["metrics"],
                impute_rules=dict(
                    impute_rules,
                    coltype="array_categorical",
                    **categorical.get("imputation", {})
                ),
                op_in_name=False,
                quote_choices=False,
                include_null=True,
                coltype=categorical.get('coltype', None)
            )
            for categorical in categorical_config
        ]

    def _aggregation(self, aggregation_config, feature_dates, state_table):
        logger.debug(
            f"Building collate.SpacetimeAggregation for config {aggregation_config} and {len(feature_dates)} as_of_dates",
        )

        # read top-level imputation rules from the aggregation config; we'll allow
        # these to be overridden by imputation rules at the individual feature
        # level as those get parsed as well
        agimp = aggregation_config.get("aggregates_imputation", {})
        catimp = aggregation_config.get("categoricals_imputation", {})
        arrcatimp = aggregation_config.get("array_categoricals_imputation", {})

        aggregates = [
            Aggregate(
                aggregate["quantity"],
                aggregate["metrics"],
                dict(agimp, coltype="aggregate", **aggregate.get("imputation", {})),
                coltype=aggregate.get('coltype', None)
            )
            for aggregate in aggregation_config.get("aggregates", [])
        ]
        logger.debug(f"Found {len(aggregates)} quantity aggregates")
        categoricals = self._build_categoricals(
            aggregation_config.get("categoricals", []), catimp
        )
        logger.debug(f"Found {len(categoricals)} categorical aggregates")
        array_categoricals = self._build_array_categoricals(
            aggregation_config.get("array_categoricals", []), arrcatimp
        )
        logger.debug(f"Found {len(array_categoricals)} array categorical aggregates")
        return SpacetimeAggregation(
            aggregates + categoricals + array_categoricals,
            from_obj=aggregation_config["from_obj"],
            intervals=aggregation_config["intervals"],
            groups=[self.entity_id_column],
            dates=feature_dates,
            state_table=state_table,
            state_group=self.entity_id_column,
            date_column=aggregation_config["knowledge_date_column"],
            output_date_column="as_of_date",
            input_min_date=self.feature_start_time,
            schema=self.features_schema_name,
            prefix=aggregation_config["prefix"],
            join_with_cohort_table=not self.features_ignore_cohort
        )

    def aggregations(self, feature_aggregation_config, feature_dates, state_table):
        """Creates collate.SpacetimeAggregations from the given arguments

        Args:
            feature_aggregation_config (list) all values, except for feature
                date, necessary to instantiate a collate.SpacetimeAggregation
            feature_dates (list) dates to generate features as of
            state_table (string) schema.table_name for state table with all entity/date pairs

        Returns: (list) collate.SpacetimeAggregations
        """
        return [
            self.preprocess_aggregation(
                self._aggregation(aggregation_config, feature_dates, state_table)
            )
            for aggregation_config in feature_aggregation_config
        ]

    def preprocess_aggregation(self, aggregation):
        create_schema = aggregation.get_create_schema()

        if create_schema is not None:
            with self.db_engine.begin() as conn:
                conn.execute(create_schema)

        if self.materialize_subquery_fromobjs:
            # materialize from obj
            from_obj = FromObj(
                from_obj=aggregation.from_obj.text,
                name=f"{aggregation.schema}.{aggregation.prefix}",
                knowledge_date_column=aggregation.date_column
            )
            from_obj.maybe_materialize(self.db_engine)
            aggregation.from_obj = from_obj.table
        return aggregation

    def generate_all_table_tasks(self, aggregations, task_type):
        """Generates SQL commands for creating, populating, and indexing
        feature group tables

        Args:
            aggregations (list) collate.SpacetimeAggregation objects
            type (str) either 'aggregation' or 'imputation'

        Returns: (dict) keys are group table names, values are themselves dicts,
            each with keys for different stages of table creation (prepare, inserts, finalize)
            and with values being lists of SQL commands
        """

        # pick the method to use for generating tasks depending on whether we're
        # building the aggregations or imputations
        if task_type == "aggregation":
            task_generator = self._generate_agg_table_tasks_for
            logger.verbose("Starting Feature aggregation")
        elif task_type == "imputation":
            task_generator = self._generate_imp_table_tasks_for
            logger.verbose("Starting Feature imputation")
        else:
            raise ValueError(f"Table task type must be aggregation or imputation, was: {table_task}")

        table_tasks = OrderedDict()
        for aggregation in aggregations:
            table_tasks.update(task_generator(aggregation))
        return table_tasks

    def create_features_before_imputation(
        self, feature_aggregation_config, feature_dates, state_table=None
    ):
        """Create features before imputation for a set of dates"""
        all_tasks = self.generate_all_table_tasks(
            self.aggregations(
                feature_aggregation_config, feature_dates, state_table=state_table
            ),
            task_type="aggregation",
        )
        logger.debug(f"Generated a total of {len(all_tasks)} table tasks")
        for task_num, task in enumerate(all_tasks.values(), 1):
            prepares = task.get("prepare", [])
            inserts = task.get("inserts", [])
            finalize = task.get("finalize", [])
            logger.verbose(f"Executing task [{task_num} / {len(all_tasks)}]")
            logger.verbose(
                f"{len(prepares)} prepare queries, {len(inserts)} insert queries, {len(finalize)} finalize queries"
            )
            logger.verbose("Executing {len(prepares)} preparation queries ")
            for query_num, query in enumerate(prepares, 1):
                logger.verbose(f"Prepare query {query_num} / {len(prepares)}")
                logger.debug(
                    f"Prepare query {query_num}: {sqlparse.format(str(query), reindent=True)}",
                )
            logger.verbose("Executing {len(inserts)} insert queries ")
            for query_num, query in enumerate(inserts, 1):
                logger.verbose(f"Insert query {query_num} / {len(inserts)}")
                logger.debug(
                    f"Insert query {query_num}: {sqlparse.format(str(query), reindent=True)}"
                )
            logger.verbose("Executing {len(finalize)} finalize queries ")
            for query_num, query in enumerate(finalize, 1):
                logger.verbose(f"Finalize query {query_num} / {len(finalize)}")
                logger.debug(
                    f"Finalize query {query_num}: {sqlparse.format(str(query), reindent=True)}"
                )
            self.process_table_task(task)

    def create_all_tables(self, feature_aggregation_config, feature_dates, state_table):
        """Create all feature tables.

        First builds the aggregation tables, and then performs
        imputation on any null values, (requiring a two-step process to
        determine which columns contain nulls after the initial
        aggregation tables are built).

        Args:
            feature_aggregation_config (list) all values, except for
                feature date, necessary to instantiate a
                `collate.SpacetimeAggregation`
            feature_dates (list) dates to generate features as of
            state_table (string) schema.table_name for state table with
                all entity/date pairs

        Returns: (list) table names

        """
        aggs = self.aggregations(feature_aggregation_config, feature_dates, state_table)

        # first, generate and run table tasks for aggregations
        table_tasks_aggregate = self.generate_all_table_tasks(aggs, task_type="aggregation")
        self.process_table_tasks(table_tasks_aggregate)

        # second, perform the imputations (this will query the tables
        # constructed above to identify features containing nulls)
        table_tasks_impute = self.generate_all_table_tasks(aggs, task_type="imputation")
        impute_keys = self.process_table_tasks(table_tasks_impute)

        # double-check that the imputation worked and no nulls remain
        # in the data:
        nullcols = []
        with self.db_engine.begin() as conn:
            for agg in aggs:
                results = conn.execute(agg.find_nulls(imputed=True))
                null_counts = results.first().items()
                nullcols += [col for (col, val) in null_counts if val > 0]

        if len(nullcols) > 0:
            raise ValueError(
                f"Imputation failed for {len(nullcols)} columns. Null values remain in: {nullcols}"
            )

        return impute_keys

    def process_table_task(self, task):
        self.run_commands(task.get("prepare", []))
        self.run_commands(task.get("inserts", []))
        self.run_commands(task.get("finalize", []))

    def process_table_tasks(self, table_tasks):
        for table_name, task in table_tasks.items():
            self.process_table_task(task)
        return table_tasks.keys()

    def _explain_selects(self, aggregations):
        with self.db_engine.begin() as conn:
            for aggregation in aggregations:
                for selectlist in aggregation.get_selects().values():
                    for select in selectlist:
                        query = "explain " + str(select)
                        results = list(conn.execute(query))
                        logger.spam(str(select))
                        logger.spam(results)

    def _clean_table_name(self, table_name):
        # remove the schema and quotes from the name
        return table_name.split(".")[1].replace('"', "")

    def _table_exists(self, table_name):
        try:
            with self.db_engine.begin() as conn:
                conn.execute(
                    "select 1 from {}.{} limit 1".format(
                        self.features_schema_name, table_name
                    )
                ).first()
        except sqlalchemy.exc.ProgrammingError:
            return False
        else:
            return True

    def run_commands(self, command_list):
        with self.db_engine.begin() as conn:
            for command in command_list:
                logger.spam(f"Executing feature generation query: {command}")
                conn.execute(command)

    def _aggregation_index_query(self, aggregation, imputed=False):
        return f"CREATE INDEX ON {aggregation.get_table_name(imputed=imputed)} ({self.entity_id_column}, {aggregation.output_date_column})"

    def _aggregation_index_columns(self, aggregation):
        return sorted(
            [group for group in aggregation.groups.keys()]
            + [aggregation.output_date_column]
        )

    def index_column_lookup(self, aggregations, imputed=True):
        return dict(
            (
                self._clean_table_name(aggregation.get_table_name(imputed=imputed)),
                self._aggregation_index_columns(aggregation),
            )
            for aggregation in aggregations
        )

    def _needs_features(self, aggregation):
        imputed_table = self._clean_table_name(
            aggregation.get_table_name(imputed=True)
        )

        if self._table_exists(imputed_table):
            check_query = (
                f"select 1 from {aggregation.state_table} "
                f"left join {self.features_schema_name}.{imputed_table} "
                f"using (entity_id, as_of_date) "
                f"where {self.features_schema_name}.{imputed_table}.entity_id is null limit 1"
            )
            if self.db_engine.execute(check_query).scalar():
                logger.notice(
                    f"Imputed feature table {imputed_table} did not contain rows from the "
                    f"entire cohort, need to rebuild features"
                )
                return True
        else:
            logger.notice(
                f"Imputed feature table {imputed_table} did not exist, "
                f"need to build features"
            )
            return True
        logger.notice(f"Imputed feature table {imputed_table} looks good, "
                      f"skipping feature building!")
        return False

    def _generate_agg_table_tasks_for(self, aggregation):
        """Generates SQL commands for preparing, populating, and finalizing
        each feature group table in the given aggregation

        Args:
            aggregation (collate.SpacetimeAggregation)

        Returns: (dict) of structure {
            'prepare': list of commands to prepare table for population
            'inserts': list of commands to populate table
            'finalize': list of commands to finalize table after population
        }
        """
        creates = aggregation.get_creates()
        drops = aggregation.get_drops()
        indexes = aggregation.get_indexes()
        inserts = aggregation.get_inserts()
        table_tasks = OrderedDict()

        needs_features = self._needs_features(aggregation)

        for group in aggregation.groups:
            group_table = self._clean_table_name(
                aggregation.get_table_name(group=group)
            )
            if self.replace or needs_features:
                table_tasks[group_table] = {
                    "prepare": [drops[group], creates[group]],
                    "inserts": inserts[group],
                    "finalize": [indexes[group]],
                }
                logger.debug(f"Created task for table {group_table}")
            else:
                logger.debug(f"Skipping feature creation for table {group_table}")
                table_tasks[group_table] = {}
        if self.replace or needs_features:
            table_tasks[self._clean_table_name(aggregation.get_table_name())] = {
                "prepare": [aggregation.get_drop(), aggregation.get_create()],
                "inserts": [],
                "finalize": [self._aggregation_index_query(aggregation)],
            }
            logger.debug(f"Created tasks for aggregation {self._clean_table_name(aggregation.get_table_name())}" )
        else:
            logger.debug(f"Skipping feature creation for table {self._clean_table_name(aggregation.get_table_name())}")
            table_tasks[self._clean_table_name(aggregation.get_table_name())] = {}

        return table_tasks

    def _generate_imp_table_tasks_for(self, aggregation, impute_cols=None, nonimpute_cols=None, drop_preagg=True):
        """Generate SQL statements for preparing, populating, and
        finalizing imputations, for each feature group table in the
        given aggregation.

        Requires the existance of the underlying feature and aggregation
        tables defined in `_generate_agg_table_tasks_for()`.

        Args:
            aggregation (collate.SpacetimeAggregation)
            drop_preagg: boolean to specify dropping pre-imputation
                tables

        Returns: (dict) of structure {
                'prepare': list of commands to prepare table for population
                'inserts': list of commands to populate table
                'finalize': list of commands to finalize table after population
            }

        """
        table_tasks = OrderedDict()
        imp_tbl_name = self._clean_table_name(aggregation.get_table_name(imputed=True))

        if not self.replace and not self._needs_features(aggregation):
            logger.debug("Skipping imputation table creation for %s", imp_tbl_name)
            table_tasks[imp_tbl_name] = {}
            return table_tasks

        if not aggregation.state_table:
            logger.debug(
                f"No state table defined in aggregation, cannot create imputation table for {imp_tbl_name}",
            )
            table_tasks[imp_tbl_name] = {}
            return table_tasks

        if not table_exists(aggregation.state_table, self.db_engine):
            logger.debug(
                f"State table {aggregation.state_table} does not exist, cannot create imputation table for {imp_tbl_name}",
            )
            table_tasks[imp_tbl_name] = {}
            return table_tasks

        # excute query to find columns with null values and create lists of columns
        # that do and do not need imputation when creating the imputation table
        with self.db_engine.begin() as conn:
            results = conn.execute(aggregation.find_nulls())
            null_counts = results.first().items()
        if impute_cols is None:
            impute_cols = [col for (col, val) in null_counts if val > 0]
        if nonimpute_cols is None:
            nonimpute_cols = [col for (col, val) in null_counts if val == 0]

        # table tasks for imputed aggregation table, most of the work is done here
        # by collate's get_impute_create()
        table_tasks[imp_tbl_name] = {
            "prepare": [
                aggregation.get_drop(imputed=True),
                aggregation.get_impute_create(
                    impute_cols=impute_cols, nonimpute_cols=nonimpute_cols
                ),
            ],
            "inserts": [],
            "finalize": [self._aggregation_index_query(aggregation, imputed=True)],
        }
        logger.debug("Created table tasks for imputation: %s", imp_tbl_name)

        # do some cleanup:
        # drop the group-level and aggregation tables, just leaving the
        # imputation table if drop_preagg=True
        if drop_preagg:
            drops = aggregation.get_drops()
            table_tasks[imp_tbl_name]["finalize"] += list(drops.values()) + [
                aggregation.get_drop()
            ]
            logger.debug("Added drop table cleanup tasks: %s", imp_tbl_name)

        return table_tasks