uktrade/aioftps3

View on GitHub
aioftps3/server.py

Summary

Maintainability
A
2 hrs
Test Coverage
A
93%
import asyncio
from pathlib import (
    PurePosixPath,
)
import random
import stat

from aioftps3.server_logger import (
    get_child_logger,
    logged,
)

from aioftps3.server_s3 import (
    s3_delete,
    s3_get,
    s3_list,
    s3_mkdir,
    s3_put,
    s3_rename,
    s3_rmdir,
)

from aioftps3.server_socket import (
    recv_lines,
    recv_until_close,
    send_all,
    send_line,
    server,
    shutdown_socket,
    ssl_complete_handshake,
    ssl_get_socket,
    ssl_unwrap_socket,
)

from aioftps3.server_utils import (
    timeout,
)


# Upper bound, in seconds, on the handling of any single command
COMMAND_TIMEOUT_SECONDS = 60

# Seconds a client is given to connect to the PASV server after requesting it
DATA_CONNECT_TIMEOUT_SECONDS = 10

# Seconds a client is given, once connected to the PASV server, to issue the
# command that makes use of that connection
DATA_COMMAND_TIMEOUT_SECONDS = 2

# Chunk sizes passed to the send/receive helpers: the first for the command
# connection, the second for data transfers
COMMAND_CHUNK_BYTES = 1 << 10
DATA_CHUNK_SIZE = 64 << 10


# We are very specific in terms of tasks created:
#
# - A task for the main server
# - Two tasks for each client connection:
#   - One for receiving and processing incoming commands
#   - One for sending outgoing command responses
# - A task for each client's data server: only running until the data transfer
#   started
# - A task for each data connection from the client: any more than one will
#   be very short-lived, e.g. if multiple clients connect in the brief amount
#   of time after the data server has started but before the data transfer
#   has started, the second will close the connection and end its task
#
# The separation of incoming/outgoing on the client command connection is a
# consequence of wanting outgoing command responses from the _data_ task.
# Sending data is not atomic, e.g. it can decide to send a single byte at a
# time and yield to other tasks, so there could be a risk of sending corrupt
# responses unless we ensure only one response is sent at a time. One way of
# doing this is a dedicated task for outgoing command data.
#
# Each task also keeps track of the tasks that it creates: on cancellation of
# the main task, it's the responsibility of each task to cancel its child
# tasks.
#
# Communication to S3 can happen from any task. There is locking to try to
# prevent race conditions


