dmyersturnbull/pocketutils

View on GitHub
src/pocketutils/core/dot_dict.py

Summary

Maintainability
B
6 hrs
Test Coverage
# SPDX-FileCopyrightText: Copyright 2020-2023, Contributors to pocketutils
# SPDX-PackageHomePage: https://github.com/dmyersturnbull/pocketutils
# SPDX-License-Identifier: Apache-2.0
"""

"""

from __future__ import annotations

import abc
import io
import json
import pickle
from collections import defaultdict
from configparser import ConfigParser
from dataclasses import dataclass
from datetime import date, datetime
from typing import TYPE_CHECKING, Any, Self, TypeVar, Unpack

if TYPE_CHECKING:
    from collections.abc import Callable, Iterable, Mapping, MutableMapping

# A single scalar value (or None).
_Single = None | str | int | float | date | datetime
# A leaf of the tree: a scalar or a list of scalars.
TomlLeaf = list[_Single] | _Single
# The recursive definition that is intended, kept for reference:
# TomlBranch = dict[str, TomlLeaf | "TomlBranch"]
# The definition actually used: nested branches are typed as plain `dict`.
TomlBranch = dict[str, TomlLeaf | dict]
# Any node type (leaf or branch); used by the typed getters.
T = TypeVar("T", bound=TomlLeaf | TomlBranch)


class _Utils:
    @classmethod
    def json_encode_default(cls: type[Self], obj: Any) -> Any:
        if isinstance(obj, NestedDotDict):
            return dict(obj)

    @classmethod
    def dots_to_dict(cls: type[Self], items: Mapping[str, TomlLeaf]) -> dict[str, TomlLeaf | TomlBranch]:
        """
        Make sub-dictionaries from substrings in `items` delimited by `.`.
        Used for TOML.

        Example:

            Utils.dots_to_dict({"genus.species": "fruit bat"}) == {"genus": {"species": "fruit bat"}}

        See Also:
            [`dict_to_dots`](pocketutils.core.dot_dict._Utils.dicts_to_dots)
        """
        dct = {}
        cls._un_leaf(dct, items)
        return dct

    @classmethod
    def dict_to_dots(cls: type[Self], items: Mapping[str, TomlLeaf | TomlBranch]) -> dict[str, TomlLeaf]:
        """
        Performs the inverse of [`dict_to_dots`](pocketutils.core.dot_dict._Utils.dots_to_dicts].

        Example:

            Utils.dict_to_dots({"genus": {"species": "fruit bat"}}) == {"genus.species": "fruit bat"}
        """
        return dict(cls._re_leaf("", items))

    @classmethod
    def _un_leaf(cls: type[Self], to: MutableMapping[str, Any], items: Mapping[str, Any]) -> None:
        for k, v in items.items():
            if "." not in k:
                to[k] = v
            else:
                k0, k1 = k.split(".", 1)
                if k0 not in to:
                    to[k0] = {}
                cls._un_leaf(to[k0], {k1: v})

    @classmethod
    def _re_leaf(cls: type[Self], at: str, items: Mapping[str, Any]) -> Iterable[tuple[str, Any]]:
        for k, v in items.items():
            me = at + "." + k if len(at) > 0 else k
            if hasattr(v, "items") and hasattr(v, "keys") and hasattr(v, "values"):
                yield from cls._re_leaf(me, v)
            else:
                yield me, v

    @classmethod
    def check(cls: type[Self], dct: Mapping | TomlBranch | TomlLeaf) -> TomlLeaf | TomlBranch:
        if hasattr(dct, "items") and hasattr(dct, "keys") and hasattr(dct, "values"):
            bad = [k for k in dct if "." in k]
            if len(bad) > 0:
                msg = f"Key(s) contain '.': {bad}"
                raise ValueError(msg)
            return dict({k: cls.check(v) for k, v in dct.items()})
        return dct


@dataclass(frozen=True, slots=True)
class AbstractYamlMixin(metaclass=abc.ABCMeta):
    @classmethod
    def from_yaml(cls: type[Self], data: str) -> Self:
        raise NotImplementedError()

    def to_yaml(self: Self, **kwargs: Unpack[Mapping[str, Any]]) -> str:
        """
        Returns YAML text.
        """
        raise NotImplementedError()


