# src/triage/experiments/validate.py
#
# Header text captured from a code-quality report page (not part of the module):
#   Summary
#   Maintainability: F (3 days)
#   Test Coverage
import importlib

import verboselogs, logging

# Module-level logger; VerboseLogger supplies the spam() level used throughout this module.
logger = verboselogs.VerboseLogger(__name__)

from itertools import permutations
from datetime import datetime
from textwrap import dedent

from sklearn.model_selection import ParameterGrid

from triage.component import architect
from triage.component import catwalk
from triage.component.timechop import Timechop

from triage.util.conf import convert_str_to_relativedelta, load_query_if_needed
from triage.validation_primitives import string_is_tablesafe


class Validator:
    """Base class for the experiment-config validators in this module.

    Subclasses implement ``_run`` and raise ``ValueError`` when the section
    they check is invalid. ``run`` is the public entry point; it decides,
    based on ``strict``, whether such an error is fatal or only logged.
    """

    def __init__(self, db_engine=None, strict=True, entity_id_column="entity_id"):
        """Store the settings shared by all validators.

        Args:
            db_engine: engine object that subclasses use to run EXPLAIN
                queries; may be None for validators that never query.
            strict (bool): if True, validation errors propagate to the
                caller; if False, they are logged as warnings and swallowed.
            entity_id_column (str): name of the entity id column, used when
                checking feature aggregation groupings.
        """
        self.db_engine = db_engine
        self.strict = strict
        self.entity_id_column = entity_id_column

    def _run(self, *args, **kwargs):
        """Perform the actual validation; must be provided by subclasses.

        Raises:
            ValueError: (in subclasses) when the validated section is invalid.
        """
        raise NotImplementedError(
            "{} must implement _run".format(type(self).__name__)
        )

    def run(self, *args, **kwargs):
        """Run the validation, honoring the ``strict`` setting.

        All arguments are forwarded to ``_run`` unchanged.

        Raises:
            ValueError: the validation error raised by ``_run``, but only
                when running in strict mode.
        """
        try:
            self._run(*args, **kwargs)
        except ValueError as e:
            if self.strict:
                # Re-raise the original exception rather than wrapping it in
                # a new ValueError, so its type and traceback are preserved.
                raise
            logger.warning(
                "Validation error hit, not running in strict mode so continuing on: %s",
                str(e),
            )


class TemporalValidator(Validator):
    """Validator for the temporal_config section of an experiment config."""

    def _run(self, temporal_config):
        """Chop time with the given config and sanity-check every split.

        Args:
            temporal_config (dict): temporal settings; the four start/end
                times are parsed as "%Y-%m-%d" strings, every other value is
                handed to Timechop untouched.

        Raises:
            ValueError: if Timechop fails on the config, if any train or test
                matrix has no as_of_times, or if a test as_of_time falls
                before the latest train as_of_time plus the training label
                timespan.
        """
        logger.spam("Validating temporal configuration")

        def parse_day(text):
            return datetime.strptime(text, "%Y-%m-%d")

        # Keys are listed in the order Timechop's arguments are evaluated,
        # so the first problem encountered is the one that gets reported.
        date_keys = (
            "feature_start_time",
            "feature_end_time",
            "label_start_time",
            "label_end_time",
        )
        passthrough_keys = (
            "model_update_frequency",
            "training_label_timespans",
            "test_label_timespans",
            "training_as_of_date_frequencies",
            "test_as_of_date_frequencies",
            "max_training_histories",
            "test_durations",
        )

        splits = []
        try:
            timechop_kwargs = {
                key: parse_day(temporal_config[key]) for key in date_keys
            }
            timechop_kwargs.update(
                (key, temporal_config[key]) for key in passthrough_keys
            )
            splits = Timechop(**timechop_kwargs).chop_time()
        except Exception as e:
            chop_failure = """
            Section: temporal_config -
            Timechop could not produce temporal splits from config {}.
            Error: {}
            """.format(
                temporal_config, e
            )
            raise ValueError(dedent(chop_failure))

        for split_index, split in enumerate(splits):
            train_matrix = split["train_matrix"]
            if len(train_matrix["as_of_times"]) == 0:
                empty_train = """
                Section: temporal_config -
                Computed split {} has a train matrix with no as_of_times.
                """.format(
                    split
                )
                raise ValueError(dedent(empty_train))

            # timechop already records the last time available to train data
            # as 'matrix_info_end_time'; this recomputes it from the as_of_times
            # as an independent cross-check against the test matrices.
            latest_train_time = max(train_matrix["as_of_times"])
            label_window = convert_str_to_relativedelta(
                train_matrix["training_label_timespan"]
            )
            train_cutoff = latest_train_time + label_window

            for test_matrix in split["test_matrices"]:
                test_times = test_matrix["as_of_times"]
                if len(test_times) == 0:
                    empty_test = """
                    Section: temporal_config -
                    Computed split {} has a test matrix with no as_of_times.
                    """.format(
                        split
                    )
                    raise ValueError(dedent(empty_test))

                too_early = [
                    moment for moment in test_times if moment < train_cutoff
                ]
                if not too_early:
                    continue

                overlap = """
                    Section: temporal_config -
                    Computed split index {} has a test matrix with as_of_times {}
                    < the maximum train as_of_time + train label timespan.
                    ({}). This is likely an error in timechop. See the
                    experiment's split_definitions[{}] for more information""".format(
                    split_index,
                    too_early,
                    train_cutoff,
                    split_index,
                )
                raise ValueError(dedent(overlap))

        logger.debug("Validation of temporal configuration was successful")


