stphivos/django-mock-queries

View on GitHub
django_mock_queries/utils.py

Summary

Maintainability
C
1 day
Test Coverage
from datetime import datetime, date
from django.core.exceptions import FieldError
from django.db.models import F, Value, Case
from django.db.models.functions import Coalesce
from unittest.mock import Mock

from .comparisons import *
from .constants import *
from .exceptions import *

import django_mock_queries.query


def merge(first, second):
    """Return ``first`` followed by the items of ``second`` not already present.

    ``first`` is kept intact (including any duplicates it contains).  Items of
    ``second`` that are not in ``first`` are appended once each, in the order
    they first appear in ``second``, so the result order is deterministic
    instead of depending on set iteration order.  Items must be hashable.
    """
    seen = set(first)
    extra = []
    for item in second:
        if item not in seen:
            seen.add(item)
            extra.append(item)
    return first + extra


def intersect(first, second):
    """Return the distinct items that occur in both ``first`` and ``second``.

    The result follows the order in which items first appear in ``first``
    (duplicates collapsed), so it is deterministic rather than set-ordered.
    Items must be hashable.
    """
    in_second = set(second)
    # dict.fromkeys de-duplicates while preserving insertion order.
    return [item for item in dict.fromkeys(first) if item in in_second]


def get_field_mapping(field):
    """Map the lookup name of a relation field to its accessor name.

    ``field`` is expected to expose ``get_accessor_name()`` and
    ``related_model``.  When the accessor ends in ``_set`` (e.g. ``book_set``)
    the lookup name is the lowercased related model name; otherwise the
    accessor name maps to itself.  Returns a single-entry dict.
    """
    accessor = field.get_accessor_name()
    related_name = field.related_model._meta.model_name.lower()
    lookup_name = related_name if accessor.endswith('_set') else accessor
    return {lookup_name: accessor}


def find_field_names_from_meta(meta, annotated=None, **kwargs):
    """Collect field names from a model ``_meta`` options object.

    Returns two parallel lists: names usable in lookups, and the attribute or
    accessor names they correspond to.

    With ``concrete_only=True`` only the ``attname`` of each concrete field
    plus the ``annotated`` names are returned, each mapping to itself.
    Otherwise the forward fields, the ``annotated`` names and the lookup names
    of every model in ``meta.parents`` map to themselves, and each entry of
    ``meta.fields_map`` is translated through ``get_field_mapping``.
    """
    extra_names = annotated or []

    if kwargs.get('concrete_only', False):
        plain_names = [concrete.attname for concrete in meta.concrete_fields] + extra_names
        related_fields = []
    else:
        plain_names = list(meta._forward_fields_map.keys()) + extra_names
        related_fields = list(meta.fields_map.values())
        for parent_model in meta.parents.keys():
            plain_names.extend(find_field_names(parent_model)[0])

    # A dict keeps first-seen order and lets related mappings override/extend.
    names = dict((name, name) for name in plain_names)
    for related in related_fields:
        names.update(get_field_mapping(related))

    return list(names.keys()), list(names.values())


def find_field_names_from_obj(obj, **kwargs):
    """Return ``(lookup_names, attribute_names)`` for an object without ``_meta``.

    * A plain ``dict`` yields its keys, used as both lookup and attribute names.
    * An object with a truthy ``model`` attribute yields two empty lists.
    * A non-empty list-like container (e.g. a ``MockSet`` without a model)
      yields the field names of its first item.
    * Anything else yields two empty lists.
    """
    if type(obj) is dict:
        keys = list(obj.keys())
        return keys, keys

    # Containers that declare a model are not inspected item-by-item here.
    if getattr(obj, 'model', None):
        return [], []

    if is_list_like_iter(obj) and len(obj) > 0:
        lookup_names, attribute_names = find_field_names(obj[0], **kwargs)
        return lookup_names, attribute_names

    return [], []


def find_field_names(obj, **kwargs):
    """Return ``(lookup_names, attribute_names)`` describing the fields of ``obj``.

    Objects that have a ``_meta`` attribute (e.g. Django models and their
    instances) are described from that metadata, including any
    ``_annotated_fields`` they carry.  Everything else is delegated to
    ``find_field_names_from_obj``.  ``kwargs`` (such as ``concrete_only``)
    are forwarded to whichever helper is used.
    """
    if not hasattr(obj, '_meta'):
        return find_field_names_from_obj(obj, **kwargs)

    meta = obj._meta
    annotations = getattr(obj, '_annotated_fields', [])
    return find_field_names_from_meta(meta, annotated=annotations, **kwargs)


