# winter_openapi/generator.py
#
# Summary
#
# Maintainability
# C
# 1 day
# Test Coverage
import inspect
import warnings
from itertools import groupby
from typing import Any
from typing import Dict
from typing import Iterable
from typing import List
from typing import Optional
from typing import Sequence
from typing import Set
from typing import Tuple
from typing import Type
from typing import Union

from django.http.response import HttpResponseBase
from openapi_pydantic import schema_validate
from openapi_pydantic.v3.v3_0_3 import Components
from openapi_pydantic.v3.v3_0_3 import Info
from openapi_pydantic.v3.v3_0_3 import MediaType as MediaTypeModel
from openapi_pydantic.v3.v3_0_3 import OpenAPI
from openapi_pydantic.v3.v3_0_3 import Operation
from openapi_pydantic.v3.v3_0_3 import Parameter
from openapi_pydantic.v3.v3_0_3 import PathItem
from openapi_pydantic.v3.v3_0_3 import Paths
from openapi_pydantic.v3.v3_0_3 import Reference
from openapi_pydantic.v3.v3_0_3 import RequestBody
from openapi_pydantic.v3.v3_0_3 import Response
from openapi_pydantic.v3.v3_0_3 import Responses
from openapi_pydantic.v3.v3_0_3 import Schema
from openapi_pydantic.v3.v3_0_3 import Server
from openapi_pydantic.v3.v3_0_3 import Tag

from winter.core import ComponentMethod
from winter.web import MediaType
from winter.web.default_response_status import get_response_status
from winter.web.exceptions import MethodExceptionsManager
from winter.web.request_body_annotation import RequestBodyAnnotation
from winter.web.routing import Route
from winter.web.routing import RouteAnnotation
from winter_openapi.inspection.inspection import inspect_type
from .inspection import DataTypes
from .inspection import TypeInfo
from .inspection.inspection import InspectorNotFound
from .inspectors import get_route_parameters_inspectors