class FeatureAggregationsValidator(Validator):
    """Validator for the feature_aggregations section of an experiment config.

    Each aggregation is checked for required keys, a from_obj that the
    database can plan, valid categorical choices, parseable time intervals
    and well-formed imputation rules. SQL is only ever sent to the database
    prefixed with EXPLAIN.
    """

    def _validate_keys(self, aggregation_config):
        """Check required keys, the prefix format and the optional 'groups' key.

        Args:
            aggregation_config (dict): one entry of feature_aggregations.

        Raises:
            ValueError: if a required key is missing, the prefix is not
                table-safe, or 'groups' is anything other than
                ``[self.entity_id_column]``.
        """
        logger.spam("Validating feature aggregation keys")
        for key in ("from_obj", "intervals", "knowledge_date_column", "prefix"):
            if key not in aggregation_config:
                raise ValueError(
                    dedent(
                        """
                Section: feature_aggregations -
                '{}' required as key: aggregation config: {}""".format(
                            key, aggregation_config
                        )
                    )
                )
        if not string_is_tablesafe(aggregation_config["prefix"]):
            # The literal starts with a line continuation so that every line
            # shares the same indentation and dedent() can actually strip it.
            raise ValueError(
                dedent(
                    f"""\
                    Section: feature_aggregations -
                    Feature aggregation prefix should only contain
                    lowercase letters, numbers, and underscores.
                    Aggregation config: {aggregation_config}
                    """
                )
            )
        if "groups" in aggregation_config:
            if aggregation_config["groups"] != [self.entity_id_column]:
                raise ValueError(
                    dedent(
                        """\
                        Specifying groupings for feature aggregation is
                        not supported. Features can only be grouped at the
                        entity_id level."""
                    )
                )
            # The only accepted value is redundant, so nudge users to drop it.
            logger.warning(
                dedent(
                    """\
                    Specifying groupings for feature aggregation is
                    not supported. In the future, please exclude this key
                    from your feature configuration."""
                )
            )

        logger.debug("Validation of feature aggregation keys was successful")

    def _validate_aggregates(self, aggregation_config):
        """Require at least one of aggregates/categoricals/array_categoricals.

        Raises:
            ValueError: if none of the three keys is present.
        """
        logger.spam("Validating aggregates")
        feature_kinds = ("aggregates", "categoricals", "array_categoricals")
        if not any(kind in aggregation_config for kind in feature_kinds):
            raise ValueError(
                dedent(
                    """
            Section: feature_aggregations -
            Need either aggregates, categoricals, or array_categoricals
            in {}""".format(
                        aggregation_config
                    )
                )
            )
        logger.debug("Validation of aggregates was successful")

    def _validate_categoricals(self, categoricals):
        """Check that each categorical defines exactly one source of choices.

        A 'choice_query', when given, is additionally run through EXPLAIN.

        Args:
            categoricals (list): the 'categoricals' entries of one aggregation.

        Raises:
            ValueError: if a categorical has both or neither of 'choice_query'
                and 'choices', or if its choice query cannot be explained.
        """
        logger.spam("Validating categoricals")
        # The context manager guarantees the connection is released even when
        # a validation error is raised part-way through the loop.
        with self.db_engine.connect() as conn:
            for categorical in categoricals:
                has_query = "choice_query" in categorical
                has_choices = "choices" in categorical
                if has_query and has_choices:
                    raise ValueError(
                        dedent(
                            """
                Section: feature_aggregations -
                Both 'choice_query' and 'choices' specified for {}.
                Please only specify one.""".format(
                                categorical
                            )
                        )
                    )
                if not (has_query or has_choices):
                    raise ValueError(
                        dedent(
                            """
                Section: feature_aggregations -
                Neither 'choice_query' nor 'choices' specified for {}.
                Please specify one.""".format(
                                categorical
                            )
                        )
                    )
                if has_query:
                    logger.spam("Validating choice query")
                    choice_query = categorical["choice_query"]
                    try:
                        conn.execute("explain {}".format(choice_query))
                    except Exception as e:
                        raise ValueError(
                            dedent(
                                """
                    Section: feature_aggregations -
                    choice query does not run.
                    choice query: "{}"
                    Full error: {}""".format(
                                    choice_query, e
                                )
                            )
                        ) from e
                    logger.debug("Validation of choice query was successful")

        logger.debug("Validation of categoricals was successful")

    def _validate_from_obj(self, from_obj):
        """Check that ``select * from <from_obj>`` can be explained.

        Raises:
            ValueError: if the EXPLAIN statement fails for any reason.
        """
        logger.spam("Validating from_obj")
        with self.db_engine.connect() as conn:
            try:
                conn.execute("explain select * from {}".format(from_obj))
            except Exception as e:
                raise ValueError(
                    dedent(
                        """
                Section: feature_aggregations -
                from_obj query does not run.
                from_obj: "{}"
                Full error: {}""".format(
                            from_obj, e
                        )
                    )
                ) from e
        logger.debug("Validation of from_obj was successful")

    def _validate_time_intervals(self, intervals):
        """Check that every interval is 'all' or converts to a relativedelta.

        Raises:
            ValueError: if an interval cannot be converted.
        """
        logger.spam("Validating time intervals")
        for interval in intervals:
            if interval == "all":
                continue
            # this function, used elsewhere to break up time intervals,
            # will throw an error if the interval can't be converted to a
            # relativedelta
            try:
                convert_str_to_relativedelta(interval)
            except Exception as e:
                raise ValueError(
                    dedent(
                        """
                    Section: feature_aggregations -
                    Time interval is invalid.
                    interval: "{}"
                    Full error: {}""".format(
                            interval, e
                        )
                    )
                ) from e
        # Logged once, after all intervals passed (including the all-"all" case).
        logger.debug("Validation of time intervals was successful")

    def _validate_imputation_rule(self, aggregate_type, impute_rule):
        """Validate the imputation rule for a given aggregation type.

        Args:
            aggregate_type (str): 'aggregates', 'categoricals' or
                'array_categoricals'.
            impute_rule (dict): the rule to check; needs a 'type' key plus any
                parameters that type requires.

        Raises:
            ValueError: if the rule has no type, the type is not valid for
                this aggregation type, or a required parameter is missing.
        """
        logger.spam("Validating imputation rule")
        # dictionary of imputation type : required parameters
        valid_imputations = {
            "all": {
                "mean": [],
                "constant": ["value"],
                "zero": [],
                "zero_noflag": [],
                "error": [],
            },
            "aggregates": {"binary_mode": []},
            "categoricals": {"null_category": []},
        }
        valid_imputations["array_categoricals"] = valid_imputations["categoricals"]

        # the valid imputation rules for the specific aggregation type being checked
        valid_types = {
            **valid_imputations["all"],
            **valid_imputations[aggregate_type],
        }

        # no imputation rule was specified
        if "type" not in impute_rule:
            raise ValueError(
                dedent(
                    """
            Section: feature_aggregations -
            Imputation type must be specified"""
                )
            )

        # a rule was specified, but not valid for this type of aggregate
        if impute_rule["type"] not in valid_types:
            raise ValueError(
                dedent(
                    """
            Section: feature_aggregations -
            Invalid imputation type %s for %s"""
                    % (impute_rule["type"], aggregate_type)
                )
            )

        # check that all required parameters exist in the keys of the imputation rule
        for param in valid_types[impute_rule["type"]]:
            if param not in impute_rule:
                raise ValueError(
                    dedent(
                        """
                Section: feature_aggregations -
                Missing param %s for %s"""
                        % (param, impute_rule["type"])
                    )
                )
        logger.debug("Validation of imputation rule was successful")

    def _validate_imputations(self, aggregation_config):
        """Validate the imputation rules in an aggregation config, looping
        through all three types of aggregates. Most of the work here is
        done by _validate_imputation_rule() to check the requirements of
        each imputation rule found

        Raises:
            ValueError: propagated from _validate_imputation_rule().
        """
        logger.spam("Validating imputation definitions")
        agg_types = ["aggregates", "categoricals", "array_categoricals"]

        for agg_type in agg_types:
            logger.spam("Validating imputation rules for aggregation type %s", agg_type)
            # base_imp are the top-level rules, such as `aggregates_imputation`
            base_imp = aggregation_config.get(agg_type + "_imputation", {})

            # loop through the individual aggregates
            for agg in aggregation_config.get(agg_type, []):
                logger.spam("Validating imputation rules for aggregation %s", agg)
                # combine any aggregate-level imputation rules with top-level ones
                imp_dict = dict(base_imp, **agg.get("imputation", {}))

                # imputation rules are metric-specific, so check each metric's rule
                for metric in agg["metrics"]:
                    logger.spam("Validating imputation rules for metric: %s", metric)
                    # metric rules may be defined by the metric name (e.g., 'max')
                    # or with the 'all' catch-all, with named metrics taking
                    # precedence. If we fall back to {}, the rule validator will
                    # error out because no imputation type is present.
                    impute_rule = imp_dict.get(metric, imp_dict.get("all", {}))
                    self._validate_imputation_rule(agg_type, impute_rule)
        logger.debug("Validation of imputation definitions was successful")

    def _validate_aggregation(self, aggregation_config):
        """Run every per-aggregation check against one aggregation config.

        Raises:
            ValueError: from the first check that fails.
        """
        logger.spam("Validating aggregation config %s", aggregation_config)
        self._validate_keys(aggregation_config)
        self._validate_aggregates(aggregation_config)
        self._validate_categoricals(aggregation_config.get("categoricals", []))
        self._validate_from_obj(aggregation_config["from_obj"])
        self._validate_time_intervals(aggregation_config["intervals"])
        self._validate_imputations(aggregation_config)
        logger.debug("Validation of aggregation config was successful")

    def _run(self, feature_aggregation_config):
        """Validate a feature aggregation config applied to this object

        The validations range from basic type checks, key presence checks,
        as well as validating the sql in from objects.

        Args:
            feature_aggregation_config (list) all values, except for feature
                date, necessary to instantiate a collate.SpacetimeAggregation

        Raises: ValueError if any part of the config is found to be invalid
        """
        if not feature_aggregation_config:
            raise ValueError(
                dedent(
                    """
            Section: feature_aggregations -
            Section not found. You must define feature aggregations."""
                )
            )
        for aggregation in feature_aggregation_config:
            self._validate_aggregation(aggregation)