def validate_field(field_name, model_fields, for_update=False):
    """Raise ``FieldError`` unless ``field_name`` is acceptable for ``model_fields``.

    When ``for_update`` is true, any name containing ``__`` (a lookup that
    spans a relation) is rejected.  Otherwise the name must be ``'pk'`` or a
    member of ``model_fields``; the error lists the available choices sorted.
    Returns ``None`` when the name is valid.
    """
    spans_relation = '__' in field_name
    if spans_relation and for_update:
        raise FieldError(
            'Cannot update model field %r (only non-relations and foreign keys permitted).' % field_name
        )

    if field_name == 'pk' or field_name in model_fields:
        return

    choices = ', '.join(repr(str(name)) for name in sorted(model_fields))
    raise FieldError("Cannot resolve keyword '{}' into field. Choices are {}.".format(field_name, choices))


def get_field_value(obj, field_name, default=None):
    """Read ``field_name`` from ``obj``, dispatching on the kind of object.

    * ``dict`` (exact type): key lookup, falling back to ``default``.
    * list-like (per ``is_list_like_iter``): a list with the value resolved on
      each item through ``get_attribute``.
    * exact ``date``/``datetime``: ``obj`` itself; ``field_name`` is ignored.
    * anything else: attribute lookup, falling back to ``default``.
    """
    if type(obj) is dict:
        return obj.get(field_name, default)

    if is_list_like_iter(obj):
        values = []
        for item in obj:
            values.append(get_attribute(item, field_name, default)[0])
        return values

    if is_like_date_or_datetime(obj):
        return obj

    return getattr(obj, field_name, default)


def get_attribute(obj, attr, default=None):
    """Resolve a lookup path or query expression against ``obj``.

    ``attr`` may be:

    * an ``F`` expression - resolved through the field name it references;
    * a ``Value`` - its wrapped value is returned directly;
    * a ``Case`` - the result of the first ``When`` whose condition matches
      ``obj`` (checked with ``filter_results``), otherwise its ``default``;
    * a ``Coalesce`` - the first source expression that resolves to a
      non-``None`` value, or ``None`` when every one resolves to ``None``;
    * a string path such as ``"author__name"``, optionally ending in a lookup
      name from ``COMPARISONS`` - walked part by part across related objects.

    Returns a ``(value, comparison)`` pair.  ``comparison`` is ``None`` when
    the path carries no lookup, the lookup name when it does, or a
    ``(date_part, lookup)`` tuple when a ``DATETIME_COMPARISONS`` part is
    applied to a ``date``/``datetime`` value (the lookup defaults to
    ``COMPARISON_EXACT``).  ``default`` replaces the value when the walk
    reaches ``None`` before the path ends, and is passed to
    ``get_field_value`` for missing fields.

    Raises ``FieldError`` (via ``validate_field``) when a path part is not a
    known field of an object whose field names can be discovered.
    """
    result = obj
    comparison = None

    if isinstance(attr, F):
        # F.deconstruct() returns (path, args, kwargs); args[0] is the field name.
        attr = attr.deconstruct()[1][0]
    elif isinstance(attr, Value):
        return attr.value, None
    elif isinstance(attr, Case):
        for case in attr.cases:
            if filter_results([obj], case.condition):
                return get_attribute(obj, case.result)
        return get_attribute(obj, attr.default)
    elif isinstance(attr, Coalesce):
        for expr in attr.source_expressions:
            res, comp = get_attribute(obj, expr)
            if res is not None:
                return res, comp
        # Every argument resolved to None, so (as with SQL COALESCE) the result
        # is None.  Returning here also keeps the Coalesce object away from the
        # string-path handling below, which would fail on ``attr.split``.
        return None, None

    parts = attr.split('__')

    for i, attr_part in enumerate(parts):
        if attr_part in COMPARISONS:
            comparison = attr_part
        elif attr_part in DATETIME_COMPARISONS and type(result) in [date, datetime]:
            # Date-part lookup, optionally followed by a comparison lookup;
            # the remainder of the path is consumed here.
            comparison_type = parts[i + 1] if i + 1 < len(parts) else COMPARISON_EXACT
            comparison = (attr_part, comparison_type)
            break
        elif result is None:
            # The walk reached None before the path ended.
            result = default
            break
        else:
            lookup_fields, actual_fields = find_field_names(result)

            if lookup_fields:
                validate_field(attr_part, lookup_fields)

            # Translate a lookup name into the attribute that actually holds it.
            field = actual_fields[lookup_fields.index(attr_part)] if attr_part in lookup_fields else attr_part
            result = get_field_value(result, field, default)
    return result, comparison


