apimodel/apimodel.py
"""APIModel class with all the validation."""
from __future__ import annotations
import typing
from . import errors, fields, tutils, utility, validation
__all__ = ["APIModel", "APIModelMeta", "create_model"]
ValidatorT = typing.TypeVar("ValidatorT", bound=validation.BaseValidator)
APIModelT = typing.TypeVar("APIModelT", bound="APIModel")
def _get_ordered(validators: typing.Sequence[ValidatorT], order: validation.Order) -> typing.Sequence[ValidatorT]:
"""Get validators that fall into the selected order category."""
return [validator for validator in validators if order <= validator.order < (order + 10)]
def _serialize_attr(attr: object, **kwargs: object) -> object:
"""Serialize an attribute."""
if isinstance(attr, APIModel):
return attr.as_dict(**kwargs)
if tutils.generic_isinstance(attr, typing.Mapping[object, object]):
return {_serialize_attr(k, **kwargs): _serialize_attr(v, **kwargs) for k, v in attr.items()}
if tutils.generic_isinstance(attr, typing.Sequence[object], exclude=str):
return [_serialize_attr(x, **kwargs) for x in attr]
return attr
def _to_mapping(obj: object, **kwargs: object) -> typing.Mapping[str, object]:
"""Turn an arbitrary object into a mapping for APIModel."""
if isinstance(obj, APIModel):
obj = obj.as_dict()
if obj is None:
obj = {}
if not isinstance(obj, typing.Mapping):
raise TypeError(f"Unparsable object: {obj}")
return {**obj, **kwargs}
class APIModelMeta(type):
"""API model metaclass.
Stores fields and validators generated for every model.
"""
__slots__: typing.Sequence[str]
__fields__: typing.Mapping[str, fields.ModelFieldInfo]
"""Fields with their validators."""
__extras__: typing.Mapping[str, fields.ExtraInfo]
"""Extra values proxied to all nested models."""
__properties__: typing.Mapping[str, str]
"""Instance attributes to be included in the serialized representation."""
__root_validators__: typing.Sequence[validation.RootValidator]
"""Root validators."""
def __new__(
cls,
name: str,
bases: typing.Tuple[type],
namespace: typing.Dict[str, object],
*,
field_cls: typing.Optional[typing.Type[fields.ModelFieldInfo]] = None,
slots: typing.Optional[bool] = None,
**options: object,
) -> tutils.Self:
"""Create a new model class.
Collects all fields and validators.
"""
self = super().__new__(cls, name, bases, namespace)
self.__fields__ = {}
self.__extras__ = {}
self.__properties__ = {}
self.__root_validators__ = []
if field_cls is None:
possible: typing.Collection[typing.Type[fields.ModelFieldInfo]]
possible = set(type(field) for base in bases for field in getattr(base, "__fields__", {}).values())
field_cls = next(iter(possible)) if len(possible) == 1 else fields.ModelFieldInfo
slots = hasattr(bases[0], "__slots__") if slots is None else slots
for name, annotation in typing.get_type_hints(self).items():
obj = getattr(self, name, ...)
if isinstance(obj, fields.ExtraInfo):
continue # resolved later
self.__fields__[name] = field_cls.from_annotation(name, annotation, obj, model=self)
for name in dir(self):
obj = getattr(self, name, ...)
if isinstance(obj, validation.RootValidator):
self.__root_validators__.append(obj)
elif isinstance(obj, validation.Validator):
for field_name in obj._fields:
self.__fields__[field_name].add_validators(obj)
elif isinstance(obj, fields.ExtraInfo):
obj.alias = obj.alias or name.lstrip("_")
self.__extras__[name] = obj
elif isinstance(obj, property):
if isinstance(obj, fields.NamedProperty):
if obj.exclude:
continue
self.__properties__[name] = obj.alias
elif name[0] != "_":
self.__properties__[name] = name
self.__root_validators__.sort(key=lambda v: v.order)
if slots and "__slots__" not in namespace:
previous_slots = set(slot for base in bases for slot in utility.get_slots(base))
all_slots = set((*self.__fields__.keys(), *self.__extras__.keys()))
self.__slots__ = tuple(all_slots - previous_slots)
return self
def __repr__(self) -> str:
args = ", ".join(f"{k}={v!r}" for k, v in self.__fields__.items())
return f"{self.__class__.__name__}({self.__name__!r}{f', {args}' if args else ''})"
def __devtools_pretty(self, fmt: typing.Callable[[object], str], **kwargs: object) -> typing.Iterator[object]:
"""Devtools pretty formatting."""
yield from utility.devtools_pretty(
fmt,
self.__name__,
self.__root_validators__,
__name__=self.__class__.__name__,
**self.__extras__,
**self.__fields__,
)
if not typing.TYPE_CHECKING:
def __getattribute__(self, name: str) -> object:
if name == "__pretty__":
return self.__devtools_pretty
return super().__getattribute__(name)
@property
def isasync(self) -> bool:
"""Whether the model is async."""
return (
# field validators
any(validator.isasync for field in self.__fields__.values() for validator in field.validators)
# root validators
or any(validator.isasync for validator in self.__root_validators__)
)
@utility.as_universal_method
async def validate( # noqa # C901: too complex
self,
obj: tutils.JSONMapping,
*,
instance: typing.Optional[APIModel] = None,
extras: bool = False,
) -> tutils.JSONMapping:
"""Validate a mapping.
Returns the validated mapping.
If an instance is not passed in, a dummy instance will be created.
"""
if instance is None:
instance = APIModel._empty(freeform=True)
self = typing.cast("typing.Type[APIModel]", self)
# =============================
# EXTRAS
if extras:
with errors.catch_errors(self) as catcher:
for attr_name, extra in self.__extras__.items():
if extra.alias in obj:
setattr(instance, attr_name, obj[extra.alias])
elif extra.default is not ...:
setattr(instance, attr_name, extra.default)
else:
catcher.add_error(TypeError(f"Missing required extra field: {extra.alias!r}"), loc=attr_name)
# =============================
# INITIAL ROOT
with errors.catch_errors(self) as catcher:
for validator in _get_ordered(self.__root_validators__, order=validation.Order.INITIAL_ROOT):
with catcher.catch():
obj = await validator(instance, obj)
obj = dict(obj)
# =============================
# ALIAS
new_obj: tutils.JSONMapping = {}
for attr_name, field in self.__fields__.items():
default = field._get_default()
if field.alias not in obj and default is not ...:
obj[attr_name] = default
if field.alias in obj:
setattr(instance, attr_name, obj[field.alias])
new_obj[attr_name] = obj[field.alias]
obj = new_obj
# =============================
# ROOT
with errors.catch_errors(self) as catcher:
for validator in _get_ordered(self.__root_validators__, order=validation.Order.ROOT):
with catcher.catch():
obj = await validator(instance, obj)
# =============================
# FIELD CHECK
with errors.catch_errors(self) as catcher:
for attr_name, field in self.__fields__.items():
if attr_name not in obj:
catcher.add_error(TypeError(f"Missing required field: {field.alias!r}"), loc=attr_name)
obj = dict(obj)
# =============================
# VALIDATOR
orders = (validation.Order.VALIDATOR, validation.Order.ANNOTATION, validation.Order.POST_VALIDATOR)
for order in orders:
# order is next to arbitrary, only here because of ANNOTATION
with errors.catch_errors(self) as catcher:
for attr_name, field in self.__fields__.items():
for validator in _get_ordered(field.validators, order=order):
with catcher.catch(loc=attr_name):
obj[attr_name] = await validator(instance, obj[attr_name])
setattr(instance, attr_name, obj[attr_name])
# =============================
# FINAL ROOT
with errors.catch_errors(self) as catcher:
for validator in _get_ordered(self.__root_validators__, order=validation.Order.FINAL_ROOT):
with catcher.catch():
obj = await validator(instance, obj)
# =============================
return obj
class APIModel(utility.Representation, metaclass=APIModelMeta):
"""Base APIModel class."""
# populated by metaclass
__slots__ = ()
def __new__(
cls: typing.Type[APIModelT],
obj: typing.Optional[object] = None,
**kwargs: object,
) -> APIModelT:
"""Create a new model instance.
All async models should be created using `APIModel.create` instead.
"""
if cls.isasync:
raise TypeError("Must use the create method with an async APIModel.")
if isinstance(obj, cls):
return obj
self = super().__new__(cls)
self.update_model.synchronous(_to_mapping(obj, **kwargs))
return self
# TODO: universal async for classmethods
@classmethod
async def create(
cls: typing.Type[APIModelT],
obj: typing.Optional[object] = None,
**kwargs: object,
) -> APIModelT:
"""Create a new model instance asynchronously."""
if isinstance(obj, cls):
return obj
self = super().__new__(cls)
await self.update_model(_to_mapping(obj, **kwargs))
return self
@utility.as_universal_method
async def update_model(self, obj: tutils.JSONMapping) -> tutils.JSONMapping:
"""Update a model instance asynchronously."""
return await self.__class__.validate(obj, instance=self, extras=True)
def as_dict(
self,
*,
private: bool = False,
properties: bool = True,
alias: bool = False,
**options: object,
) -> typing.Mapping[str, object]:
"""Create a mapping from the model instance.
Args:
private: Include private attributes (prefixed with an underscore `_`).
properties: Include methods decorated with `@property`.
alias: Rename fields to their declared name.
"""
obj: typing.Mapping[str, object] = {}
for attr_name, field in self.__class__.__fields__.items():
if field.private and not private:
continue
field_name = field.alias if alias else attr_name
attr = getattr(self, attr_name)
obj[field_name] = _serialize_attr(attr, private=private, alias=alias)
if properties:
obj.update({name: getattr(self, name) for name in self.__class__.__properties__})
return obj
def get_extras(self, alias: bool = True) -> typing.Mapping[str, object]:
"""Get extra fields which are normally not part of the model."""
obj: typing.Mapping[str, object] = {}
for attr_name, extra in self.__class__.__extras__.items():
field_name = extra.alias if alias else attr_name
if hasattr(self, attr_name):
obj[field_name] = getattr(self, attr_name)
return obj
def __repr_args__(self) -> typing.Mapping[str, object]:
args: typing.Mapping[str, object] = {}
args.update({attr: getattr(self, attr) for attr in self.__class__.__fields__})
args.update({name: getattr(self, attr_name) for attr_name, name in self.__class__.__properties__.items()})
return args
if not typing.TYPE_CHECKING:
@classmethod
def __get_validators__(cls) -> typing.Iterator[typing.Callable[..., object]]:
"""Get pydantic validators for compatibility."""
yield cls
@classmethod
def __modify_schema__(cls, field_schema: typing.Dict[str, object]) -> None:
"""Create a schema for pydantic."""
field_schema.update(
type="object",
properties={field.alias: dict(type="any") for field in cls.__fields__.values() if not field.private},
)
@classmethod
def _empty(cls, freeform: bool = False) -> APIModel:
"""Return an empty base APIModel."""
if freeform:
cls = type(cls.__name__, (cls,), {})
return super().__new__(cls)
def create_model(
__name__: str,
__bases__: tutils.MaybeSequence[typing.Type[APIModel]] = APIModel,
/,
__fields__: typing.Optional[typing.Mapping[str, fields.ModelFieldInfo]] = None,
__extras__: typing.Optional[typing.Mapping[str, fields.ExtraInfo]] = None,
__root_validators__: typing.Optional[typing.Sequence[validation.RootValidator]] = None,
**attrs: typing.Union[object, fields.FieldInfo, fields.ExtraInfo, typing.Tuple[object, object]],
) -> typing.Type[APIModel]:
"""Dynamically create a model.
Fields must be in the formats `name=type` | `name=FieldInfo()` | `name=(type, default)`
"""
namespace: typing.Dict[str, typing.Any] = {}
annotations = namespace["__annotations__"] = {}
for name, attr in attrs.items():
if isinstance(attr, (fields.FieldInfo, fields.ExtraInfo)):
namespace[name] = attr
annotations[attr] = attr.tp if isinstance(attr, fields.ModelFieldInfo) else object
elif tutils.generic_isinstance(attr, typing.Tuple[object, object]):
if len(attr) != 2:
raise TypeError("Tuple must be (type, default)")
annotations[name] = attr[0]
namespace[name] = attr[1]
else:
annotations[name] = attr
model: typing.Type[APIModel] = type(__name__, tuple(utility.flatten_sequences(__bases__)), namespace) # type: ignore
if __fields__:
model.__fields__ = {**model.__fields__, **__fields__}
if __extras__:
model.__extras__ = {**model.__extras__, **__extras__}
if __root_validators__:
model.__root_validators__ = [*model.__root_validators__, *__root_validators__]
return model