alaouimehdi1995/django-rest

View on GitHub
django_rest/decorators/utils.py

Summary

Maintainability
B
4 hrs
Test Coverage
# -*- coding: utf-8 -*-

import json
from functools import wraps

from django.http import JsonResponse
from django.views import View
from django.views.decorators.csrf import csrf_exempt

from django_rest.deserializers import AllPassDeserializer, Deserializer
from django_rest.http.exceptions import (
    BadRequest,
    BaseAPIException,
    InternalServerError,
    MethodNotAllowed,
    PermissionDenied,
    UnsupportedMediaType,
)
from django_rest.http.methods import SUPPORTING_PAYLOAD_METHODS
from django_rest.permissions import BasePermission

FORMS_CONTENT_TYPES = (
    "application/x-www-form-urlencoded",
    "multipart/form-data",
)


def build_deserializer_map(deserializer_class):
    if isinstance(deserializer_class, dict):
        if not all(
            issubclass(elem, Deserializer) for elem in deserializer_class.values()
        ):
            raise TypeError(
                "The values of the `deserializer_class` mapping "
                "should be subclass of `Deserializer`."
            )
        return {
            http_method: deserializer_class.get(http_method, AllPassDeserializer)
            for http_method in SUPPORTING_PAYLOAD_METHODS
        }
    if issubclass(deserializer_class, Deserializer):
        return {
            http_method: deserializer_class
            for http_method in SUPPORTING_PAYLOAD_METHODS
        }

    raise TypeError(
        "`deserializer_class` parameter should be either a subclass of `Deserializer` "
        "or a dict of `Deserializer` classes, not {}".format(
            deserializer_class.__name__
        )
    )


def transform_query_dict_into_regular_dict(query_dict):
    return {
        key: list_val if len(list_val) > 1 else list_val[0]
        for key, list_val in dict(query_dict).items()
    }


def extract_request_payload(request, allow_form_data=False):
    """
    Return the request's content (form or data).
    For `POST`, `PUT` and `PATCH` requests:
    - If the form data is allowed and the request is a form, returns a `dict`
    - If the form data isn't allowed and the request is a form, raises
      UnsupportedMediaType exception.
    - For application/json requests, returns a `dict` containing the request's
      data
    For other HTTP methods:
    - Returns `None`
    """
    method = request.method
    if not allow_form_data and request.content_type in FORMS_CONTENT_TYPES:
        raise UnsupportedMediaType
    elif request.content_type in FORMS_CONTENT_TYPES and request.method == "POST":
        return transform_query_dict_into_regular_dict(request.POST)

    if method not in SUPPORTING_PAYLOAD_METHODS:
        return None

    return json.loads(request.body.decode())


def build_function_wrapper(
    permission_class,  # type: BasePermission
    allowed_methods,  # type: Iterable[str]
    deserializers_http_methods_map,  # type: Dict[str, ClassVar[Deserializer]]
    allow_forms,  # type: bool
    view_function,  # Callable
):  # type:(...) -> Callable
    @wraps(view_function)
    def function_wrapper(request, *args, **kwargs):
        # type:(HttpRequest, List[Any]) -> JsonResponse
        try:
            if request.method not in allowed_methods:
                raise MethodNotAllowed
            if not permission_class().has_permission(request, view_function):
                raise PermissionDenied
            query_params = transform_query_dict_into_regular_dict(request.GET)
            payload = extract_request_payload(request, allow_forms)
            if request.method in SUPPORTING_PAYLOAD_METHODS:
                deserializer = deserializers_http_methods_map[request.method](
                    data=payload
                )
                if not deserializer.is_valid():
                    raise BadRequest
                deserialized_data = deserializer.data
            else:
                deserialized_data = None

            return view_function(
                request,
                url_params=kwargs,
                query_params=query_params,
                deserialized_data=deserialized_data,
            )
        except BaseAPIException as e:
            return JsonResponse({"error_msg": e.RESPONSE_MESSAGE}, status=e.STATUS_CODE)
        except Exception:
            return JsonResponse(
                {"error_msg": InternalServerError.RESPONSE_MESSAGE},
                status=InternalServerError.STATUS_CODE,
            )

    return csrf_exempt(function_wrapper)


def build_class_wrapper(
    permission_class,  # type: BasePermission
    allowed_methods,  # type: Iterable[str]
    deserializers_http_methods_map,  # type: Dict[str, ClassVar[Deserializer]]
    allow_forms,  # type: bool
    view_class,  # type: ClassVar
):  # type:(...) -> ClassVar
    class ViewWrapper(View):
        __doc__ = view_class.__doc__
        __module__ = view_class.__module__
        __slots__ = "_wrapped_view"

        def __init__(self, *args, **kwargs):
            self._wrapped_view = view_class(*args, **kwargs)
            self.compiled_dispatch = build_function_wrapper(
                permission_class,
                allowed_methods,
                deserializers_http_methods_map,
                allow_forms,
                self._handle_request,
            )

        def _handle_request(self, request, *args, **kwargs):
            # type:(HttpRequest, List[Any]) -> JsonResponse
            # No need for additional check on request.method, since it's been
            # already checked
            handler = getattr(self, request.method.lower())
            return handler(request, *args, **kwargs)

        @csrf_exempt
        def dispatch(self, *args, **kwargs):
            # type:(HttpRequest, List[Any]) -> JsonResponse
            return self.compiled_dispatch(*args, **kwargs)

        def __getattribute__(self, name):
            try:
                return super(ViewWrapper, self).__getattribute__(name)
            except AttributeError:
                attribute = self._wrapped_view.__getattribute__(name)
                return attribute

    ViewWrapper.__name__ = view_class.__name__
    return ViewWrapper