def is_match(first, second, comparison=None):
    """Compare a resolved attribute value with a filter value.

    ``first`` is the value read from the object, ``second`` the value given in
    the filter, and ``comparison`` the lookup to apply: ``None`` (or another
    falsy value) for plain equality, a lookup name, or a
    ``(date_part, lookup)`` tuple for date/datetime parts.

    * A ``MockSet`` ``first`` matches when any of its items matches.
    * When an ``int``/``str`` is compared with a ``MockSet``, the set is first
      reduced to its items' primary keys (if every item has one).
    * An unknown lookup name raises ``KeyError``.
    """
    mock_set_type = django_mock_queries.query.MockSet

    if isinstance(first, mock_set_type):
        return is_match_in_children(comparison, first, second)

    if isinstance(first, (int, str)) and isinstance(second, mock_set_type):
        second = convert_to_pks(second)

    # ``datetime`` is a subclass of ``date``, so one isinstance check covers both.
    if isinstance(first, date) and isinstance(comparison, tuple) and len(comparison) == 2:
        date_part, comparison = comparison
        first = extract(first, date_part)

    if not comparison:
        return first == second

    comparators = {
        COMPARISON_EXACT: exact_comparison,
        COMPARISON_IEXACT: iexact_comparison,
        COMPARISON_CONTAINS: contains_comparison,
        COMPARISON_ICONTAINS: icontains_comparison,
        COMPARISON_GT: gt_comparison,
        COMPARISON_GTE: gte_comparison,
        COMPARISON_LT: lt_comparison,
        COMPARISON_LTE: lte_comparison,
        COMPARISON_IN: in_comparison,
        COMPARISON_STARTSWITH: startswith_comparison,
        COMPARISON_ISTARTSWITH: istartswith_comparison,
        COMPARISON_ENDSWITH: endswith_comparison,
        COMPARISON_IENDSWITH: iendswith_comparison,
        COMPARISON_ISNULL: isnull_comparison,
        COMPARISON_REGEX: regex_comparison,
        COMPARISON_IREGEX: iregex_comparison,
        COMPARISON_RANGE: range_comparison,
        COMPARISON_OVERLAP: overlap_comparison,
    }
    compare = comparators[comparison]
    return compare(first, second)


def extract(obj, comparison):
    """Return the ``comparison`` part of a ``date``/``datetime`` value.

    Supported parts are date, year, month, day and week day for any date, plus
    hour, minute and second for ``datetime`` values.  The week day follows
    Django's convention of 1 = Sunday through 7 = Saturday.  Requesting a
    time part from a plain ``date`` raises ``KeyError``; a non-date ``obj``
    raises ``TypeError``.
    """
    parts = None
    if isinstance(obj, date):
        is_datetime = isinstance(obj, datetime)
        parts = {
            COMPARISON_DATE: obj.date() if is_datetime else obj,
            COMPARISON_YEAR: obj.year,
            COMPARISON_MONTH: obj.month,
            COMPARISON_DAY: obj.day,
            # weekday(): Monday=0..Sunday=6  ->  Sunday=1..Saturday=7
            COMPARISON_WEEK_DAY: (obj.weekday() + 1) % 7 + 1,
        }
        if is_datetime:
            parts[COMPARISON_HOUR] = obj.hour
            parts[COMPARISON_MINUTE] = obj.minute
            parts[COMPARISON_SECOND] = obj.second
    return parts[comparison]


def convert_to_pks(query):
    """Return the ``pk`` of every item in ``query``.

    If any item lacks a ``pk`` attribute, ``query`` itself is returned
    unchanged (the very same object).
    """
    pks = []
    try:
        for item in query:
            pks.append(item.pk)
    except AttributeError:
        # Not every item exposes ``pk``: leave the input untouched.
        return query
    return pks


def is_match_in_children(comparison, first, second):
    """Return True if any item of ``first`` matches ``second`` under ``comparison``."""
    for child in first:
        if is_match(child, second, comparison):
            return True
    return False