class LabelConfigValidator(Validator):
    """Validator for the label_config section of an experiment config."""

    def _validate_query(self, query):
        """Check the label query's placeholders and that it can be explained.

        The placeholders are bound to a sample as_of_date ("2016-01-01") and
        label_timespan ("6month") before the query is sent through EXPLAIN.

        Raises:
            ValueError: if '{as_of_date}' or '{label_timespan}' is missing from
                the query, or if the bound query cannot be explained.
        """
        if "{as_of_date}" not in query:
            raise ValueError(
                dedent(
                    """
            Section: label_config -
            If 'query' is used as label_config,
            {as_of_date} must be present"""
                )
            )
        if "{label_timespan}" not in query:
            raise ValueError(
                dedent(
                    """
            Section: label_config -
            If 'query' is used as label_config,
            {label_timespan} must be present"""
                )
            )
        bound_query = query.replace("{as_of_date}", "2016-01-01").replace(
            "{label_timespan}", "6month"
        )
        logger.spam("Validating label query via SQL EXPLAIN")
        # The context manager releases the connection whether or not the
        # EXPLAIN succeeds.
        with self.db_engine.connect() as conn:
            try:
                conn.execute("explain {}".format(bound_query))
            except Exception as e:
                raise ValueError(
                    dedent(
                        """
                Section: label_config -
                given query can not be run with a sample as_of_date and label_timespan.
                query: "{}"
                Full error: {}""".format(
                            query, e
                        )
                    )
                ) from e
        logger.debug("Validation of label query was successful")

    @staticmethod
    def _validate_include_missing_labels_in_train_as(missing_label_flag):
        """Check that the missing-label flag is None, True or False.

        Raises:
            ValueError: if the flag has any other value.
        """
        logger.spam("Validating include_missing_labels_in_train")
        # Tuple membership compares by equality without hashing, so unhashable
        # values (e.g. a list) yield the ValueError below, not a TypeError.
        if missing_label_flag not in (None, True, False):
            raise ValueError(
                dedent(
                    """
            Section: label_config -
            The value for 'include_missing_labels_in_train_as', {}, is invalid.
            The key must be either absent, or a boolean value True or False
            Triage only supports binary labels at this time.""".format(
                        missing_label_flag
                    )
                )
            )
        logger.debug("Validation of include_missing_labels_in_train was successful")

    def _run(self, label_config):
        """Validate the label_config section.

        Args:
            label_config (dict): must contain exactly one of 'query' or
                'filepath'; may contain 'name' and
                'include_missing_labels_in_train_as'.

        Raises:
            ValueError: if the section is missing, does not have exactly one of
                'query'/'filepath', has a non-table-safe name, or fails the
                query or missing-label-flag checks.
        """
        logger.spam("Validating label configuration")
        if not label_config:
            raise ValueError(
                dedent(
                    """
            Section: label_config -
            Section not found. You must define a label config."""
                )
            )

        if len(set(label_config.keys()).intersection({"query", "filepath"})) != 1:
            # f-string so the offending keys are actually shown in the message.
            raise ValueError(
                dedent(
                    f"""
            Section: label_config -
            keys ({list(label_config.keys())}) do not contain exactly one of 'filepath'
            or 'query'. You must pass a filepath to a label query or include one
            in the config."""
                )
            )
        # NOTE(review): presumably resolves 'filepath' into a 'query' entry;
        # the code below relies on 'query' being present afterwards.
        label_config = load_query_if_needed(label_config)
        if "name" in label_config and not string_is_tablesafe(label_config["name"]):
            raise ValueError(
                "Section: label_config - "
                "name should only contain lowercase letters, numbers, and underscores"
            )
        self._validate_query(label_config["query"])
        self._validate_include_missing_labels_in_train_as(
            label_config.get("include_missing_labels_in_train_as", None)
        )
        logger.debug("Validation of label configuration was successful")


