stphivos/django-mock-queries

View on GitHub
django_mock_queries/query.py

Summary

Maintainability
D
2 days
Test Coverage
import datetime
import random
from collections import OrderedDict, defaultdict, namedtuple
from unittest.mock import Mock, MagicMock, PropertyMock

from .constants import *
from .exceptions import *
from .utils import (
    matches, get_attribute, validate_mock_set, is_list_like_iter, flatten_list, truncate,
    hash_dict, filter_results, get_nested_attr
)


class MockSetMeta(type):
    def __call__(cls, *initial_items, **kwargs):
        obj = super(MockSetMeta, cls).__call__(**kwargs)
        obj.add(*initial_items)
        return obj


class MockSet(MagicMock, metaclass=MockSetMeta):
    EVENT_ADDED = 'added'
    EVENT_UPDATED = 'updated'
    EVENT_SAVED = 'saved'
    EVENT_DELETED = 'deleted'
    SUPPORTED_EVENTS = [EVENT_ADDED, EVENT_UPDATED, EVENT_SAVED, EVENT_DELETED]
    RETURN_SELF_METHODS = [
        'all',
        'only',
        'defer',
        'using',
        'select_related',
        'prefetch_related',
        'select_for_update',
        'iterator'
    ]

    def __init__(self, *initial_items, **kwargs):
        clone = kwargs.pop('clone', None)
        model = kwargs.pop('model', None)

        for x in self.RETURN_SELF_METHODS:
            kwargs.update({x: self._return_self})

        super(MockSet, self).__init__(spec=DjangoQuerySet, **kwargs)

        self.items = list()
        self.clone = clone
        self.model = getattr(clone, 'model', model)
        self.events = {}

        self.add(*initial_items)

        self.__len__ = lambda s: len(s.items)
        self.__iter__ = lambda s: iter(s.items)
        self.__getitem__ = lambda s, k: self.items[k]
        self.__bool__ = self.__nonzero__ = lambda s: len(s.items) > 0

    def _return_self(self, *_, **__):
        return self

    def _mockset_class(self):
        return type(self)

    def count(self):
        return len(self.items)

    def fire(self, obj, *events):
        for name in events:
            for handler in self.events.get(name, []):
                handler(obj)

    def on(self, event, handler):
        assert event in self.SUPPORTED_EVENTS, event
        self.events[event] = self.events.get(event, []) + [handler]

    def _register_fields(self, obj):
        if not (isinstance(obj, MockModel) or isinstance(obj, Mock)):
            return

        for f in self.model._meta.fields:
            if f.name not in obj.keys():
                setattr(obj, f.name, None)

    def add(self, *models):
        if self.model:
            # Initialize MockModel default fields from MockSet model fields if defined
            for obj in models:
                self._register_fields(obj)

        for model in models:
            self.items.append(model)
            self.fire(model, self.EVENT_ADDED, self.EVENT_SAVED)

    def filter(self, *args, **attrs):
        results = list(self.items)
        for x in args:
            if not isinstance(x, DjangoQ):
                raise ArgumentNotSupported()

            if len(x) > 0:
                results = filter_results(results, x)

        return self._mockset_class()(*matches(*results, **attrs), clone=self)

    def exclude(self, *args, **attrs):
        excluded = set(self.filter(*args, **attrs))
        results = [item for item in self.items if item not in excluded]
        return self._mockset_class()(*results, clone=self)

    def exists(self):
        return len(self.items) > 0

    def in_bulk(self, id_list=None, *, field_name='pk'):
        result = {}
        for model in self.items:
            if id_list is None or getattr(model, field_name) in id_list:
                result[getattr(model, field_name)] = model
        return result

    def annotate(self, **kwargs):
        results = list(self.items)
        for key, value in kwargs.items():
            for row in results:
                if not (hasattr(row, '_annotated_fields') and isinstance(row._annotated_fields, list)):
                    row._annotated_fields = []
                row._annotated_fields.append(key)
                setattr(row, key, get_attribute(row, value)[0])

        return self._mockset_class()(*results, clone=self)

    def aggregate(self, *args, **kwargs):
        result = {}

        for expr in set(args):
            kwargs['{0}__{1}'.format(expr.source_expressions[0].name, expr.function).lower()] = expr

        for alias, expr in kwargs.items():
            values = []
            expr_result = None

            for x in self.items:
                val = get_attribute(x, expr.source_expressions[0].name)[0]
                if val is None:
                    continue
                values.extend(val if is_list_like_iter(val) else [val])

            if len(values) > 0:
                expr_result = {
                    AGGREGATES_SUM: lambda: sum(values),
                    AGGREGATES_COUNT: lambda: len(values),
                    AGGREGATES_MAX: lambda: max(values),
                    AGGREGATES_MIN: lambda: min(values),
                    AGGREGATES_AVG: lambda: sum(values) / len(values),
                    AGGREGATES_ARRAY: lambda: values,
                }[expr.function]()

            if len(values) == 0 and expr.function == AGGREGATES_COUNT:
                expr_result = 0

            result[alias] = expr_result

        return result

    def order_by(self, *fields):
        results = self.items
        for field in reversed(fields):
            if field == '?':
                random.shuffle(results)
                break
            is_reversed = field.startswith('-')
            attr = field[1:] if is_reversed else field
            results = sorted(results,
                             key=lambda r: get_attribute(r, attr),
                             reverse=is_reversed)
        return self._mockset_class()(*results, clone=self, ordered=True)

    def distinct(self, *fields):
        results = OrderedDict()
        for item in self.items:
            key = hash_dict(item, *fields)
            if key not in results:
                results[key] = item
        return self._mockset_class()(*results.values(), clone=self)

    def set(self, objs, **attrs):
        self.delete(**attrs)
        self.add(*objs)

    def _raise_does_not_exist(self):
        does_not_exist = getattr(self.model, 'DoesNotExist', ObjectDoesNotExist)
        raise does_not_exist()

    def _get_order_fields(self, fields, field_name):
        if fields and field_name is not None:
            raise ValueError('Cannot use both positional arguments and the field_name keyword argument.')

        if field_name is not None:
            # The field_name keyword argument is deprecated in favor of passing positional arguments.
            order_fields = (field_name,)
        elif fields:
            order_fields = fields
        else:
            order_fields = self.model._meta.get_latest_by
            if order_fields and not isinstance(order_fields, (tuple, list)):
                order_fields = (order_fields,)

        if order_fields is None:
            raise ValueError(
                "earliest() and latest() require either fields as positional "
                "arguments or 'get_latest_by' in the model's Meta."
            )

        return order_fields

    def _earliest_or_latest(self, *fields, **field_kwargs):
        """
        Mimic Django's behavior
        https://github.com/django/django/blob/746caf3ef821dbf7588797cb2600fa81b9df9d1d/django/db/models/query.py#L560
        """
        field_name = field_kwargs.get('field_name', None)
        reverse = field_kwargs.get('reverse', False)
        order_fields = self._get_order_fields(fields, field_name)

        results = sorted(
            self.items,
            key=lambda obj: tuple(get_attribute(obj, key) for key in order_fields),
            reverse=reverse,
        )

        if len(results) == 0:
            self._raise_does_not_exist()

        return results[0]

    def earliest(self, *fields, **field_kwargs):
        return self._earliest_or_latest(*fields, **field_kwargs)

    def latest(self, *fields, **field_kwargs):
        return self._earliest_or_latest(*fields, reverse=True, **field_kwargs)

    def first(self):
        for item in self.items:
            return item

    def last(self):
        return self.items and self.items[-1] or None

    def create(self, **attrs):
        validate_mock_set(self, **attrs)

        obj = self.model(**attrs)
        self.add(obj)

        return obj

    def update(self, **attrs):
        validate_mock_set(self, for_update=True, **attrs)

        count = 0
        for item in self.items:
            count += 1
            for k, v in attrs.items():
                setattr(item, k, v)
                self.fire(item, self.EVENT_UPDATED, self.EVENT_SAVED)

        return count

    def _delete_recursive(self, *items_to_remove, **attrs):
        removed_items = defaultdict(int)

        for item in matches(*items_to_remove, **attrs):
            self.items.remove(item)
            self.fire(item, self.EVENT_DELETED)

            # Support returning detailed information about removed items
            # even if items are not `MockModel` instances
            item_label = get_nested_attr(item, '_meta.label', default=type(item).__name__)
            removed_items[item_label] += 1

        if self.clone is not None:
            self.clone._delete_recursive(*items_to_remove, **attrs)

        return sum(removed_items.values()), removed_items

    def delete(self, **attrs):
        # Delete normally doesn't take **attrs - they're only needed for remove
        return self._delete_recursive(*self.items, **attrs)

    # The following 2 methods were kept for backwards compatibility and
    # should be removed in the future since they are covered by filter & delete
    def clear(self, **attrs):
        return self.delete(**attrs)

    def remove(self, **attrs):
        return self.delete(**attrs)

    def get(self, *args, **attrs):
        results = self.filter(*args, **attrs)
        if not results.exists():
            self._raise_does_not_exist()
        elif results.count() > 1:
            raise MultipleObjectsReturned()
        else:
            return results[0]

    def get_or_create(self, defaults=None, **attrs):
        if defaults is not None:
            validate_mock_set(self)
        defaults = defaults or {}
        lookup = attrs.copy()
        attrs.update(defaults)
        results = self.filter(**lookup)
        if not results.exists():
            return self.create(**attrs), True
        elif results.count() > 1:
            raise MultipleObjectsReturned()
        else:
            return results[0], False

    def update_or_create(self, defaults=None, **attrs):
        if defaults is not None:
            validate_mock_set(self)
        defaults = defaults or {}
        lookup = attrs.copy()
        attrs.update(defaults)
        results = self.filter(**lookup)
        if not results.exists():
            return self.create(**attrs), True
        elif results.count() > 1:
            raise MultipleObjectsReturned()
        else:
            obj = results[0]
            for k, v in attrs.items():
                setattr(obj, k, v)
                self.fire(obj, self.EVENT_UPDATED, self.EVENT_SAVED)
            return obj, False

    def _item_values(self, item, fields):
        field_buckets = {}
        result_count = 1

        if len(fields) == 0:
            field_names = [f.attname for f in item._meta.concrete_fields]
        else:
            field_names = list(fields)

        for field in sorted(field_names, key=lambda k: k.count('__')):
            value = get_attribute(item, field)[0]

            if is_list_like_iter(value):
                value = flatten_list(value)
                result_count = max(result_count, len(value))

                for bucket, data in field_buckets.items():
                    while len(data) < result_count:
                        data.append(data[-1])

                field_buckets[field] = value
            else:
                field_buckets[field] = [value]

        item_values = []
        for i in range(result_count):
            item_values.append({k: v[i] for k, v in field_buckets.items()})

        return item_values

    def values(self, *fields):
        result = []

        for item in self.items:
            item_values = self._item_values(item, fields)
            result.extend(item_values)

        return self._mockset_class()(*result, clone=self)

    def _item_values_list(self, values_dict, fields, flat):
        if flat:
            return values_dict[fields[0]]
        else:
            data = []
            for key in sorted(values_dict.keys(), key=lambda k: fields.index(k)):
                data.append(values_dict[key])
            return tuple(data)

    def _values_row(self, values_dict, fields, **kwargs):
        flat = kwargs.pop('flat', False)
        named = kwargs.pop('named', False)

        if kwargs:
            raise TypeError('Unexpected keyword arguments to values_list: %s' % (list(kwargs),))
        if flat and len(fields) > 1:
            raise TypeError('`flat` is not valid when values_list is called with more than one field.')
        if flat and named:
            raise TypeError('`flat` and `named` can\'t be used together.')

        if named:
            Row = namedtuple('Row', fields)
            row = Row(**values_dict)
        else:
            row = self._item_values_list(values_dict, fields, flat)

        return row

    def values_list(self, *fields, **kwargs):
        # Django doesn't complain about this:
        # https://github.com/django/django/blob/a4e6030904df63b3f10aa0729b86dc6942b0458e/django/db/models/query.py#L845
        # if len(fields) == 0:
        #     raise NotImplementedError('values_list() with no arguments is not implemented')

        result = []
        item_values_dicts = list(self.values(*fields))

        for values_dict in item_values_dicts:
            result.append(self._values_row(values_dict, fields, **kwargs))

        return self._mockset_class()(*result, clone=self)

    def _date_values(self, field, kind, order, key_func):
        initial_values = list(self.values_list(field, flat=True))

        return self._mockset_class()(*sorted(
            {truncate(x, kind) for x in initial_values},
            key=key_func,
            reverse=True if order == 'DESC' else False
        ), clone=self)

    def dates(self, field, kind, order='ASC'):
        assert kind in ("year", "month", "day"), "'kind' must be one of 'year', 'month' or 'day'."
        assert order in ('ASC', 'DESC'), "'order' must be either 'ASC' or 'DESC'."

        return self._date_values(field, kind, order, lambda y: datetime.date.timetuple(y)[:3])

    def datetimes(self, field, kind, order='ASC'):
        # TODO: Handle `tzinfo` parameter
        assert kind in ("year", "month", "day", "hour", "minute", "second"), \
            "'kind' must be one of 'year', 'month', 'day', 'hour', 'minute' or 'second'."
        assert order in ('ASC', 'DESC'), "'order' must be either 'ASC' or 'DESC'."

        return self._date_values(field, kind, order, lambda y: datetime.datetime.timetuple(y)[:6])