@dataclass(frozen=True, slots=True)
class RuamelYamlMixin(AbstractYamlMixin):
    """
    YAML support backed by `ruamel.yaml`, using its "safe" loader and dumper.
    """

    @classmethod
    def from_yaml(cls: type[Self], data: str) -> Self:
        """
        Parses YAML text and passes the result to the constructor.
        """
        from ruamel.yaml import YAML

        yaml = YAML(typ="safe")
        return cls(yaml.load(data))

    def to_yaml(self: Self, **kwargs: Unpack[Mapping[str, Any]]) -> str:
        """
        Returns this object as YAML text.

        Args:
            kwargs: Passed through to `YAML.dump` (e.g. `transform`)
        """
        from ruamel.yaml import YAML

        def plain(node: Any) -> Any:
            # The safe representer only knows exact built-in types, not dict
            # subclasses such as this one, so copy the tree into plain containers.
            if isinstance(node, dict):
                return {k: plain(v) for k, v in node.items()}
            if isinstance(node, list):
                return [plain(v) for v in node]
            return node

        yaml = YAML(typ="safe")
        # `YAML.dump` needs a stream to write to; it does not return the text.
        buffer = io.StringIO()
        yaml.dump(plain(self), buffer, **kwargs)
        return buffer.getvalue()


@dataclass(frozen=True, slots=True)
class AbstractJsonMixin(metaclass=abc.ABCMeta):
    @classmethod
    def from_json(cls: type[Self], data: str) -> Self:
        raise NotImplementedError()

    def to_json(self: Self) -> str:
        """
        Returns JSON text.
        """
        raise NotImplementedError()


@dataclass(frozen=True, slots=True)
class OrjsonJsonMixin(AbstractJsonMixin):
    """
    JSON support backed by `orjson`.
    """

    @classmethod
    def from_json(cls: type[Self], data: str) -> Self:
        """
        Parses JSON text with `orjson.loads` and passes the result to the constructor.
        """
        from orjson import loads

        parsed = loads(data)
        return cls(parsed)

    def to_json(self: Self) -> str:
        """
        Returns this object as JSON text.

        `orjson.dumps` produces bytes, which are decoded as UTF-8 here.
        """
        from orjson import dumps

        raw = dumps(self, default=_Utils.json_encode_default)
        return raw.decode(encoding="utf-8")


@dataclass(frozen=True, slots=True)
class AbstractTomlMixin(metaclass=abc.ABCMeta):
    @classmethod
    def from_toml(cls: type[Self], data: str) -> Self:
        raise NotImplementedError()

    def to_toml(self: Self) -> str:
        """
        Returns TOML text.
        """
        raise NotImplementedError()


@dataclass(frozen=True, slots=True)
class TomllibTomlMixin(AbstractTomlMixin):
    """
    Read-only TOML support backed by the standard-library `tomllib`.
    """

    @classmethod
    def from_toml(cls: type[Self], data: str) -> Self:
        """
        Parses TOML text with `tomllib.loads` and passes the result to the constructor.
        """
        from tomllib import loads

        parsed = loads(data)
        return cls(parsed)

    def to_toml(self: Self) -> str:
        """
        Not supported by this mixin.

        Raises:
            NotImplementedError: Always
        """
        raise NotImplementedError


@dataclass(frozen=True, slots=True)
class TomlkitTomlMixin(AbstractTomlMixin):
    """
    TOML support backed by `tomlkit`.
    """

    @classmethod
    def from_toml(cls: type[Self], data: str) -> Self:
        """
        Parses TOML text with `tomlkit.loads` and passes the result to the constructor.
        """
        from tomlkit import loads

        document = loads(data)
        return cls(document)

    def to_toml(self: Self) -> str:
        """
        Returns this object as TOML text, produced by `tomlkit.dumps`.
        """
        from tomlkit import dumps

        return dumps(self)


