fabiommendes/sidekick

View on GitHub
sidekick-types/sidekick/types/named_record.py

Summary

Maintainability
B
5 hrs
Test Coverage
import collections.abc
import keyword
from types import MappingProxyType

from .anonymous_record import MutableMapView, MapView, record, namespace, MetaMixin
from ..typing import Mapping

# Sentinel meaning "no default value was supplied" (None is a legitimate default).
NOT_GIVEN = object()
# Description of a single record field: its name, declared type and default.
Field = collections.namedtuple("Field", "name type default")


class RecordMeta(type):
    """
    Metaclass for Record types.

    Class statements that derive from Record/Namespace are routed through
    :func:`new_record_type`, using the class annotations as field
    declarations. The :meth:`define` method offers the same functionality
    through a functional API.
    """

    _meta: "Meta" = NOT_GIVEN

    def __new__(mcs, name, bases, ns, use_invalid=False, **kwargs):
        # While the Record and Namespace base classes themselves are being
        # created the module-level names still hold the NotImplemented
        # placeholder; build those as ordinary classes.
        if Namespace is NotImplemented or Record is NotImplemented:
            return super().__new__(mcs, name, bases, ns)

        fields = extract_fields_from_annotations(bases, ns)
        # A class is mutable by default when any base derives from Namespace.
        kwargs.setdefault(
            "is_mutable", any(issubclass(cls, Namespace) for cls in bases)
        )
        # Forward use_invalid explicitly: it is a named parameter and thus is
        # not part of **kwargs.
        return new_record_type(
            name, fields, bases, ns, use_invalid=use_invalid, mcs=mcs, **kwargs
        )

    def __init__(cls, name, bases, ns, **kwargs):
        # Swallow the extra class keywords; type.__init__ does not accept them.
        super().__init__(name, bases, ns)

    @classmethod
    def __prepare__(mcs, name, bases, **kwargs):
        """
        Return the (ordered) mapping used as namespace of the class body.
        """
        return collections.OrderedDict()

    def define(
        self,
        name: str,
        fields: list,
        bases=(),
        ns: dict = None,
        use_invalid=False,
        is_mutable=None,
    ):
        """
        Declare a new record class.

        Args:
            name:
                The name of the new record class.
            fields:
                A list of field names or a tuples with (name, type) or even
                (name, type, default). If fields is a mapping, it is treated
                as sequence of (name, type) pairs if values are types or
                (name, default) pairs if values are instances.
            bases:
                An optional list of base classes for the derived class. By
                default, it has a single base of :class:`sidekick.Record`.
            ns:
                An optional dictionary of additional methods and properties
                the resulting record class declares.
            use_invalid:
                If True, accept invalid Python names as record fields. Those
                fields are still available from the getattr() and setattr()
                interfaces but are very inconvenient to use.
            is_mutable:
                If given, controls the mutability of resulting class. This
                controls the implicit base class (Record or Namespace). If not
                given, and no base class derive from Record or Namespace,
                assumes that result is immutable.

        Usage:

            >>> Point = Record.define('Point', ['x', 'y'])
            >>> Point(1, 2)
            Point(1, 2)


        Returns:
            A new Record subclass.

        Raises:
            ValueError:
                If ``is_mutable`` is explicitly false but one of the bases
                derives from Namespace.
        """
        kwargs = {"use_invalid": use_invalid}

        # type.__new__ requires a tuple of bases; accept any iterable here.
        bases = tuple(bases)

        # Compute mutability of class
        has_namespace = any(issubclass(cls, Namespace) for cls in bases)
        if is_mutable is None and has_namespace:
            kwargs["is_mutable"] = True
        elif is_mutable is None:
            # Fall back to the mutability of the class define() was called on.
            kwargs["is_mutable"] = issubclass(self, Namespace)
        elif not is_mutable and has_namespace:
            raise ValueError("Immutable record cannot have a mutable super class.")
        else:
            kwargs["is_mutable"] = bool(is_mutable)

        # Force either Record or Namespace be in bases
        if kwargs["is_mutable"] and Namespace not in bases:
            bases = (*bases, Namespace)
        elif not kwargs["is_mutable"] and Record not in bases:
            bases = (*bases, Record)

        return new_record_type(name, fields, bases, ns or {}, **kwargs)


