collerek/ormar

View on GitHub
ormar/queryset/queries/prefetch_query.py

Summary

Maintainability
A
2 hrs
Test Coverage
import abc
import logging
from abc import abstractmethod
from typing import (
    TYPE_CHECKING,
    Any,
    Dict,
    List,
    Sequence,
    Tuple,
    Type,
    Union,
    cast,
)

import ormar  # noqa:  I100, I202
from ormar.queryset.clause import QueryClause
from ormar.queryset.queries.query import Query
from ormar.queryset.utils import translate_list_to_dict

if TYPE_CHECKING:  # pragma: no cover
    from ormar import ForeignKeyField, Model
    from ormar.models.excludable import ExcludableItems
    from ormar.queryset import FilterAction, OrderAction

logger = logging.getLogger(__name__)


class UniqueList(list):
    """
    List that ignores appended items it already contains.
    A set cannot be used instead, because the order of items has to be kept.
    """

    def append(self, item: Any) -> None:
        """
        Adds the item at the end of the list, unless an equal item is already
        present - in that case the call is a no-op.

        :param item: value to add
        :type item: Any
        """
        if item in self:
            return
        super().append(item)


class Node(abc.ABC):
    """
    Abstract element of the query tree. The tree divides the job between
    models that are already loaded and the ones that still need to be
    fetched from the database.
    """

    def __init__(self, relation_field: "ForeignKeyField", parent: "Node") -> None:
        self.relation_field = relation_field
        self.table_prefix = ""
        self.use_alias: bool = False
        self.rows: List = []
        self.models: List["Model"] = []
        self.children: List["Node"] = []
        self.parent = parent
        # register in the parent, so the tree can be walked from the top
        if parent:
            parent.children.append(self)

    @property
    def target_name(self) -> str:
        """
        Name of the relation, used to fetch excludes/includes from the
        excludable mixin and to specify the target to join in m2m relations.

        :return: name of the relation
        :rtype: str
        """
        field = self.relation_field
        is_primary_self_reference = (
            field.self_reference and field.self_reference_primary == field.name
        )
        if is_primary_self_reference:
            return field.default_source_field_name()
        return field.default_target_field_name()

    @abstractmethod
    def extract_related_ids(self, column_name: str) -> List:  # pragma: no cover
        """
        Returns values of the given relation column for this node.
        """

    @abstractmethod
    def reload_tree(self) -> None:  # pragma: no cover
        """
        Reloads this node and the nodes below it.
        """

    @abstractmethod
    async def load_data(self) -> None:  # pragma: no cover
        """
        Loads the data of this node and the nodes below it.
        """

    def get_filter_for_prefetch(self) -> List["FilterAction"]:
        """
        Builds the filter that limits the query to the ids extracted
        from the parent node.
        When the parent yields no ids an empty list is returned.

        :return: list of filter actions based on parent node ids
        :rtype: List["FilterAction"]
        """
        parent = self.parent
        relation_column = self.relation_field.get_model_relation_fields(
            parent.use_alias
        )
        related_ids = parent.extract_related_ids(column_name=relation_column)
        if not related_ids:
            return []
        return self._prepare_filter_clauses(ids=related_ids)

    def _prepare_filter_clauses(self, ids: List) -> List["FilterAction"]:
        """
        Turns the list of ids into an `<related field alias>__in` filter
        prepared against the filter clause target of the relation field.

        :param ids: list of ids that should be used to fetch data
        :type ids: List
        :return: list of filter actions to use in query
        :rtype: List["FilterAction"]
        """
        field = self.relation_field
        target = field.get_filter_clause_target()
        lookup = f"{cast(str, field.get_related_field_alias())}__in"
        clause = QueryClause(model_cls=target, select_related=[], filter_clauses=[])
        prepared, _ = clause.prepare_filter(_own_only=False, **{lookup: ids})
        return prepared


class AlreadyLoadedNode(Node):
    """
    Node whose models were already loaded in the select statement,
    so its own models are taken from the parent node models.
    """

    def __init__(self, relation_field: "ForeignKeyField", parent: "Node") -> None:
        super().__init__(relation_field=relation_field, parent=parent)
        self.use_alias = False
        self._extract_own_models()

    def _extract_own_models(self) -> None:
        """
        Collects own models from the relation attribute of each parent model.
        List values are added item by item, other truthy values as single model.
        """
        relation_name = self.relation_field.name
        for parent_model in self.parent.models:
            related = getattr(parent_model, relation_name)
            if isinstance(related, list):
                self.models.extend(related)
            elif related:
                self.models.append(related)

    async def load_data(self) -> None:
        """
        Own data is already present, so only the child nodes are asked to load.
        """
        for node in self.children:
            await node.load_data()

    def reload_tree(self) -> None:
        """
        Passes the reload request down to each of the child nodes, so the
        freshly loaded nodes are included in the tree.
        """
        for node in self.children:
            node.reload_tree()

    def extract_related_ids(self, column_name: str) -> List:
        """
        Reads the given attribute from each of own models and returns unique
        values in order of appearance. If the value is a model its pk is used,
        None values are skipped.

        :param column_name: name of the column that holds the relation info
        :type column_name: str
        :return: List of extracted values of relation column
        :rtype: List
        """
        unique_ids = UniqueList()
        for model in self.models:
            value = getattr(model, column_name)
            if isinstance(value, ormar.Model):
                value = value.pk
            elif value is None:
                continue
            unique_ids.append(value)
        return unique_ids


