vbudovski/django-tree

View on GitHub
django_tree/models.py

Summary

Maintainability
A
1 hr
Test Coverage
A
100%
# Copyright [2020] [Vitaly Budovski]
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from collections import OrderedDict

from django.core.exceptions import ObjectDoesNotExist
from django.db import connections
from django.db import models
from django.db import transaction


class BaseTreeNodeManager(models.Manager):
    """Manager providing ordered reads and structural edits for tree nodes.

    Nodes are expected to carry a ``parent`` link and a ``previous`` link to
    the sibling immediately before them. The children of one parent therefore
    form a linked list whose head has ``previous`` set to ``None``.
    """

    def in_order(self):
        """Return a raw queryset of the tree's nodes, level by level.

        Each returned node is annotated with ``depth`` (0 for nodes without a
        parent) and ``index`` (position in its sibling list, 0 for the head).
        Rows are ordered by ``depth``, then ``index``, then ``id``, so every
        parent is yielded before any of its children.

        Only nodes reachable from a parentless node through ``parent`` links
        AND from a list head through ``previous`` links are returned.
        """
        # Quote the table name the same way Django does for its own queries;
        # an unquoted mixed-case or special-character name would not resolve.
        connection = connections[self.db]
        table_name = connection.ops.quote_name(self.model._meta.db_table)
        query = f"""
            SELECT nodes.*, depth, index
            FROM
            {table_name} nodes,
            (
                WITH RECURSIVE depth_cte AS (
                    SELECT id, parent_id, 0 as depth
                    FROM {table_name}
                    WHERE parent_id IS NULL
                    UNION ALL
                    SELECT
                    t.id,
                    t.parent_id,
                    CASE
                        WHEN t.parent_id = depth_cte.id THEN depth_cte.depth + 1
                        ELSE 0
                    END AS depth
                    FROM {table_name} t
                    INNER JOIN depth_cte ON depth_cte.id = t.parent_id
                )
                SELECT id, depth
                FROM depth_cte
            ) depth_cte,
            (
                WITH RECURSIVE index_cte AS (
                    SELECT id, previous_id, 0 as index
                    FROM {table_name}
                    WHERE previous_id IS NULL
                    UNION ALL
                    SELECT
                    t.id,
                    t.previous_id,
                    CASE
                        WHEN t.previous_id = index_cte.id THEN index_cte.index + 1
                        ELSE 0
                    END AS index
                    FROM {table_name} t
                    INNER JOIN index_cte ON index_cte.id = t.previous_id
                )
                SELECT id, index
                FROM index_cte
            ) index_cte
            WHERE nodes.id = depth_cte.id
            AND depth_cte.id = index_cte.id
            ORDER BY depth, index, id
        """

        return self.get_queryset().raw(query)

    def build_tree(self) -> OrderedDict:
        """Return the nodes from :meth:`in_order` as nested ``OrderedDict``s.

        The top level maps the pk of each parentless node to an entry of the
        form ``{"node": node, "children": OrderedDict(...)}``. ``children``
        has the same shape, recursively. Within one parent, children appear
        in sibling-list order.
        """
        node_tree = OrderedDict()
        # pk -> that node's "children" mapping. in_order() yields parents
        # before their children, so a parent's mapping always exists already
        # and each node is attached in O(1) instead of walking its ancestors.
        children_by_pk = {}
        for node in self.in_order():
            if node.parent_id is None:
                siblings = node_tree
            else:
                siblings = children_by_pk[node.parent_id]
            children = OrderedDict()
            siblings[node.pk] = {"node": node, "children": children}
            children_by_pk[node.pk] = children

        return node_tree

    @transaction.atomic
    def _remove(self, this: "BaseTreeNode"):
        """Unlink ``this`` from its sibling list and detach it from its parent.

        The sibling that followed ``this`` (if any) is re-pointed at the
        sibling that preceded it, closing the gap. Afterwards ``this`` has
        ``previous`` and ``parent`` set to ``None`` (saved).
        """
        # The caller's instance, and any ``next`` relation cached on it, may
        # be stale after earlier moves; reload the links from the database.
        this.refresh_from_db(fields=["previous", "parent"])

        # Save the nodes on either side of 'this' node.
        this_previous = this.previous
        try:
            this_next = this.next
        except ObjectDoesNotExist:
            this_next = None

        # Unlink 'this' first: ``previous`` is one-to-one, so this_previous
        # must be released before this_next can point at it.
        this.previous = None
        this.parent = None
        this.save(update_fields=["previous", "parent"])

        # Join the nodes on either side of 'this' to bridge the gap.
        if this_next is not None:
            this_next.previous = this_previous
            this_next.save(update_fields=["previous"])

    @transaction.atomic
    def insert_before(self, before: "BaseTreeNode", this: "BaseTreeNode"):
        """Move ``this`` so that it immediately precedes ``before``.

        ``this`` is first removed from its current position, then re-linked
        as a sibling of ``before`` (taking ``before``'s parent).

        Raises:
            ValueError: if ``before`` and ``this`` are the same node.
        """
        if before == this:
            raise ValueError("Cannot insert a node before itself.")

        self._remove(this)
        # If ``before`` directly followed ``this``, _remove re-linked it via a
        # different instance; reload so ``before.previous`` is current.
        before.refresh_from_db(fields=["previous", "parent"])

        # Link 'this' before 'before'.
        before_previous = before.previous
        before.previous = this
        before.save(update_fields=["previous"])
        this.previous = before_previous
        this.parent = before.parent
        this.save(update_fields=["previous", "parent"])

    @transaction.atomic
    def insert_after(self, after: "BaseTreeNode", this: "BaseTreeNode"):
        """Move ``this`` so that it immediately follows ``after``.

        ``this`` is first removed from its current position, then re-linked
        as a sibling of ``after`` (taking ``after``'s parent).

        Raises:
            ValueError: if ``after`` and ``this`` are the same node.
        """
        if after == this:
            raise ValueError("Cannot insert a node after itself.")

        self._remove(this)
        # ``after.next`` may be cached (possibly as ``this``, or as a node
        # whose link _remove just changed); reload to clear stale relations.
        after.refresh_from_db(fields=["previous", "parent"])

        # Link 'this' after 'after'.
        try:
            after_next = after.next
        except ObjectDoesNotExist:
            pass
        else:
            after_next.previous = this
            after_next.save(update_fields=["previous"])

        this.previous = after
        this.parent = after.parent
        this.save(update_fields=["previous", "parent"])


class BaseTreeNode(models.Model):
    """Abstract base model for a node in an ordered tree.

    Every node links to its parent (``parent``; ``None`` for a top-level node)
    and to the sibling immediately before it (``previous``; ``None`` for the
    first sibling). Both links use ``PROTECT``, so a node that is still
    referenced through one of them cannot be deleted.
    """

    objects = BaseTreeNodeManager()

    # Reverse accessor ``children`` yields the nodes directly below this one.
    parent = models.ForeignKey(
        to="self",
        on_delete=models.PROTECT,
        related_name="children",
        null=True,
    )
    # One-to-one: a node has at most one follower, reachable via ``next``.
    previous = models.OneToOneField(
        to="self",
        on_delete=models.PROTECT,
        related_name="next",
        null=True,
    )

    class Meta:
        abstract = True