cloudsmith-io/cloudsmith-cli

View on GitHub
cloudsmith_cli/core/rest.py

Summary

Maintainability
C
7 hrs
Test Coverage
B
84%
"""A RESTful API client with retry builtin."""

import io
import json
import logging
import re
import time
from urllib.parse import urlencode

import requests
import requests.exceptions
from cloudsmith_api.configuration import Configuration
from cloudsmith_api.rest import ApiException, RESTClientObject
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


class RetryWithCallback(Retry):
    """A urllib3 ``Retry`` that invokes a callback before each retry sleep.

    The optional ``error_retry_cb`` keyword argument is called as
    ``error_retry_cb(seconds, context=...)`` right before sleeping, where
    ``context`` is ``"retry-after"`` or ``"backoff"``.
    """

    def __init__(self, *args, **kwargs):
        # Remove our own keyword before delegating the rest to the base class.
        callback = kwargs.pop("error_retry_cb", None)
        self.error_retry_cb = callback
        super().__init__(*args, **kwargs)

    def new(self, **kw):
        """Create a new instance via the base class, carrying the callback over."""
        kw.update(error_retry_cb=self.error_retry_cb)
        return super().new(**kw)

    def sleep_for_retry(self, response=None):
        """Sleep for the ``get_retry_after`` delay, if any.

        :return: True when a sleep happened, False otherwise.
        """
        delay = self.get_retry_after(response)
        if not delay:
            return False

        self._sleep_with_callback(delay, context="retry-after")
        return True

    def _sleep_backoff(self):
        """Sleep for the current backoff time, unless it is zero or negative."""
        delay = self.get_backoff_time()
        if delay <= 0:
            return
        self._sleep_with_callback(delay, context="backoff")

    def _sleep_with_callback(self, seconds, context=None):
        """Notify the callback (when set and callable), then sleep."""
        callback = self.error_retry_cb
        if callback and callable(callback):
            callback(seconds, context=context)
        return time.sleep(seconds)


def create_requests_session(
    retries=None,
    backoff_factor=None,
    status_forcelist=None,
    pools_size=4,
    maxsize=4,
    ssl_verify=None,
    ssl_cert=None,
    proxy=None,
    session=None,
    error_retry_cb=None,
    respect_retry_after_header=True,
):
    """Create a requests session that retries some errors.

    Arguments left as ``None`` are filled from ``Configuration()`` and, where
    the configuration has no value either, from the defaults noted below.

    :param retries: maximum retries (total, connect, read and status);
                    falls back to ``error_retry_max``, then 5.
    :param backoff_factor: retry backoff factor; falls back to
                           ``error_retry_backoff``, then 0.23.
    :param status_forcelist: HTTP status codes that force a retry; falls back
                             to ``error_retry_codes``, then 500/502/503/504.
    :param pools_size: passed to the adapter as ``pool_connections``.
    :param maxsize: passed to the adapter as ``pool_maxsize``.
    :param ssl_verify: session ``verify`` value; falls back to ``verify_ssl``.
    :param ssl_cert: session ``cert`` value; falls back to the configured
                     ``(cert_file, key_file)`` pair, or ``cert_file`` alone.
    :param proxy: proxy URL used for both http and https; falls back to the
                  configured ``proxy``.
    :param session: an existing session to configure; a new
                    ``requests.Session`` is created when not given.
    :param error_retry_cb: callback handed to ``RetryWithCallback``.
    :param respect_retry_after_header: forwarded to the retry object.
    :return: the configured session, with the retrying adapter mounted for
             ``http://`` and ``https://``.
    """
    # pylint: disable=too-many-branches
    config = Configuration()

    if retries is None:
        if config.error_retry_max is None:  # pylint: disable=no-member
            retries = 5
        else:
            retries = config.error_retry_max  # pylint: disable=no-member

    if backoff_factor is None:
        if config.error_retry_backoff is None:  # pylint: disable=no-member
            backoff_factor = 0.23
        else:
            backoff_factor = config.error_retry_backoff  # pylint: disable=no-member

    if status_forcelist is None:
        if config.error_retry_codes is None:  # pylint: disable=no-member
            status_forcelist = [500, 502, 503, 504]
        else:
            status_forcelist = config.error_retry_codes  # pylint: disable=no-member

    if ssl_verify is None:
        ssl_verify = config.verify_ssl

    if ssl_cert is None:
        if config.cert_file and config.key_file:
            ssl_cert = (config.cert_file, config.key_file)
        elif config.cert_file:
            ssl_cert = config.cert_file

    if proxy is None:
        # Reuse the configuration already loaded above instead of building
        # a second one just for this attribute.
        proxy = config.proxy

    session = session or requests.Session()
    session.verify = ssl_verify
    session.cert = ssl_cert

    if proxy:
        session.proxies = {"http": proxy, "https": proxy}

    retry_kwargs = {
        "backoff_factor": backoff_factor,
        "connect": retries,
        "read": retries,
        "status_forcelist": tuple(status_forcelist),
        "status": retries,
        "total": retries,
        "error_retry_cb": error_retry_cb,
        "respect_retry_after_header": respect_retry_after_header,
    }

    # urllib3 1.26 renamed ``method_whitelist`` to ``allowed_methods`` and
    # urllib3 2.0 removed the old keyword entirely, so pick whichever name the
    # installed version understands. A false value means "retry on any method".
    if hasattr(Retry, "DEFAULT_ALLOWED_METHODS"):
        retry_kwargs["allowed_methods"] = False
    else:
        retry_kwargs["method_whitelist"] = False

    retry = RetryWithCallback(**retry_kwargs)

    adapter = HTTPAdapter(
        max_retries=retry,
        pool_connections=pools_size,
        pool_maxsize=maxsize,
        pool_block=True,
    )

    session.mount("http://", adapter)
    session.mount("https://", adapter)

    return session


