digitalfabrik/integreat-cms

View on GitHub
integreat_cms/cms/utils/repair_tree.py

Summary

Maintainability
A
0 mins
Test Coverage
A
96%
"""
This module contains utilities to repair or detect inconsistencies in a tree
"""

from __future__ import annotations

import logging
from collections import deque
from typing import TYPE_CHECKING

from django.apps import apps
from django.db import transaction

from .shadow_instance import ShadowInstance

if TYPE_CHECKING:
    from typing import Iterable

    from django.apps.registry import Apps
    from django.db.models import Model

    Page: Model = apps.get_model("cms", "Page")


@transaction.atomic
def repair_tree(
    page_id: Iterable[int] | None = None,
    commit: bool = False,
    logging_name: str = __name__,
    dj_apps: Apps = apps,
) -> None:
    """
    Fix the tree for the given pages, or all trees if no id is given.
    Changes are only written to the database if ``commit`` is set to ``True``.
    For details see :class:`MPTTFixer`.

    :param page_id: Ids of pages whose trees should be checked. If ``None`` or empty, all trees
                    are checked. Any iterable is accepted, including one-shot iterators such as
                    generators.
    :param commit: Whether the fixed values should be saved to the database
    :param logging_name: The name of the logger which should be used for the output
    :param dj_apps: The app registry used to look up the page model (allows usage in migrations)
    :raises ValueError: If one of the given page ids does not exist
    """
    logger = logging.getLogger(logging_name)

    # Use get_model() instead of importing so this function can be used in migrations
    Page: Model = dj_apps.get_model("cms", "Page")
    mptt_fixer = MPTTFixer(dj_apps=dj_apps)
    # The ids are iterated twice below (validation + root lookup), so materialize them once.
    # Otherwise a generator would be exhausted by the validation and no tree would be processed.
    page_ids = list(page_id) if page_id is not None else []
    root_nodes: Iterable[ShadowInstance[Page]]

    if page_ids:
        # Assert that all of the requested pages actually exist
        for single_id in page_ids:
            try:
                Page.objects.get(id=single_id)
            except Page.DoesNotExist as e:
                raise ValueError(
                    f'The page with id "{single_id}" does not exist.'
                ) from e
        # All ids are valid: collect their root nodes, deduplicated but kept in the
        # order of the requested ids so the log output is deterministic
        root_nodes = list(
            dict.fromkeys(
                mptt_fixer.get_fixed_root_node(single_id) for single_id in page_ids
            )
        )
    else:
        root_nodes = mptt_fixer.get_fixed_root_nodes()

    action = "Fixing" if commit else "Detecting problems in"
    for root_node in root_nodes:
        logger.info(
            "%s tree with id %i... (%r)",
            action,
            root_node.tree_id__original,
            root_node.instance,
        )
        for tree_node in mptt_fixer.get_fixed_tree_of_page(root_node.pk):
            print_changed_fields(tree_node, logging_name=logging_name)

    if commit:
        # Every node is saved, not only the requested trees: the fixer renumbers the
        # tree ids of all root nodes, so other trees may have changed as well
        for page in mptt_fixer.get_fixed_tree_nodes():
            page.apply_changes()
            page.save()


def print_changed_fields(
    tree_node: ShadowInstance[Page], logging_name: str = __name__
) -> None:
    """
    Log the tree attributes of a single page, highlighting the ones that were changed.

    Unchanged values (and the ``parent_id``) are reported via the ``success`` log method,
    while changed ``tree_id``/``depth``/``lft``/``rgt`` values are reported as errors
    showing the old and the new value. The node is expected to be a
    :class:`~integreat_cms.cms.utils.shadow_instance.ShadowInstance` of a
    :class:`~integreat_cms.cms.models.pages.page.Page`.

    :param tree_node: The shadow instance of the page whose attributes should be logged
    :param logging_name: The name of the logger which should be used for the output
    """
    log = logging.getLogger(logging_name)
    changes = tree_node.changed_attributes

    log.info("Page %s:", tree_node.id)
    log.success("\tparent_id: %s", tree_node.parent_id)  # type: ignore[attr-defined]

    for field in ("tree_id", "depth", "lft", "rgt"):
        if field not in changes:
            log.success("\t%s: %i", field, getattr(tree_node, field))  # type: ignore[attr-defined]
            continue
        change = changes[field]
        log.error("\t%s: %i → %i", field, change["old"], change["new"])