class CohortConfigValidator(Validator):
    """Validator for the optional cohort_config section."""

    def _run(self, cohort_config):
        """Validate the cohort_config section, if one was given.

        The cohort query is bound to a sample as_of_date ("2016-01-01") and
        sent through EXPLAIN.

        Args:
            cohort_config (dict): may be empty/None, in which case nothing is
                validated; otherwise must contain exactly one of 'query' or
                'filepath', and may contain 'name'.

        Raises:
            ValueError: if the section does not have exactly one of
                'query'/'filepath', the query lacks '{as_of_date}', the name is
                not table-safe, or the bound query cannot be explained.
        """
        logger.spam("Validating of cohort configuration")
        if not cohort_config:
            logger.debug("No cohort config specified, label config will be used instead")
            return
        if len(set(cohort_config.keys()).intersection({"query", "filepath"})) != 1:
            # f-string so the offending keys are actually shown in the message.
            raise ValueError(
                dedent(
                    f"""
            Section: cohort_config -
            keys ({list(cohort_config.keys())}) do not contain exactly one of 'filepath'
            or 'query'. You must pass a filepath to a cohort query or include one
            in the config."""
                )
            )
        # NOTE(review): presumably resolves 'filepath' into a 'query' entry;
        # the code below relies on 'query' being present afterwards.
        cohort_config = load_query_if_needed(cohort_config)
        query = cohort_config["query"]
        if "{as_of_date}" not in query:
            raise ValueError(
                dedent(
                    """
            Section: cohort_config -
            If 'query' is used as cohort_config,
            {as_of_date} must be present"""
                )
            )
        if "name" in cohort_config and not string_is_tablesafe(cohort_config["name"]):
            raise ValueError(
                "Section: cohort_config - "
                "name should only contain lowercase letters, numbers, and underscores"
            )
        dated_query = query.replace("{as_of_date}", "2016-01-01")
        logger.spam("Validating cohort query via SQL EXPLAIN")
        try:
            self.db_engine.execute(f"explain {dated_query}")
        except Exception as e:
            raise ValueError(
                dedent(
                    f"""
                Section: cohort_config -
                given query can not be run with a sample as_of_date .
                query: "{query}"
                Full error: {e}"""
                )
            ) from e
        logger.debug("Validation of cohort query was successful")
        logger.debug("Validation of cohort configuration was successful")


