# rrebase/knboard
#
# View on GitHub
# backend/boards/api.py
#
# Summary
#
# Maintainability
# A
# 0 mins
# Test Coverage
from itertools import chain

from django.contrib.auth import get_user_model
from django.db import transaction
from django_filters.rest_framework import DjangoFilterBackend
from rest_framework import mixins, viewsets
from rest_framework import filters
from rest_framework.decorators import action
from rest_framework.permissions import IsAuthenticated
from rest_framework.response import Response
from rest_framework.status import HTTP_400_BAD_REQUEST, HTTP_200_OK
from rest_framework.views import APIView
from rest_framework.viewsets import GenericViewSet
from django.db.models import Q
from django.db.models import Prefetch

from .models import Board, Task, Column, Label, Comment
from .permissions import IsOwner, IsOwnerForDangerousMethods
from .serializers import (
    BoardSerializer,
    TaskSerializer,
    ColumnSerializer,
    BoardDetailSerializer,
    MemberSerializer,
    BoardMemberSerializer,
    LabelSerializer,
    CommentSerializer,
)
from .viewsets import ModelDetailViewSet

# User model class as returned by django.contrib.auth.get_user_model().
User = get_user_model()


class BoardViewSet(
    mixins.CreateModelMixin,
    mixins.RetrieveModelMixin,
    mixins.UpdateModelMixin,
    mixins.DestroyModelMixin,
    mixins.ListModelMixin,
    GenericViewSet,
):
    """Create/retrieve/update/destroy/list endpoints for boards.

    Also exposes the ``invite_member`` and ``remove_member`` detail actions.
    Every query is restricted to boards the requesting user is a member of.
    """

    queryset = Board.objects.all()
    serializer_class = BoardSerializer
    permission_classes = [IsAuthenticated, IsOwnerForDangerousMethods]

    def perform_create(self, serializer):
        """Save the new board with the requesting user as its owner."""
        serializer.save(owner=self.request.user)

    def get_serializer_class(self):
        """Return the detail serializer for ``retrieve``, the default otherwise."""
        if self.action == "retrieve":
            return BoardDetailSerializer
        return super().get_serializer_class()

    def get_queryset(self):
        """Return the boards visible to the requesting user.

        For ``retrieve`` the tasks of each column are prefetched. When the
        ``assignees`` query parameter is given (comma-separated user ids),
        only tasks assigned to at least one of those users are prefetched.

        Raises:
            rest_framework.exceptions.ValidationError: ``assignees`` contains
                a value that is not an integer (answered as HTTP 400).
        """
        user = self.request.user
        qs = super().get_queryset().filter(members=user)
        if self.action != "retrieve":
            return qs

        assignees = self.request.query_params.get("assignees", None)
        task_queryset = None
        if assignees:
            try:
                assignee_ids = [int(x) for x in assignees.split(",")]
            except ValueError as exc:
                # Malformed input is a client error, not a server crash.
                from rest_framework.exceptions import ValidationError

                raise ValidationError(
                    {"assignees": "Expected a comma-separated list of integer ids."}
                ) from exc
            # A task matching several assignees would be joined more than
            # once; distinct-on-id collapses those rows.
            task_queryset = (
                Task.objects.filter(assignees__in=assignee_ids)
                .order_by("id")
                .distinct("id")
            )
        return qs.prefetch_related(Prefetch("columns__tasks", queryset=task_queryset))

    def get_member(self):
        """Return the user named by ``username`` in the request data, or None."""
        try:
            return User.objects.get(username=self.request.data.get("username"))
        except User.DoesNotExist:
            return None

    @action(
        detail=True,
        methods=["post"],
        serializer_class=MemberSerializer,
        permission_classes=[IsAuthenticated, IsOwner],
    )
    def invite_member(self, request, pk):
        """Add the users whose ids are listed under ``users`` to the board.

        Responds with HTTP 400 when ``users`` is missing/empty, is malformed,
        or does not match exactly as many existing users as ids were sent.
        On success responds with the serialized new members.
        """
        users_ids = self.request.data.get("users")
        if not users_ids:
            return Response(status=HTTP_400_BAD_REQUEST)

        try:
            new_members = User.objects.filter(id__in=users_ids)
            # len() evaluates the queryset; the cached result is reused below.
            all_found = len(new_members) == len(users_ids)
        except (TypeError, ValueError):
            # Non-iterable payload or ids that cannot be coerced to a pk.
            return Response(status=HTTP_400_BAD_REQUEST)
        if not all_found:
            return Response(status=HTTP_400_BAD_REQUEST)

        self.get_object().members.add(*new_members)
        return Response(
            data=BoardMemberSerializer(instance=new_members, many=True).data
        )

    @action(detail=True, methods=["post"], serializer_class=MemberSerializer)
    def remove_member(self, request, pk):
        """Remove the user named by ``username`` from the board and its tasks.

        Responds with HTTP 400 when the user does not exist, is the board
        owner, or is not a member of the board. On success the user is also
        unassigned from every task of the board and the serialized member is
        returned.
        """
        member = self.get_member()
        board = self.get_object()

        if (
            not member
            or member == board.owner
            # Single EXISTS query instead of loading all of the user's boards.
            or not member.boards.filter(pk=board.pk).exists()
        ):
            return Response(status=HTTP_400_BAD_REQUEST)

        board.members.remove(member)
        # Only tasks the member is actually assigned to need updating.
        for task in Task.objects.filter(column__board=board, assignees=member):
            task.assignees.remove(member)
        return Response(data=BoardMemberSerializer(instance=member).data)