def is_disqualified(obj, attrs, negated):
    """Tell whether ``obj`` must be filtered out by the lookups in ``attrs``.

    ``attrs`` maps lookup paths to filter values.  Without ``negated`` an
    object is disqualified by the first lookup it fails; with ``negated`` it
    is disqualified by the first lookup it matches.
    """
    for lookup, expected in attrs.items():
        value, comparison = get_attribute(obj, lookup)
        matched = is_match(value, expected, comparison)
        # Disqualified whenever the match outcome agrees with the negation flag.
        if bool(matched) == bool(negated):
            return True
    return False


def matches(*source, **attrs):
    """Return the items of ``source`` that satisfy every lookup in ``attrs``.

    ``attrs`` maps lookup paths (as understood by ``get_attribute``) to the
    values to compare against.  The special keyword ``negated=True`` inverts
    the test: an item is kept only if none of the lookups match.  Order and
    duplicates of ``source`` are preserved.
    """
    # NOTE(review): because ``negated`` is popped here, a model field literally
    # named ``negated`` cannot be filtered through this function.
    negated = attrs.pop('negated', False)
    # Each item is judged independently in a single pass, so an item is never
    # dropped merely because it compares equal to a disqualified one, and the
    # cost stays linear in len(source).
    return [item for item in source if not is_disqualified(item, attrs, negated)]


def validate_mock_set(mock_set, for_update=False, **fields):
    """Check that ``mock_set`` has a model and that ``fields`` are valid for it.

    Raises ``ModelNotSpecified`` when ``mock_set.model`` is ``None``.  Each key
    of ``fields`` is checked with ``validate_field`` against the model's
    attribute names (``for_update`` is forwarded), which raises
    ``FieldError`` for an invalid name.
    """
    if mock_set.model is None:
        raise ModelNotSpecified()

    model_field_names = find_field_names(mock_set.model)[1]
    for name in fields:
        validate_field(name, model_field_names, for_update)


def validate_date_or_datetime(value, comparison):
    """Raise ``ValueError`` if ``value`` is out of range for date part ``comparison``.

    Years are accepted unconditionally.  Month, day, week day, hour, minute and
    second are checked against the inclusive ``*_BOUNDS`` pair for that part.
    An unsupported ``comparison`` raises ``KeyError``.
    """
    # Bounds are fetched lazily so only the constant for the requested part is read.
    bounds_for = {
        COMPARISON_YEAR: lambda: None,
        COMPARISON_MONTH: lambda: MONTH_BOUNDS,
        COMPARISON_DAY: lambda: DAY_BOUNDS,
        COMPARISON_WEEK_DAY: lambda: WEEK_DAY_BOUNDS,
        COMPARISON_HOUR: lambda: HOUR_BOUNDS,
        COMPARISON_MINUTE: lambda: MINUTE_BOUNDS,
        COMPARISON_SECOND: lambda: SECOND_BOUNDS,
    }
    bounds = bounds_for[comparison]()
    if bounds is None:
        return
    if not bounds[0] <= value <= bounds[1]:
        raise ValueError('{} is incorrect value for {}'.format(value, comparison))


def is_list_like_iter(obj):
    """Tell whether ``obj`` should be treated as a collection of items.

    ``MockSet`` instances are collections.  ``MockModel`` instances and any
    ``unittest.mock.Mock`` (including ``MagicMock``, which provides
    ``__iter__``) are single objects.  Otherwise anything with ``__iter__``
    counts, except strings.
    """
    query_module = django_mock_queries.query

    if isinstance(obj, query_module.MockModel):
        return False
    if isinstance(obj, query_module.MockSet):
        return True
    if isinstance(obj, Mock):
        return False

    is_iterable = hasattr(obj, '__iter__')
    return is_iterable and not isinstance(obj, str)


def is_like_date_or_datetime(obj):
    """Return True only if ``obj`` is exactly a ``date`` or ``datetime`` (subclasses excluded)."""
    obj_type = type(obj)
    return obj_type is date or obj_type is datetime


def flatten_list(source):
    """Recursively flatten nested list-like values of ``source`` into one list.

    What counts as list-like is decided by ``is_list_like_iter`` (strings,
    mocks and ``MockModel`` instances stay single items).  Item order is
    preserved.
    """
    def _walk(items):
        # Depth-first traversal yielding leaves in their original order.
        for item in items:
            if is_list_like_iter(item):
                yield from _walk(item)
            else:
                yield item

    return list(_walk(source))