class SchemaRegistry:
    """Build OpenAPI schemas for Python types and share titled ones as components.

    Schemas that carry a title are cached per ``(hint_class, output)`` pair and are
    returned to callers as ``$ref`` references; :meth:`get_schemas` exposes them so
    they can be placed under ``#/components/schemas``. Untitled schemas are built on
    every request and returned inline.
    """

    def __init__(self):
        # Titled schemas keyed by (hint class, output flag). output=False builds the
        # "<Title>Input" variant, so input and output schemas are cached separately.
        self._schemas: Dict[Tuple[Type, bool], Schema] = {}
        # Component title -> the type that first claimed it; used to detect two
        # different types that would collide on the same component name.
        self._types_by_titles: Dict[str, Type] = {}

    def get_schema_or_reference(self, type_info: Union[Type, TypeInfo], output: bool) -> Union[Schema, Reference]:
        """Return an inline schema, or a reference to a registered component schema.

        :param type_info: a Python type (inspected with ``inspect_type``) or an
            already inspected ``TypeInfo``.
        :param output: True for response (output) schemas, False for request input.
        :return: the built ``Schema`` itself when it has no title; otherwise a
            ``Reference`` to ``#/components/schemas/<title>``, wrapped as
            ``Schema(nullable=True, allOf=[reference])`` when the type is nullable
            and its schema type is ``'object'``.
        :raises ValueError: if the schema title is already registered for a
            different type.
        """
        if not isinstance(type_info, TypeInfo):
            type_info = inspect_type(type_info)

        cache_key = (type_info.hint_class, output)
        schema = self._schemas.get(cache_key)
        if schema is None:
            schema = self.build_schema(type_info, output=output)
            if not schema.title:
                # Untitled schemas are never shared as components; inline them.
                return schema
            # The shared component must not carry the nullability of the first usage
            # site; nullability is re-applied per usage below (object types only).
            if schema.nullable:
                schema.nullable = None
            self._schemas[cache_key] = schema

        self._register_title(schema.title, type_info.hint_class)

        reference = Reference(ref='#/components/schemas/' + schema.title)

        if type_info.nullable and schema.type == 'object':
            # https://stackoverflow.com/questions/40920441/how-to-specify-a-property-can-be-null-or-a-reference-with-swagger
            # Better solution, but not implemented yet https://github.com/OpenAPITools/openapi-generator/issues/9083
            return Schema(nullable=True, allOf=[reference])

        # NOTE(review): a nullable type whose titled schema is not an object (e.g. a
        # titled enum) gets a plain reference here, so its nullability is dropped —
        # confirm this is intended.
        return reference

    def _register_title(self, title: str, hint_class: Type) -> None:
        """Record that ``title`` names ``hint_class``; raise ValueError if another type already owns it."""
        if title not in self._types_by_titles:
            self._types_by_titles[title] = hint_class
        elif self._types_by_titles[title] != hint_class:
            raise ValueError(f'Title {title} for type {hint_class} is already used for another type: {self._types_by_titles[title]}')

    def get_schemas(self) -> Dict[str, Schema]:
        """Return every cached (titled) schema keyed by its title, for ``components.schemas``."""
        return {
            schema.title: schema
            for schema in self._schemas.values()
        }

    def build_schema(self, type_info: TypeInfo, *, output: bool) -> Schema:
        """Build a fresh ``Schema`` for ``type_info`` (no caching at this level).

        Nested item and property types are resolved through
        :meth:`get_schema_or_reference`, so titled nested types become references.

        :param type_info: the inspected type to describe.
        :param output: True for output schemas. When False, the title gets an
            ``Input`` suffix and only properties without a default that are neither
            nullable nor allowed to be undefined are listed as required; when True,
            every property is required.
        :return: the built schema.
        """
        if type_info.type_ == DataTypes.ANY:
            # No type, title or properties: an untitled schema describing any value.
            return Schema(
                description='Can be any value - string, number, boolean, array or object.',
                nullable=type_info.nullable,
            )

        data = {
            'type': type_info.type_,
        }

        if type_info.nullable:
            data['nullable'] = True

        if type_info.title:
            data['title'] = type_info.title if output else f'{type_info.title}Input'

        if type_info.description:
            data['description'] = type_info.description

        if type_info.format_ is not None:
            data['schema_format'] = type_info.format_

        if type_info.child is not None:
            data['items'] = self.get_schema_or_reference(type_info.child, output=output)

        if type_info.enum is not None:
            data['enum'] = type_info.enum

        if type_info.properties:
            # Properties are emitted in name order for a stable document.
            sorted_keys = sorted(type_info.properties.keys())
            data['properties'] = {
                key: self.get_schema_or_reference(type_info.properties[key], output=output)
                for key in sorted_keys
            }

        if output:
            required_properties = list(type_info.properties)
        else:
            required_properties = [
                property_name
                for property_name in type_info.properties
                if (
                    property_name not in type_info.properties_defaults
                    and not type_info.properties[property_name].nullable
                    and not type_info.properties[property_name].can_be_undefined
                )
            ]

        if required_properties:
            data['required'] = required_properties

        return Schema(**data)


def generate_openapi(
    title: str,
    version: str,
    routes: Sequence[Route],
    description: Optional[str] = None,
    tags: Optional[List[Dict[str, Any]]] = None,
    validate: bool = True,
    add_url_segment_as_tag: bool = True,
) -> Dict[str, Any]:
    """Generate an OpenAPI 3.0.3 document for ``routes`` and return it as a plain dict.

    Routes are sorted and grouped by URL path. The common prefix of all paths
    (see ``determine_path_prefix``) is stripped from each path and published as the
    single server URL instead.

    :param title: document title (``info.title``).
    :param version: API version (``info.version``).
    :param routes: routes to document.
    :param description: optional ``info.description``.
    :param tags: keyword dicts used to build the top-level ``Tag`` objects. The names
        of these tags are attached to every operation.
    :param validate: when True, run ``schema_validate`` on the resulting dict.
    :param add_url_segment_as_tag: when True, also tag each operation with the first
        URL segment after the common prefix (path parameters are skipped).
    :return: the document serialized with aliases and without ``None`` values.
    :raises CanNotInspectType: if a type used by a route cannot be inspected.
    :raises ValueError: if a URL path is inconsistent with the computed prefix, or a
        schema title collides between types.
    """
    sorted_routes = sorted(routes, key=lambda route: route.url_path)
    registry = SchemaRegistry()
    declared_tags = [Tag(**tag_kwargs) for tag_kwargs in tags or []]
    declared_tag_names = [declared_tag.name for declared_tag in declared_tags]
    paths: Paths = {}
    seen_operation_ids: Set[str] = set()
    common_prefix = determine_path_prefix([route.url_path for route in sorted_routes])

    for url_path, routes_for_path in groupby(sorted_routes, key=lambda route: route.url_path):
        relative_path = get_url_path_without_prefix(url_path, common_prefix)
        if not relative_path.startswith('/'):
            relative_path = f'/{relative_path}'

        # Copy so the per-path segment tag does not leak into other paths.
        operation_tags = list(declared_tag_names)
        segment_tag = get_url_path_tag(url_path, common_prefix) if add_url_segment_as_tag else None
        if segment_tag:
            operation_tags.append(segment_tag)

        paths[relative_path] = _get_openapi_path(
            routes=routes_for_path,
            operation_ids=seen_operation_ids,
            tag_names=operation_tags,
            schema_registry=registry,
        )

    document = OpenAPI(
        info=Info(title=title, version=version, description=description),
        servers=[Server(url=common_prefix)],
        paths=paths,
        components=Components(
            schemas=registry.get_schemas(),
            responses={},
            parameters={},
        ),
        tags=declared_tags,
    )
    openapi_dict = document.dict(by_alias=True, exclude_none=True)
    if validate:
        schema_validate(openapi_dict)
    return openapi_dict