class MPTTFixer:
    """
    Gets ALL nodes and coughs out fixed ``lft``, ``rgt`` and ``depth`` values.
    Uses the parent field to fix hierarchy and sorts siblings by (potentially inconsistent) ``lft``.

    Every node without a parent becomes the root of its own tree (its new ``tree_id`` is its
    position in :attr:`trees` + 1); every other node adopts the ``tree_id`` of its parent.
    """

    logger = logging.getLogger(__name__)

    def __init__(self, dj_apps: Apps = apps) -> None:
        """
        Create a fixed tree, using a :class:`~integreat_cms.cms.utils.shadow_instance.ShadowInstance` of each page.

        :param dj_apps: The app registry used to look up the page model (allows usage in migrations)
        :raises ValueError: If the parents of some pages can never be resolved (see :meth:`recreate_structure`)
        """
        Page: Model = dj_apps.get_model("cms", "Page")
        # Queue of nodes whose position in the hierarchy has not been determined yet
        self.broken_nodes: deque[ShadowInstance[Page]] = deque(
            ShadowInstance(page) for page in Page.objects.all()
        )
        # A list of root nodes, also determining the new tree_id (index + 1)
        self.trees: list[ShadowInstance[Page]] = []
        # A dictionary of fixed nodes, indexable by primary key
        self.fixed_nodes: dict[int, ShadowInstance[Page]] = {}
        self.recreate_structure()
        self.fix_values()

    def recreate_structure(self) -> None:
        """
        Extract nodes and recreate their hierarchy.

        Nodes whose parent has not been processed yet are put back at the end of the queue.
        If the whole remaining queue is cycled through without placing a single node, none of
        the remaining nodes can ever be placed (e.g. because their parent relations form a
        cycle or point to a page that was not loaded). In that case a :class:`ValueError` is
        raised instead of looping forever.

        :raises ValueError: If the parents of some nodes can never be resolved
        """
        # Number of consecutive nodes that were put back into the queue without any progress
        deferred_in_a_row = 0
        while self.broken_nodes:
            node = self.broken_nodes.popleft()
            if node.parent_id is None:
                # This is a root node
                self.trees.append(node)
                node.tree_id = len(self.trees)
                node.depth = 1
            elif node.parent_id not in self.fixed_nodes:
                # We don't know the parent yet, put it back in the queue
                self.broken_nodes.append(node)
                deferred_in_a_row += 1
                if deferred_in_a_row >= len(self.broken_nodes):
                    # Every remaining node was deferred once since the last progress,
                    # so the state can no longer change
                    unresolved = sorted(pending.pk for pending in self.broken_nodes)
                    raise ValueError(
                        f"Could not resolve the parents of the pages with ids {unresolved} "
                        "(cyclic parent relations or parents which do not exist)."
                    )
                continue
            else:
                # Adopt the tree id of the parent and register us as found child
                parent = self.fixed_nodes[node.parent_id]
                node.tree_id = parent.tree_id
                parent.fixed_children.append(node.pk)
            deferred_in_a_row = 0
            # This node will have a list of which children have been found
            node.fixed_children = []
            self.fixed_nodes[node.pk] = node

    def get_fixed_root_nodes(self) -> Iterable[ShadowInstance[Page]]:
        """
        Return all fixed root nodes, ordered by their new tree id.
        """
        return tuple(self.trees)

    def get_fixed_root_node(self, page_id: int) -> ShadowInstance[Page]:
        """
        Travel up ancestors of a page until we get the root node.

        :raises KeyError: If no page with the given id was loaded
        """
        page = self.fixed_nodes[page_id]
        while page.parent_id is not None:
            page = self.fixed_nodes[page.parent_id]
        return page

    def get_fixed_tree_nodes(self) -> Iterable[ShadowInstance[Page]]:
        """
        Return all fixed nodes of all page trees.
        """
        return self.fixed_nodes.values()

    def get_fixed_tree_of_page(
        self, node_id: int | None = None
    ) -> Iterable[ShadowInstance[Page]]:
        """
        Yield all nodes of the same page tree as the node specified by id in order (fixed).
        If no ``node_id`` is specified, nodes from all trees are considered.
        """
        tree_ids = (
            [self.fixed_nodes[node_id].tree_id]
            if node_id is not None
            else range(1, len(self.trees) + 1)
        )
        for tree_id in tree_ids:
            # tree ids are 1-based indices into self.trees
            node = self.trees[tree_id - 1]
            yield from self.yield_subtree(node.pk)

    def yield_subtree(self, node_id: int) -> Iterable[ShadowInstance[Page]]:
        """
        Yield all nodes of the subtree of the node specified by id in order (fixed, pre-order).
        """
        node = self.fixed_nodes[node_id]
        yield node
        for child_id in node.fixed_children:
            yield from self.yield_subtree(child_id)

    def fix_values(self) -> None:
        """
        Recalculate ``lft``, ``rgt`` and ``depth`` values for the reconstructed hierarchical structure.
        """
        for root_node in self.trees:
            self.fix_values_on_subtree(root_node.pk, 1, 1)

    def fix_values_on_subtree(self, node_id: int, counter: int, depth: int) -> int:
        """
        Recalculate ``lft``, ``rgt`` and ``depth`` values for the subtree of the reconstructed hierarchical structure.

        :param node_id: The primary key of the subtree's root node
        :param counter: The ``lft`` value to assign to that node
        :param depth: The depth to assign to that node
        :return: The ``rgt`` value assigned to that node
        """
        node = self.fixed_nodes[node_id]
        node.lft = counter
        node.depth = depth
        # Assure that the child nodes are ordered by old lft value
        # (the children have not been visited yet, so their lft is still the old one)
        node.fixed_children.sort(key=lambda child_id: self.fixed_nodes[child_id].lft)
        for child_id in node.fixed_children:
            counter = self.fix_values_on_subtree(child_id, counter + 1, depth + 1)
        node.rgt = counter + 1
        return counter + 1