def new_record_type(
    name: str,
    fields: list,
    bases: tuple,
    ns: dict,
    use_invalid=False,
    is_mutable=False,
    mcs: type = RecordMeta,
) -> type:
    """
    Build a record class from a list (or mapping) of field declarations.

    The resulting class receives a ``_meta`` attribute describing its fields
    and a generated initializer, installed as ``__init_data__`` and
    ``__init__`` unless ``ns`` already provides them.
    """
    # Mappings are first expanded into explicit Field declarations.
    if isinstance(fields, collections.abc.Mapping):
        fields = list(normalize_field_mapping(fields))

    cleaned = [clean_field(f, use_invalid) for f in fields]
    meta_info = Meta(cleaned, is_mutable)

    # Entries given by the caller override the generated ones.
    ns = dict(make_record_namespace(bases, meta_info, is_mutable), **ns)

    # Bypass mcs.__new__ on purpose and create the class directly.
    cls = type.__new__(mcs, name, bases, ns)
    cls._meta = meta_info

    # The initializer can only be generated once the class (and its slots)
    # exist.
    init = make_init_function(cls)
    for attr in ("__init_data__", "__init__"):
        if attr not in ns:
            setattr(cls, attr, init)
    return cls


def clean_field(field, use_invalid=False):
    """
    Normalize a field declaration into a Field instance.

    Accepts a bare name, an existing Field (returned untouched) or a sequence
    of the form (name,), (name, type) or (name, type, default). Raises
    ValueError for invalid names unless ``use_invalid`` is true.
    """
    if isinstance(field, Field):
        return field

    # Treat a bare string as a 1-element declaration.
    parts = (field,) if isinstance(field, str) else field
    size = len(parts)

    tt, default = object, NOT_GIVEN
    if size == 1:
        (name,) = parts
    elif size == 2:
        name, tt = parts
    else:
        name, tt, default = parts

    if not (use_invalid or is_valid_name(name)):
        raise ValueError("%s is an invalid field name" % name)
    # A falsy type (e.g. None) means "anything".
    return Field(name, tt or object, default)


def normalize_field_mapping(fields):
    """
    Normalize each declaration in a field mapping.

    Yields one Field per (name, value) item:

    * if value is a type, it is the field type and no default is set;
    * if value is None, Ellipsis or NotImplemented, the field accepts any
      object and value is the default;
    * otherwise value is the default and its type is the field type.
    """
    for name, value in fields.items():
        if isinstance(value, type):
            yield Field(name, value, NOT_GIVEN)
        # Compare by identity: a membership test would call value.__eq__,
        # which misbehaves for objects with custom or element-wise equality.
        elif value is None or value is Ellipsis or value is NotImplemented:
            yield Field(name, object, value)
        else:
            yield Field(name, type(value), value)


def is_valid_name(name: str) -> bool:
    """
    Tell whether name can be used as a regular attribute name, i.e., it is a
    Python identifier that is not a reserved keyword.
    """
    if not name.isidentifier():
        return False
    return not keyword.iskeyword(name)


def safe_names(names):
    """
    Receive a list of names and return a map from names to the corresponding
    safe name to use as a Python variable.

    Keywords are escaped by appending an underscore. Raises ValueError if an
    escaped name collides with another given name (e.g. "class" and "class_").
    """
    mapping = {
        name: name + "_" if keyword.iskeyword(name) else name for name in names
    }
    if len(mapping) != len(set(mapping.values())):
        msg = "collision between escaped field names and given field names"
        # Wrap in a 1-tuple: when names is itself a tuple, "%s" % names would
        # spread it over the format arguments and raise TypeError instead.
        raise ValueError(msg + ": %s" % (names,))
    return mapping


def make_record_namespace(bases, meta_info, is_mutable=False):
    """
    Return the initial class namespace for a record type.

    It always declares one slot per field. Immutable types additionally get a
    __hash__ based on the instance contents and the __setattr__ of the
    anonymous record type. The bases argument is currently unused.
    """
    result = {"__slots__": tuple(meta_info.fields)}
    if is_mutable:
        return result

    result["__hash__"] = lambda self: hash(tuple(self))
    result["__setattr__"] = record.__setattr__
    return result


