uktrade/dns-rewrite-proxy

dnsrewriteproxy.py

Summary

Maintainability: B (5 hrs)
Test Coverage: A (100%)
from asyncio import (
    CancelledError,
    Queue,
    create_task,
    get_running_loop,
)
from enum import (
    IntEnum,
)
import logging
import re
from random import (
    choices,
)
import string
import socket

from aiodnsresolver import (
    RESPONSE,
    TYPES,
    DnsRecordDoesNotExist,
    DnsResponseCode,
    Message,
    Resolver,
    ResourceRecord,
    ResolverLoggerAdapter,
    pack,
    parse,
    recvfrom,
)


def get_socket_default():
    # Listen for plain DNS over UDP on all interfaces, on the standard port 53;
    # the socket is non-blocking so it can be driven by the event loop
    sock = socket.socket(family=socket.AF_INET, type=socket.SOCK_DGRAM)
    sock.setblocking(False)
    sock.bind(('', 53))
    return sock


def get_resolver_default():
    return Resolver()


class DnsProxyLoggerAdapter(logging.LoggerAdapter):
    def process(self, msg, kwargs):
        return \
            ('[dnsproxy] %s' % (msg,), kwargs) if not self.extra else \
            ('[dnsproxy:%s] %s' % (','.join(str(v) for v in self.extra.values()), msg), kwargs)
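
# With no extra context the adapter prefixes messages as '[dnsproxy] ...'; with,
# say, extra={'dnsrewriteproxy_requestid': 'a1B2c3D4'} (a hypothetical request
# id), messages become '[dnsproxy:a1B2c3D4] ...'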


def get_logger_adapter_default(extra):
    return DnsProxyLoggerAdapter(logging.getLogger('dnsrewriteproxy'), extra)


def get_resolver_logger_adapter_default(parent_adapter):
    def _get_resolver_logger_adapter_default(dns_extra):
        return ResolverLoggerAdapter(parent_adapter, dns_extra)
    return _get_resolver_logger_adapter_default


def DnsProxy(
        get_resolver=get_resolver_default,
        get_logger_adapter=get_logger_adapter_default,
        get_resolver_logger_adapter=get_resolver_logger_adapter_default,
        get_socket=get_socket_default, num_workers=1000,
        rules=(),
):
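    # Each rule is a (regex, replacement) pair: the first pattern to match the
    # lowercased, IDNA-decoded query name is applied via re.subn, and names
    # matching no rule are refused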

    class ERRORS(IntEnum):
        # Standard DNS RCODE values (RFC 1035); NOTIMP (4) is not used here
        FORMERR = 1
        SERVFAIL = 2
        NXDOMAIN = 3
        REFUSED = 5

    loop = get_running_loop()
    logger = get_logger_adapter({})
    request_id_alphabet = string.ascii_letters + string.digits

    # The "main" task of the server: it receives incoming requests and puts
    # them in a queue that is then fetched from and processed by the proxy
    # workers

    async def server_worker(sock, resolve, stop):
        upstream_queue = Queue(maxsize=num_workers)

        # We have multiple upstream workers to be able to send multiple
        # requests upstream concurrently
        upstream_worker_tasks = [
            create_task(upstream_worker(sock, resolve, upstream_queue))
            for _ in range(0, num_workers)]

        try:
            while True:
                logger.info('Waiting for next request')
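                # 512 bytes is the maximum DNS message size over plain UDP
                # without EDNS (RFC 1035)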
                request_data, addr = await recvfrom(loop, [sock], 512)
                request_logger = get_logger_adapter(
                    {'dnsrewriteproxy_requestid': ''.join(choices(request_id_alphabet, k=8))})
                request_logger.info('Received request from %s', addr)
                await upstream_queue.put((request_logger, request_data, addr))
        finally:
            logger.info('Stopping: waiting for requests to finish')
            await upstream_queue.join()

            logger.info('Stopping: cancelling workers...')
            for upstream_task in upstream_worker_tasks:
                upstream_task.cancel()
            for upstream_task in upstream_worker_tasks:
                try:
                    await upstream_task
                except CancelledError:
                    pass
            logger.info('Stopping: cancelling workers... (done)')

            logger.info('Stopping: final cleanup')
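            # stop() closes the listening socket and clears the resolver cache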
            await stop()
            logger.info('Stopping: done')

    async def upstream_worker(sock, resolve, upstream_queue):
        while True:
            request_logger, request_data, addr = await upstream_queue.get()

            try:
                request_logger.info('Processing request')
                response_data = await get_response_data(request_logger, resolve, request_data)
                # Sendto for non-blocking UDP sockets cannot raise a BlockingIOError
                # https://stackoverflow.com/a/59794872/1319998
                sock.sendto(response_data, addr)
            except Exception:
                request_logger.exception('Error processing request')
            finally:
                request_logger.info('Finished processing request')
                upstream_queue.task_done()

    async def get_response_data(request_logger, resolve, request_data):
        # This may raise an exception, which is handled at a higher level.
        # We can't [and I suspect shouldn't try to] return an error to the
        # client, since we're not able to extract the QID, so the client won't
        # be able to match it with an outgoing request
        query = parse(request_data)

        try:
            return pack(await proxy(request_logger, resolve, query))
        except Exception:
            request_logger.exception('Failed to proxy %s', query)
            return pack(error(query, ERRORS.SERVFAIL))

    async def proxy(request_logger, resolve, query):
        name_bytes = query.qd[0].name
        request_logger.info('Name: %s', name_bytes)

        name_str_lower = query.qd[0].name.lower().decode('idna')
        request_logger.info('Decoded: %s', name_str_lower)

        if query.qd[0].qtype != TYPES.A:
            request_logger.info('Unhandled query type: %s', query.qd[0].qtype)
            return error(query, ERRORS.REFUSED)

        for pattern, replace in rules:
            rewritten_name_str, num_matches = re.subn(pattern, replace, name_str_lower)
            if num_matches:
                request_logger.info('Matches rule (%s, %s)', pattern, replace)
                break
        else:
            # No break was triggered, i.e. no match
            request_logger.info('Does not match a rule')
            return error(query, ERRORS.REFUSED)

        try:
            ip_addresses = await resolve(
                rewritten_name_str, TYPES.A,
                get_logger_adapter=get_resolver_logger_adapter(request_logger))
        except DnsRecordDoesNotExist:
            request_logger.info('Does not exist')
            return error(query, ERRORS.NXDOMAIN)
        except DnsResponseCode as dns_response_code_error:
            request_logger.info('Received error from upstream: %s',
                                dns_response_code_error.args[0])
            return error(query, dns_response_code_error.args[0])

        request_logger.info('Resolved to %s', ip_addresses)
        now = loop.time()

        def ttl(ip_address):
            # Remaining lifetime of the cached upstream record, clamped at zero
            return int(max(0.0, ip_address.expires_at - now))

        response_records = tuple(
            ResourceRecord(name=name_bytes, qtype=TYPES.A,
                           qclass=1, ttl=ttl(ip_address), rdata=ip_address.packed)
            for ip_address in ip_addresses
        )
        return Message(
            qid=query.qid, qr=RESPONSE, opcode=0, aa=0, tc=0, rd=0, ra=1, z=0, rcode=0,
            qd=query.qd, an=response_records, ns=(), ar=(),
        )

    async def start():
        # The socket is created synchronously and passed to the server worker,
        # so if there is an error creating it, this function will raise an
        # exception. If no exception is raised, we are indeed listening
        sock = get_socket()

        # The resolver is also created synchronously, since it can parse
        # /etc/hosts or /etc/resolv.conf, and can raise an exception if
        # something goes wrong with that
        resolve, clear_cache = get_resolver()

        async def stop():
            sock.close()
            await clear_cache()

        return create_task(server_worker(sock, resolve, stop))

    return start


def error(query, rcode):
    return Message(
        qid=query.qid, qr=RESPONSE, opcode=0, aa=0, tc=0, rd=0, ra=1, z=0, rcode=rcode,
        qd=query.qd, an=(), ns=(), ar=(),
    )
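

A minimal usage sketch (not part of the file above): it assumes the module is importable as dnsrewriteproxy, that the process is allowed to bind UDP port 53 (which typically requires elevated privileges), and that the source.example/target.example names and the one-hour lifetime are purely illustrative.

from asyncio import CancelledError, run, sleep

from dnsrewriteproxy import DnsProxy


async def main():
    # Rewrite queries for any name under source.example to the corresponding
    # name under target.example; names matching no rule are refused
    start = DnsProxy(rules=(
        (r'^(.+)\.source\.example$', r'\1.target.example'),
    ))

    # DnsProxy calls get_running_loop, so it must be called from within a
    # running event loop; start() returns the long-running server task once
    # the socket is bound and the resolver is constructed
    proxy_task = await start()

    # Serve for an hour, then shut down cleanly
    await sleep(3600)
    proxy_task.cancel()
    try:
        await proxy_task
    except CancelledError:
        pass


run(main())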