snorklerjoe/CubeServer

View on GitHub
src/CubeServer-common/cubeserver_common/models/utils/modelutils.py

Summary

Maintainability
B
4 hrs
Test Coverage
"""Some utility classes to help with object mapping"""

from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, List, Mapping, Optional, Tuple, Type, cast
from pydoc import locate
import warnings
from bson import ObjectId
from json import loads
from bson import _BUILT_IN_TYPES as BSON_TYPES
from bson.codec_options import TypeCodec, TypeRegistry
from pymongo import MongoClient, ASCENDING, DESCENDING
from pymongo.collection import Collection

from .dummycodec import DummyCodec
from .enumcodec import EnumCodec

# Names exported by `from ... import *`. ASCENDING and DESCENDING are
# imported from pymongo above and re-exported here, presumably so callers can
# build sort specifications without importing pymongo themselves.
__all__ = ["PyMongoModel", "Encodable", "EncodableCodec", "ASCENDING", "DESCENDING"]


def _locatable_name(type_to_name: type) -> str:
    """Build the dotted name of a type so that pydoc.locate can resolve it.

    Types whose module name contains "builtin" are named by their bare
    class name; every other type is named "<module>.<class name>"."""
    name = type_to_name.__name__
    module = type_to_name.__module__
    return name if "builtin" in module else module + "." + name


def value_ref(key):
    """Rewrite a field path so that it points at the stored ``_val`` entry.

    "_id" is returned untouched. For any other key, "._val" is inserted
    directly after the first path component: "a" becomes "a._val" and
    "a.b.c" becomes "a._val.b.c"."""
    if key == "_id":
        return key

    head, dot, tail = key.partition(".")
    if not dot:  # No nested path, only the field name itself
        return f"{key}._val"
    return f"{head}._val.{tail}"


def map_filter(f):
    """Return a copy of filter ``f`` with every key passed through value_ref.

    A falsy filter (None, empty dict, ...) is handed back unchanged. Only the
    top-level keys are rewritten; the values are copied as they are."""
    if not f:
        return f
    return {value_ref(key): condition for key, condition in f.items()}


def map_sort(s):
    """Return the sort specification ``s`` with each key passed through
    value_ref. ``s`` is iterated as (key, direction) pairs and the
    directions are kept as given."""
    return [(value_ref(field), direction) for field, direction in s]


class _Encoder(ABC):
    """Abstract interface for objects that convert to and from plain
    dictionaries. Concrete subclasses must provide both encode() and
    decode()."""

    @abstractmethod
    def encode(self) -> dict:
        """Return this object as a plain, bson-able dictionary"""

    @classmethod
    @abstractmethod
    def decode(cls, value: dict):
        """Build an object of this class from a dictionary"""


class EncodableCodec(TypeCodec):
    """A TypeCodec that converts between one encodable class and a dict.

    The class to handle is given to the constructor and is reported back as
    python_type; the bson side of the conversion is always a dict."""

    bson_type = dict

    def __init__(self, encodable: Type[_Encoder]):
        # The class whose objects this codec encodes and decodes
        self.encodable = encodable

    @property
    def python_type(self) -> Type[_Encoder]:
        """The class handled by this codec (as given to the constructor)"""
        return self.encodable

    def transform_python(self, value: _Encoder) -> dict:
        """Turns an object into a plain old, bson-able dictionary by
        calling its encode() method"""
        encoded = value.encode()
        return encoded

    def transform_bson(self, value: dict) -> _Encoder:
        """Turns a dictionary back into an object by calling decode() on
        this codec's python_type"""
        target_type = self.python_type
        return target_type.decode(value)


class Encodable(_Encoder):
    """Abstract base for classes that carry their own encode()/decode()
    logic. Two Encodables compare equal when their encoded dictionaries
    are equal."""

    @abstractmethod
    def __init__(self) -> None:
        """Subclasses must provide a constructor that can be called with no
        arguments and fills in default values; decoding relies on it."""
        super().__init__()

    # Redeclared from _Encoder:
    @abstractmethod
    def encode(self) -> dict:
        """Return this object as a plain, bson-able dictionary"""

    @classmethod
    @abstractmethod
    def decode(cls, value: dict) -> _Encoder:
        """Build an Encodable object from a dictionary"""

    def __eq__(self, __value: object) -> bool:
        # Anything that is not an Encodable is never equal to one
        if isinstance(__value, Encodable):
            return self.encode() == __value.encode()
        return False