@dataclass(frozen=True, slots=True)
class BuiltinJsonMixin(AbstractJsonMixin):
    """
    JSON support backed by the standard-library `json` module.
    """

    @classmethod
    def from_json(cls: type[Self], data: str) -> Self:
        """
        Parses JSON text with `json.loads` and passes the result to the constructor.
        """
        return cls(json.loads(data))

    def to_json(self: Self) -> str:
        """
        Returns this object as JSON text, leaving non-ASCII characters unescaped.

        `date` and `datetime` values are written as ISO 8601 strings.

        Raises:
            TypeError: If a value is of a type that cannot be serialized
        """

        def encode(obj: Any) -> Any:
            # `json` cannot serialize dates, yet they are valid leaves (see `TomlLeaf`).
            # `datetime` is a subclass of `date`, so one check covers both.
            if isinstance(obj, date):
                return obj.isoformat()
            msg = f"Object of type {type(obj).__name__} is not JSON-serializable"
            raise TypeError(msg)

        return json.dumps(self, ensure_ascii=False, default=encode)


@dataclass(frozen=True, slots=True)
class PickleMixin(metaclass=abc.ABCMeta):
    @classmethod
    def from_pickle(cls: type[Self], data: bytes) -> Self:
        if not isinstance(data, bytes):
            data = bytes(data)
        return cls(pickle.loads(data))  # noqa: S301

    def to_pickle(self: Self) -> bytes:
        """
        Writes to a pickle file.
        """
        return pickle.dumps(self, protocol=self.PICKLE_PROTOCOL)


@dataclass(frozen=True, slots=True)
class AbstractIniMixin(metaclass=abc.ABCMeta):
    @classmethod
    def from_ini(cls: type[Self], data: str) -> Self:
        raise NotImplementedError()

    def to_ini(self: Self) -> str:
        """
        Returns INI text.
        """
        raise NotImplementedError()


@dataclass(frozen=True, slots=True)
class BuiltinIniMixin(AbstractIniMixin):
    @classmethod
    def from_ini(cls: type[Self], data: str) -> Self:
        parser = ConfigParser()
        parser.read_string(data)
        return cls(parser)

    def to_ini(self: Self) -> str:
        config = ConfigParser()
        config.read_dict(self)
        writer = io.StringIO()
        config.write(writer)
        return writer.getvalue()