def extract_fields_from_annotations(bases, ns):
    """
    Build the list of Field declarations of a class from its annotations.

    Annotations declared directly on record bases come first, followed by
    the ones in ``ns``. Default values are popped from ``ns`` (so ``ns`` is
    modified in place) or, if absent there, looked up as attributes of the
    bases.
    """
    annotations = dict(ns.get("__annotations__", ()))
    for base in bases:
        if isinstance(base, RecordMeta):
            # Work on a copy: updating the dict stored in base.__dict__ would
            # leak the fields of this class into the base class and, through
            # it, into every other subclass created afterwards.
            merged = dict(base.__dict__.get("__annotations__", {}))
            merged.update(annotations)
            annotations = merged

    fields = []
    for name, tt in annotations.items():
        try:
            default = ns.pop(name)
        except KeyError:
            # NOTE(review): for a field inherited from a record base this
            # lookup may find the base's slot descriptor rather than a real
            # default value -- confirm the intended behavior.
            default = getattr_from_bases(bases, name, NOT_GIVEN)

        fields.append(Field(name, tt, default))
    return fields


def getattr_from_bases(bases, attr, default):
    """
    Return the attribute ``attr`` of the first base that defines it, or
    ``default`` if none of the bases does.
    """
    missing = object()
    for base in bases:
        found = getattr(base, attr, missing)
        if found is not missing:
            return found
    return default


def make_init_function(cls: RecordMeta):
    """
    Generate the ``__init__`` function of a record class.

    The source code is produced by :func:`make_init_function_code` and
    executed in a private namespace that holds the default values and the
    getter/setter of the descriptor associated with each field.
    """

    # noinspection PyProtectedMember
    meta = cls._meta
    descriptors = {field: get_slot(cls, field) for field in meta.fields}
    escaped = safe_names(meta.fields)

    # Names referenced by the generated source code
    env = {}
    for field, value in meta.defaults.items():
        env["_%s_default" % escaped[field]] = value
    for field, descriptor in descriptors.items():
        safe = escaped[field]
        env["_%s_getter" % safe] = descriptor.__get__
        env["_%s_setter" % safe] = descriptor.__set__

    # The source is assembled from the field names of the class itself.
    source = make_init_function_code(escaped, meta.defaults)
    exec(source, env, env)
    return env["__init__"]


def make_init_function_code(names_map: dict, defaults: Mapping) -> str:
    """
    Render the source code of an ``__init__`` function.

    ``names_map`` maps field names to the argument names used in the
    signature. Fields present in ``defaults`` become optional arguments bound
    to the ``_<name>_default`` variable and every field is stored through its
    ``_<name>_setter`` function.
    """

    signature = ", ".join(
        "%s=_%s_default" % (safe, safe) if field in defaults else safe
        for field, safe in names_map.items()
    )

    statements = [
        "%s(self, %s)" % ("_%s_setter" % safe, safe) for safe in names_map.values()
    ]
    # A record without fields still needs a non-empty function body.
    body = "\n    ".join(statements) or "pass"

    template = "def __init__(self, {args}):\n    {body}"
    return template.format(args=signature, body=body)


def make_eq_function(fields):
    """
    Build an ``__eq__`` method that compares the given attribute names.

    The resulting method returns NotImplemented when the other operand is not
    an instance of the class of the first one.
    """

    names = tuple(fields)

    def __eq__(self, other):  # noqa: N802
        if not isinstance(other, self.__class__):
            return NotImplemented
        for name in names:
            if not getattr(self, name) == getattr(other, name):
                return False
        return True

    return __eq__


def get_slot(cls, name):
    """
    Return the class attribute ``name`` of ``cls`` or, if there is none, a
    property that reads and writes the ``name`` entry of the instance
    ``__dict__``.
    """
    missing = object()
    slot = getattr(cls, name, missing)
    if slot is not missing:
        return slot

    def read(instance):
        return instance.__dict__[name]

    def write(instance, value):
        instance.__dict__[name] = value

    return property(read, write)


class Meta(MetaMixin):
    """
    Description of a record type: field names, field types, default values
    and mutability. Iterating over it yields the field names.
    """

    __slots__ = ("fields", "types", "defaults", "is_mutable")

    def __init__(self, fields, is_mutable):
        self.fields = tuple(field.name for field in fields)
        self.types = tuple(field.type for field in fields)

        # Only fields that actually declare a default are registered.
        given = {}
        for field in fields:
            if field.default is not NOT_GIVEN:
                given[field.name] = field.default
        self.defaults = MappingProxyType(given)

        self.is_mutable = is_mutable

    def __iter__(self):
        for name in self.fields:
            yield name


