
"""Specification wrappers for Swagger 2.0 and OpenAPI 3.x documents."""
import abc
import copy
import pathlib
from urllib.parse import urlsplit

import jinja2
import yaml
from openapi_spec_validator.exceptions import OpenAPIValidationError

from .exceptions import InvalidSpecification
from .json_schema import resolve_refs
from .operations import OpenAPIOperation, Swagger2Operation
from .utils import deep_get

try:
    import collections.abc as collections_abc  # python 3.3+
except ImportError:
    import collections as collections_abc

# Error message used when the spec's version key cannot be read
# (no "openapi"/"swagger" key, or the spec is not mapping-like).
NO_SPEC_VERSION_ERR_MSG = (
    "Unable to get the spec version.\n"
    "You are missing either '\"swagger\": \"2.0\"' or '\"openapi\": \"3.0.0\"'\n"
    "from the top level of your spec."
)

def canonical_base_path(base_path):
    """Make ``base_path`` a canonical base URL.

    Trailing slashes are stripped so the result can be prepended to paths
    that start with ``/`` without producing ``//``. A bare ``/`` becomes
    ``''``.

    :param base_path: the raw base path string.
    :return: ``base_path`` without trailing slashes.
    """
    return base_path.rstrip('/')

class Specification(collections_abc.Mapping):
    """Read-only mapping view of an API specification document.

    ``_raw_spec`` keeps a deep copy of the document as it was passed in,
    taken before defaults are applied. ``_spec`` is the defaulted,
    validated version that has been run through ``resolve_refs``; all
    mapping lookups go through it. Subclasses implement the
    version-specific hooks ``_set_defaults`` and ``_validate_spec``.
    """

    def __init__(self, raw_spec):
        self._raw_spec = copy.deepcopy(raw_spec)
        # NOTE: the hooks receive ``raw_spec`` itself, so the caller's dict may
        # be modified in place (the subclasses call ``setdefault`` on it).
        self._set_defaults(raw_spec)
        self._validate_spec(raw_spec)
        self._spec = resolve_refs(raw_spec)

    @classmethod
    @abc.abstractmethod
    def _set_defaults(cls, spec):
        """Set version-specific default values in ``spec``, in place."""

    @classmethod
    @abc.abstractmethod
    def _validate_spec(cls, spec):
        """Validate ``spec`` against the schema for its spec version.

        Implementations raise ``InvalidSpecification`` on failure.
        """

    def get_path_params(self, path):
        """Return the path-level ``parameters`` of ``path`` (``[]`` if absent)."""
        return deep_get(self._spec, ["paths", path]).get("parameters", [])

    def get_operation(self, path, method):
        """Look up the operation at ``paths -> path -> method`` via ``deep_get``."""
        return deep_get(self._spec, ["paths", path, method])

    @property
    def raw(self):
        """The spec as passed in (deep copy taken before defaults were set)."""
        return self._raw_spec

    @property
    def version(self):
        """Spec version as a tuple of ints, e.g. ``(3, 0, 0)``."""
        return self._get_spec_version(self._spec)

    @property
    def security(self):
        """Top-level ``security`` requirements, or ``None`` if absent."""
        return self._spec.get('security')

    def __getitem__(self, k):
        return self._spec[k]

    def __iter__(self):
        return iter(self._spec)

    def __len__(self):
        return len(self._spec)

    @staticmethod
    def _load_spec_from_file(arguments, specification):
        """Load a YAML specification file, rendering it with Jinja2 first.

        :param arguments: template variables passed to the Jinja2 renderer;
            ``None`` means no variables.
        :param specification: ``pathlib.Path`` of the specification file.
        :return: the parsed YAML document.
        """
        arguments = arguments or {}

        # Only the read needs the open handle; close it before templating.
        with specification.open(mode='rb') as openapi_yaml:
            contents = openapi_yaml.read()

        try:
            openapi_template = contents.decode()
        except UnicodeDecodeError:
            # Fall back to lossy decoding instead of failing on stray bytes.
            openapi_template = contents.decode('utf-8', 'replace')

        openapi_string = jinja2.Template(openapi_template).render(**arguments)
        # safe_load: never construct arbitrary Python objects from the file.
        return yaml.safe_load(openapi_string)

    @classmethod
    def from_file(cls, spec, arguments=None):
        """Load the YAML file at path ``spec`` and return a Specification.

        :param spec: path (anything ``pathlib.Path`` accepts) to the file.
        :param arguments: optional Jinja2 template variables.
        """
        specification_path = pathlib.Path(spec)
        spec = cls._load_spec_from_file(arguments, specification_path)
        return cls.from_dict(spec)

    @staticmethod
    def _get_spec_version(spec):
        """Return the ``openapi``/``swagger`` version of ``spec`` as an int tuple.

        :raises InvalidSpecification: if ``spec`` is not mapping-like, has no
            version key, or its version is not a dotted string of integers.
        """
        try:
            version_string = spec.get('openapi') or spec.get('swagger')
        except AttributeError:
            raise InvalidSpecification(NO_SPEC_VERSION_ERR_MSG)
        if version_string is None:
            raise InvalidSpecification(NO_SPEC_VERSION_ERR_MSG)
        try:
            version_tuple = tuple(map(int, version_string.split(".")))
        except (AttributeError, TypeError, ValueError) as err:
            # AttributeError: unquoted YAML such as ``swagger: 2.0`` loads as a
            # float; ValueError: non-numeric parts such as "3.0.0-rc1".
            msg = ('Unable to convert version string to semantic version tuple: '
                   '{version_string}.').format(version_string=version_string)
            raise InvalidSpecification(msg) from err
        return version_tuple

    @staticmethod
    def _enforce_string_keys(obj):
        """Return ``obj`` with every dict key converted to ``str``.

        YAML supports integer keys, but JSON does not. Dicts and lists are
        rebuilt recursively, including dicts nested inside lists; any other
        value is returned unchanged.
        """
        if isinstance(obj, dict):
            return {
                str(k): Specification._enforce_string_keys(v)
                for k, v in obj.items()
            }
        if isinstance(obj, list):
            return [Specification._enforce_string_keys(item) for item in obj]
        return obj

    @classmethod
    def from_dict(cls, spec):
        """Build the matching Specification subclass from a dictionary.

        Versions below ``3.0.0`` yield a ``Swagger2Specification``; all
        others yield an ``OpenAPISpecification``.
        """
        spec = cls._enforce_string_keys(spec)
        version = cls._get_spec_version(spec)
        if version < (3, 0, 0):
            return Swagger2Specification(spec)
        return OpenAPISpecification(spec)

    def clone(self):
        """Return a new instance of the same class built from a copy of the raw spec."""
        return type(self)(copy.deepcopy(self._raw_spec))

    @classmethod
    def load(cls, spec, arguments=None):
        """Load from a dict, or from a file path rendered with ``arguments``."""
        if not isinstance(spec, dict):
            return cls.from_file(spec, arguments=arguments)
        return cls.from_dict(spec)

    def with_base_path(self, base_path):
        """Return a clone of this spec with ``base_path`` set to ``base_path``."""
        new_spec = self.clone()
        new_spec.base_path = base_path
        return new_spec