class RootNode(AlreadyLoadedNode):
    """
    Top node of the tree - holds the root models from which both main
    and prefetch query originated.
    """

    def __init__(self, models: List["Model"]) -> None:
        # Parent initializers are not called here: AlreadyLoadedNode.__init__
        # reads self.parent.models and the root is created without a parent.
        self.children = []
        self.use_alias = False
        self.models = models

    def reload_tree(self) -> None:
        """
        Asks each of the direct child nodes to reload its own part of the tree.
        """
        for node in self.children:
            node.reload_tree()


class LoadNode(Node):
    """
    Nodes that actually need to be fetched from database in the prefetch query
    """

    def __init__(
        self,
        relation_field: "ForeignKeyField",
        excludable: "ExcludableItems",
        orders_by: List["OrderAction"],
        parent: "Node",
        source_model: Type["Model"],
    ) -> None:
        super().__init__(relation_field=relation_field, parent=parent)
        self.excludable = excludable
        self.exclude_prefix: str = ""
        self.orders_by = orders_by
        self.use_alias = True
        # own models grouped by the value of the column linking them with parent
        self.grouped_models: Dict[Any, List["Model"]] = dict()
        self.source_model = source_model

    async def load_data(self) -> None:
        """
        Ensures that at least primary key columns from current model are included in
        the query.

        Gets the filter values from the parent model and runs the query.

        Triggers a data load in child tasks.

        If the parent yields no ids no query is issued and the children
        are not loaded either.
        """
        self._update_excludable_with_related_pks()
        if self.relation_field.is_multi:
            # m2m: query the through model and join the target model
            query_target = self.relation_field.through
            select_related = [self.target_name]
        else:
            query_target = self.relation_field.to
            select_related = []

        filter_clauses = self.get_filter_for_prefetch()
        if not filter_clauses:
            return

        qry = Query(
            model_cls=query_target,
            select_related=select_related,
            filter_clauses=filter_clauses,
            exclude_clauses=[],
            offset=None,
            limit_count=None,
            excludable=self.excludable,
            order_bys=self._extract_own_order_bys(),
            limit_raw_sql=False,
        )
        expr = qry.build_select_expression()
        # Compile the statement only if it is really going to be logged -
        # rendering with literal binds is wasted work (and one more place
        # that can raise) when the debug level is disabled.
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(
                expr.compile(
                    dialect=self.source_model.ormar_config.database._backend._dialect,
                    compile_kwargs={"literal_binds": True},
                )
            )
        self.rows = await query_target.ormar_config.database.fetch_all(expr)

        for child in self.children:
            await child.load_data()

    def _update_excludable_with_related_pks(self) -> None:
        """
        Resolves the alias used for excludes (and for table prefix in m2m
        relations) and makes sure that the related field names are included
        if the excludable of the target model has any include clauses.
        """
        related_field_names = self.relation_field.get_related_field_name()
        alias_manager = self.relation_field.to.ormar_config.alias_manager
        relation_key = self._build_relation_key()
        self.exclude_prefix = alias_manager.resolve_relation_alias_after_complex(
            source_model=self.source_model,
            relation_str=relation_key,
            relation_field=self.relation_field,
        )
        if self.relation_field.is_multi:
            self.table_prefix = self.exclude_prefix
        target_model = self.relation_field.to
        model_excludable = self.excludable.get(
            model_cls=target_model, alias=self.exclude_prefix
        )
        # includes nested pks if not included already
        # NOTE(review): if get_related_field_name() returns a plain str this loop
        # iterates over its characters - confirm it returns a collection of names
        for related_name in related_field_names:
            if model_excludable.include and not model_excludable.is_included(
                related_name
            ):
                model_excludable.set_values({related_name}, is_exclude=False)

    def _build_relation_string(self) -> str:
        """
        Builds the relation names path leading from the root node to this node.

        :return: relation names joined with double underscore
        :rtype: str
        """
        return self._get_full_tree_path()

    def _build_relation_key(self) -> str:
        """
        Returns the key under which the relation alias is resolved.

        :return: relation names joined with double underscore
        :rtype: str
        """
        return self._build_relation_string()

    def _extract_own_order_bys(self) -> List["OrderAction"]:
        """
        Extracts list of order actions related to current model.
        Since same model can happen multiple times in a tree we check not only the
        match on given model but also that path from relation tree matches the
        path in order action.

        Note that matching order actions are modified in place - they are marked
        as source model order and receive the table prefix of this node.

        :return: list of order actions related to current model
        :rtype: List[OrderAction]
        """
        own_order_bys = []
        own_path = self._get_full_tree_path()
        for order_by in self.orders_by:
            if (
                order_by.target_model == self.relation_field.to
                and order_by.related_str.endswith(own_path)
            ):
                order_by.is_source_model_order = True
                order_by.table_prefix = self.table_prefix
                own_order_bys.append(order_by)
        return own_order_bys

    def _get_full_tree_path(self) -> str:
        """
        Iterates the nodes to extract path from root node.

        :return: path from root node
        :rtype: str
        """
        node: Node = self
        relation_str = node.relation_field.name
        # walk up until the node attached directly to the root is reached
        while not isinstance(node.parent, RootNode):
            node = node.parent
            relation_str = f"{node.relation_field.name}__{relation_str}"
        return relation_str

    def extract_related_ids(self, column_name: str) -> List:
        """
        Extracts the selected column values from own rows.
        Those values are used to construct filter clauses and populate child models.

        :param column_name: name of the column that holds the relation info
        :type column_name: str
        :return: List of extracted values of relation column
        :rtype: List
        """
        column_name = self._prefix_column_names_with_table_prefix(
            column_name=column_name
        )
        return self._extract_simple_relation_keys(column_name=column_name)

    def _prefix_column_names_with_table_prefix(self, column_name: str) -> str:
        """
        Prepends the table prefix (followed by underscore) to the column name
        if the node has a table prefix set.

        :param column_name: name of the column
        :type column_name: str
        :return: column name with optional prefix
        :rtype: str
        """
        return (f"{self.table_prefix}_" if self.table_prefix else "") + column_name

    def _extract_simple_relation_keys(self, column_name: str) -> List:
        """
        Extracts simple relation keys values.
        Only missing (None) values are skipped - falsy keys like 0 are valid
        and are kept, same as in AlreadyLoadedNode.extract_related_ids.

        :param column_name: name of the column that holds the relation info
        :type column_name: str
        :return: List of unique values of relation column in order of appearance
        :rtype: List
        """
        list_of_ids = UniqueList()
        for row in self.rows:
            value = row[column_name]
            if value is not None:
                list_of_ids.append(value)
        return list_of_ids

    def reload_tree(self) -> None:
        """
        Instantiates models from loaded database rows.
        Groups those instances by relation key for easy extract per parent.
        Triggers same for child nodes and then populates
        the parent node with own related models.
        Nothing happens if no rows were loaded for this node.
        """
        if self.rows:
            self._instantiate_models()
            self._group_models_by_relation_key()
            for child in self.children:
                child.reload_tree()
            self._populate_parent_models()

    def _instantiate_models(self) -> None:
        """
        Iterates the rows and initializes instances of ormar.Models.
        Each model is instantiated only once (they can be duplicates for m2m relation
        when multiple parent models refer to same child model since the query have to
        also include the through model - hence full rows are unique, but related
        models without through models can be not unique).

        One entry per row is appended to own models, so the position of a model
        matches the position of the row it comes from.
        """
        target_model = self.relation_field.to
        fields_to_exclude = target_model.get_names_to_exclude(
            excludable=self.excludable, alias=self.exclude_prefix
        )
        parsed_rows: Dict[Tuple, "Model"] = {}
        for row in self.rows:
            item = target_model.extract_prefixed_table_columns(
                item={},
                row=row,
                table_prefix=self.table_prefix,
                excludable=self.excludable,
            )
            hashable_item = self._hash_item(item)
            # build the model only for the first occurrence of given values,
            # setdefault would construct (and drop) one for each duplicated row
            if hashable_item not in parsed_rows:
                parsed_rows[hashable_item] = target_model(
                    **item, **{"__excluded__": fields_to_exclude}
                )
            self.models.append(parsed_rows[hashable_item])

    def _hash_item(self, item: Dict) -> Tuple:
        """
        Converts model dictionary into tuple to make it hashable and allow to use it
        as a dictionary key - used to ensure unique instances of related models.
        Items are sorted by key and nested dictionaries are converted recursively.

        :param item: instance dictionary
        :type item: Dict
        :return: tuple out of model dictionary
        :rtype: Tuple
        """
        return tuple(
            (key, self._hash_item(value) if isinstance(value, dict) else value)
            for key, value in sorted(item.items())
        )

    def _group_models_by_relation_key(self) -> None:
        """
        Groups own models by relation keys so it's easy later to extract those models
        when iterating parent models. Note that order is important as it reflects
        order by issued by the user.
        """
        relation_key = self.relation_field.get_related_field_alias()
        # rows and models are aligned by position (see _instantiate_models)
        for row, model in zip(self.rows, self.models):
            self.grouped_models.setdefault(row[relation_key], []).append(model)

    def _populate_parent_models(self) -> None:
        """
        Populate parent node models with own child models from grouped dictionary
        """
        relation_key = self._get_relation_key_linking_models()
        for model in self.parent.models:
            children = self._get_own_models_related_to_parent(
                model=model, relation_key=relation_key
            )
            for child in children:
                setattr(model, self.relation_field.name, child)

    def _get_relation_key_linking_models(self) -> Tuple[str, str]:
        """
        Extract name and alias of relation column to use
        in linking between own models and parent models

        :return: tuple of name and alias of relation column
        :rtype: Tuple[str, str]
        """
        column_name = self.relation_field.get_model_relation_fields(False)
        column_alias = self.relation_field.get_model_relation_fields(True)
        return column_name, column_alias

    def _get_own_models_related_to_parent(
        self, model: "Model", relation_key: Tuple[str, str]
    ) -> List["Model"]:
        """
        Extracts related column value from parent and based on this key gets the
        own grouped models. Only the name from relation key is used here.

        :param model: parent model from parent node
        :type model: Model
        :param relation_key: name and alias linking relations
        :type relation_key: Tuple[str, str]
        :return: list of own models to set on parent, empty if there are none
        :rtype: List[Model]
        """
        column_name, _ = relation_key
        model_value = getattr(model, column_name)
        if isinstance(model_value, ormar.Model):
            # models are grouped by raw key values, so use the pk of the model
            model_value = model_value.pk
        return self.grouped_models.get(model_value, [])


