collerek/ormar

View on GitHub
ormar/models/mixins/pydantic_mixin.py

Summary

Maintainability
A
0 mins
Test Coverage
import copy
import string
from random import choices
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
    List,
    Optional,
    Set,
    Tuple,
    Type,
    Union,
    cast,
)

import pydantic
from pydantic import BaseModel
from pydantic._internal._decorators import DecoratorInfos
from pydantic.fields import FieldInfo

from ormar.fields import BaseField, ForeignKeyField, ManyToManyField
from ormar.models.mixins.relation_mixin import RelationMixin  # noqa: I100, I202
from ormar.queryset.utils import translate_list_to_dict


class PydanticMixin(RelationMixin):
    __cache__: Dict[str, Type[pydantic.BaseModel]] = {}

    if TYPE_CHECKING:  # pragma: no cover
        __pydantic_decorators__: DecoratorInfos
        model_fields: Dict[str, FieldInfo]
        _skip_ellipsis: Callable
        _get_not_excluded_fields: Callable

    @classmethod
    def get_pydantic(
        cls,
        *,
        include: Union[Set, Dict, None] = None,
        exclude: Union[Set, Dict, None] = None,
    ) -> Type[pydantic.BaseModel]:
        """
        Returns a pydantic model out of ormar model.

        Converts also nested ormar models into pydantic models.

        Can be used to fully exclude certain fields in fastapi response and requests.

        :param include: fields of own and nested models to include
        :type include: Union[Set, Dict, None]
        :param exclude: fields of own and nested models to exclude
        :type exclude: Union[Set, Dict, None]
        """
        relation_map = translate_list_to_dict(cls._iterate_related_models())

        return cls._convert_ormar_to_pydantic(
            include=include, exclude=exclude, relation_map=relation_map
        )

    @classmethod
    def _convert_ormar_to_pydantic(
        cls,
        relation_map: Dict[str, Any],
        include: Union[Set, Dict, None] = None,
        exclude: Union[Set, Dict, None] = None,
    ) -> Type[pydantic.BaseModel]:
        if include and isinstance(include, Set):
            include = translate_list_to_dict(include)
        if exclude and isinstance(exclude, Set):
            exclude = translate_list_to_dict(exclude)
        fields_dict: Dict[str, Any] = dict()
        defaults: Dict[str, Any] = dict()
        fields_to_process = cls._get_not_excluded_fields(
            fields={*cls.ormar_config.model_fields.keys()},
            include=include,
            exclude=exclude,
        )
        fields_to_process.sort(
            key=lambda x: list(cls.ormar_config.model_fields.keys()).index(x)
        )

        cache_key = f"{cls.__name__}_{str(include)}_{str(exclude)}"
        if cache_key in cls.__cache__:
            return cls.__cache__[cache_key]

        for name in fields_to_process:
            field = cls._determine_pydantic_field_type(
                name=name,
                defaults=defaults,
                include=include,
                exclude=exclude,
                relation_map=relation_map,
            )
            if field is not None:
                fields_dict[name] = field
        model = type(
            f"{cls.__name__}_{''.join(choices(string.ascii_uppercase, k=3))}",
            (pydantic.BaseModel,),
            {"__annotations__": fields_dict, **defaults},
        )
        model = cast(Type[pydantic.BaseModel], model)
        cls._copy_field_validators(model=model)
        cls.__cache__[cache_key] = model
        return model

    @classmethod
    def _determine_pydantic_field_type(
        cls,
        name: str,
        defaults: Dict,
        include: Union[Set, Dict, None],
        exclude: Union[Set, Dict, None],
        relation_map: Dict[str, Any],
    ) -> Any:
        field = cls.ormar_config.model_fields[name]
        target: Any = None
        if field.is_relation and name in relation_map:
            target, default = cls._determined_included_relation_field_type(
                name=name,
                field=field,
                include=include,
                exclude=exclude,
                defaults=defaults,
                relation_map=relation_map,
            )
        elif not field.is_relation:
            defaults[name] = cls.model_fields[name].default
            target = field.__type__
        if target is not None and field.nullable:
            target = Optional[target]
        return target

    @classmethod
    def _determined_included_relation_field_type(
        cls,
        name: str,
        field: Union[BaseField, ForeignKeyField, ManyToManyField],
        include: Union[Set, Dict, None],
        exclude: Union[Set, Dict, None],
        defaults: Dict,
        relation_map: Dict[str, Any],
    ) -> Tuple[Type[BaseModel], Dict]:
        target = field.to._convert_ormar_to_pydantic(
            include=cls._skip_ellipsis(include, name),
            exclude=cls._skip_ellipsis(exclude, name),
            relation_map=cls._skip_ellipsis(relation_map, name, default_return=dict()),
        )
        if field.is_multi or field.virtual:
            target = List[target]  # type: ignore
        if field.nullable:
            defaults[name] = None
        return target, defaults

    @classmethod
    def _copy_field_validators(cls, model: Type[pydantic.BaseModel]) -> None:
        """
        Copy field validators from ormar model to generated pydantic model.
        """
        filed_names = list(model.model_fields.keys())
        cls.copy_selected_validators_type(
            model=model, fields=filed_names, validator_type="field_validators"
        )
        cls.copy_selected_validators_type(
            model=model, fields=filed_names, validator_type="validators"
        )

        class_validators = cls.__pydantic_decorators__.root_validators
        model.__pydantic_decorators__.root_validators.update(
            copy.deepcopy(class_validators)
        )
        model_validators = cls.__pydantic_decorators__.model_validators
        model.__pydantic_decorators__.model_validators.update(
            copy.deepcopy(model_validators)
        )
        model.model_rebuild(force=True)

    @classmethod
    def copy_selected_validators_type(
        cls, model: Type[pydantic.BaseModel], fields: List[str], validator_type: str
    ) -> None:
        """
        Copy field validators from ormar model to generated pydantic model.
        """
        validators = getattr(cls.__pydantic_decorators__, validator_type)
        for name, decorator in validators.items():
            if any(field_name in decorator.info.fields for field_name in fields):
                copied_decorator = copy.deepcopy(decorator)
                copied_decorator.info.fields = [
                    field_name
                    for field_name in decorator.info.fields
                    if field_name in fields
                ]
                getattr(model.__pydantic_decorators__, validator_type)[
                    name
                ] = copied_decorator