michalc/lowhaio
lowhaio.py
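
# lowhaio: a minimal HTTP/1.1 client for asyncio, with connection pooling
# and keep-alive, TLS, chunked transfer-encoding, and DNS resolution via
# aiodnsresolver. Request and response bodies are async iterators of bytes,
# so messages can be streamed without being buffered whole in memory.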
import asyncio
import contextlib
import ipaddress
import logging
import urllib.parse
import ssl
import socket

from aiodnsresolver import (
    TYPES,
    DnsError,
    Resolver,
    ResolverLoggerAdapter,
)


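# Failures establishing a connection (DNS, TCP connect, TLS handshake)
# derive from HttpConnectionError; failures while sending or receiving data
# derive from HttpDataError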
class HttpError(Exception):
    pass


class HttpConnectionError(HttpError):
    pass


class HttpDnsError(HttpConnectionError):
    pass


class HttpTlsError(HttpConnectionError):
    pass


class HttpDataError(HttpError):
    pass


class HttpConnectionClosedError(HttpDataError):
    pass


class HttpHeaderTooLong(HttpDataError):
    pass


class HttpLoggerAdapter(logging.LoggerAdapter):
    def process(self, msg, kwargs):
        return \
            ('[http] %s' % (msg,), kwargs) if not self.extra else \
            ('[http:%s] %s' % (','.join(str(v) for v in self.extra.values()), msg), kwargs)


def get_logger_adapter_default(extra):
    return HttpLoggerAdapter(logging.getLogger('lowhaio'), extra)


def get_resolver_logger_adapter_default(http_extra):
    def _get_resolver_logger_adapter_default(resolver_extra):
        http_adapter = HttpLoggerAdapter(logging.getLogger('aiodnsresolver'), http_extra)
        return ResolverLoggerAdapter(http_adapter, resolver_extra)
    return _get_resolver_logger_adapter_default


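# The default request body: an async iterator that yields nothing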
async def empty_async_iterator():
    while False:
        yield


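# asyncio.current_task replaced asyncio.Task.current_task in Python 3.7;
# support both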
get_current_task = \
    asyncio.current_task if hasattr(asyncio, 'current_task') else \
    asyncio.Task.current_task


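# Bodies are async iterators of bytes: streamed wraps fixed bytes as a
# request body, e.g. body=streamed(b'some-data'), and buffered collects a
# response body into bytes, e.g. data = await buffered(response_body)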
def streamed(data):
    async def _streamed():
        yield data
    return _streamed


async def buffered(data):
    return b''.join([chunk async for chunk in data])


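# A non-blocking TCP socket, with Nagle's algorithm disabled so small
# writes are sent without delay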
def get_nonblocking_sock():
    sock = socket.socket(family=socket.AF_INET, type=socket.SOCK_STREAM, proto=socket.IPPROTO_TCP)
    sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
    sock.setblocking(False)
    return sock


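# On Linux, TCP_CORK is set before sending a message and unset afterwards,
# so the separately-written header and body can coalesce into fewer packets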
def set_tcp_cork(sock):
    sock.setsockopt(socket.SOL_TCP, socket.TCP_CORK, 1)  # pylint: disable=no-member


def unset_tcp_cork(sock):
    sock.setsockopt(socket.SOL_TCP, socket.TCP_CORK, 0)  # pylint: disable=no-member


async def send_body_async_gen_bytes(logger, loop, sock, socket_timeout,
                                    body, body_args, body_kwargs):
    logger.debug('Sending body')
    num_bytes = 0
    async for chunk in body(*body_args, **dict(body_kwargs)):
        num_bytes += len(chunk)
        await send_all(loop, sock, socket_timeout, chunk)
    logger.debug('Sent body bytes: %s', num_bytes)


async def send_header_tuples_of_bytes(logger, loop, sock, socket_timeout,
                                      http_version, method, parsed_url, params, headers):
    logger.debug('Sending header')
    outgoing_qs = urllib.parse.urlencode(params, doseq=True).encode()
    outgoing_path = urllib.parse.quote(parsed_url.path).encode()
    outgoing_path_qs = outgoing_path + \
        ((b'?' + outgoing_qs) if outgoing_qs != b'' else b'')
    host_specified = any(key == b'host' for key, _ in headers)
    headers_with_host = \
        headers if host_specified else \
        ((b'host', parsed_url.hostname.encode('idna')),) + headers
    await send_all(loop, sock, socket_timeout, b'%s %s %s\r\n%s\r\n' % (
        method, outgoing_path_qs, http_version, b''.join(
            b'%s:%s\r\n' % (key, value)
            for (key, value) in headers_with_host
        )
    ))
    logger.debug('Sent header')


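# Returns a (request, close) pair: request makes an HTTP request, reusing a
# pooled keep-alive connection where possible; close closes all pooled
# connections. On Python 3.7+ Pool must be called while the event loop is
# running.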
def Pool(
        get_dns_resolver=Resolver,
        get_sock=get_nonblocking_sock,
        get_ssl_context=ssl.create_default_context,
        sock_pre_message=set_tcp_cork if hasattr(socket, 'TCP_CORK') else lambda _: None,
        sock_post_message=unset_tcp_cork if hasattr(socket, 'TCP_CORK') else lambda _: None,
        send_header=send_header_tuples_of_bytes,
        send_body=send_body_async_gen_bytes,
        http_version=b'HTTP/1.1',
        keep_alive_timeout=15,
        recv_bufsize=16384,
        max_header_length=16384,
        socket_timeout=10,
        get_logger_adapter=get_logger_adapter_default,
        get_resolver_logger_adapter=get_resolver_logger_adapter_default,
):

    loop = \
        asyncio.get_running_loop() if hasattr(asyncio, 'get_running_loop') else \
        asyncio.get_event_loop()
    ssl_context = get_ssl_context()

    logger_extra = {}
    logger = get_logger_adapter(logger_extra)

    dns_resolve, dns_resolver_clear_cache = get_dns_resolver(
        get_logger_adapter=get_resolver_logger_adapter(logger_extra),
    )

    pool = {}

    async def request(method, url, params=(), headers=(),
                      body=empty_async_iterator, body_args=(), body_kwargs=(),
                      get_logger_adapter=get_logger_adapter,
                      get_resolver_logger_adapter=get_resolver_logger_adapter,
                      ):
        parsed_url = urllib.parse.urlsplit(url)

        logger_extra = {'lowhaio_method': method.decode(), 'lowhaio_url': url}
        logger = get_logger_adapter(logger_extra)

        try:
            ip_addresses = (ipaddress.ip_address(parsed_url.hostname),)
        except ValueError:
            try:
                ip_addresses = await dns_resolve(
                    parsed_url.hostname, TYPES.A,
                    get_logger_adapter=get_resolver_logger_adapter(logger_extra),
                )
            except DnsError as exception:
                raise HttpDnsError() from exception

        key = (parsed_url.scheme, parsed_url.netloc)

        sock = get_from_pool(logger, key, ip_addresses)

        if sock is None:
            sock = get_sock()
            try:
                logger.debug('Connecting: %s', sock)
                await connect(sock, parsed_url, str(ip_addresses[0]))
                logger.debug('Connected: %s', sock)
            except asyncio.CancelledError:
                sock.close()
                raise
            except Exception as exception:
                sock.close()
                raise HttpConnectionError() from exception
            except BaseException:
                sock.close()
                raise

            try:
                if parsed_url.scheme == 'https':
                    logger.debug('TLS handshake started')
                    sock = tls_wrapped(sock, parsed_url.hostname)
                    await tls_complete_handshake(loop, sock, socket_timeout)
                    logger.debug('TLS handshake completed')
            except asyncio.CancelledError:
                sock.close()
                raise
            except Exception as exception:
                sock.close()
                raise HttpTlsError() from exception
            except BaseException:
                sock.close()
                raise

        try:
            sock_pre_message(sock)
            await send_header(logger, loop, sock, socket_timeout, http_version,
                              method, parsed_url, params, headers)
            await send_body(logger, loop, sock, socket_timeout, body, body_args, body_kwargs)
            sock_post_message(sock)

            code, version, response_headers, unprocessed = await recv_header(sock)
            logger.debug('Received header with code: %s', code)
            connection, body_length, body_handler = connection_length_body_handler(
                logger, method, version, response_headers)
            response_body = response_body_generator(
                logger, sock, unprocessed, key, connection, body_length, body_handler)
        except asyncio.CancelledError:
            sock.close()
            raise
        except Exception as exception:
            sock.close()
            if isinstance(exception, HttpDataError):
                raise
            raise HttpDataError() from exception
        except BaseException:
            sock.close()
            raise

        return code, response_headers, response_body

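    # Takes a previously-pooled connection for (scheme, netloc), discarding
    # any whose peer is no longer a current IP for the domain or that have
    # already been closed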
    def get_from_pool(logger, key, ip_addresses):
        try:
            socks = pool[key]
        except KeyError:
            logger.debug('Connection not in pool: %s', key)
            return None

        while socks:
            _sock, close_callback = next(iter(socks.items()))
            close_callback.cancel()
            del socks[_sock]

            try:
                connected_ip = ipaddress.ip_address(_sock.getpeername()[0])
            except OSError:
                logger.debug('Unable to get peer name: %s', _sock)
                _sock.close()
                continue

            if connected_ip not in ip_addresses:
                logger.debug('Not current for domain, closing: %s', _sock)
                _sock.close()
                continue

            if _sock.fileno() == -1:
                logger.debug('Connection already closed: %s', _sock)
                continue

            logger.debug('Reusing connection: %s', _sock)
            return _sock

        del pool[key]

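    # Stores the connection for reuse, scheduling it to be closed if not
    # reused within keep_alive_timeout seconds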
    def add_to_pool(key, sock):
        try:
            key_pool = pool[key]
        except KeyError:
            key_pool = {}
            pool[key] = key_pool

        key_pool[sock] = loop.call_later(keep_alive_timeout, close_by_keep_alive_timeout,
                                         key, sock)

    def close_by_keep_alive_timeout(key, sock):
        logger.debug('Closing by timeout: %s,%s', key, sock)
        sock.close()
        del pool[key][sock]
        if not pool[key]:
            del pool[key]

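    # Connects to the resolved IP; the port is taken from the URL, defaulting
    # to 443 for https and 80 otherwise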
    async def connect(sock, parsed_url, ip_address):
        scheme = parsed_url.scheme
        _, _, port_specified = parsed_url.netloc.partition(':')
        port = \
            int(port_specified) if port_specified != '' else \
            443 if scheme == 'https' else \
            80
        address = (ip_address, port)
        await loop.sock_connect(sock, address)

    def tls_wrapped(sock, host):
        return ssl_context.wrap_socket(sock, server_hostname=host, do_handshake_on_connect=False)

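    # Receives until the end of the response header, returning the status
    # code, HTTP version, headers, and any already-received body bytes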
    async def recv_header(sock):
        unprocessed = b''
        while True:
            unprocessed += await recv(loop, sock, socket_timeout, recv_bufsize)
            try:
                header_end = unprocessed.index(b'\r\n\r\n')
            except ValueError:
                if len(unprocessed) >= max_header_length:
                    raise HttpHeaderTooLong()
                continue
            else:
                break

        header_bytes, unprocessed = unprocessed[:header_end], unprocessed[header_end + 4:]
        lines = header_bytes.split(b'\r\n')
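        # The status line is e.g. b'HTTP/1.1 200 OK': the version is at
        # [5:8] and the status code at [9:12]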
        code = lines[0][9:12]
        version = lines[0][5:8]
        response_headers = tuple(
            (key.strip().lower(), value.strip())
            for line in lines[1:]
            for (key, _, value) in (line.partition(b':'),)
        )
        return code, version, response_headers, unprocessed

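    # Yields the transfer-decoded body; on clean completion the connection is
    # pooled if it can be kept alive, and otherwise closed. On any error the
    # socket is closed.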
    async def response_body_generator(
            logger, sock, unprocessed, key, connection, body_length, body_handler):
        try:
            generator = body_handler(logger, sock, body_length, unprocessed)
            unprocessed = None  # So can be garbage collected

            logger.debug('Receiving body')
            num_bytes = 0
            async for chunk in generator:
                yield chunk
                num_bytes += len(chunk)
            logger.debug('Received transfer-decoded body bytes: %s', num_bytes)
        except BaseException:
            sock.close()
            raise
        else:
            if connection == b'keep-alive':
                logger.debug('Keeping connection alive: %s', sock)
                add_to_pool(key, sock)
            else:
                logger.debug('Closing connection: %s', sock)
                sock.close()

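    # Determines from the method, HTTP version and response headers: whether
    # the connection can be kept alive, the expected body length (None if
    # unknown), and which handler transfer-decodes the body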
    def connection_length_body_handler(logger, method, version, response_headers):
        headers_dict = dict(response_headers)
        transfer_encoding = headers_dict.get(b'transfer-encoding', b'identity')
        logger.debug('Effective transfer-encoding: %s', transfer_encoding)
        connection = \
            b'close' if keep_alive_timeout == 0 else \
            headers_dict.get(b'connection', b'keep-alive').lower() if version == b'1.1' else \
            headers_dict.get(b'connection', b'close').lower()
        logger.debug('Effective connection: %s', connection)

        body_length = \
            0 if method == b'HEAD' else \
            0 if connection == b'keep-alive' and b'content-length' not in headers_dict else \
            None if b'content-length' not in headers_dict else \
            int(headers_dict[b'content-length'])

        uses_identity = (method == b'HEAD' or transfer_encoding == b'identity')
        body_handler = \
            identity_handler_known_body_length if uses_identity and body_length is not None else \
            identity_handler_unknown_body_length if uses_identity else \
            chunked_handler

        return connection, body_length, body_handler

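    # The body length is known: yield exactly body_length bytes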
    async def identity_handler_known_body_length(logger, sock, body_length, unprocessed):
        logger.debug('Expected incoming body bytes: %s', body_length)
        total_remaining = body_length

        if unprocessed and total_remaining:
            total_remaining -= len(unprocessed)
            yield unprocessed

        while total_remaining:
            unprocessed = None  # So can be garbage collected
            unprocessed = await recv(loop, sock, socket_timeout,
                                     min(recv_bufsize, total_remaining))
            total_remaining -= len(unprocessed)
            yield unprocessed

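    # No usable Content-Length: read until the server closes the connection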
    async def identity_handler_unknown_body_length(logger, sock, _, unprocessed):
        logger.debug('Unknown incoming body length')
        if unprocessed:
            yield unprocessed

        unprocessed = None  # So can be garbage collected

        try:
            while True:
                yield await recv(loop, sock, socket_timeout, recv_bufsize)
        except HttpConnectionClosedError:
            pass

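    # Transfer-Encoding: chunked: each chunk is a hex length, CRLF, the chunk
    # bytes, then CRLF, with a zero-length chunk ending the body, e.g.
    # b'5\r\nhello\r\n0\r\n\r\n' decodes to b'hello'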
    async def chunked_handler(_, sock, __, unprocessed):
        while True:
            # Fetch until have chunk header
            while b'\r\n' not in unprocessed:
                if len(unprocessed) >= max_header_length:
                    raise HttpHeaderTooLong()
                unprocessed += await recv(loop, sock, socket_timeout, recv_bufsize)

            # Find chunk length
            chunk_header_end = unprocessed.index(b'\r\n')
            chunk_header_hex = unprocessed[:chunk_header_end]
            chunk_length = int(chunk_header_hex, 16)

            # End of body signalled by a 0-length chunk
            if chunk_length == 0:
                while b'\r\n\r\n' not in unprocessed:
                    if len(unprocessed) >= max_header_length:
                        raise HttpHeaderTooLong()
                    unprocessed += await recv(loop, sock, socket_timeout, recv_bufsize)

                break

            # Remove chunk header
            unprocessed = unprocessed[chunk_header_end + 2:]

            # Yield whatever amount of chunk we have already, which
            # might be nothing
            chunk_remaining = chunk_length
            in_chunk, unprocessed = \
                unprocessed[:chunk_remaining], unprocessed[chunk_remaining:]
            if in_chunk:
                yield in_chunk
            chunk_remaining -= len(in_chunk)

            # Fetch and yield rest of chunk
            while chunk_remaining:
                unprocessed += await recv(loop, sock, socket_timeout, recv_bufsize)
                in_chunk, unprocessed = \
                    unprocessed[:chunk_remaining], unprocessed[chunk_remaining:]
                chunk_remaining -= len(in_chunk)
                yield in_chunk

            # Fetch until have chunk footer, and remove
            while len(unprocessed) < 2:
                unprocessed += await recv(loop, sock, socket_timeout, recv_bufsize)
            unprocessed = unprocessed[2:]

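    # Clears the DNS cache and closes all pooled connections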
    async def close(
            get_logger_adapter=get_logger_adapter,
            get_resolver_logger_adapter=get_resolver_logger_adapter,
    ):
        logger_extra = {}
        logger = get_logger_adapter(logger_extra)

        logger.debug('Closing pool')

        await dns_resolver_clear_cache(
            get_logger_adapter=get_resolver_logger_adapter(logger_extra),
        )
        for key, socks in pool.items():
            for sock, close_callback in socks.items():
                logger.debug('Closing: %s,%s', key, sock)
                close_callback.cancel()
                sock.close()
        pool.clear()

    return request, close


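# Sends all of data: optimistically attempts a synchronous send, and
# registers a writer callback with the event loop for any remainder, with a
# timeout that is reset whenever progress is made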
async def send_all(loop, sock, socket_timeout, data):
    if not data:
        # An empty send would return 0 and be mistaken below for a closed
        # connection
        return

    try:
        latest_num_bytes = sock.send(data)
    except (BlockingIOError, ssl.SSLWantWriteError):
        latest_num_bytes = 0
    else:
        if latest_num_bytes == 0:
            raise HttpConnectionClosedError()

    if latest_num_bytes == len(data):
        return

    total_num_bytes = latest_num_bytes

    def writer():
        nonlocal total_num_bytes
        try:
            latest_num_bytes = sock.send(data_memoryview[total_num_bytes:])
        except (BlockingIOError, ssl.SSLWantWriteError):
            pass
        except Exception as exception:
            loop.remove_writer(fileno)
            if not result.done():
                result.set_exception(exception)
        else:
            total_num_bytes += latest_num_bytes
            if latest_num_bytes == 0 and not result.done():
                loop.remove_writer(fileno)
                result.set_exception(HttpConnectionClosedError())
            elif total_num_bytes == len(data) and not result.done():
                loop.remove_writer(fileno)
                result.set_result(None)
            else:
                reset_timeout()

    result = asyncio.Future()
    fileno = sock.fileno()
    data_memoryview = memoryview(data)
    loop.add_writer(fileno, writer)

    try:
        with timeout(loop, socket_timeout) as reset_timeout:
            return await result
    finally:
        loop.remove_writer(fileno)


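# Receives at most recv_bufsize bytes; an empty result means the peer closed
# the connection, raised as HttpConnectionClosedError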
async def recv(loop, sock, socket_timeout, recv_bufsize):
    incoming = await _recv(loop, sock, socket_timeout, recv_bufsize)
    if not incoming:
        raise HttpConnectionClosedError()
    return incoming


async def _recv(loop, sock, socket_timeout, recv_bufsize):
    try:
        return sock.recv(recv_bufsize)
    except (BlockingIOError, ssl.SSLWantReadError):
        pass

    def reader():
        try:
            chunk = sock.recv(recv_bufsize)
        except (BlockingIOError, ssl.SSLWantReadError):
            pass
        except Exception as exception:
            loop.remove_reader(fileno)
            if not result.done():
                result.set_exception(exception)
        else:
            loop.remove_reader(fileno)
            if not result.done():
                result.set_result(chunk)

    result = asyncio.Future()
    fileno = sock.fileno()
    loop.add_reader(fileno, reader)

    try:
        with timeout(loop, socket_timeout):
            return await result
    finally:
        loop.remove_reader(fileno)


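# Completes a non-blocking TLS handshake, retrying as the socket becomes
# readable or writable and resetting the timeout on each attempt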
async def tls_complete_handshake(loop, ssl_sock, socket_timeout):
    try:
        return ssl_sock.do_handshake()
    except (ssl.SSLWantReadError, ssl.SSLWantWriteError):
        pass

    def handshake():
        try:
            ssl_sock.do_handshake()
        except (ssl.SSLWantReadError, ssl.SSLWantWriteError):
            reset_timeout()
        except Exception as exception:
            loop.remove_reader(fileno)
            loop.remove_writer(fileno)
            if not done.done():
                done.set_exception(exception)
        else:
            loop.remove_reader(fileno)
            loop.remove_writer(fileno)
            if not done.done():
                done.set_result(None)

    done = asyncio.Future()
    fileno = ssl_sock.fileno()
    loop.add_reader(fileno, handshake)
    loop.add_writer(fileno, handshake)

    try:
        with timeout(loop, socket_timeout) as reset_timeout:
            return await done
    finally:
        loop.remove_reader(fileno)
        loop.remove_writer(fileno)


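# Cancels the current task if it takes longer than max_time seconds,
# converting the cancellation to asyncio.TimeoutError if caused by the
# timeout; yields a function that resets the clock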
@contextlib.contextmanager
def timeout(loop, max_time):

    cancelling_due_to_timeout = False
    current_task = get_current_task()

    def cancel():
        nonlocal cancelling_due_to_timeout
        cancelling_due_to_timeout = True
        current_task.cancel()

    def reset():
        nonlocal handle
        handle.cancel()
        handle = loop.call_later(max_time, cancel)

    handle = loop.call_later(max_time, cancel)

    try:
        yield reset
    except asyncio.CancelledError:
        if cancelling_due_to_timeout:
            raise asyncio.TimeoutError()
        raise
    finally:
        handle.cancel()
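

# A minimal usage sketch, not part of the library itself: fetch a URL and
# print the status code and body. The URL is illustrative; asyncio.run
# requires Python 3.7+.
if __name__ == '__main__':
    async def _example():
        request, close = Pool()
        status, _headers, body = await request(b'GET', 'http://www.example.com/')
        print(status, await buffered(body))
        await close()

    asyncio.run(_example())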