class PrefetchQuery:
    """
    Query used to fetch related models in subsequent queries.
    Each model is fetched only once by the name of the relation.
    That means that for each prefetch_related entry next query is issued to database.
    """

    def __init__(  # noqa: CFQ002
        self,
        model_cls: Type["Model"],
        excludable: "ExcludableItems",
        prefetch_related: List,
        select_related: List,
        orders_by: List["OrderAction"],
    ) -> None:
        self.model = model_cls
        self.excludable = excludable
        # relation lists are kept as nested dictionaries keyed by relation name
        self.select_dict = translate_list_to_dict(select_related, default={})
        self.prefetch_dict = translate_list_to_dict(prefetch_related, default={})
        self.orders_by = orders_by
        self.load_tasks: List[Node] = []

    async def prefetch_related(self, models: Sequence["Model"]) -> Sequence["Model"]:
        """
        Main entry point for prefetch_query.

        Receives list of already initialized parent models with all children from
        select_related already populated.

        Builds the tree of nodes, loads the data for the nodes that need it
        and then reloads the tree to set the loaded models on their parents.

        Returns list with related models already prefetched and set.

        :param models: list of already instantiated models from main query
        :type models: Sequence[Model]
        :return: list of models with children prefetched
        :rtype: List[Model]
        """
        parent_task = RootNode(models=cast(List["Model"], models))
        self._build_load_tree(
            prefetch_dict=self.prefetch_dict,
            select_dict=self.select_dict,
            parent=parent_task,
            model=self.model,
        )
        await parent_task.load_data()
        parent_task.reload_tree()
        return parent_task.models

    def _build_load_tree(
        self,
        select_dict: Dict,
        prefetch_dict: Dict,
        parent: Node,
        model: Type["Model"],
    ) -> None:
        """
        Build a tree of already loaded nodes and nodes that need
        to be loaded through the prefetch query.

        Relations present also in select_dict become AlreadyLoadedNode,
        all other become LoadNode. Created nodes register themselves
        in the parent node, so nothing is returned.

        :param select_dict: dictionary with select query structure
        :type select_dict: Dict
        :param prefetch_dict: dictionary with prefetch query structure
        :type prefetch_dict: Dict
        :param parent: parent Node
        :type parent: Node
        :param model: currently processed model
        :type model: Model
        """
        for related in prefetch_dict:
            relation_field = cast(
                "ForeignKeyField", model.ormar_config.model_fields[related]
            )
            if related in select_dict:
                task: Node = AlreadyLoadedNode(
                    relation_field=relation_field, parent=parent
                )
            else:
                task = LoadNode(
                    relation_field=relation_field,
                    excludable=self.excludable,
                    orders_by=self.orders_by,
                    parent=parent,
                    source_model=self.model,
                )
            nested_prefetch = prefetch_dict.get(related, {})
            # recursion over an empty dictionary creates no nodes - skip it
            if nested_prefetch:
                self._build_load_tree(
                    select_dict=select_dict.get(related, {}),
                    prefetch_dict=nested_prefetch,
                    parent=task,
                    model=relation_field.to,
                )