src/pook/interceptors/aiohttp.py

Summary

Maintainability
A
0 mins
Test Coverage
from ..request import Request
from .base import BaseInterceptor

from unittest import mock

from urllib.parse import urlunparse, urlencode
from http.client import responses as http_reasons

import asyncio
from aiohttp.helpers import TimerNoop
from aiohttp.streams import EmptyStreamReader

# Try to load yarl URL parser package used by aiohttp
try:
    import yarl
    import multidict
except Exception:
    yarl, multidict = None, None

PATCHES = ("aiohttp.client.ClientSession._request",)

RESPONSE_CLASS = "ClientResponse"
RESPONSE_PATH = "aiohttp.client_reqrep"


class SimpleContent(EmptyStreamReader):
    def __init__(self, content, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.content = content

    async def read(self, n=-1):
        return self.content


def HTTPResponse(*args, **kw):
    # Dynamically load package
    module = __import__(RESPONSE_PATH, fromlist=(RESPONSE_CLASS,))
    ClientResponse = getattr(module, RESPONSE_CLASS)

    # Return response instance
    return ClientResponse(
        *args,
        request_info=mock.Mock(),
        writer=None,
        continue100=None,
        timer=TimerNoop(),
        traces=[],
        loop=mock.Mock(),
        session=mock.Mock(),
        **kw,
    )


class AIOHTTPInterceptor(BaseInterceptor):
    """
    aiohttp HTTP client traffic interceptor.
    """

    def _url(self, url):
        return yarl.URL(url) if yarl else None

    async def _on_request(
        self, _request, session, method, url, data=None, headers=None, **kw
    ):
        # Create request contract based on incoming params
        req = Request(method)
        req.headers = headers or {}
        req.body = data

        # Expose extra variadic arguments
        req.extra = kw

        # Compose URL
        if not kw.get("params"):
            req.url = str(url)
        else:
            req.url = (
                str(url) + "?" + urlencode([(x, y) for x, y in kw["params"].items()])
            )

        # Match the request against the registered mocks in pook
        mock = self.engine.match(req)

        # If cannot match any mock, run real HTTP request if networking
        # or silent model are enabled, otherwise this statement won't
        # be reached (an exception will be raised before).
        if not mock:
            return await _request(
                session, method, url, data=data, headers=headers, **kw
            )

        # Simulate network delay
        if mock._delay:
            await asyncio.sleep(mock._delay / 1000)  # noqa

        # Shortcut to mock response
        res = mock._response

        # Aggregate headers as list of tuples for interface compatibility
        headers = []
        for key in res._headers:
            headers.append((key, res._headers[key]))

        # Create mock equivalent HTTP response
        _res = HTTPResponse(req.method, self._url(urlunparse(req.url)))

        # response status
        _res.version = (1, 1)
        _res.status = res._status
        _res.reason = http_reasons.get(res._status)
        _res._should_close = False

        # Add response headers
        _res._raw_headers = tuple(headers)
        _res._headers = multidict.CIMultiDictProxy(multidict.CIMultiDict(headers))

        if res._body:
            _res.content = SimpleContent(
                res._body
                if res._binary
                else res._body.encode("utf-8", errors="replace"),
            )
        else:
            # Define `_content` attribute with an empty string to
            # force do not read from stream (which won't exists)
            _res.content = EmptyStreamReader()

        # Return response based on mock definition
        return _res

    def _patch(self, path):
        # If not able to import aiohttp dependencies, skip
        if not yarl or not multidict:
            return None

        async def handler(session, method, url, data=None, headers=None, **kw):
            return await self._on_request(
                _request, session, method, url, data=data, headers=headers, **kw
            )

        try:
            # Create a new patcher for Urllib3 urlopen function
            # used as entry point for all the HTTP communications
            patcher = mock.patch(path, handler)
            # Retrieve original patched function that we might need for real
            # networking
            _request = patcher.get_original()[0]
            # Start patching function calls
            patcher.start()
        except Exception:
            # Exceptions may accur due to missing package
            # Ignore all the exceptions for now
            pass
        else:
            self.patchers.append(patcher)

    def activate(self):
        """
        Activates the traffic interceptor.
        This method must be implemented by any interceptor.
        """
        [self._patch(path) for path in PATCHES]

    def disable(self):
        """
        Disables the traffic interceptor.
        This method must be implemented by any interceptor.
        """
        [patch.stop() for patch in self.patchers]