class FeatureGroupDefinitionValidator(Validator):
    """Validator for the feature_group_definition section."""

    def _validate_prefixes(self, prefix_list):
        """
        Ensure that no prefix starts with another prefix in the list
        (duplicates included) to avoid an error with feature group subsets
        when the group names overlap

        Raises:
            ValueError: for the first overlapping pair found.
        """
        logger.spam("Validating feature group definitions prefixes")
        # permutations() yields both (a, b) and (b, a), so each pair is
        # checked in both directions.
        for prefix1, prefix2 in permutations(prefix_list, 2):
            if prefix2.startswith(prefix1):
                raise ValueError(
                    dedent(
                        """
                Section: feature_group_definition -
                Feature group prefixes must not overlap when using `prefix`: %s and %s"""
                        % (prefix1, prefix2)
                    )
                )
        logger.debug("Validation of feature group definitions prefixes was successful")

    def _run(self, feature_group_definition, feature_aggregation_config):
        """Validate the feature group definition against the aggregations.

        Args:
            feature_group_definition (dict): maps subsetter names to lists of
                values.
            feature_aggregation_config (list): the feature aggregations whose
                'prefix' values determine the valid prefixes and tables.

        Raises:
            ValueError: if the definition is not a dict, uses an unknown
                subsetter, has a non-list value, or names prefixes/tables that
                do not come from the feature aggregations.
        """
        logger.spam("Validating of feature group definitions")
        if not isinstance(feature_group_definition, dict):
            raise ValueError(
                dedent(
                    """
            Section: feature_group_definition -
            feature_group_definition must be a dictionary"""
                )
            )

        available_subsetters = (
            architect.feature_group_creator.FeatureGroupCreator.subsetters
        )
        for subsetter_name, value in feature_group_definition.items():
            if subsetter_name not in available_subsetters:
                raise ValueError(
                    dedent(
                        """
                Section: feature_group_definition -
                Unknown feature_group_definition key {} received.
                Available keys are {}""".format(
                            subsetter_name, available_subsetters
                        )
                    )
                )
            # Strings and bytes are iterable but are not acceptable here.
            if not hasattr(value, "__iter__") or isinstance(value, (str, bytes)):
                raise ValueError(
                    dedent(
                        """
                Section: feature_group_definition -
                feature_group_definition value for {}, {}
                should be a list""".format(
                            subsetter_name, value
                        )
                    )
                )

        if "prefix" in feature_group_definition:
            available_prefixes = {
                aggregation["prefix"] for aggregation in feature_aggregation_config
            }
            bad_prefixes = set(feature_group_definition["prefix"]) - available_prefixes
            if bad_prefixes:
                raise ValueError(
                    dedent(
                        """
                Section: feature_group_definition -
                The following given feature group prefixes: '{}'
                are invalid. Available prefixes from this experiment's feature
                aggregations are: '{}'
                """.format(
                            bad_prefixes, available_prefixes
                        )
                    )
                )
            self._validate_prefixes(feature_group_definition["prefix"])

        if "tables" in feature_group_definition:
            available_tables = {
                aggregation["prefix"] + "_aggregation_imputed"
                for aggregation in feature_aggregation_config
            }
            bad_tables = set(feature_group_definition["tables"]) - available_tables
            if bad_tables:
                raise ValueError(
                    dedent(
                        """
                Section: feature_group_definition -
                The following given feature group tables: '{}'
                are invalid. Available tables from this experiment's feature
                aggregations are: '{}'
                """.format(
                            bad_tables, available_tables
                        )
                    )
                )

        # Only report success once the prefix and table checks have passed too.
        logger.debug("Validation of feature group definition was successful")


class FeatureGroupStrategyValidator(Validator):
    """Validator for the feature_group_strategies section."""

    def _run(self, feature_group_strategies):
        """Check that the strategies are a list of known strategy names.

        Raises:
            ValueError: if the section is not a list, or names a strategy
                missing from FeatureGroupMixer.strategy_lookup.
        """
        logger.spam("Validating feature group strategies")
        if not isinstance(feature_group_strategies, list):
            not_a_list = """
            Section: feature_group_strategies -
            feature_group_strategies section must be a list"""
            raise ValueError(dedent(not_a_list))

        strategy_lookup = architect.feature_group_mixer.FeatureGroupMixer.strategy_lookup
        available_strategies = {name for name in strategy_lookup.keys()}
        bad_strategies = set(feature_group_strategies) - available_strategies
        if bad_strategies:
            unknown = """
            Section: feature_group_strategies -
            The following given feature group strategies:
            '{}' are invalid. Available strategies are: '{}'
            """.format(
                bad_strategies, available_strategies
            )
            raise ValueError(dedent(unknown))
        logger.debug("Validation of feature group strategies was successful")