class Swagger2Specification(Specification):
    """Specification for Swagger 2.0 documents."""

    yaml_name = 'swagger.yaml'
    operation_cls = Swagger2Operation

    @classmethod
    def _set_defaults(cls, spec):
        """Ensure the top-level keys read by this class exist in ``spec``."""
        spec.setdefault('produces', [])
        spec.setdefault('consumes', ['application/json'])
        spec.setdefault('definitions', {})
        spec.setdefault('parameters', {})
        spec.setdefault('responses', {})

    @property
    def produces(self):
        """Top-level ``produces`` list (``[]`` by default)."""
        return self._spec['produces']

    @property
    def consumes(self):
        """Top-level ``consumes`` list (``['application/json']`` by default)."""
        return self._spec['consumes']

    @property
    def definitions(self):
        """Top-level ``definitions`` object (``{}`` by default)."""
        return self._spec['definitions']

    @property
    def parameter_definitions(self):
        """Top-level ``parameters`` object (``{}`` by default)."""
        return self._spec['parameters']

    @property
    def response_definitions(self):
        """Top-level ``responses`` object (``{}`` by default)."""
        return self._spec['responses']

    @property
    def security_definitions(self):
        """``securityDefinitions`` object, or ``{}`` if absent."""
        return self._spec.get('securityDefinitions', {})

    @property
    def base_path(self):
        """``basePath`` without trailing slashes (``''`` if absent)."""
        return canonical_base_path(self._spec.get('basePath', ''))

    @base_path.setter
    def base_path(self, base_path):
        # Keep raw and resolved specs in sync so clone() preserves the change.
        base_path = canonical_base_path(base_path)
        self._raw_spec['basePath'] = base_path
        self._spec['basePath'] = base_path

    @classmethod
    def _validate_spec(cls, spec):
        """Validate ``spec`` as Swagger 2.0.

        :raises InvalidSpecification: when the validator rejects ``spec``.
        """
        from openapi_spec_validator import validate_v2_spec as validate_spec
        try:
            validate_spec(spec)
        except OpenAPIValidationError as e:
            raise InvalidSpecification.create_from(e)

class OpenAPISpecification(Specification):
    """Specification for OpenAPI 3.x documents."""

    yaml_name = 'openapi.yaml'
    operation_cls = OpenAPIOperation

    @classmethod
    def _set_defaults(cls, spec):
        """Ensure ``spec`` has a ``components`` object."""
        spec.setdefault('components', {})

    @property
    def security_definitions(self):
        """``components.securitySchemes``, or ``{}`` if absent."""
        return self._spec['components'].get('securitySchemes', {})

    @property
    def components(self):
        """The ``components`` object (``{}`` by default)."""
        return self._spec['components']

    @classmethod
    def _validate_spec(cls, spec):
        """Validate ``spec`` as OpenAPI 3.

        :raises InvalidSpecification: when the validator rejects ``spec``.
        """
        from openapi_spec_validator import validate_v3_spec as validate_spec
        try:
            validate_spec(spec)
        except OpenAPIValidationError as e:
            raise InvalidSpecification.create_from(e)

    @property
    def base_path(self):
        """Path part of the first server URL, without trailing slashes.

        ``{name}`` placeholders in the URL are filled from each server
        variable's ``default``. Returns ``''`` when ``servers`` is absent,
        empty, or null.
        """
        servers = self._spec.get('servers')
        if not servers:
            return ''
        # Assume we're the first server in the list.
        server = servers[0]
        server_vars = server.get('variables', {})
        url = server['url'].format(
            **{name: var['default'] for name, var in server_vars.items()}
        )
        return canonical_base_path(urlsplit(url).path)

    @base_path.setter
    def base_path(self, base_path):
        # Replace the server list in both raw and resolved specs so clone()
        # preserves it; use separate lists so the two specs don't alias.
        base_path = canonical_base_path(base_path)
        self._raw_spec['servers'] = [{'url': base_path}]
        self._spec['servers'] = [{'url': base_path}]