class TaskViewSet(ModelDetailViewSet):
    """Task endpoints, limited to tasks on boards the requester belongs to."""

    queryset = Task.objects.all()
    serializer_class = TaskSerializer
    permission_classes = [IsAuthenticated]

    def get_queryset(self):
        """Return only tasks whose column's board has the requester as member."""
        base = super().get_queryset()
        return base.filter(column__board__members=self.request.user)


class CommentViewSet(
    mixins.CreateModelMixin,
    mixins.DestroyModelMixin,
    mixins.RetrieveModelMixin,
    mixins.ListModelMixin,
    viewsets.GenericViewSet,
):
    """Create/destroy/retrieve/list endpoints for task comments.

    Comments are restricted to tasks on boards the requesting user is a
    member of, and the list can be filtered by ``task``.
    """

    queryset = Comment.objects.all()
    serializer_class = CommentSerializer
    permission_classes = [IsAuthenticated]
    filter_backends = [DjangoFilterBackend]
    filterset_fields = ["task"]

    def get_queryset(self):
        """Return comments on tasks of boards the requester is a member of."""
        return (
            super()
            .get_queryset()
            .filter(task__column__board__members=self.request.user)
        )

    def create(self, request, *args, **kwargs):
        """Create a comment authored by the requesting user.

        Responds with HTTP 400 when the referenced task is missing, has a
        malformed id, or belongs to a board the requester is not a member of.
        """
        # The author is always the requester, whatever the client sent.
        request.data.update({"author": request.user.id})

        task_id = request.data.get("task")
        try:
            # One EXISTS query covers both "task exists" and "user is a
            # member of its board"; an unknown task simply yields False.
            allowed = Task.objects.filter(
                id=task_id, column__board__members=request.user
            ).exists()
        except (TypeError, ValueError):
            # task id could not be coerced to a primary key
            allowed = False

        if not allowed:
            return Response(status=HTTP_400_BAD_REQUEST)

        return super().create(request, *args, **kwargs)

    def destroy(self, request, *args, **kwargs):
        """Delete a comment; only its author may do so (HTTP 400 otherwise)."""
        # No request payload is needed for deletion, so request.data is left
        # untouched (mutating it fails on immutable form-encoded bodies).
        if request.user != self.get_object().author:
            return Response(status=HTTP_400_BAD_REQUEST)

        return super().destroy(request, *args, **kwargs)


class ColumnViewSet(ModelDetailViewSet):
    """Column endpoints, limited to boards the requester belongs to."""

    queryset = Column.objects.all()
    serializer_class = ColumnSerializer
    permission_classes = [IsAuthenticated]

    def get_queryset(self):
        """Return only columns whose board has the requester as a member."""
        base = super().get_queryset()
        return base.filter(board__members=self.request.user)


class LabelViewSet(ModelDetailViewSet):
    """Label endpoints, limited to boards the requester belongs to."""

    queryset = Label.objects.all()
    serializer_class = LabelSerializer
    permission_classes = [IsAuthenticated]

    def get_queryset(self):
        """Return only labels whose board has the requester as a member."""
        base = super().get_queryset()
        return base.filter(board__members=self.request.user)


class SortColumn(APIView):
    """Reorder columns according to the ``order`` list in the request body."""

    permission_classes = [IsAuthenticated]

    def post(self, request, **kwargs):
        """Apply the requested column order atomically.

        Responds with whatever ``sort_model`` returns, or HTTP 400 when the
        payload is malformed or references unknown columns.
        """
        # NOTE(review): nothing here restricts the columns to boards the
        # requester is a member of -- confirm whether that is intended.
        try:
            # The atomic block must sit INSIDE the try: an exception has to
            # propagate out of it for partial updates to be rolled back.
            with transaction.atomic():
                return sort_model(request, Column)
        except (
            KeyError,
            IndexError,
            AttributeError,
            TypeError,
            ValueError,
            Column.DoesNotExist,
        ):
            return Response(status=HTTP_400_BAD_REQUEST)