def _get_openapi_path(
    *,
    routes: Iterable[Route],
    operation_ids: Set[str],
    tag_names: Iterable[str],
    schema_registry: SchemaRegistry,
) -> PathItem:
    """Build the ``PathItem`` for all routes sharing one URL path.

    :param routes: routes of a single URL path, one per HTTP method.
    :param operation_ids: operation ids seen so far across the document. Updated in
        place; a duplicate id only triggers a warning.
    :param tag_names: tags attached to every operation of this path.
    :param schema_registry: registry collecting component schemas.
    :return: the path item, with operations keyed by lower-cased HTTP method.
    :raises CanNotInspectType: if a type used by one of the routes cannot be
        inspected (chained to the original ``InspectorNotFound``).
    """
    path = {}
    for route in routes:
        operation_id = route.method.full_name
        if operation_id in operation_ids:
            warnings.warn(f"Duplicate Operation ID {operation_id}")
        operation_ids.add(operation_id)

        try:
            operation = _get_openapi_operation(
                route=route,
                operation_id=operation_id,
                tag_names=tag_names,
                schema_registry=schema_registry,
            )
        except InspectorNotFound as e:
            # Chain explicitly so the traceback shows the inspector failure as the cause.
            raise CanNotInspectType(route.method, str(e)) from e
        path[route.http_method.lower()] = operation

    return PathItem.parse_obj(path)


def _get_openapi_operation(
    *,
    route: Route,
    operation_id: str,
    tag_names: Iterable[str],
    schema_registry: SchemaRegistry,
) -> Operation:
    """Describe a single route as an OpenAPI ``Operation``.

    The summary and description come from the method docstring. Parameters, request
    body and responses are collected in that order, which also determines the order
    in which schemas are registered in ``schema_registry``.
    """
    method_docstring = route.method.docstring
    return Operation(
        tags=tag_names,
        summary=method_docstring.short_description,
        description=method_docstring.long_description,
        operationId=operation_id,
        parameters=get_route_parameters(route, schema_registry),
        requestBody=get_request_body_parameters(route, schema_registry),
        responses=get_responses_schemas(route, schema_registry),
    )


class CanNotInspectType(Exception):
    """Raised when a type used by a controller method cannot be inspected for OpenAPI.

    The string form is ``<component module>.<method full name>: <message>``, which
    points straight at the failing endpoint.
    """

    def __init__(
        self,
        method: 'ComponentMethod',
        message: str,
    ):
        # Initialize the Exception base explicitly with the constructor arguments
        # (keeps ``args`` consistent with the signature, e.g. for pickling).
        super().__init__(method, message)
        self._message = message
        self._method = method

    def __repr__(self):
        return f'{self.__class__.__name__}({self})'

    def __str__(self):
        component_cls = self._method.component.component_cls
        method_path = f'{component_cls.__module__}.{self._method.full_name}'
        return f'{method_path}: {self._message}'