class PyMongoModel(Encodable):  # TODO: Clean up some code by making an
    #  AutoEncodable superclass that implements
    #  encode() and decode() for non-document
    #  objects.
    """A class for easy object-mapping to bson.
    Extend this class for any classes that describe a type of document."""

    mongo: Optional[MongoClient] = None

    @classmethod
    def update_mongo_client(cls, mongo_client: Optional[MongoClient]):
        """Sets the MongoClient reference in PyMongoModel, which is then
        used by any models that extend this class.

        NOTE(review): this assigns ``cls.mongo``, whereas
        set_collection_name() and __init_subclass__() read
        ``PyMongoModel.mongo``. If this is called on a subclass, the base
        class attribute is left unchanged -- confirm callers always call it
        on PyMongoModel itself."""
        cls.mongo = mongo_client

    @property
    @classmethod
    def model_type_registry(cls) -> TypeRegistry:
        """A TypeRegistry to be used when getting the collection from
        the database. It holds a single EncodableCodec built for this class.

        NOTE(review): ``@property`` is stacked on top of ``@classmethod``, so
        the property wraps a classmethod object rather than a plain
        function -- confirm that reading this attribute really yields a
        TypeRegistry. No code in this class references it."""
        return TypeRegistry([EncodableCodec(cls)])

    @property
    @classmethod
    def collection(cls) -> Collection:
        """Placeholder for the Mongodb collection that backs a model class.

        The body is empty. set_collection_name() assigns the actual
        collection object to ``cls.collection`` on the class it is called
        for, which takes the place of this attribute on that class.

        NOTE(review): ``@property`` is stacked on top of ``@classmethod``, so
        the property wraps a classmethod object rather than a plain
        function -- confirm what reading this placeholder returns when no
        collection has been assigned."""

    @classmethod
    def set_collection_name(cls, collection_name: str):
        """Binds this class to the collection with the given name.

        The collection is looked up on ``PyMongoModel.mongo.db`` and stored
        as ``cls.collection``. If that lookup raises AttributeError (for
        example because PyMongoModel.mongo is still None), the error is
        swallowed and ``cls.collection`` is left as it was."""
        try:
            database = PyMongoModel.mongo.db
            cls.collection = database.get_collection(collection_name)
        except AttributeError:
            # Best effort only: without a usable client there is nothing to bind
            pass

    def __init_subclass__(cls):
        """Sets up the per-model registries whenever a subclass is defined.

        Each subclass gets its own ``_ignored`` list, ``_codecs`` mapping and
        ``_fields`` mapping, and is bound to the collection named after the
        lower-cased class name. A warning is issued if no MongoClient has
        been set on PyMongoModel yet.

        Note that subclasses must implement a constructor or a __new__()
        which initializes all attributes with the proper values."""

        if PyMongoModel.mongo is None:
            # Each fragment ends with a space (or newline) so the
            # concatenated message reads as separate words.
            warnings.warn(
                "Buddy, you forgot to initialize PyMongoModel "
                "with the MongoClient!\n"
                "You can ignore this if it is a part of the "
                "Sphinx-api build process."
            )

        # Attribute names that are never treated as document fields:
        cls._ignored: List[str] = [
            "_id",
            "type_codec",
            "_codecs",
            "_fields",
            "_ignored",
        ]

        # Registered TypeCodecs:
        cls._codecs: Mapping[type, TypeCodec] = {}

        # Registered fields and their corresponding TypeCodecs:
        # (None is an acceptable codec for directly bson-compatible types)
        cls._fields: Mapping[str, Optional[TypeCodec]] = {}

        cls.set_collection_name(cls.__name__.lower())

        super().__init_subclass__()

    @abstractmethod
    def __init__(self):
        """Initializes the PyMongoModel overhead.

        Declared abstract, so concrete models must define their own
        __init__; this implementation sets ``_id`` to None and then calls
        the parent constructor."""
        # Id:
        self._id: Optional[ObjectId] = None
        super().__init__()

    def register_codec(self, type_codec: TypeCodec, replace=False):
        """Adds a TypeCodec to the codec registry, keyed by its python_type.

        Raises KeyError if a codec is already registered for that type,
        unless replace is True, in which case the old codec is overwritten."""
        python_type = type_codec.python_type
        already_registered = python_type in self._codecs
        if already_registered and not replace:
            raise KeyError(
                f"A TypeCodec is already registered for type {python_type}"
            )
        self._codecs[python_type] = type_codec

    def ignore_attribute(self, attr_name: str):
        """Marks an attribute name so that it is never registered as a
        document field. Does nothing if the name is already ignored."""
        if attr_name in self._ignored:
            return
        # Augmented assignment extends the existing list in place and then
        # rebinds self._ignored to that same list.
        self._ignored += [attr_name]

    def locate_codec(self, data_type: type) -> Optional[TypeCodec]:
        """Tries to find a TypeCodec for the specified type if possible.

        - An Encodable subclass gets an EncodableCodec built for that class.
        - A TypeCodec subclass is instantiated with no arguments.
        - An Enum subclass gets an EnumCodec built from the enum class and
          the type of its first member's value (the same arguments that
          register_field() passes).

        Raises TypeError for any other type, and for an Enum with no
        members."""
        if issubclass(data_type, Encodable):
            # EncodableCodec wraps the class itself, not an instance of it
            return EncodableCodec(data_type)
        if issubclass(data_type, TypeCodec):
            return data_type()
        if issubclass(data_type, Enum):  # Use the EnumCodec class:
            members = list(data_type)
            if members:
                # The second argument is the type of the underlying value
                return EnumCodec(data_type, type(members[0].value))
        raise TypeError(
            f"No TypeCodec is registered for type "
            f"{data_type}. "
            f"Please register one before setting."
        )

    def register_field(
        self,
        attr_name: str,
        value: Optional[Any] = None,
        custom_codec: Optional[TypeCodec] = None,
    ):
        """Registers an attribute of the model as a document field.

        If a value is given, it is first stored on the object (bypassing
        __setattr__). If a field with this name is already registered, the
        method returns at that point without changing the registration.

        Otherwise a codec is chosen for the field:
        - custom_codec, when given. It is used for this field only and is
          not added to the codec registry; call register_codec() for that.
        - a codec created from the value when the value is not None, its
          type is not a built-in bson type and no codec is registered for
          its type yet: an EncodableCodec for Encodables, the value itself
          when it is a TypeCodec, an EnumCodec for Enums. The new codec is
          also added to the codec registry.
        - the codec already registered for the value's type, if any.
        - None, meaning the value is stored without a codec.

        Raises TypeError when a codec is needed but none can be created."""

        if value is not None:  # Also set the value while we're at it:
            self._setattr_shady(attr_name, value)

        if attr_name in self._fields:
            return

        value_type = type(value)
        codec = custom_codec
        # TODO: Check recursively for bson compat- there can be dicts of enums for ex
        needs_new_codec = (
            codec is None
            and value is not None
            and value_type not in BSON_TYPES
            and value_type not in self._codecs
        )
        if needs_new_codec:
            # Find or make a TypeCodec for this field:
            if isinstance(value, Encodable):  # Use the one specified:
                codec = EncodableCodec(value_type)
            elif isinstance(value, TypeCodec):
                codec = value
            elif isinstance(value, Enum):  # Use the EnumCodec class:
                codec = EnumCodec(value_type, type(value.value))
            else:
                raise TypeError(
                    f"No TypeCodec is registered for type "
                    f"{value_type}, for attribute {attr_name}. "
                    f"Please register one before setting."
                )
            self.register_codec(codec)

        if codec is None:
            # Fall back on a codec registered for this type (None if absent)
            codec = self._codecs.get(value_type)
        self._fields[attr_name] = codec  # Register the field!

    def encode(self) -> dict:
        """Builds the bson-able dictionary for this document.

        Every registered field becomes a {"_type": ..., "_val": ...} entry.
        Fields with a codec store the locatable name of the codec's
        python_type and the codec-transformed value; fields without one
        store "None" as the type and the raw value. As a side effect, a new
        ObjectId is assigned to ``_id`` if the object does not have one."""
        if vars(self).get("_id") is None:
            self._id = ObjectId()

        document: Mapping[str, Tuple[str, Any]] = {}
        for field, codec in self._fields.items():
            if codec:  # Encode the field through its codec:
                entry = {
                    "_type": _locatable_name(codec.python_type),
                    "_val": codec.transform_python(getattr(self, field)),
                }
            else:  # No TypeCodec was specified, so the value stays raw:
                entry = {"_type": "None", "_val": getattr(self, field)}
            document[field] = entry
        document["_id"] = self._id
        return document

    @classmethod
    def find_codec(cls, field_name: str, field_type_name: str) -> TypeCodec:
        """Finds a codec for a given field (w/ name and type name specified)

        Lookup order:
        1. the codec registered for the field, unless that entry is None
           (None marks a field that was registered without a codec);
        2. the codec registered for the type that pydoc.locate resolves
           field_type_name to;
        3. a fresh DummyCodec.

        A TypeCodec is always returned, never None."""
        field_codec = cls._fields.get(field_name)
        if field_codec is not None:
            return field_codec

        # NOTE(review): locate() resolves a dotted name and may import the
        # module it names. decode() passes the "_type" string read from the
        # document here -- confirm those documents are trusted.
        located_type = locate(field_type_name)  # Resolved only once
        if located_type in cls._codecs:
            return cls._codecs[located_type]
        return DummyCodec()

    @classmethod
    def decode(cls, value: Optional[Mapping[str, Any]]) -> Optional[Encodable]:
        """Populates an object from a dictionary of the document
        This returns None only if the bson value given is None

        A new object is created with the no-argument constructor, its ``_id``
        is taken from the "_id" entry (KeyError if absent) and every other
        entry is decoded and stored on it, bypassing __setattr__. The given
        mapping is only read, never modified.

        Each entry may be stored in one of these shapes:
        - {"_type": name, "_val": payload}, as written by encode();
        - a two-item sequence (name, payload) where name is "None" or a
          dotted type name;
        - anything else, which is taken as a raw value."""
        if value is None:  # Limit a potential failure
            return None
        new_object = cls()
        new_object._id = value["_id"]
        for field_name, x in value.items():
            if field_name == "_id":  # Already handled above
                continue
            if isinstance(x, dict) and "_type" in x:
                field_type_name = x["_type"]
                val = x["_val"]
            else:
                try:
                    # Two-item (type name, value) form:
                    field_type_name, val = x
                    if field_type_name != "None":
                        if (
                            not isinstance(field_type_name, str)
                            or "." not in field_type_name
                        ):
                            # Just a raw two-item value, not a typed pair
                            field_type_name, val = ("None", x)
                except (ValueError, TypeError):
                    # Not unpackable into exactly two items: a raw value
                    field_type_name, val = ("None", x)

            if field_type_name == "None":
                # Unwrap a nested ["None", value] pair if there is one
                if isinstance(val, list) and len(val) == 2 and val[0] == "None":
                    val = val[1]
                new_object._setattr_shady(field_name, val)
            else:
                # Try to get the codec from the fields registry
                # or else try to fall back on the codecs registry
                # or, if that fails, try the fallback dummy codec:
                codec: TypeCodec = cls.find_codec(field_name, field_type_name)
                new_object._setattr_shady(field_name, codec.transform_bson(val))
        return new_object

    def __delattr__(self, __name: str):
        """Deletes the attribute and drops its field registration, if it
        has one."""
        self._fields.pop(__name, None)
        super().__delattr__(__name)

    def __setattr__(self, __name: str, __value: Any):
        """Sets the attribute, first registering it as a document field when
        the name is not ignored, the value is not None and no field with
        that name is registered yet."""
        skip_registration = (
            __name in self._ignored
            or __value is None
            or __name in self._fields
        )
        if not skip_registration:
            self.register_field(__name, __value)
        super().__setattr__(__name, __value)

    def __hash__(self) -> int:
        """Hashes the document by its ``_id`` only, so all objects whose
        ``_id`` is still None share one hash value."""
        return hash(self._id)

    def __eq__(self, __value: object) -> bool:
        """Compares two documents field by field.

        Returns False if the other object is not a PyMongoModel. Otherwise
        every registered, non-ignored field is compared; the objects are
        equal when none of them differ. A field that is missing on one
        object counts as a difference, and a field missing on both is
        skipped."""
        if not isinstance(__value, PyMongoModel):
            return False
        # Fields are registered per class, so an individual object may not
        # have every one of them set; the sentinel stands in for "not set".
        missing = object()
        for field in self._fields:
            if field in self._ignored:
                continue
            if getattr(self, field, missing) != getattr(__value, field, missing):
                return False
        return True

    def _setattr_shady(self, __name: str, __value: Any):
        """Shadily goes around __setattr__ and straight to the superclass.
        This is a workaround to allow for easier implementation of the
        decode() method.

        Because PyMongoModel.__setattr__ is skipped, the attribute is set
        without being registered as a field."""
        super().__setattr__(__name, __value)

    @property
    def id(self):
        """The internal document identifier (read-only view of the ``_id``
        attribute)"""
        return self._id

    @property
    def id_secondary(self):
        """A dummy property; always equal to id (it simply returns self.id)
        Used to have multiple id-driven columns in a Flask-tables table
        """
        return self.id

    def save(self):
        """Saves this document to the collection

        A document without an ``_id`` is inserted (encode() assigns the id).
        A document that already has an ``_id`` replaces the stored document
        with that id, or is inserted if no such document is stored."""
        if not self._id:
            self.collection.insert_one(self.encode())
        else:
            # encode() assigns an _id as a side effect, so an object can have
            # an id without ever having been stored. upsert=True makes sure
            # such an object is written instead of silently matching nothing.
            self.collection.replace_one(
                {"_id": self._id}, self.encode(), upsert=True
            )

    def remove(self):
        """Removes this document from the collection

        Does nothing when the object has no ``_id``. The object's ``_id``
        is left unchanged after the deletion."""
        if self._id:
            self.collection.delete_one({"_id": self._id})

    @classmethod
    def find(cls, *args, **kwargs):
        """Finds documents from the collection
        Arguments are the same as those for PyMongo.collection's find().

        The filter (first positional argument or ``filter`` keyword) and the
        ``sort`` keyword are rewritten to address the stored ``_val``
        entries before the query is run. An explicit ``sort=None`` is passed
        through untouched. Returns a list of decoded objects."""
        args = list(args)
        if args:
            args[0] = map_filter(args[0])

        if "filter" in kwargs:
            kwargs["filter"] = map_filter(kwargs["filter"])

        # None means "no sort"; there is nothing to rewrite then
        if kwargs.get("sort") is not None:
            kwargs["sort"] = map_sort(kwargs["sort"])

        cursor = cls.collection.find(*args, **kwargs)
        return [cls.decode(document) for document in cursor]

    @classmethod
    def find_one(cls, *args, **kwargs):
        """Finds a document from the collection
        Arguments are the same as those for PyMongo's find_one().

        The filter (first positional argument or ``filter`` keyword) and the
        ``sort`` keyword are rewritten to address the stored ``_val``
        entries before the query is run. An explicit ``sort=None`` is passed
        through untouched. Returns the decoded object, or None when decode()
        is handed None."""
        args = list(args)
        if args:
            args[0] = map_filter(args[0])

        if "filter" in kwargs:
            kwargs["filter"] = map_filter(kwargs["filter"])

        # None means "no sort"; there is nothing to rewrite then
        if kwargs.get("sort") is not None:
            kwargs["sort"] = map_sort(kwargs["sort"])

        return cls.decode(cls.collection.find_one(*args, **kwargs))

    def find_self(self):
        """Returns the database's version of self, looked up with
        find_by_id() using this object's id"""
        return self.find_by_id(self.id)

    @classmethod
    def find_by_id(cls, identifier):
        """Finds a document from the collection, given the id

        The identifier is wrapped in ObjectId() and used as the ``_id``
        filter of find_one(), whose result is returned."""
        return cls.find_one({"_id": ObjectId(identifier)})

    def set_attr_from_string(self, field_name: str, value: str):
        """Converts a string to a field's type and stores it on the object.

        With a codec registered for the field, the string is passed through
        the codec's transform_bson(). Without one, the string is passed to
        the type of the field's current value (e.g. int("3")). The result is
        stored bypassing __setattr__. Raises KeyError if the field is not
        registered."""
        codec = self._fields[field_name]
        if codec is None:  # None means the value is already bson-serializable
            current_type = type(self.__getattribute__(field_name))
            converted = current_type(value)
        else:
            converted = codec.transform_bson(value)
        self._setattr_shady(field_name, converted)

    @classmethod
    def count_documents(cls, *args, **kwargs):
        """Return the number of documents in the filtered collection

        The filter may be given as the first positional argument or as the
        ``filter`` keyword; either way it is rewritten to address the stored
        ``_val`` entries, as find() and find_one() do. All arguments are
        then passed on to the collection's count_documents()."""
        args = list(args)
        if args:
            args[0] = map_filter(args[0])

        if "filter" in kwargs:
            kwargs["filter"] = map_filter(kwargs["filter"])

        return cls.collection.count_documents(*args, **kwargs)