class UserMetadataValidator(Validator):
    """Validator for the user_metadata section."""

    def _run(self, user_metadata):
        """Check that user_metadata is a dict.

        Raises:
            ValueError: if user_metadata is not a dict instance.
        """
        logger.spam("Validating user metadata")
        if isinstance(user_metadata, dict):
            logger.debug("Validation of user metadata was successful")
            return
        not_a_dict = """
            Section: user_metadata -
            user_metadata section must be a dict"""
        raise ValueError(dedent(not_a_dict))


class ModelGroupKeysValidator(Validator):
    """Validator for the model_group_keys section."""

    def _run(self, model_group_keys, user_metadata):
        """Check that every model group key is one of the known keys.

        Known keys are the user_metadata keys plus the fixed planner,
        temporal and classifier keys listed below.

        Raises:
            ValueError: if model_group_keys is not a list or contains an
                entry that is not a known key.
        """
        logger.spam("Validating model group keys")
        if not isinstance(model_group_keys, list):
            not_a_list = """
            Section: model_group_keys -
            model_group_keys section must be a list"""
            raise ValueError(dedent(not_a_list))

        # planner_keys are defined in architect.Planner.make_metadata
        planner_keys = (
            "feature_start_time",
            "end_time",
            "indices",
            "feature_names",
            "feature_groups",
            "label_name",
            "label_type",
            "label_timespan",
            "state",
            "cohort_name",
            "matrix_id",
            "matrix_type",
        )
        # temporal_keys are defined in
        # timechop.Timechop.generate_matrix_definition
        temporal_keys = (
            "first_as_of_time",
            "last_as_of_time",
            "matrix_info_end_time",
            "as_of_times",
            "training_label_timespan",
            "training_as_of_date_frequency",
            "max_training_history",
        )
        classifier_keys = ("class_path", "parameters")

        # Kept as a list, in this order, because it is shown in the error message.
        available_keys = [
            *user_metadata.keys(),
            *planner_keys,
            *temporal_keys,
            *classifier_keys,
        ]
        for candidate in model_group_keys:
            if candidate in available_keys:
                continue
            unknown = """
                Section: model_group_keys -
                unknown entry '{}' received. Available keys are {}
                """.format(
                candidate, available_keys
            )
            raise ValueError(dedent(unknown))
        logger.debug("Validation of model group keys was successful")


class GridConfigValidator(Validator):
    """Validator for the grid_config section."""

    def _run(self, grid_config):
        """Check that every classifier imports and accepts its parameters.

        For each class path the class is imported, the parameter grid is
        expanded with ParameterGrid, and the class is instantiated once per
        parameter combination.

        Args:
            grid_config (dict): maps dotted class paths to parameter grids.

        Raises:
            ValueError: if the section is empty, a class cannot be imported,
                its parameter grid cannot be expanded, or the class cannot be
                instantiated with one of the parameter combinations.
        """
        logger.spam("Validating grid configuration")
        if not grid_config:
            raise ValueError(
                dedent(
                    """
            Section: grid_config -
            Section not found. You must define a grid_config."""
                )
            )
        for classpath, parameter_config in grid_config.items():
            if classpath == "sklearn.linear_model.LogisticRegression":
                logger.warning(
                    "sklearn.linear_model.LogisticRegression found in grid. "
                    "This is unscaled and not well-suited for Triage experiments. "
                    "Use triage.component.catwalk.estimators.classifiers.ScaledLogisticRegression "
                    "instead"
                )

            # Each phase has its own try block so that an instantiation
            # failure is not caught and reported as an import failure.
            try:
                module_name, class_name = classpath.rsplit(".", 1)
                module = importlib.import_module(module_name)
                cls = getattr(module, class_name)
            except Exception as e:
                raise ValueError(
                    dedent(
                        """
                Section: grid_config -
                Unable to import classifier {}, error thrown: {}
                """.format(
                            classpath, e
                        )
                    )
                ) from e

            try:
                parameter_sets = list(ParameterGrid(parameter_config))
            except Exception as e:
                raise ValueError(
                    dedent(
                        """
                Section: grid_config -
                Unable to build parameter grid for classifier {} from {}, error thrown: {}
                """.format(
                            classpath, parameter_config, e
                        )
                    )
                ) from e

            for parameters in parameter_sets:
                try:
                    cls(**parameters)
                except Exception as e:
                    raise ValueError(
                        dedent(
                            """
                        Section: grid_config -
                        Unable to instantiate classifier {} with parameters {}, error thrown: {}
                        """.format(
                                classpath, parameters, e
                            )
                        )
                    ) from e

        logger.debug("Validation of grid configuration was successful")


