dmyersturnbull/pocketutils

View on GitHub
src/pocketutils/tools/common_tools.py

Summary

Maintainability
C
7 hrs
Test Coverage
# SPDX-FileCopyrightText: Copyright 2020-2023, Contributors to pocketutils
# SPDX-PackageHomePage: https://github.com/dmyersturnbull/pocketutils
# SPDX-License-Identifier: Apache-2.0
"""

"""

import logging
import operator
import re
import sys
from collections import defaultdict, deque
from collections.abc import Callable, Generator, Hashable, Iterable, Iterator, Mapping, Sequence
from contextlib import contextmanager
from dataclasses import dataclass
from types import LambdaType
from typing import Any, Literal, Self, TypeVar

from pocketutils import FrozeDict, FrozeList, FrozeSet
from pocketutils.core.exceptions import MultipleMatchesError, NoMatchesError, ValueIllegalError
from pocketutils.core.input_output import Writeable, return_none_1_param

__all__ = ["CommonUtils", "CommonTools", "Writeable"]

T = TypeVar("T")
V_co = TypeVar("V_co", covariant=True)
K_contra = TypeVar("K_contra", contravariant=True)
logger = logging.getLogger("pocketutils")
lambda_regex = re.compile(r"^<function (?:[A-Za-z_][A-Za-z0-9_.]*)?(?:<locals>\.)?<lambda> at 0x[A-F0-9]+>$")