async def on_client_connect(logger, loop, get_ssl_context, sock, get_data_ip, data_ports,
                            is_data_sock_ok, is_user_correct, is_password_correct, s3_context):
    """Handle one client's command connection from greeting to shutdown.

    Sends the 220 greeting, then reads command lines from the client and
    dispatches each one to its ``command_*`` handler, subject to the ordering
    rules in ``is_good_sequence``. Responses are queued and written to the
    socket by a dedicated task, so only one response is being sent at a time.
    On exit, however it happens, any data server/client tasks are cancelled,
    and the command socket is unwrapped from TLS (if it was wrapped) and shut
    down.

    Args:
        logger: Logger for this connection; a child logger is derived from
            it for each data server.
        loop: Event loop handed to the socket, timeout and server helpers.
        get_ssl_context: Passed through to ``ssl_get_socket`` and ``server``.
        sock: The client's command socket, before any TLS wrapping.
        get_data_ip: Awaited with the current command socket; must return a
            dotted IPv4 string, which is advertised in the PASV response.
        data_ports: Mutable pool of free data ports (needs ``add`` and
            ``remove``). A port is taken for each PASV data server and
            returned when that server task finishes.
        is_data_sock_ok: Awaited with the command socket and a newly
            connected data socket; a falsy result rejects the data socket.
        is_user_correct: Awaited with the attempted user name.
        is_password_correct: Awaited with the logger, the user name and the
            attempted password.
        s3_context: Passed through to the ``s3_*`` helpers.
    """
    user = None
    is_authenticated = False
    ssl_sock = None
    cwd = PurePosixPath('/')

    # Source path remembered between RNFR and the following RNTO
    rename_from = None

    data_server = None
    data_client = None
    data_port = None
    # Hands the transfer function chosen by LIST/RETR/STOR to the data task
    data_funcs = asyncio.Queue(maxsize=1)

    # Multiple concurrent sends on the same socket would be bad, so we only
    # allow them to be sent from a dedicated task, but other tasks can
    # queue them up
    command_responses = asyncio.Queue(maxsize=1)

    # Glue to lower level functions

    def get_sock():
        # The TLS-wrapped command socket once AUTH has run, else the raw one
        return ssl_sock if ssl_sock is not None else sock

    async def send_command_responses():
        # Sole writer to the command socket: sends queued responses in order
        while True:
            response = await command_responses.get()
            logger.debug('Out: %s', response)
            try:
                await send_line(loop, get_sock, COMMAND_CHUNK_BYTES, response)
            finally:
                # Always mark as done so callers awaiting join() are released
                command_responses.task_done()

    def command_sock_recv_lines():
        return recv_lines(loop, get_sock, COMMAND_CHUNK_BYTES)

    async def data_sock_send_line(data_sock, line):
        await send_line(loop, lambda: data_sock, DATA_CHUNK_SIZE, line)

    async def data_sock_send_all(data_sock, data):
        await send_all(loop, lambda: data_sock, DATA_CHUNK_SIZE, memoryview(data))

    # Deliberate quitting: to keep the number of code paths as small as
    # possible, uses the same method as if the whole server is shutting down,
    # using cancellation

    async def cancel_current_task():
        asyncio.current_task().cancel()
        # Causes the cancel exception to be raised, right here
        await asyncio.sleep(0)

    # Path manipulation

    def to_absolute_path(arg):
        # Resolves a client-supplied path against cwd unless it starts with
        # '/'. NOTE: an empty argument raises IndexError here rather than
        # silently resolving to cwd
        requested_path = arg.decode('utf-8')

        absolute = \
            PurePosixPath(requested_path) if requested_path[0] == '/' else \
            cwd / PurePosixPath(requested_path)

        return absolute

    # Commands

    async def command_auth(_):
        nonlocal ssl_sock
        await command_responses.put(b'234 TLS negotiation will follow.')
        # The response must be fully sent in plaintext before the socket is
        # wrapped, so wait for the sending task to finish with it
        await command_responses.join()
        with logged(logger, 'Performing TLS handshake', []):
            ssl_sock = ssl_get_socket(logger, get_ssl_context, sock)
            await ssl_complete_handshake(loop, ssl_sock)

    async def command_syst(_):
        await command_responses.put(b'215 UNIX Type: L8')

    async def command_type(_):
        await command_responses.put(b'200 Command okay.')

    async def command_feat(_):
        await command_responses.put(b'211-System status, or system help reply.')
        await command_responses.put(b'UTF8')
        await command_responses.put(b'211 End')

    async def command_opts(_):
        await command_responses.put(b'200 Command okay.')

    async def command_pbsz(_):
        await command_responses.put(b'200 Command okay.')

    async def command_prot(_):
        await command_responses.put(b'200 Command okay.')

    async def command_stat(_):
        await command_responses.put(b'211')

    async def command_user(arg):
        nonlocal user

        attempted_user = arg.decode('utf-8')
        is_ok = await is_user_correct(attempted_user)

        if not is_ok:
            # Make sure the rejection is sent before the connection is ended
            await command_responses.put(b'530 Not logged in.')
            await command_responses.join()
            await cancel_current_task()

        user = attempted_user
        await command_responses.put(b'331 User name okay, need password.')

    async def command_pass(arg):
        nonlocal is_authenticated

        attempted_password = arg.decode('utf-8')
        is_ok = await is_password_correct(logger, user, attempted_password)

        if not is_ok:
            # Make sure the rejection is sent before the connection is ended
            await command_responses.put(b'530 Not logged in.')
            await command_responses.join()
            await cancel_current_task()

        is_authenticated = True
        await command_responses.put(b'230 User logged in, proceed.')

    async def command_pwd(_):
        # Double quotes in the path are escaped by doubling them
        await command_responses.put(
            b'257 "%s"' % cwd.as_posix().encode('utf-8').replace(b'"', b'""'))

    async def command_mkd(arg):
        s3_path = to_absolute_path(arg)
        await s3_mkdir(logger, s3_context, s3_path)
        await command_responses.put(b'250 Requested file action okay, completed.')

    async def command_rmd(arg):
        s3_path = to_absolute_path(arg)
        await s3_rmdir(logger, s3_context, s3_path)
        await command_responses.put(b'250 Requested file action okay, completed.')

    async def command_cdup(_):
        nonlocal cwd
        cwd = cwd.parent
        await command_responses.put(b'250 Requested file action okay, completed.')

    async def command_cwd(arg):
        nonlocal cwd
        cwd = to_absolute_path(arg)
        await command_responses.put(b'250 Requested file action okay, completed.')

    async def command_dele(arg):
        s3_path = to_absolute_path(arg)
        await s3_delete(logger, s3_context, s3_path)
        await command_responses.put(b'250 Requested file action okay, completed.')

    async def command_rest(arg):
        # Only restarting from offset 0 is accepted; anything else ends the
        # connection
        rest_from = int(arg)
        if rest_from == 0:
            await command_responses.put(b'350 Requested file action pending further information.')
        else:
            await cancel_current_task()

    async def command_list(_):
        # Captured now, so a later change of cwd does not affect the listing
        s3_path = cwd

        async def data_task_func(ssl_data_sock):
            async for list_path in await s3_list(logger, s3_context, s3_path):
                await data_sock_send_line(ssl_data_sock, (
                    stat.filemode(list_path.stat.st_mode) + ' ' +
                    str(list_path.stat.st_nlink) + ' ' +
                    'none ' +
                    'none ' +
                    str(list_path.stat.st_size) + ' ' +
                    list_path.stat.st_mtime.strftime('%b %e %H:%M') + ' ' +
                    list_path.name
                ).encode('utf-8'))

        await data_funcs.put(data_task_func)
        await command_responses.put(b'150 File status okay; about to open data connection.')

    async def command_retr(arg):
        s3_path = to_absolute_path(arg)

        async def data_task_func(ssl_data_sock):
            async for data in s3_get(logger, s3_context, s3_path, DATA_CHUNK_SIZE):
                await data_sock_send_all(ssl_data_sock, data)

        await data_funcs.put(data_task_func)
        await command_responses.put(b'150 File status okay; about to open data connection.')

    async def command_stor(arg):
        s3_path = to_absolute_path(arg)

        async def data_task_func(ssl_data_sock):
            async with s3_put(logger, s3_context, s3_path) as write:
                async for data in recv_until_close(loop, lambda: ssl_data_sock, DATA_CHUNK_SIZE):
                    await write(data)

        await data_funcs.put(data_task_func)
        await command_responses.put(b'150 File status okay; about to open data connection.')

    async def command_rnfr(arg):
        nonlocal rename_from
        rename_from = to_absolute_path(arg)
        await command_responses.put(b'350 Requested file action pending further information.')

    async def command_rnto(arg):
        nonlocal rename_from

        _rename_to = to_absolute_path(arg)
        _rename_from = rename_from
        # Cleared before the rename so each RNTO needs its own RNFR, even if
        # the rename itself fails
        rename_from = None

        await s3_rename(logger, s3_context, _rename_from, _rename_to)

        await command_responses.put(b'250 Requested file action okay, completed.')

    async def command_pasv(_):
        nonlocal data_port
        nonlocal data_server

        data_server_listening = asyncio.Future()
        data_client_connected = asyncio.Future()

        def on_data_server_listening(success):
            if data_server_listening.done():
                return
            if success:
                data_server_listening.set_result(None)
            else:
                data_server_listening.set_exception(Exception('Unable to listen'))

        async def on_data_client_connect(data_client_logger, __, ____, data_sock):
            nonlocal data_client

            # Raise if this is the second, and so unexpected, client
            data_client_connected.set_result(None)

            if not await is_data_sock_ok(get_sock(), data_sock):
                raise Exception('Data sock is not ok: could be an attack')

            data_client = asyncio.current_task()

            # If we do have an expected data client, we cancel the server
            # to prevent more connections
            data_server.cancel()
            await asyncio.sleep(0)

            # Track how far set-up got, so the clean-up below only undoes
            # what was actually done
            ssl_data_sock = None
            func_taken = False

            try:
                with logged(data_client_logger, 'Performing TLS handshake', []):
                    ssl_data_sock = ssl_get_socket(logger, get_ssl_context, data_sock)
                    await ssl_complete_handshake(loop, ssl_data_sock)

                async with timeout(loop, DATA_COMMAND_TIMEOUT_SECONDS):
                    func = await data_funcs.get()
                    func_taken = True

                await func(ssl_data_sock)
            except BaseException:
                await command_responses.put(b'426 Connection closed; transfer aborted.')
                raise
            else:
                await command_responses.put(b'226 Closing data connection.')
            finally:
                # Only acknowledge an item that was actually taken from the
                # queue: task_done() can raise ValueError otherwise, which
                # would replace the real error and skip the clean-up below
                if func_taken:
                    data_funcs.task_done()
                data_client = None
                # There is no TLS socket to unwrap if wrapping itself failed
                if ssl_data_sock is not None:
                    data_sock = await ssl_unwrap_socket(loop, ssl_data_sock, data_sock)
                await shutdown_socket(loop, data_sock)

        async def on_data_server_cancel(_):  # Client tasks passed
            # Since cancellation of the data server is triggered at the
            # beginning of the data client task, which would call this
            # function, we don't cancel the child task here, otherwise we'll
            # be cancelling what we want to be running. On main server cancel,
            # the data client task is cancelled, near the bottom of
            # on_client_connect
            pass

        def on_data_server_close(_):
            nonlocal data_port
            nonlocal data_server

            # Return the port to the pool once the data server task is done
            data_ports.add(data_port)
            data_port = None
            data_server = None

        # data_ports is used as a set (add/remove), and random.sample no
        # longer accepts sets from Python 3.11, so sample from a snapshot
        data_port = random.sample(tuple(data_ports), 1)[0]
        data_ports.remove(data_port)
        data_logger = get_child_logger(logger, 'data')
        data_server = loop.create_task(server(data_logger, loop, get_ssl_context, data_port,
                                              on_data_server_listening, on_data_client_connect,
                                              on_data_server_cancel))
        data_server.add_done_callback(on_data_server_close)

        # The port is advertised as its high and low bytes, in decimal
        data_port_higher = str(data_port >> 8).encode('ascii')
        data_port_lower = str(data_port & 0xff).encode('ascii')
        data_ip = [part.encode('ascii') for part in (await get_data_ip(get_sock())).split('.')]

        response = b'227 Entering Passive Mode (%s,%s,%s,%s,%s,%s)' % (
            data_ip[0], data_ip[1], data_ip[2], data_ip[3],
            data_port_higher, data_port_lower)

        try:
            # No timeout: we're waiting for the task scheduler to get through to the point
            # where we're listening for a connection, and not waiting for a remote party
            await data_server_listening
            await command_responses.put(response)

            async with timeout(loop, DATA_CONNECT_TIMEOUT_SECONDS):
                await data_client_connected
        except BaseException:
            await cancel_data_tasks()
            raise

    async def command_quit(_):
        await command_responses.put(b'221 Service closing control connection.')
        await command_responses.join()
        await cancel_current_task()

    async def cancel_data_tasks():
        for task in [task for task in [data_client, data_server] if task]:
            task.cancel()
            # Yield so the cancelled task gets a chance to run its clean-up
            await asyncio.sleep(0)

    # Explicit dispatch table, keyed by upper-case command name. Looking
    # handlers up by name in locals() would also match non-handler locals
    # that happen to share the command_ prefix, such as command_responses
    command_funcs = {
        'AUTH': command_auth,
        'SYST': command_syst,
        'TYPE': command_type,
        'FEAT': command_feat,
        'OPTS': command_opts,
        'PBSZ': command_pbsz,
        'PROT': command_prot,
        'STAT': command_stat,
        'USER': command_user,
        'PASS': command_pass,
        'PWD': command_pwd,
        'MKD': command_mkd,
        'RMD': command_rmd,
        'CDUP': command_cdup,
        'CWD': command_cwd,
        'DELE': command_dele,
        'REST': command_rest,
        'LIST': command_list,
        'RETR': command_retr,
        'STOR': command_stor,
        'RNFR': command_rnfr,
        'RNTO': command_rnto,
        'PASV': command_pasv,
        'QUIT': command_quit,
    }

    def is_good_sequence(command):
        # Expects an upper-case command name. Enforces the order: AUTH, then
        # USER, then PASS, with PASV before each LIST/STOR/RETR and RNFR
        # before RNTO; everything else requires TLS and authentication
        is_ssl = ssl_sock is not None

        is_good = \
            (command == 'AUTH' and not is_ssl and not is_authenticated and not user) or \
            (command == 'USER' and is_ssl and not is_authenticated and not user) or \
            (command == 'PASS' and is_ssl and not is_authenticated and user) or \
            (command == 'PROT' and is_ssl and not is_authenticated and not user) or \
            (command == 'PBSZ' and is_ssl and not is_authenticated and not user) or \
            (command == 'PASV' and is_ssl and is_authenticated and not data_port) or \
            (command == 'RNTO' and is_ssl and is_authenticated and rename_from is not None) or \
            (command in {'LIST', 'STOR', 'RETR'} and is_authenticated and data_client) or \
            (command not in {'AUTH', 'USER', 'PASS',
                             'PASV', 'LIST', 'STOR', 'RETR',
                             'RNTO'} and is_ssl and is_authenticated)

        return is_good

    async def main_client_loop():
        async for line in command_sock_recv_lines():
            command_bytes, _, arg = line.partition(b' ')
            # Normalise the case once, so the sequence rules, the password
            # masking below and the dispatch all see the same command name.
            # Otherwise a lower-case command could be dispatched without
            # matching the upper-case names those checks compare against
            command = command_bytes.decode('utf-8').upper()
            arg_to_log = arg if command != 'PASS' else b'********'
            logger.debug('Inc: %s', command_bytes + b' ' + arg_to_log)

            async with timeout(loop, COMMAND_TIMEOUT_SECONDS):
                if not is_good_sequence(command):
                    await command_responses.put(b'503 Bad sequence of commands.')
                    await command_responses.join()
                    await cancel_current_task()
                elif command not in command_funcs:
                    await command_responses.put(b'502 Command not implemented.')
                else:
                    with logged(logger, command, []):
                        await command_funcs[command](arg)

    send_command_responses_task = asyncio.create_task(send_command_responses())
    await command_responses.put(b'220 Service ready for new user.')

    try:
        await main_client_loop()
    finally:
        await cancel_data_tasks()
        send_command_responses_task.cancel()
        await asyncio.sleep(0)

        if ssl_sock is not None:
            sock = await ssl_unwrap_socket(loop, ssl_sock, sock)
        await shutdown_socket(loop, sock)