class PredictionConfigValidator(Validator):
    """Validator for the ``prediction`` section of an experiment config."""

    def _run(self, prediction_config):
        """Check that an optional ``rank_tiebreaker`` is a known tiebreaker.

        Args:
            prediction_config (dict): the ``prediction`` section of the
                experiment configuration.

        Raises:
            ValueError: if a truthy ``rank_tiebreaker`` is given that is not
                one of ``catwalk.utils.AVAILABLE_TIEBREAKERS``.
        """
        logger.spam("Validating prediction configuration")
        tiebreaker = prediction_config.get("rank_tiebreaker")
        # A missing (or otherwise falsy) tiebreaker is allowed: the setting
        # is optional, so there is nothing to check in that case.
        if tiebreaker:
            allowed_tiebreakers = catwalk.utils.AVAILABLE_TIEBREAKERS
            if tiebreaker not in allowed_tiebreakers:
                raise ValueError(
                    "Section: prediction - "
                    f"given tiebreaker must be in {allowed_tiebreakers}"
                )
        logger.spam("Validation of prediction configuration was successful")


class ScoringConfigValidator(Validator):
    """Validator for the ``scoring`` section of an experiment config.

    Two things are checked:

    * every metric named in ``testing_metric_groups`` and
      ``training_metric_groups`` is known to
      ``catwalk.evaluation.ModelEvaluator.available_metrics`` and defines a
      ``greater_is_better`` attribute;
    * every entry in ``subsets`` has a table-safe ``name`` and a ``query``
      that mentions ``entity_id`` and contains an ``{as_of_date}``
      placeholder.
    """

    def _run(self, scoring_config):
        """Validate the scoring configuration.

        Args:
            scoring_config (dict): the ``scoring`` section of the experiment
                configuration.

        Raises:
            ValueError: on the first problem found in the metric groups or
                in the subsets.
        """
        logger.spam("Validating scoring configuration")
        # Missing metric groups are legal, so these are warnings only.
        if "testing_metric_groups" not in scoring_config:
            logger.warning(
                "Section: scoring - No testing_metric_groups configured. "
                "Your experiment may run, but you will not have any "
                "evaluation metrics computed"
            )
        if "training_metric_groups" not in scoring_config:
            logger.warning(
                "Section: scoring - No training_metric_groups configured. "
                "If training set evaluation metrics are desired, they must be added"
            )

        self._validate_metric_groups(scoring_config)
        # Subsets do not depend on the metric groups, so they are validated
        # exactly once rather than once per kind of metric group.
        self._validate_subsets(scoring_config.get("subsets", []))

        logger.debug("Validation of scoring configuration was successful")

    def _validate_metric_groups(self, scoring_config):
        """Check the metrics named in the testing and training metric groups.

        Raises:
            ValueError: if a metric group has no ``metrics`` key, names a
                metric that is not available, or names a metric whose
                function lacks ``greater_is_better``.
        """
        metric_lookup = catwalk.evaluation.ModelEvaluator.available_metrics
        available_metrics = set(metric_lookup.keys())
        for group in ("testing_metric_groups", "training_metric_groups"):
            for metric_group in scoring_config.get(group, []):
                try:
                    given_metrics = set(metric_group["metrics"])
                except KeyError as error:
                    # Surface this as a ValueError so that Validator.run can
                    # apply its strict / non-strict handling to it.
                    raise ValueError(
                        dedent(
                            """\
                            Section: scoring -
                            The {} entry {} does not have a metrics key.
                            Each metric group must list the metrics to compute
                            """
                        ).format(group, metric_group)
                    ) from error

                bad_metrics = given_metrics - available_metrics
                if bad_metrics:
                    raise ValueError(
                        dedent(
                            """\
                            Section: scoring -
                            The following given metrics '{}' are unavailable.
                            Available metrics are: '{}'
                            """
                        ).format(bad_metrics, available_metrics)
                    )

                for given_metric in given_metrics:
                    metric_function = metric_lookup[given_metric]
                    if not hasattr(metric_function, "greater_is_better"):
                        raise ValueError(
                            dedent(
                                """\
                                Section: scoring -
                                The metric {} does not define the attribute
                                'greater_is_better'. This can only be fixed in the catwalk.metrics
                                module. If you still would like to use this metric, consider
                                submitting a pull request"""
                            ).format(given_metric)
                        )

    def _validate_subsets(self, subsets):
        """Check that every subset has a usable name and query.

        The message templates are dedented *before* the values are
        substituted, so a multi-line query cannot disturb the dedenting.

        Raises:
            ValueError: if a subset lacks a ``query`` or ``name`` key, has a
                name that is not table-safe, or has a query without the
                ``{as_of_date}`` placeholder or without ``entity_id``.
        """
        for subset in subsets:
            # 1. Validate that all required keys are present
            if "query" not in subset:
                raise ValueError(
                    dedent(
                        """\
                        Section: subsets -
                        The subset {} does not have a query key.
                        To run evaluations on a subset, you must
                        include a query that returns a list of distinct
                        entity_ids and has a placeholder for an
                        as_of_date
                        """
                    ).format(subset)
                )
            if "name" not in subset:
                raise ValueError(
                    dedent(
                        """\
                        Section: subsets -
                        The subset {} does not have a name key.
                        Please give a name to your subset. This is used
                        in the namespacing of subset tables created by
                        triage.
                        """
                    ).format(subset)
                )
            if not string_is_tablesafe(subset["name"]):
                raise ValueError(
                    dedent(
                        """\
                        Section: subsets -
                        The subset {} name should only contain
                        lowercase letters, numbers, and underscores
                        """
                    ).format(subset)
                )

            # 2. Validate that query conforms to the expectations
            query = subset["query"]
            if "{as_of_date}" not in query:
                raise ValueError(
                    dedent(
                        """\
                        Section: subsets -
                        The subset query {} must
                        include a placeholder for the as_of_date
                        """
                    ).format(query)
                )
            if "entity_id" not in query:
                raise ValueError(
                    dedent(
                        """\
                        Section: subsets -
                        The subset query {} must
                        return a list of distinct entity_ids
                        """
                    ).format(query)
                )