def get_url_path_without_prefix(url_path: str, path_prefix: str) -> str:
    """Strip ``path_prefix`` from the beginning of ``url_path``.

    Leading slashes are ignored on both arguments. The prefix is only stripped when
    the match ends on a path-segment boundary, so ``/api/v1`` is removed from
    ``/api/v1/users`` but not from ``/api/v10/users``.

    :return: the remainder after the prefix (possibly empty and not guaranteed to
        start with ``/``), or ``url_path`` unchanged when the prefix is empty or does
        not match.
    """
    # TODO: use removeprefix when python 3.9 will be used
    path_prefix_stripped = path_prefix.lstrip('/')
    url_path_stripped = url_path.lstrip('/')

    if not path_prefix_stripped or not url_path_stripped.startswith(path_prefix_stripped):
        return url_path

    remainder = url_path_stripped[len(path_prefix_stripped):]
    # A bare startswith() would turn '/api/v10/users' into '0/users' for the prefix
    # '/api/v1'; only accept matches that stop at a '/' (or at the end of the path).
    if remainder and not remainder.startswith('/') and not path_prefix_stripped.endswith('/'):
        return url_path
    return remainder


def get_url_path_tag(url_path: str, path_prefix: str) -> Optional[str]:
    """Return the first URL segment after ``path_prefix``, for use as an OpenAPI tag.

    :return: that segment, or None when the URL has no segment after the prefix, when
        the segment is empty (e.g. ``/api/v1/`` with prefix ``/api/v1``), or when it
        is a path parameter such as ``{id}``.
    :raises ValueError: if ``url_path`` has fewer segments than ``path_prefix``.
    """
    # Drop empty segments from the prefix so '/' yields [] rather than [''].
    path_prefix_segments = [segment for segment in path_prefix.lstrip('/').split('/') if segment]
    url_path_segments = url_path.lstrip('/').split('/')

    if len(url_path_segments) < len(path_prefix_segments):
        raise ValueError(f'Invalid path prefix {path_prefix} for url_path {url_path}')
    # for the cases with single route: the URL is the prefix itself
    if len(url_path_segments) == len(path_prefix_segments):
        return None

    url_path_tag = url_path_segments[len(path_prefix_segments)]
    # An empty segment (trailing slash) is not a usable tag; honour Optional[str].
    if not url_path_tag or url_path_tag.startswith('{'):
        return None

    return url_path_tag


def get_route_parameters(route: Route, schema_registry: SchemaRegistry) -> List[Parameter]:
    """Concatenate the parameters reported by every route-parameter inspector, in inspector order."""
    return [
        parameter
        for inspector in get_route_parameters_inspectors()
        for parameter in inspector.inspect_parameters(route, schema_registry)
    ]


def get_request_body_parameters(route: Route, schema_registry: SchemaRegistry) -> Optional[RequestBody]:
    """Describe the request body of ``route`` from its ``RequestBodyAnnotation``.

    The body schema is built from the annotated argument's type in input mode
    (``output=False``). One content entry is emitted per media type the route
    consumes, falling back to ``MediaType.APPLICATION_JSON`` when none are declared
    or the method has no ``RouteAnnotation``.

    :return: the ``RequestBody``, or None when the method has no request-body annotation.
    """
    method = route.method
    request_body_annotation = method.annotations.get_one_or_none(RequestBodyAnnotation)
    if request_body_annotation is None:
        return None

    description = method.docstring.short_description
    argument = method.get_argument(request_body_annotation.argument_name)
    reference = schema_registry.get_schema_or_reference(argument.type_, output=False)

    route_annotation = method.annotations.get_one_or_none(RouteAnnotation)
    # get_one_or_none() may return None; fall back to the default media type instead
    # of failing with AttributeError.
    declared_consumes = route_annotation.consumes if route_annotation is not None else None
    consumes = declared_consumes or [MediaType.APPLICATION_JSON]
    content = {
        str(consume): MediaTypeModel(media_type_schema=reference)
        for consume in consumes
    }
    return RequestBody(description=description, content=content)