# ------------------------------------------------------------------------------
# Record classes
# ------------------------------------------------------------------------------
# Placeholders: RecordMeta.__new__ tests these names against NotImplemented to
# detect that the base classes are still being created.
Record = NotImplemented
Namespace = NotImplemented


class RecordMixin:
    """
    Behavior shared by the Record and Namespace base classes: generic
    initialization, representation, comparison, hashing, pickling support
    and iteration over (field, value) pairs.
    """

    __slots__ = ()
    _meta: Meta = Meta
    M: Mapping

    def __init__(*args, **extra):
        """
        Initialize fields from positional and/or keyword arguments.

        Raises:
            TypeError:
                If too many positional arguments are given, if a field is
                passed both positionally and by keyword, if an unknown field
                name is given, if a field without default is missing or if
                a value is not an instance of the declared field type.
        """
        # self is extracted from *args -- presumably to allow a field that is
        # itself called "self" to be passed by keyword.
        self, *args = args
        cls = type(self)
        fields = self._meta.fields

        # zip() below would silently discard the surplus values.
        if len(args) > len(fields):
            raise TypeError(
                f"expected at most {len(fields)} positional arguments, "
                f"got {len(args)}"
            )

        args = dict(zip(fields, args))
        kwargs = dict(self._meta.defaults)
        common = set(args).intersection(extra)
        if common:
            raise TypeError(f"repeated occurrence of arguments: {common}")

        kwargs.update(args)
        kwargs.update(extra)

        # Names that do not correspond to any declared field.
        unexpected = set(kwargs) - set(fields)
        if unexpected:
            raise TypeError(f"unexpected arguments: {unexpected}")

        # Declared fields that received neither a value nor a default.
        missing = set(fields) - set(kwargs)
        if missing:
            raise TypeError(f"missing arguments: {missing}")

        types = dict(zip(fields, self._meta.types))
        for k, v in kwargs.items():
            tt = types[k]
            # Only real classes are checked; other annotations are ignored.
            if isinstance(tt, type) and not isinstance(v, tt):
                vt = type(v).__name__
                tt = tt.__name__
                raise TypeError(f"invalid type for {k}: got {vt!r}, expected {tt!r}")
            # Write through the descriptor, bypassing the class __setattr__.
            slot = get_slot(cls, k)
            slot.__set__(self, v)

    def __repr__(self):
        return "%s(%s)" % (
            self.__class__.__name__,
            ", ".join(repr(getattr(self, x)) for x in self._meta.fields),
        )

    def __eq__(self, other):
        # Comparable with instances of the same class and with the anonymous
        # record/namespace types; anything else is left to the other operand.
        if isinstance(other, (type(self), record, namespace)):
            return len(self) == len(other) and all(
                getattr(self, k) == getattr(other, k) for k in self._meta.fields
            )
        return NotImplemented

    def __getstate__(self):
        # State is the tuple of values of the mapping view.
        return tuple(self.M.values())

    def __setstate__(self, state):
        # noinspection PyArgumentList
        self.__init__(*state)

    def __json__(self):
        return dict(self.M)

    def __hash__(self):
        # Hash of the tuple of (field, value) pairs produced by __iter__.
        return hash(tuple(self))

    def __len__(self):
        return len(self._meta.fields)

    def __iter__(self):
        # Support conversion to dict through iteration in (attr, value) pairs.
        return ((f, getattr(self, f)) for f in self._meta.fields)


# noinspection PyRedeclaration
class Record(RecordMixin, metaclass=RecordMeta):
    """
    Base class for Record types.

    A record is a lightweight class that has only a fixed number of
    attributes. It is analogous to a C struct type.

    Record types can be used to hold data or as a basis for a no-boilerplate
    class.
    """

    __slots__ = ()

    @property
    def M(self):  # noqa: N802
        # Expose the fields of the instance through a MapView wrapper.
        return MapView(self)


# noinspection PyRedeclaration
class Namespace(RecordMixin, metaclass=RecordMeta):
    """
    A mutable record-like type.
    """

    __slots__ = ()

    @property
    def M(self):  # noqa: N802
        # Expose the fields of the instance through a MutableMapView wrapper.
        return MutableMapView(self)