class BiasAuditConfigValidator(Validator):
    """Validator for the ``bias_audit_config`` section of an experiment config."""

    def _run(self, bias_audit_config):
        """Validate the bias audit configuration.

        An empty (falsy) configuration is accepted without further checks.
        Otherwise exactly one of ``from_obj_query`` / ``from_obj_table`` must
        be present, a fixed set of keys is required, and any percentile
        thresholds must be numbers between 0 and 100.

        Args:
            bias_audit_config (dict): the ``bias_audit_config`` section of
                the experiment configuration.

        Raises:
            ValueError: on the first problem found.
        """
        logger.spam("Validating bias audit configuration")
        if not bias_audit_config:
            # if empty, that's fine, shortcut out
            return

        has_query = "from_obj_query" in bias_audit_config
        has_table = "from_obj_table" in bias_audit_config
        if has_query and has_table:
            raise ValueError(
                dedent(
                    """
                    Section: bias_audit_config -
                    Both 'from_obj_query' and 'from_obj_table' specified .
                    Please only specify one."""
                )
            )
        if not has_query and not has_table:
            raise ValueError(
                dedent(
                    """
                    Section: bias_audit_config -
                    Neither 'from_obj_query' and 'from_obj_table' specified .
                    Please specify one."""
                )
            )

        for key in (
            "attribute_columns",
            "knowledge_date_column",
            "entity_id_column",
            "ref_groups_method",
        ):
            if key not in bias_audit_config:
                raise ValueError(
                    f"Section: bias_audit_config - '{key}' required as key: "
                    f"bias_audit_config config: {bias_audit_config}"
                )

        # `or` fallbacks: a key that is present but empty (None) is treated
        # the same as a key that is absent.
        thresholds = bias_audit_config.get("thresholds") or {}
        percentile_thresholds = thresholds.get("percentiles") or []
        try:
            out_of_range = any(
                threshold < 0 or threshold > 100
                for threshold in percentile_thresholds
            )
        except TypeError as error:
            # A value that cannot be compared with a number would otherwise
            # escape as TypeError, which Validator.run does not handle.
            raise ValueError(
                "Section: bias_audit_config - All percentile thresholds must be "
                f"numbers between 0 and 100, got: {percentile_thresholds}"
            ) from error
        if out_of_range:
            raise ValueError(
                "Section: bias_audit_config - All percentile thresholds must be between 0 and 100"
            )
        logger.debug("Validation of bias audit configuration was successful")


class ExperimentValidator(Validator):
    """Validator for a whole experiment config, one section at a time."""

    def run(self, experiment_config):
        """Run every section validator against its part of the config.

        This replaces ``Validator.run``; there is no ``_run`` here. Each
        section validator is created with this validator's ``strict`` flag
        (and its ``db_engine`` where one is passed) and is handed the
        matching section of ``experiment_config``, falling back to an empty
        default when the section is absent.

        Args:
            experiment_config (dict): the full experiment configuration.
        """
        logger.spam("Validating experiment configuration")
        strict = self.strict
        engine = self.db_engine
        section = experiment_config.get

        temporal = TemporalValidator(strict=strict)
        temporal.run(section("temporal_config", {}))

        feature_aggregations = FeatureAggregationsValidator(engine, strict=strict)
        feature_aggregations.run(section("feature_aggregations", {}))

        label = LabelConfigValidator(engine, strict=strict)
        label.run(section("label_config", None))

        cohort = CohortConfigValidator(engine, strict=strict)
        cohort.run(section("cohort_config", {}))

        # The group definition is validated against the feature aggregations.
        feature_group_definition = FeatureGroupDefinitionValidator(strict=strict)
        feature_group_definition.run(
            section("feature_group_definition", {}),
            section("feature_aggregations", {}),
        )

        feature_group_strategies = FeatureGroupStrategyValidator(strict=strict)
        feature_group_strategies.run(section("feature_group_strategies", []))

        user_metadata = UserMetadataValidator(strict=strict)
        user_metadata.run(section("user_metadata", {}))

        # The model group keys are validated together with the user metadata.
        model_group_keys = ModelGroupKeysValidator(strict=strict)
        model_group_keys.run(
            section("model_group_keys", []),
            section("user_metadata", {}),
        )

        grid = GridConfigValidator(strict=strict)
        grid.run(section("grid_config", {}))

        prediction = PredictionConfigValidator(engine, strict=strict)
        prediction.run(section("prediction", {}))

        scoring = ScoringConfigValidator(strict=strict)
        scoring.run(section("scoring", {}))

        bias_audit = BiasAuditConfigValidator(strict=strict)
        bias_audit.run(section("bias_audit_config", {}))

        if not strict:
            logger.warning(
                "Experiment validation complete. All configuration problems have been displayed as warnings"
            )
        else:
            logger.success("Experiment validation ran to completion with no errors")