class RestResponse(io.IOBase):
    """Wrap a requests response behind a urllib3-style response interface."""

    def __init__(self, response):
        super().__init__()
        self._data = None  # decoded body, filled in on first access to `data`
        self.response = response
        self.status, self.reason = response.status_code, response.reason

    @property
    def data(self):
        """
        Return the response body as text, decoded as UTF-8 on first access.
        """
        cached = self._data
        if cached is not None:
            return cached

        self._data = self.response.content.decode("utf-8")
        return self._data

    def getheaders(self):
        """
        Return the headers object of the wrapped response.
        """
        wrapped = self.response
        return wrapped.headers

    def getheader(self, name, default=None):
        """
        Look up a single response header, returning `default` when absent.
        """
        all_headers = self.response.headers
        return all_headers.get(name, default)


class RestClient(RESTClientObject):
    """A rest client interface based on requests, with retry."""

    def __init__(self, *args, **kwargs):
        """Build the session; all arguments go to `create_requests_session`."""
        # pylint: disable=super-init-not-called
        self.session = create_requests_session(*args, **kwargs)

    def request(
        self,
        method,
        url,
        query_params=None,
        headers=None,
        body=None,
        post_params=None,
        _preload_content=True,
        _request_timeout=None,
    ):
        """
        Perform an HTTP request through the retrying session.

        :param method: http request method
        :param url: http request url
        :param query_params: query parameters in the url
        :param headers: http request headers (not modified; a copy is used)
        :param body: request json body, for `application/json`
        :param post_params: request post parameters,
                            `application/x-www-form-urlencoded`
                            and `multipart/form-data`
        :param _preload_content: if False, the response object will be returned without
                                 reading/decoding response data. Default is True.
        :param _request_timeout: timeout setting for this request. If one number provided, it will be total request
                                 timeout. It can also be a pair (tuple) of (connection, read) timeouts.
        :return: a `RestResponse` wrapping the requests response.
        :raises ApiException: with `status=0` when the request cannot be
                              prepared or the transport fails, or built from
                              the response when the status is not 2xx.
        """
        # Based on the RESTClientObject class generated by Swagger
        method = method.upper()
        assert method in ["GET", "HEAD", "DELETE", "POST", "PUT", "PATCH", "OPTIONS"]

        post_params = post_params or {}
        # Work on a copy: Content-Type is added (and, for multipart, removed)
        # below, and those edits must not leak into the caller's mapping.
        headers = headers.copy() if headers else {}

        if "Content-Type" not in headers:
            headers["Content-Type"] = "application/json"

        request_kwargs = {}

        if query_params:
            url += "?" + urlencode(query_params)

        if method in ["POST", "PUT", "PATCH", "OPTIONS", "DELETE"]:
            if re.search("json", headers["Content-Type"], re.IGNORECASE):
                request_body = None
                if body:
                    request_body = json.dumps(body)
                request_kwargs["data"] = request_body
            elif headers["Content-Type"] == "application/x-www-form-urlencoded":
                request_kwargs["data"] = post_params
            elif headers["Content-Type"] == "multipart/form-data":
                # Drop the header so it is not sent with the literal value
                # given by the caller.
                del headers["Content-Type"]
                request_kwargs["data"] = post_params
            elif isinstance(body, str):
                request_kwargs["data"] = body
            else:
                # Cannot generate the request from given parameters
                msg = """Cannot prepare a request message for provided arguments.
                         Please check that your arguments match declared content type."""
                raise ApiException(status=0, reason=msg)

        try:
            resp = self.session.request(
                method,
                url,
                timeout=_request_timeout,
                stream=not _preload_content,
                headers=headers,
                **request_kwargs,
            )
        except requests.exceptions.RequestException as exc:
            msg = f"{type(exc).__name__}\n{str(exc)}"
            # Chain explicitly so the transport error stays attached as the cause.
            raise ApiException(status=0, reason=msg) from exc

        rest_resp = RestResponse(resp)

        if _preload_content:
            # `apparent_encoding` runs detection over the full body, so it is
            # only consulted when the body is being read up front anyway;
            # touching it for a streamed response would force the download
            # the caller asked to avoid.
            resp.encoding = resp.apparent_encoding or "utf-8"
            logger.debug("response body: %s", rest_resp.data)

        if not 200 <= rest_resp.status <= 299:
            raise ApiException(http_resp=rest_resp)

        return rest_resp