class SortTask(APIView):
    """Move tasks between columns of a board and reorder them."""

    permission_classes = [IsAuthenticated]

    def move_tasks(self, request):
        """Assign tasks to columns as described by the request body.

        Reads ``board`` (board id) and ``tasks`` (mapping of column id to a
        list of task ids) from ``request.data``. Only columns and tasks of
        that board are affected.

        Raises:
            Board.DoesNotExist: no board with the given id.
            Column.DoesNotExist: a key of ``tasks`` is not a column of the board.
            ValueError: a task id is listed more than once.
            AttributeError: ``tasks`` is missing or not a mapping.
        """
        tasks_by_column = request.data.get("tasks")
        board_id = request.data.get("board")
        # NOTE(review): the board is not checked against the requester's
        # memberships here -- confirm whether that is intended.
        board = Board.objects.get(id=board_id)
        board_columns = Column.objects.filter(board=board)
        # Used only for filter().update(), so no prefetching is needed.
        board_tasks = Task.objects.filter(column__in=board_columns)

        # Check for duplicate tasks
        flat_tasks = list(chain.from_iterable(tasks_by_column.values()))
        if len(flat_tasks) != len(set(flat_tasks)):
            raise ValueError("a task may be listed in only one column")

        for column_id, task_ids in tasks_by_column.items():
            column = board_columns.get(id=column_id)
            board_tasks.filter(pk__in=task_ids).update(column=column)

    def post(self, request, **kwargs):
        """Move and reorder tasks atomically.

        Responds with whatever ``sort_model`` returns, or HTTP 400 when the
        payload is malformed or references unknown boards, columns or tasks.
        Nothing is persisted unless the whole operation succeeds.
        """
        try:
            with transaction.atomic():
                self.move_tasks(request)
                response = sort_model(request, Task)
                if response.status_code >= HTTP_400_BAD_REQUEST:
                    # sort_model rejected the order without raising; undo the
                    # column moves so an error response never changes data.
                    transaction.set_rollback(True)
                return response
        except (
            KeyError,
            IndexError,
            AttributeError,
            TypeError,
            ValueError,
            Board.DoesNotExist,
            Column.DoesNotExist,
            Task.DoesNotExist,
        ):
            return Response(status=HTTP_400_BAD_REQUEST)


def sort_model(request, Model):
    """Renumber ``Model`` instances to follow the order given in the request.

    ``request.data["order"]`` is a list of primary keys in the desired order.
    The objects are assigned consecutive values of the model's first
    ``Meta.ordering`` field, starting from the smallest value currently held
    by any of them. Only objects whose value actually changes are saved.

    All validation happens before the first write, so a rejected request
    leaves the data untouched.

    Args:
        request: DRF request whose ``data`` carries the ``order`` list.
        Model: model class to reorder.

    Returns:
        Response: HTTP 200 on success, HTTP 400 if ``order`` has duplicates.

    Raises:
        ValueError: ``order`` is not a list, is empty, or names a primary key
            that does not exist.
        IndexError: ``Model`` declares no ``Meta.ordering``.
    """
    ordered_pks = request.data.get("order", [])
    if not isinstance(ordered_pks, (list, tuple)):
        raise ValueError("'order' must be a list of primary keys")

    # Check for duplicates; compare string forms so 1 and "1" count as equal,
    # matching how the objects are looked up below.
    if len(ordered_pks) != len({str(pk) for pk in ordered_pks}):
        return Response(status=HTTP_400_BAD_REQUEST)

    objects_dict = {
        str(obj.pk): obj for obj in Model.objects.filter(pk__in=ordered_pks)
    }
    # Reject empty or partially unknown input up front instead of failing
    # halfway through the save loop.
    if not objects_dict or any(str(pk) not in objects_dict for pk in ordered_pks):
        raise ValueError("'order' is empty or references unknown objects")

    order_field_name = Model._meta.ordering[0]

    # Numbering starts at the lowest position currently used by these objects.
    start_index = min(
        getattr(obj, order_field_name) for obj in objects_dict.values()
    )

    for position, pk in enumerate(ordered_pks, start=start_index):
        obj = objects_dict[str(pk)]

        # perform the update only if the order field has changed
        if getattr(obj, order_field_name) != position:
            setattr(obj, order_field_name, position)
            # only update the object's order field
            obj.save(update_fields=[order_field_name])

    return Response(status=HTTP_200_OK)