def truncate(obj, kind):
    """Truncate a ``date``/``datetime`` to the start of the ``kind`` period.

    Mirrors Django's ``Trunc*`` functions and the ``kind`` argument of
    ``QuerySet.dates()`` / ``QuerySet.datetimes()``:

    * any date or datetime: ``'year'``, ``'quarter'``, ``'month'``, ``'week'``
      (Monday of the week), ``'day'``;
    * datetimes only: ``'hour'``, ``'minute'``, ``'second'``.

    Every time component finer than ``kind`` is reset, including
    ``microsecond``.  An unsupported ``kind`` raises ``KeyError``; a value
    that is not a date raises ``TypeError``.
    """
    from datetime import timedelta

    is_datetime = isinstance(obj, datetime)
    if not is_datetime and not isinstance(obj, date):
        raise TypeError('truncate() expects a date or datetime, got {}'.format(type(obj).__name__))

    if is_datetime:
        if kind == 'second':
            return obj.replace(microsecond=0)
        if kind == 'minute':
            return obj.replace(second=0, microsecond=0)
        if kind == 'hour':
            return obj.replace(minute=0, second=0, microsecond=0)
        start_of_day = obj.replace(hour=0, minute=0, second=0, microsecond=0)
    else:
        start_of_day = obj

    if kind == 'day':
        return start_of_day
    if kind == 'week':
        return start_of_day - timedelta(days=start_of_day.weekday())
    if kind == 'month':
        return start_of_day.replace(day=1)
    if kind == 'quarter':
        first_month = 3 * ((start_of_day.month - 1) // 3) + 1
        return start_of_day.replace(month=first_month, day=1)
    if kind == 'year':
        return start_of_day.replace(month=1, day=1)
    raise KeyError(kind)


def hash_dict(obj, *fields):
    """Hash ``obj`` by the values of selected fields.

    ``fields`` defaults to the concrete field attribute names of ``obj``.
    Values are read with ``get_field_value`` and hashed as a tuple of
    ``(name, value)`` pairs sorted by name, so the result does not depend on
    the order in which fields are given.  Values must be hashable.
    """
    names = fields if fields else find_field_names(obj, concrete_only=True)[1]

    values = {}
    for name in names:
        values[name] = get_field_value(obj, name)

    return hash(tuple(sorted(values.items())))


def filter_results(source, query):
    """Return the items of ``source`` that satisfy the Django ``Q`` object ``query``.

    Each child of ``query`` (a ``(lookup, value)`` pair or a nested ``Q``) is
    evaluated without negation, and the partial results are combined with the
    query's connector: union for ``OR``, intersection otherwise.  If ``query``
    is negated, the complement of the combined result within ``source`` is
    returned, so ``~Q(a=1, b=2)`` keeps the items for which
    ``a=1 AND b=2`` does *not* hold, as in SQL.  A ``Q`` with no children
    (negated or not) places no restriction on ``source``, as in Django.
    """
    if not query.children:
        return list(source)

    is_or = query.connector == CONNECTORS_OR
    matched = []
    for index, child in enumerate(query.children):
        # Negation is applied once to the whole node below.  Applying it to each
        # child would turn NOT (a AND b) into NOT a AND NOT b, and would never
        # reach nested Q children at all.
        filtered = _filter_single_q(source, child, False)
        if index == 0 or is_or:
            matched = merge(matched, filtered)
        else:
            matched = intersect(matched, filtered)
        if not is_or and not matched:
            # A conjunction that is already empty can never match again; stop so
            # a later child cannot re-seed the result.
            break

    if query.negated:
        return [item for item in source if item not in matched]
    return matched


def _filter_single_q(source, q_obj, negated):
    """Apply one child of a ``Q`` node to ``source``.

    A nested ``Q`` is handed to ``filter_results`` as-is (``negated`` is not
    used for it).  Any other child is treated as a ``(lookup, value)`` pair and
    filtered with ``matches``, honouring ``negated``.
    """
    if not isinstance(q_obj, DjangoQ):
        lookup, value = q_obj[0], q_obj[1]
        return matches(*source, negated=negated, **{lookup: value})
    return filter_results(source, q_obj)


def get_nested_attr(obj, attr_path, default=None):
    """Follow a dotted attribute path (e.g. ``"a.b.c"``) starting from ``obj``.

    Returns ``default`` as soon as any attribute along the path is missing
    (``AttributeError``).
    """
    current = obj
    for name in attr_path.split('.'):
        try:
            current = getattr(current, name)
        except AttributeError:
            return default
    return current