@dataclass(slots=True, frozen=True)
class CommonUtils:
    def freeze(
        self: Self,
        v: Sequence[V_co] | set[V_co] | dict[K_contra, V_co] | frozenset[V_co] | Hashable,
    ) -> FrozeList[V_co] | FrozeSet[V_co] | FrozeDict[K_contra, V_co] | Hashable:
        """
        Returns `v` or a hashable view of it.
        Note that the returned types must be hashable but might not be ordered.
        You can generally add these values as DataFrame elements, but you might not
        be able to sort on those columns.

        Args:
            v: Any value

        Returns:
            Either `v` itself,
            [`FrozeSet`](pocketutils.core.frozen_types.FrozeSet) (subclass of `collections.abc.Set`),
            a [`FrozeList`](pocketutils.core.frozen_types.FrozeList) (subclass of `collections.abc.Sequence`),
            or a [`FrozeDict`](pocketutils.core.frozen_types.FrozeDict) (subclass of `collections.abc.Mapping`).
            int, float, str, np.generic, and tuple are always returned as-is.

        Raises:
            AttributeError: If `v` is not hashable and could not be converted to
                            a FrozeSet, FrozeList, or FrozeDict, *or* if one of the elements for
                            one of the above types is not hashable.
            TypeError: If `v` is an `Iterator` or `deque`
        """
        if isinstance(v, int | float | str | tuple | frozenset):
            return v  # short-circuit
        if str(type(v)).startswith("<class 'numpy."):
            return v
        match v:
            case Iterator():  # let's not ruin their iterator by traversing
                msg = "Type is an iterator"
                raise TypeError(msg)
            case deque():  # the only other major built-in type we won't accept
                msg = "Type is a deque"
                raise TypeError(msg)
            case Sequence():
                return FrozeList(v)
            case set():
                return FrozeSet(v)
            case frozenset():
                return FrozeSet(v)
            case Mapping():
                return FrozeDict(v)
            case _:
                hash(v)  # raise an AttributeError if not hashable
                return v

    def nice_size(self: Self, n_bytes: int, *, space: str = "") -> str:
        """
        Uses IEC 1998 units, such as KiB (1024).
            n_bytes: Number of bytes
            space: Separator between digits and units

            Returns:
                Formatted string
        """
        data = {
            "PiB": 1024**5,
            "TiB": 1024**4,
            "GiB": 1024**3,
            "MiB": 1024**2,
            "KiB": 1024**1,
        }
        for suffix, scale in data.items():
            if n_bytes >= scale:
                break
        else:
            scale, suffix = 1, "B"
        return str(n_bytes // scale) + space + suffix

    def limit(self: Self, items: Iterable[V_co], n: int) -> Iterator[V_co]:
        for _i, x in zip(range(n), items, strict=False):
            yield x

    def is_float(self: Self, s: Any) -> bool:
        """
        Returns whether `float(s)` succeeds.
        """
        try:
            float(s)
            return True
        except ValueError:
            return False

    def try_or_none(
        self: Self,
        function: Callable[[], V_co],
        or_else: V_co | None = None,
        exception: type[Exception] = Exception,
    ) -> V_co | None:
        """
        Returns the value of a function or None if it raised an exception.

        Args:
            function: Try calling this function
            or_else: Return this value
            exception: Restrict caught exceptions to subclasses of this type
        """
        # noinspection PyBroadException
        try:
            return function()
        except exception:
            return or_else

    def succeeds(self: Self, function: Callable[[], Any], exception: type[BaseException] = Exception) -> bool:
        """Returns True iff `function` does not raise an error."""
        try:
            function()
        except exception:
            return False
        return True

    def iterator_has_elements(self: Self, it: Iterator) -> bool:
        """
        Returns False iff `next(x)` raises a `StopIteration`.
        WARNING: Tries to call `next(x)`, progressing iterators. Don't use `x` after calling this.
        Note that calling `iterator_has_elements([5])` will raise a `TypeError`

        Args:
            it: Must be an Iterator
        """
        return self.succeeds(lambda: next(it), StopIteration)

    def is_null(self: Self, v: V_co) -> bool:
        """
        Returns True for None, NaN, and NaT (not a time) values from Numpy, Pandas, and Python.
        Not perfect; may return false positive True for types declared outside Numpy and Pandas.
        """
        if v is None:
            return True
        if isinstance(v, str):
            return False
        return str(v) in [
            "nan",  # float('NaN') and Numpy float NaN
            "NaN",  # Pandas NaN and decimal.Decimal NaN
            "<NA>",  # Pandas pd.NA
            "NaT",  # Numpy datetime and timedelta NaT
        ]

    def is_empty(self: Self, v: V_co) -> bool:
        """
        Returns True if x is None, NaN according to Pandas, or contains 0 items.

        That is, if and only if:
            - `self.is_null(x)`
            - x is something with 0 length
            - x is iterable and has 0 elements (will call `__iter__`)

        Raises:
            RefusingRequestError: If `x` is an Iterator. Calling this would empty the iterator, which is dangerous.
        """
        if isinstance(v, Iterator):
            msg = "Do not call is_empty on an iterator."
            raise TypeError(msg)
        try:
            if self.is_null(v):
                return True
        except (ValueError, TypeError):  # noqa: S110
            pass
        return hasattr(v, "__len__") and len(v) == 0 or hasattr(v, "__iter__") and len(list(iter(v))) == 0

    def is_probable_null(self: Self, v: V_co) -> bool:
        """
        Returns True if `x` is None, NaN according to Pandas, 0 length, or a string representation.

        Specifically, returns True if and only if:
            - `is_null`
            - x is something with 0 length
            - x is iterable and has 0 elements (will call `__iter__`)
            - a str(x) is 'nan', 'na', 'n/a', 'null', or 'none'; case-insensitive

        Things that are **NOT** probable nulls:
            - `0`
            - `[None]`

        Raises:
            TypeError: If `x` is an Iterator. Calling this would empty the iterator, which is dangerous.
        """
        return self.is_empty(v) or str(v).lower() in {"nan", "n/a", "na", "null", "none", "<NA>", "NaT"}

    def unique(self: Self, sequence: Iterable[V_co]) -> Sequence[V_co]:
        """
        Returns the unique items in `sequence`, in the order they appear in the iteration.

        Args:
            sequence: Any once-iterable sequence

        Returns:
            An ordered List of unique elements
        """
        seen = set()
        return [x for x in sequence if not (x in seen or seen.add(x))]

    def first(self: Self, collection: Iterable[V_co]) -> V_co:
        """
        Gets the first element.

        Warning: Tries to call `next(x)`, progressing iterators.

        Args:
            collection: Any iterable

        Returns:
            Either `None` or the value, according to the rules:
                - The attribute of the first element if `attr` is defined on an element
                - None if the sequence is empty
                - None if the sequence has no attribute `attr`
        """
        try:
            # note: calling iter on an iterator creates a view only
            return next(iter(collection))
        except StopIteration:
            return None

    def iter_row_col(self: Self, rows: int, cols: int) -> Iterator[tuple[int, int]]:
        """
        An iterator over (row column) pairs for a row-first grid traversal.

        Example:

            it = CommonTools.iter_rowcol(5, 3)
            [next(it) for _ in range(5)]  # [(0,0),(0,1),(0,2),(1,0),(1,1)]
        """
        for i in range(rows * cols):
            yield i // cols, i % cols

    def multidict(
        self: Self,
        sequence: Iterable[V_co],
        key_attr: str | Iterable[str] | Callable[[K_contra], V_co],
        skip_none: bool = False,
    ) -> Mapping[K_contra, Sequence[V_co]]:
        """
        Builds a mapping from keys to multiple values.
        Builds a mapping of some attribute in `sequence` to
        the containing elements of `sequence`.

        Args:
            sequence: Any iterable
            key_attr: Usually string like 'attr1.attr2'; see `look`
            skip_none: If None, raises a `KeyError` if the key is missing for any item; otherwise, skips it
        """
        dct = defaultdict(list)
        for item in sequence:
            v = self.look(item, key_attr)
            if not skip_none and v is None:
                msg = f"No {key_attr} in {item}"
                raise KeyError(msg)
            if v is not None:
                dct[v].append(item)
        return dct

    def parse_bool(self: Self, s: str) -> bool:
        """
        Parses a 'true'/'false' string to a bool, ignoring case.

        Raises:
            ValueError: If neither true nor false
        """
        if isinstance(s, bool):
            return s
        if s.lower() == "false":
            return False
        if s.lower() == "true":
            return True
        msg = f"{s} is not true/false"
        raise ValueIllegalError(msg, value=s)

    def parse_bool_flex(self: Self, s: str) -> bool:
        """
        Parses a 'true'/'false'/'yes'/'no'/... string to a bool, ignoring case.

        Allowed:
            - "true", "t", "yes", "y", "1"
            - "false", "f", "no", "n", "0"

        Raises:
            XValueError: If neither true nor false
        """
        mp = {
            **{v: True for v in ("true", "t", "yes", "y", "1")},
            **{v: False for v in ("false", "f", "no", "n", "0")},
        }
        v = mp.get(s.lower())
        if v is None:
            msg = f"{s.lower()} is not in {','.join(mp.keys())}"
            raise ValueIllegalError(msg, value=s)
        return v

    def is_lambda(self: Self, function: Any) -> bool:
        """
        Returns whether this is a lambda function. Will return False for non-callables.
        """
        return isinstance(function, LambdaType)

    def only(
        self: Self,
        sequence: Iterable[T],
        condition: str | Callable[[T], bool] | None = None,
        *,
        exception_if_none: Exception = NoMatchesError(),
        exception_if_multiple: Exception = MultipleMatchesError(),
    ) -> T:
        """
        Returns either the SINGLE (ONLY) UNIQUE ITEM in the sequence or raises an exception.
        Each item must have __hash__ defined on it.

        Args:
            sequence: A list of any items (untyped)
            condition: If nonnull, consider only those matching this condition (using `self.look`)
            exception_if_none: Exception to raise if none match
            exception_if_multiple: Exception to raise if multiple match

        Returns:
            The first item the sequence.

        Raises:
            LookupError: If the sequence is empty
            MultipleMatchesError: If there is more than one unique item.
        """

        def _only(sq: Iterable[T]):
            st = set(sq)
            if len(st) > 1:
                raise exception_if_multiple
            if len(st) == 0:
                raise exception_if_none
            return next(iter(st))

        if condition and isinstance(condition, str):
            return _only(
                [
                    s
                    for s in sequence
                    if (not self.look(s, condition[1:]) if condition.startswith("!") else self.look(s, condition))
                ],
            )
        elif condition:
            return _only([s for s in sequence if condition(s)])
        return _only(sequence)

    def forever(self: Self) -> Iterator[int]:
        """
        Yields i for i in range(0, infinity).
        Useful for simplifying a i = 0; while True: i += 1 block.
        """
        i = 0
        while True:
            yield i
            i += 1

    def is_true_iterable(self: Self, v: Any) -> bool:
        """
        Returns whether `s` is an iterable but not `str` and not `bytes`.

        Warning:
            Types that do not define `__iter__` but are iterable
            via `__getitem__` will not be included.
        """
        return (
            v is not None
            and isinstance(v, Iterable)
            and not isinstance(v, str)
            and not isinstance(v, bytes | bytearray | memoryview)
        )

    @contextmanager
    def null_context(self: Self) -> Generator[None, None, None]:
        """
        Returns an empty context (literally just yields).
        Useful to simplify when a generator needs to be used depending on a switch.

        Example:

            if verbose_flag:
                do_something()
            else:
                with Tools.silenced():
                    do_something()

        Can become:

        ```python
        with (Tools.null_context() if verbose else Tools.silenced()):
            do_something()
        ```
        """
        yield

    def look(self: Self, obj: T, attrs: str | Iterable[str] | Callable[[T], V_co]) -> V_co | None:
        """
        Follows a dotted syntax for getting an item nested in class attributes.
        Returns the value of a chain of attributes on object `obj`,
        or None any object in that chain is None or lacks the next attribute.

        Example:

            # Get a kitten's breed
            BaseTools.look(kitten), 'breed.name')  # either None or a string

        Args:
            obj: Any object
            attrs: One of:
                - A string in the form attr1.attr2, translating to `obj.attr1`
                - An iterable of strings of the attributes
                - A function that maps `obj` to its output;
                   equivalent to calling `attrs(obj)` but returning None on `AttributeError`.

        Returns:
            Either None or the type of the attribute
        """
        if attrs is None:
            return obj
        if not isinstance(attrs, str) and hasattr(attrs, "__len__") and len(attrs) == 0:
            return obj
        if isinstance(attrs, str):
            attrs = operator.attrgetter(attrs)
        elif isinstance(attrs, Iterable) and all(isinstance(a, str) for a in attrs):
            attrs = operator.attrgetter(".".join(attrs))
        elif not callable(attrs):
            msg = f"Type {type(attrs)} unrecognized for key/attrib. Must be a function, string, or sequence of strings"
            raise TypeError(msg)
        try:
            return attrs(obj)
        except AttributeError:
            return None

    def make_writer(self: Self, writer: Writeable | Callable[[str], Any]) -> Writeable:
        if Writeable.isinstance(writer):
            return writer
        elif callable(writer):

            class W(Writeable[V_co]):
                def write(self: Self, msg: V_co) -> int:
                    writer(msg)
                    return len(msg)

                def flush(self: Self) -> None:
                    pass

                def close(self: Self) -> None:
                    pass

            return W()
        msg_ = f"{type(writer)} cannot be wrapped into a Writeable"
        raise TypeError(msg_)

    def get_log_fn(
        self: Self,
        log: (Literal["stdout"] | str | int | Writeable | Callable[[str], Any] | None),
    ) -> Callable[[str], Any]:
        """
        Gets a logging function from user input.

        The rules are:
            - If None, uses logger.info
            - If 'print' or 'stdout',  use sys.stdout.write
            - If 'stderr', use sys.stderr.write
            - If another str or int, try using that logger level (raises an error if invalid)
            - If callable, returns it
            - If it has a callable method called 'write', uses that

        Returns:
            A function of the log message that returns None
        """
        if log is None:
            return return_none_1_param
        elif isinstance(log, str) and log.lower() == "stdout":
            # noinspection PyTypeChecker
            return sys.stdout.write
        elif log == "stderr":
            # noinspection PyTypeChecker
            return sys.stderr.write
        elif isinstance(log, int):
            return getattr(logger, logging.getLevelName(log).lower())
        elif isinstance(log, str):
            return getattr(logger, log.lower())
        elif callable(log):
            return log
        elif hasattr(log, "write"):
            return log.write
        msg = f"Log type {type(log)} not known"
        raise ValueIllegalError(msg, value=log)

    def sentinel(self: Self, name: str) -> Any:
        class _Sentinel:
            def __eq__(self: Self, other: Self) -> bool:
                return self is other

            def __reduce__(self: Self) -> str:
                return name  # returning string is for singletons

            def __hash__(self: Self) -> int:
                return hash(name)

            def __str__(self: Self) -> str:
                return name

            def __repr__(self: Self) -> str:
                return name

        return _Sentinel()

    def longest(self: Self, parts: Iterable[V_co]) -> V_co:
        """
        Returns an element with the highest `len`.
        """
        mx = ""
        for _i, x in enumerate(parts):
            if len(x) > len(mx):
                mx = x
        return mx


CommonTools = CommonUtils()