class AbstractDotDict(dict[str, TomlLeaf | TomlBranch]):
    """
    A thin wrapper around a nested dict, a wrapper for TOML.

    Keys must be strings that do not contain a dot (.).
    A dot is reserved for splitting values to traverse the tree.
    For example, `wrapped["pet.species.name"]`.
    """

    def __init__(self: Self, x: dict[str, TomlLeaf | TomlBranch] | Self) -> None:
        """
        Validates `x` and copies its top-level items into this dict.

        Args:
            x: A (possibly nested) dict

        Raises:
            TypeError: If `x` is not a dict
            ValueError: If a key (in this dict or a sub-dict) contains a dot
        """
        if not isinstance(x, dict):
            msg = f"Not a dict; actually {type(x)} (value: '{x}')"
            raise TypeError(msg)
        # Called only to validate; the converted copy it returns is not used.
        _Utils.check(x)
        super().__init__(x)

    @classmethod
    def from_leaves(cls: type[Self], x: Mapping[str, TomlLeaf] | Self) -> Self:
        """
        Builds an instance from a flat mapping of `dotted-key -> value`.
        """
        return cls(_Utils.dots_to_dict(x))

    def transform_leaves(self: Self, fn: Callable[[str, TomlLeaf], TomlLeaf]) -> Self:
        """
        Returns a new instance in which each leaf is replaced by `fn(dotted_key, value)`.
        """
        transformed = {k: fn(k, v) for k, v in self.leaves().items()}
        # The keys are dotted, so rebuild the tree; the constructor rejects dotted keys.
        return self.from_leaves(transformed)

    def walk(self: Self) -> Iterable[TomlLeaf | TomlBranch]:
        """
        Yields the value of every leaf, descending into nested dicts.
        """
        for value in self.values():
            if isinstance(value, dict):
                yield from self.__class__(value).walk()
            else:
                yield value

    def nodes(self: Self) -> dict[str, TomlBranch | TomlLeaf]:
        """
        Returns the union of `branches()` and `leaves()`; leaves win if a key is in both.
        """
        return {**self.branches(), **self.leaves()}

    def branches(self: Self) -> dict[str, TomlBranch]:
        """
        Maps each lowest-level branch to a dict of its values.

        Note:
            Leaves directly under the root are assigned to key `''`.

        Returns:
            `dotted-key:str -> (non-dotted-key:str -> value)`
        """
        dicts = defaultdict(dict)
        for k, v in self.leaves().items():
            # Split at the last dot: everything before it is the branch's path.
            k0, _, k1 = str(k).rpartition(".")
            dicts[k0][k1] = v
        return dicts

    def leaves(self: Self) -> dict[str, TomlLeaf]:
        """
        Gets the leaves in this tree.

        Returns:
            `dotted-key:str -> value`
        """
        dct = {}
        for key, value in self.items():
            if isinstance(value, dict):
                dct.update({key + "." + k: v for k, v in self.__class__(value).leaves().items()})
            else:
                dct[key] = value
        return dct

    def sub(self: Self, items: str) -> Self:
        """
        Returns the dictionary under (dotted) keys `items`.

        Raises:
            KeyError: If the key is not found
            TypeError: If the value found is not a dict
        """
        # noinspection PyTypeChecker
        return self.__class__(self[items])

    def get_as(self: Self, items: str, as_type: type[T], default: T | None = None) -> T | None:
        """
        Gets the key `items` from the dict, or `default` if it does not exist.

        Args:
            items: The key hierarchy, with a dot (.) as a separator
            as_type: The type, which will be checked using `isinstance`
            default: Returned, without a type check, if the key is not found

        Returns:
            The value in the required type, or `default`

        Raises:
            TypeError: If the key exists and not `isinstance(value, as_type)`
        """
        try:
            z = self[items]
        except KeyError:
            # A missing key is not a type error; mirror `get_list_as`.
            return default
        if not isinstance(z, as_type):
            msg = f"Value {z} from {items} is a {type(z)}, not {as_type}"
            raise TypeError(msg)
        return z

    def req_as(self: Self, items: str, as_type: type[T]) -> T:
        """
        Gets the key `items` from the dict.

        Args:
            items: The key hierarchy, with a dot (.) as a separator
            as_type: The type, which will be checked using `isinstance`

        Returns:
            The value in the required type

        Raises:
            KeyError: If the key is not found
            TypeError: If not `isinstance(value, as_type)`
        """
        z = self[items]
        if not isinstance(z, as_type):
            msg = f"Value {z} from {items} is a {type(z)}, not {as_type}"
            raise TypeError(msg)
        return z

    def get_list(self: Self, items: str, default: list[T] | None = None) -> list[T]:
        """
        Gets the value from an *optional* key, without checking that it is a list.

        Returns:
            The value; or `default` (or a new empty list if that is `None`) if the key is not found
        """
        try:
            return self[items]
        except KeyError:
            return [] if default is None else default

    def get_list_as(self: Self, items: str, as_type: type[T], default: list[T] | None = None) -> list[T]:
        """
        Gets list values from an *optional* key.

        Returns:
            The list; or `default` (or a new empty list if that is `None`) if the key is not found

        Raises:
            TypeError: If the value is not a list, or an element is not an `as_type`
        """
        try:
            x = self[items]
        except KeyError:
            return [] if default is None else default
        if not isinstance(x, list):
            msg = f"Value {x} is not a list for lookup {items}"
            raise TypeError(msg)
        bad = [y for y in x if not isinstance(y, as_type)]
        if bad:
            msg = f"Value(s) from {items} are not {as_type}: {bad}"
            raise TypeError(msg)
        return x

    def req_list_as(self: Self, items: str, as_type: type[T]) -> list[T]:
        """
        Gets list values from a *required* key.

        Raises:
            KeyError: If the key is not found
            TypeError: If the value is not a list, or an element is not an `as_type`
        """
        x = self[items]
        if not isinstance(x, list):
            msg = f"Value {x} is not a list for lookup {items}"
            raise TypeError(msg)
        bad = [y for y in x if not isinstance(y, as_type)]
        if bad:
            # Report the offending elements, as `get_list_as` does.
            msg = f"Value(s) from {items} are not {as_type}: {bad}"
            raise TypeError(msg)
        return x

    def req(self: Self, items: str) -> TomlLeaf | dict:
        """
        Gets a value from a required key; equivalent to `self[items]`.
        """
        return self[items]

    def get(self: Self, items: str, default: TomlLeaf | dict = None) -> TomlLeaf | dict:
        """
        Gets a value from an optional key.
        Also see `__getitem__`.
        """
        try:
            return self[items]
        except KeyError:
            return default

    def __getitem__(self: Self, items: str) -> TomlLeaf | dict:
        """
        Gets a value from a required key, operating on `.`-joined strings.

        Example:

            d = WrappedToml(dict(a=dict(b=1)))
            assert d["a.b"] == 1

        Raises:
            KeyError: If any part of the key is not found, or the path passes through a non-dict
        """
        if "." in items:
            # Resolve the first segment here, then recurse on the remainder.
            i0, _, i_ = items.partition(".")
            z = self[i0]
            if not isinstance(z, dict):
                msg = f"No key {items} (ends at {i0})"
                raise KeyError(msg)
            return self.__class__(z)[i_]
        return super().__getitem__(items)

    def __rich_repr__(self: Self) -> str:
        """
        Formats this dict using `json.dumps` with a 2-space indent.

        Returns:
            A multi-line string
        """
        # NOTE(review): rich's `__rich_repr__` protocol normally yields argument
        # tuples rather than returning a string; confirm this is intended.
        return json.dumps(self, ensure_ascii=True, indent=2)

    def _to_date(self: Self, s: str | date) -> date:
        """
        Returns `s` as a date, parsing ISO 8601 if it is a str.

        Raises:
            TypeError: If `s` is neither a str nor a date
        """
        if isinstance(s, date):
            return s
        elif isinstance(s, str):
            # This is MUCH faster than tomlkit's
            return date.fromisoformat(s)
        else:
            msg = f"Invalid type {type(s)} for {s}"
            raise TypeError(msg)

    def _to_datetime(self: Self, s: str | datetime) -> datetime:
        """
        Returns `s` as a datetime, parsing ISO 8601 if it is a str.

        Raises:
            ValueError: If `s` is a str containing fewer than two colons
            TypeError: If `s` is neither a str nor a datetime
        """
        if isinstance(s, datetime):
            return s
        elif isinstance(s, str):
            # This is MUCH faster than tomlkit's
            if s.count(":") < 2:
                msg = f"Datetime {s} does not contain hours, minutes, and seconds"
                raise ValueError(msg)
            # Upper-case first so that a lowercase 'z' suffix is also mapped to UTC.
            return datetime.fromisoformat(s.upper().replace("Z", "+00:00"))
        else:
            msg = f"Invalid type {type(s)} for {s}"
            raise TypeError(msg)


