whylabs/whylogs-python

View on GitHub
python/whylogs/core/predicate_parser.py

Summary

Maintainability
C
7 hrs
Test Coverage
import re
from typing import List, Optional, Tuple

from whylogs.core.dataset_profile import DatasetProfile
from whylogs.core.metric_getters import MetricGetter, ProfileGetter
from whylogs.core.metrics.metrics import Metric
from whylogs.core.relations import (
    LiteralGetter,
    Predicate,
    Relation,
    ValueGetter,
    unescape_colon,
    unescape_quote,
)


def _unescape_token(input: str) -> str:
    if input[0] == '"':
        return unescape_quote(input)
    if input[0] == ":":
        return unescape_colon(input)
    return input


# We use this with findall() to split an input string into tokens. The three
# alternatives, in the order they appear in the pattern:
#   (?::(?:[^:]|\\:)*:[_a-zA-Z0-9]+/[_a-zA-Z0-9:/]+) -- a metric reference of
#       the form ":<column>:<path>" with \: escaping allowed in the column name
#   (?:[^ ":][^ ]*) -- any other space-delimited token (operators, numbers, ...)
#       whose first character is not '"' or ':'
#   "(?:[^\\"]|\\[^"]|\\")*" -- a quoted string literal, allowing \" and \\
#       escapes (and therefore embedded spaces)

_TOKEN_RE = re.compile(r'(?::(?:[^:]|\\:)*:[_a-zA-Z0-9]+/[_a-zA-Z0-9:/]+)|(?:[^ ":][^ ]*)|"(?:[^\\"]|\\[^"]|\\")*"')


def _tokenize(input: str) -> List[str]:
    """Split a serialized predicate expression into a list of unescaped tokens."""
    raw_tokens = _TOKEN_RE.findall(input)
    return list(map(_unescape_token, raw_tokens))


def _get_component(token: List[str], i: int) -> Tuple[str, int]:
    return token[i], i + 1  # either dummy variable x or metric component name in metric::to_summary_dict()


# These match metric references with or without column names. Both allow an
# optional submetric_name:submetric_namespace/ section to support MultiMetric
# references.
#
# _METRIC_REF matches a column-less "::namespace/path" reference (no capture
# groups; the path is extracted by slicing in _get_value).
# _PROFILE_REF matches ":column:namespace/path" and captures the column name
# (group 1) and metric path (group 2) for lookup in a DatasetProfile.

_METRIC_REF = re.compile(r"::[_a-zA-Z0-9]+/(?:[_a-zA-Z0-9]+:[_a-zA-Z0-9]+/)?[_a-zA-Z0-9]+")
_PROFILE_REF = re.compile(r":(.+?):([_a-zA-Z0-9]+/(?:[_a-zA-Z0-9]+:[_a-zA-Z0-9]+/)?[_a-zA-Z0-9]+)")


def _get_value(
    token: List[str], i: int, metric: Optional[Metric] = None, profile: Optional[DatasetProfile] = None
) -> Tuple[ValueGetter, int]:
    """Consume one token and convert it into a ValueGetter.

    Recognizes, in order: quoted string literals, "::namespace/path" metric
    references (require `metric`), ":column:namespace/path" profile references
    (require `profile`), integer literals, and finally float literals (a token
    matching none of these raises ValueError from float()).

    Returns the getter and the index of the next unconsumed token.
    """
    tok = token[i]
    next_i = i + 1

    if tok.startswith('"'):
        # Quoted string literal; drop the surrounding quotes.
        return LiteralGetter(tok[1:-1]), next_i

    if _METRIC_REF.fullmatch(tok):
        if metric is None:
            raise ValueError("Must specify metric to use with MetricGetter")

        # Strip the leading "::" and split "namespace/path" at the first "/".
        namespace, path = tok[2:].split("/", 1)
        if namespace != metric.namespace:
            raise ValueError(f"Expected {namespace} metric but got {metric.namespace}")

        return MetricGetter(metric, path), next_i

    profile_match = _PROFILE_REF.fullmatch(tok)
    if profile_match:
        if profile is None:
            raise ValueError("Must specify profile to use with ProfileGetter")

        column_name, path = profile_match.groups()  # type: ignore
        return ProfileGetter(profile, column_name, path), next_i

    if re.fullmatch(r"[-+]?\d+", tok):
        return LiteralGetter(int(tok)), next_i

    # Fall through: anything left must parse as a float.
    return LiteralGetter(float(tok)), next_i


# Prefix operators whose Predicate receives the ValueGetter itself.
_COMPARISON_RELATIONS = {
    "==": Relation.equal,
    "<": Relation.less,
    "<=": Relation.leq,
    ">": Relation.greater,
    ">=": Relation.geq,
    "!=": Relation.neq,
}

# Regex-match operators; these pass the evaluated literal (value()) to the
# Predicate rather than the getter, since the pattern must be a concrete string.
_MATCH_RELATIONS = {
    "~": Relation.match,
    "~=": Relation.fullmatch,
}


def _deserialize(
    token: List[str], i: int, metric: Optional[Metric] = None, profile: Optional[DatasetProfile] = None
) -> Tuple[Predicate, int]:
    """Recursively parse the prefix-notation token stream starting at token[i].

    The grammar is prefix: an operator token followed by its operand(s), e.g.
    "< x 42" or "and < x 42 > x 0". Returns the parsed Predicate and the index
    of the first unconsumed token. Raises ValueError if token[i] is not a
    recognized operator.
    """
    op = token[i]

    if op in _MATCH_RELATIONS:
        component, i = _get_component(token, i + 1)
        value, i = _get_value(token, i, metric, profile)
        return Predicate(_MATCH_RELATIONS[op], value(), component=component), i

    if op in _COMPARISON_RELATIONS:
        component, i = _get_component(token, i + 1)
        value, i = _get_value(token, i, metric, profile)
        return Predicate(_COMPARISON_RELATIONS[op], value, component=component), i

    if op in ("and", "or"):
        # BUG FIX: metric/profile are now forwarded to the left subtree (they
        # were dropped before, so metric/profile references failed on the left
        # side of a conjunction while working on the right).
        left, i = _deserialize(token, i + 1, metric, profile)
        right, i = _deserialize(token, i, metric, profile)
        relation = Relation._and if op == "and" else Relation._or
        return Predicate(relation, left=left, right=right, component=left._component), i

    if op == "not":
        # BUG FIX: forward metric/profile to the negated subexpression.
        right, i = _deserialize(token, i + 1, metric, profile)
        return Predicate(Relation._not, right=right, component=right._component), i

    # BUG FIX: this message was missing the f prefix, so the expression and
    # token index were printed literally instead of interpolated.
    raise ValueError(f"Unable to parse predicate expression '{' '.join(token)}' at token {i+1}.")


def parse_predicate(
    expression: str, metric: Optional[Metric] = None, profile: Optional[DatasetProfile] = None
) -> Predicate:
    """Parse a serialized predicate expression into a Predicate.

    `metric` must be supplied if the expression contains "::namespace/path"
    metric references; `profile` if it contains ":column:namespace/path"
    profile references.
    """
    tokens = _tokenize(expression)
    result, _ = _deserialize(tokens, 0, metric, profile)
    return result