thesadru/apimodel

View on GitHub
apimodel/errors.py

Summary

Maintainability
A
1 hr
Test Coverage
A
93%
"""Pretty errors."""
from __future__ import annotations

import contextlib
import typing

from . import apimodel, tutils, utility

# Public names exported by a star-import of this module.
__all__ = ["LocError", "ValidationError", "catch_errors"]


# A normalized error location: a tuple whose parts are ints and/or strs.
Loc = typing.Tuple[typing.Union[int, str], ...]
# A location as accepted from callers: a single int or str, or an already-built Loc.
RawLoc = typing.Union[int, str, Loc]
# A recursive error structure: either one LocError or a sequence of further ErrorLists.
ErrorList = typing.Union[typing.Sequence["ErrorList"], "LocError"]


class LocError(utility.Representation, Exception):
    """An exception paired with the location at which it was raised."""

    error: Exception
    """The wrapped, original exception."""

    loc: Loc
    """Location where the error occurred, as a tuple of ints and/or strs."""

    def __init__(self, error: Exception, loc: RawLoc = "__root__") -> None:
        """Wrap ``error`` and remember where it happened.

        A location that is not already a tuple is stored as a one-element
        tuple; a tuple location is stored as given. The exception message
        is ``str(error)``.
        """
        if not isinstance(loc, tuple):
            loc = (loc,)

        self.error = error
        self.loc = loc

        super().__init__(str(error))

    def __instancecheck__(self, instance: typing.Any) -> bool:
        """Report whether ``instance`` is an instance of the wrapped error's class.

        Python consults this hook on the type of the second argument of
        ``isinstance``, so it applies to ``isinstance(obj, loc_error)`` where
        ``loc_error`` is a LocError instance.
        """
        wrapped_type = self.error.__class__
        return isinstance(instance, wrapped_type)


class ValidationError(utility.Representation, ValueError):
    """Aggregate of located errors with a readable report, inspired by pydantic."""

    errors: typing.Sequence[LocError]
    """Errors collected for this validation failure."""

    model: typing.Type[apimodel.APIModel]
    """The model in which the errors occurred. Used for debugging."""

    def __init__(self, *errors: ErrorList, model: typing.Type[apimodel.APIModel]) -> None:
        """Store the given errors and the model they belong to.

        The positional ``errors`` are passed through
        ``utility.flatten_sequences`` and the result is kept both on
        ``self.errors`` and as the exception's argument.
        """
        collected = utility.flatten_sequences(*errors)

        self.errors = collected
        self.model = model

        super().__init__(self.errors)

    def __str__(self) -> str:
        """Render a header line followed by one entry per flattened error."""
        located = list(flatten_errors(self.errors))
        count = len(located)
        plural = "" if count == 1 else "s"
        header = f"{count} validation error{plural} for {self.model.__name__}\n"

        entries = []
        for loc, error in located:
            # Location parts are joined into a breadcrumb such as "a -> 0 -> b".
            path = " -> ".join(str(part) for part in loc)
            entries.append(f"{path}\n  {error.__class__.__name__}: {error}")

        return header + "\n".join(entries)

    @property
    def locations(self) -> typing.Sequence[typing.Tuple[Loc, Exception]]:
        """Return the flattened ``(location, exception)`` pairs as a list."""
        return [*flatten_errors(self.errors)]


def flatten_errors(
    errors: typing.Sequence[ErrorList],
    loc: typing.Optional[Loc] = None,
) -> typing.Iterator[typing.Tuple[Loc, Exception]]:
    """Walk a nested error structure and yield ``(location, exception)`` pairs.

    Items that are not LocErrors are treated as nested sequences and walked
    with the same ``loc`` prefix. A LocError wrapping a ValidationError is
    expanded into that error's own entries, each prefixed with the LocError's
    location. Any other LocError yields one pair.
    """
    for item in errors:
        if not isinstance(item, LocError):
            # Nested sequence: descend without extending the location prefix.
            yield from flatten_errors(item, loc=loc)
            continue

        # A missing or empty prefix leaves the item's own location unchanged.
        full_loc = loc + item.loc if loc else item.loc
        wrapped = item.error

        if isinstance(wrapped, ValidationError):
            yield from flatten_errors(wrapped.errors, loc=full_loc)
        else:
            yield (full_loc, wrapped)


class ErrorCatcher:
    """Collect exceptions as LocErrors so they can be raised together later.

    Use ``catch(loc)`` to record errors at a specific location. Using the
    catcher itself in a ``with`` statement records errors at ``"__root__"``.
    """

    __slots__ = ("errors", "model")

    errors: typing.MutableSequence[LocError]
    model: typing.Type[apimodel.APIModel]

    def __init__(self, model: tutils.MaybeType[apimodel.APIModel]) -> None:
        """Initialize an ErrorCatcher.

        ``model`` may be a class or an instance; an instance is replaced by
        its type so that ``self.model`` is always a class.
        """
        if not isinstance(model, type):
            model = type(model)

        self.errors = []
        self.model = model

    def add_error(self, error: Exception, loc: RawLoc) -> None:
        """Wrap ``error`` in a LocError at ``loc`` and append it to the list."""
        self.errors.append(LocError(error, loc))

    @contextlib.contextmanager
    def catch(self, loc: RawLoc = "__root__") -> typing.Iterator[None]:
        """Record any ``Exception`` raised in the ``with`` body instead of propagating it.

        Exceptions that do not derive from ``Exception`` (for example
        ``KeyboardInterrupt``) are not caught.
        """
        try:
            yield
        except Exception as e:
            self.add_error(e, loc)

    def __enter__(self) -> typing.ContextManager[None]:
        """Enter a root-level catching scope.

        The return value is a fresh, un-entered ``catch()`` manager, kept for
        backward compatibility. The catching itself is done by ``__exit__``.
        """
        return self.catch()

    def __exit__(
        self,
        exc_type: typing.Optional[typing.Type[BaseException]],
        exc: typing.Optional[BaseException],
        traceback: object,
    ) -> bool:
        """Record an ``Exception`` raised in the ``with`` body at ``"__root__"``.

        Without this method the class could not be used in a ``with``
        statement, because the context manager protocol needs both
        ``__enter__`` and ``__exit__``. Returns ``True`` to suppress a
        recorded exception, mirroring ``catch()``. Returns ``False`` when
        nothing was raised or the exception does not derive from ``Exception``.
        """
        if isinstance(exc, Exception):
            self.add_error(exc, "__root__")
            return True

        return False

    def raise_errors(self) -> None:
        """Raise a ValidationError holding the collected errors, if there are any."""
        if self.errors:
            raise ValidationError(self.errors, model=self.model)


@contextlib.contextmanager
def catch_errors(model: tutils.MaybeType[apimodel.APIModel]) -> typing.Iterator[ErrorCatcher]:
    """Provide an ErrorCatcher and raise its collected errors on exit.

    ``model`` may be a class or an instance; an instance is replaced by its
    type. When the ``with`` block ends, whether normally or through an
    exception, ``raise_errors`` is called on the catcher, which raises a
    ValidationError if at least one error was collected.

    Examples
    --------
    >>> with apimodel.errors.catch_errors(model) as catcher:
    ...     for function in functions:
    ...         with catcher.catch(loc=function.__name__):
    ...             function()
    """
    model_type = model if isinstance(model, type) else type(model)
    collector = ErrorCatcher(model_type)

    try:
        yield collector
    finally:
        # Runs on every exit path so collected errors are always surfaced.
        collector.raise_errors()