# JSON backend: orjson if it is installed, otherwise the standard library.
try:
    import orjson

    _Json = OrjsonJsonMixin
except ImportError:
    _Json = BuiltinJsonMixin

# TOML backend: tomlkit if it is installed, otherwise the read-only tomllib.
# Probe tomlkit itself, because that is the package TomlkitTomlMixin imports.
try:
    import tomlkit

    _Toml = TomlkitTomlMixin
except ImportError:
    _Toml = TomllibTomlMixin

# YAML backend: optional. Probe `ruamel.yaml`, the module RuamelYamlMixin
# imports, rather than the bare `ruamel` package.
try:
    import ruamel.yaml

    _Yaml = RuamelYamlMixin
except ImportError:
    _Yaml = None

if _Yaml is None:

    class NestedDotDict(AbstractDotDict, _Json, _Toml, BuiltinIniMixin, PickleMixin):
        """
        A nested dict with dotted-key access and JSON, TOML, INI, and pickle support.
        """
else:

    class NestedDotDict(AbstractDotDict, _Json, _Toml, _Yaml, BuiltinIniMixin, PickleMixin):
        """
        A nested dict with dotted-key access and JSON, TOML, YAML, INI, and pickle support.
        """


# Public API: the concrete dict class and the two type aliases.
__all__ = ["NestedDotDict", "TomlLeaf", "TomlBranch"]