class MockModel(dict):
    def __init__(self, *args, **kwargs):
        super(MockModel, self).__init__(*args, **kwargs)

        self.save = PropertyMock()
        object_name = self.get('mock_name', type(self).__name__)
        self.__meta = MockOptions(object_name, *self.get_fields())

    def __getattr__(self, item):
        return self.get(item, None)

    def __setattr__(self, key, value):
        self.__setitem__(key, value)

    def __hash__(self):
        return hash_dict(self)

    def __call__(self, *args, **kwargs):
        return MockModel(*args, **kwargs)

    def get_fields(self):
        skip_keys = ['save', '_MockModel__meta']
        return [key for key in self.keys() if key not in skip_keys]

    @property
    def _meta(self):
        self.__meta.load_fields(*self.get_fields())
        return self.__meta

    def __repr__(self):
        return self.get('mock_name', None) or super(MockModel, self).__repr__()


def create_model(*fields):
    if len(fields) == 0:
        raise ValueError('create_model() is called without fields specified')
    return MockModel(**{f: None for f in fields})


class MockOptions:
    def __init__(self, object_name, *field_names):
        self.load_fields(*field_names)
        self.get_latest_by = None
        self.object_name = object_name
        self.model_name = object_name.lower()

    @property
    def label(self):
        return self.object_name

    @property
    def label_lower(self):
        return self.model_name

    def load_fields(self, *field_names):
        fields = {name: MockField(name) for name in field_names}

        for key in ('_forward_fields_map', 'parents', 'fields_map'):
            self.__dict__[key] = {}

            if key == '_forward_fields_map':
                for name, obj in fields.items():
                    self.__dict__[key][name] = obj

        for key in ('local_concrete_fields', 'concrete_fields', 'fields'):
            self.__dict__[key] = []

            for name, obj in fields.items():
                self.__dict__[key].append(obj)


class MockField:
    def __init__(self, field):
        for key in ('name', 'attname'):
            self.__dict__[key] = field