def get_responses_schemas(route: Route, schema_registry: SchemaRegistry) -> Responses:
    """Build the responses of ``route``, keyed by HTTP status code as a string.

    The success response is derived from the route method itself. One extra response
    is added per exception class the method declares, described by that exception
    handler's ``handle`` method.
    """
    http_method = route.http_method
    success_status = str(get_response_status(http_method, route.method))
    responses: Responses = {
        success_status: _build_response_schema(route.method, schema_registry),
    }

    exceptions_manager = MethodExceptionsManager(route.method)
    for exception_cls in exceptions_manager.declared_exception_classes:
        handler = exceptions_manager.get_handler(exception_cls)
        handle_method = ComponentMethod.get_or_create(handler.__class__.handle)
        error_status = str(get_response_status(http_method, handle_method))
        # NOTE(review): handlers sharing a status (or reusing the success status)
        # overwrite the earlier entry here — confirm this is acceptable.
        responses[error_status] = _build_response_exception_handler_schema(handle_method, schema_registry)

    return responses


def determine_path_prefix(url_paths: List[str]) -> str:
    """
    https://github.com/encode/django-rest-framework/blob/master/LICENSE.md
    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES,
    INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
    DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
    SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
    SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
    WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
    USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

    Given a list of all paths, return the common prefix which should be
    discounted when generating a schema structure.

    This will be the longest common string that does not include that last
    component of the URL, or the last component before a path parameter.

    For example:

    /api/v1/users/
    /api/v1/users/{pk}/

    The path prefix is '/api/v1'

    An empty list of paths yields '/'.
    """
    if not url_paths:
        # Nothing to strip; also min()/max() below would raise on an empty list.
        return '/'

    prefixes = []
    for path in url_paths:
        components = path.strip('/').split('/')
        initial_components = []
        for component in components:
            if '{' in component:
                break
            initial_components.append(component)
        prefix = '/'.join(initial_components[:-1])
        if not prefix:
            # We can just break early in the case that there's at least
            # one URL that doesn't have a path prefix.
            return '/'
        prefixes.append('/' + prefix + '/')

    # The lexicographically smallest and largest segment lists bound every other
    # list, so their common leading segments are common to all prefixes.
    split_paths = [path.strip('/').split('/') for path in prefixes]
    s1 = min(split_paths)
    s2 = max(split_paths)
    common = s1

    for i, c in enumerate(s1):
        if c != s2[i]:
            common = s1[:i]
            break

    return '/' + '/'.join(common)


def _build_response_schema(method: ComponentMethod, schema_registry: SchemaRegistry) -> Response:
    """Build the success response for a route method from its return type.

    Methods returning None/``NoneType`` or an ``HttpResponseBase`` subclass get a
    response without content. Otherwise the return type's output schema is published
    for each media type the route produces, falling back to
    ``MediaType.APPLICATION_JSON`` when none are declared or the method has no
    ``RouteAnnotation``.
    """
    return_value_type = method.return_value_type
    if _is_abstract_or_none_return_type(return_value_type):
        return Response(description='')

    reference = schema_registry.get_schema_or_reference(return_value_type, output=True)
    route_annotation = method.annotations.get_one_or_none(RouteAnnotation)
    # get_one_or_none() may return None; fall back to the default media type instead
    # of failing with AttributeError.
    declared_produces = route_annotation.produces if route_annotation is not None else None
    produces = declared_produces or [MediaType.APPLICATION_JSON]
    content = {
        str(produce): MediaTypeModel(media_type_schema=reference)
        for produce in produces
    }
    return Response(description='', content=content)


def _build_response_exception_handler_schema(method: ComponentMethod, schema_registry: SchemaRegistry) -> Response:
    """Build the response produced by an exception handler's ``handle`` method.

    The handler's return type is published only under ``MediaType.APPLICATION_JSON``.
    Handlers returning None/``NoneType`` or an ``HttpResponseBase`` subclass get a
    response without content.
    """
    handler_return_type = method.return_value_type
    if not _is_abstract_or_none_return_type(handler_return_type):
        schema_or_ref = schema_registry.get_schema_or_reference(handler_return_type, output=True)
        json_media_type = str(MediaType.APPLICATION_JSON)
        return Response(
            description='',
            content={json_media_type: MediaTypeModel(media_type_schema=schema_or_ref)},
        )
    return Response(description='')


def _is_abstract_or_none_return_type(return_value_type):
    return (
        return_value_type in (None, type(None)) or
        (
            inspect.isclass(return_value_type) and
            issubclass(return_value_type, HttpResponseBase)
        )
    )