uktrade/aioftps3

View on GitHub
aioftps3/server_s3.py

Summary

Maintainability
B
6 hrs
Test Coverage
A
91%
import asyncio
from collections import (
    namedtuple,
)
from contextlib import (
    AsyncExitStack,
    asynccontextmanager,
)
from datetime import datetime
import hashlib
import json
from pathlib import PurePosixPath
import re
import weakref
import xml.etree.ElementTree as ET

from fifolock import FifoLock

from aioftps3.server_aws import (
    aws_request,
)
from aioftps3.server_logger import (
    logged,
)


# Size at which the chunks buffered for a multipart upload are sent to S3 as
# a part. This must be between 5 and 2000MB
MULTIPART_UPLOAD_MIN_BYTES = 1024 * 1024 * 25

# How long we sleep per incoming write if uploads haven't kept up
MULTIPART_UPLOAD_IF_SLOW_SLEEP_SECONDS = 1

# Maximum number of part uploads that may be in progress at once for a single
# file being uploaded: going over this fails the upload
MULTIPART_UPLOAD_MAX_CONCURRENT_UPLOADS_PER_FILE = 3

# Mode bits reported in listings for files and directories respectively.
# S_IFREG is 0o100000: the previous value of 0o10666 had the type bits of a
# FIFO (S_IFIFO is 0o010000) rather than those of a regular file
REG_MODE = 0o100666  # stat.S_IFREG | 0o666
DIR_MODE = 0o40777  # stat.S_IFDIR | 0o777

# Separator between the components of a key, also used as the delimiter when
# listing the immediate children of a directory
S3_DIR_SEPARATOR = '/'


# The stat-like information attached to each path yielded when listing a
# directory
Stat = namedtuple(
    'Stat',
    ['st_size', 'st_mtime', 'st_ctime', 'st_nlink', 'st_mode'],
)

# What the credential getters in this module resolve to. pre_auth_headers is
# a dict of extra headers, e.g. the security token of temporary credentials
AwsCredentials = namedtuple('AwsCredentials', [
    'access_key_id', 'secret_access_key', 'pre_auth_headers',
])

# Where the bucket is and how to talk to it. dir_suffix is appended to a
# directory's path to make the key of the object that marks its existence.
# Note the tuple's own type name is AwsS3Bucket, which differs from the
# module-level name
S3Bucket = namedtuple('AwsS3Bucket', [
    'region', 'host', 'verify_certs', 'name', 'dir_suffix',
])

# Everything needed to make requests to a bucket: the HTTP session, the lock
# over paths, the credentials getter and the bucket. Note the tuple's own
# type name is Context, which differs from the module-level name
S3Context = namedtuple('Context', [
    'session', 'lock', 'credentials', 'bucket',
])

# A single object (key) parsed from a page of a list response
ListKey = namedtuple('ListKey', [
    'key', 'size', 'last_modified',
])

# A single common prefix parsed from a page of a list response, with its
# trailing separator removed
ListPrefix = namedtuple('ListPrefix', [
    'prefix',
])


class S3ListPath(PurePosixPath):
    """A pure POSIX path that also carries stat information.

    Positional arguments are the path segments, exactly as for
    PurePosixPath. The required keyword argument ``stat`` is stored as the
    ``stat`` attribute of the instance.
    """

    def __new__(cls, *args, stat):
        self = super().__new__(cls, *args)
        self.stat = stat
        return self

    def __init__(self, *args, stat):  # pylint: disable=unused-argument
        # From Python 3.12 PurePath builds the path in __init__, which only
        # accepts positional segments: without this override the stat
        # keyword argument reaches it and construction raises a TypeError.
        # Earlier versions inherit object.__init__, which must not be passed
        # any arguments, so it is only called when PurePath defines its own
        if PurePosixPath.__init__ is not object.__init__:
            super().__init__(*args)


def get_secret_access_key_credentials(access_key_id, secret_access_key):
    """Return a credentials getter for a fixed access key ID and secret key.

    The returned coroutine function takes two positional arguments, both of
    which are ignored, and resolves to AwsCredentials with the given key ID
    and secret and no pre-auth headers.
    """

    async def fixed_credentials(_, __):
        # A new (empty) headers dict is made for each call
        no_headers = {}
        return AwsCredentials(access_key_id, secret_access_key, no_headers)

    return fixed_credentials


def get_ecs_role_credentials(url):
    """Return a credentials getter that fetches temporary credentials from url.

    The returned coroutine function takes a logger and an HTTP session. The
    credentials are cached between calls, and only requested again, with a
    GET of url, once the expiration time of the cached ones has passed. It
    resolves to AwsCredentials whose pre-auth headers carry the token as
    x-amz-security-token.
    """
    # Only needed in this function, so imported here
    from datetime import timezone  # pylint: disable=import-outside-toplevel

    aws_access_key_id = None
    aws_secret_access_key = None
    token = None
    # Far enough in the past to force a request on the first call
    expiration = datetime(1900, 1, 1)

    async def get(logger, session):
        nonlocal aws_access_key_id
        nonlocal aws_secret_access_key
        nonlocal token
        nonlocal expiration

        # The expiration is parsed from a UTC ("Z"-suffixed) timestamp into a
        # naive datetime, so it has to be compared with the current time in
        # UTC, also naive. Comparing with local time would treat expired
        # credentials as valid on a host whose clock is behind UTC
        now = datetime.now(timezone.utc).replace(tzinfo=None)

        if now > expiration:
            method = 'GET'
            with logged(logger, 'Requesting temporary credentials from %s', [url]):
                async with session.request(method, url) as response:
                    response.raise_for_status()
                    creds = json.loads(await response.read())

            aws_access_key_id = creds['AccessKeyId']
            aws_secret_access_key = creds['SecretAccessKey']
            token = creds['Token']
            expiration = datetime.strptime(creds['Expiration'], '%Y-%m-%dT%H:%M:%SZ')

        return AwsCredentials(
            access_key_id=aws_access_key_id,
            secret_access_key=aws_secret_access_key,
            pre_auth_headers={
                'x-amz-security-token': token,
            },
        )

    return get


def get_s3_bucket(region, host, verify_certs, name, dir_suffix):
    """Bundle the description of a bucket into an S3Bucket tuple."""
    # Positional, in the order the fields of S3Bucket are declared
    bucket = S3Bucket(region, host, verify_certs, name, dir_suffix)
    return bucket


def get_s3_context(session, credentials, bucket):
    """Bundle what is needed to make requests into an S3Context tuple.

    Each context gets its own, new, lock over paths.
    """
    path_lock = _PathLock()
    # Positional, in the order the fields of S3Context are declared
    return S3Context(session, path_lock, credentials, bucket)


async def s3_exists(logger, context, path):
    """Return whether path is an existing file or directory."""
    found = await _exists(logger, context, path)
    return found


async def s3_is_dir(logger, context, path):
    """Return whether path is an existing directory."""
    is_dir = await _is_dir(logger, context, path)
    return is_dir


async def s3_is_file(logger, context, path):
    """Return whether path is an existing file."""
    is_file = await _is_file(logger, context, path)
    return is_file


async def s3_mkdir(logger, context, path):
    """Create the directory at path, holding the lock on path throughout.

    Raises an Exception if path already exists as a file or directory, or
    if its parent is not a directory.
    """
    parent = path.parent

    async with context.lock(logger, [path]):
        already_exists = await _exists(logger, context, path)
        if already_exists:
            raise Exception('{} already exists'.format(path))

        parent_is_dir = await _is_dir(logger, context, parent)
        if not parent_is_dir:
            raise Exception('{} is not a directory'.format(parent))

        await _mkdir(logger, context, path)


async def s3_rmdir(logger, context, path):
    """Remove the directory at path, holding the lock on path throughout.

    Raises an Exception if path is not an existing directory.
    """
    async with context.lock(logger, [path]):
        is_dir = await _is_dir(logger, context, path)
        if is_dir:
            return await _rmdir(logger, context, path)

        raise Exception('{} does not exist'.format(path))


async def s3_delete(logger, context, path):
    """Delete the file at path, holding the lock on path throughout.

    Raises an Exception if path is not an existing file.
    """
    async with context.lock(logger, [path]):
        is_file = await _is_file(logger, context, path)
        if is_file:
            return await _delete(logger, context, path)

        raise Exception('{} does not exist'.format(path))


async def s3_list(logger, context, path):
    """Return an async iterator over the immediate children of path.

    Raises an Exception if path is not an existing directory.
    """
    # No lock is taken, so a list can see changes that are not consistent.
    # Renames or deletes of folders with lots of files can take a while,
    # and nothing too horrible happens if a list runs half way through one
    is_dir = await _is_dir(logger, context, path)
    if is_dir:
        return _list(logger, context, path)

    raise Exception('{} is not a directory'.format(path))


def s3_get(logger, context, path, chunk_size):
    """Return an async iterator over the contents of the file at path.

    The contents are yielded in chunks of up to chunk_size bytes.
    """
    chunks = _get(logger, context, path, chunk_size)
    return chunks


@asynccontextmanager
async def s3_put(logger, context, path):
    """Async context manager for uploading a file to path.

    Yields a coroutine function to be awaited with each chunk of the file.
    """
    upload = _put(logger, context, path)
    async with upload as write_chunk:
        yield write_chunk


async def s3_rename(logger, context, rename_from, rename_to):
    """Rename rename_from to rename_to, holding the locks on both paths."""
    locked_paths = [rename_from, rename_to]
    async with context.lock(logger, locked_paths):
        result = await _rename(logger, context, rename_from, rename_to)
        return result


def _key(path):
    key = \
        '' if path == PurePosixPath('/') else \
        path.relative_to(PurePosixPath('/')).as_posix()

    return key


def _dir_key(context, path):
    key = \
        '' if path == PurePosixPath('/') else \
        path.relative_to(PurePosixPath('/')).as_posix() + context.bucket.dir_suffix

    return key


def _dir_prefix(path):
    key = \
        '' if path == PurePosixPath('/') else \
        path.relative_to(PurePosixPath('/')).as_posix() + S3_DIR_SEPARATOR

    return key


async def _exists(logger, context, path):
    """Return whether path is an existing file or directory.

    The check for a directory is only made if path is not a file.
    """
    if await _is_file(logger, context, path):
        return True

    return await _is_dir(logger, context, path)


async def _is_dir(logger, context, path):
    result = \
        True if path == PurePosixPath('/') else \
        await _dir_exists(logger, context, path)
    return result


async def _is_file(logger, context, path):
    result = \
        False if path == PurePosixPath('/') else \
        await _file_exists(logger, context, path)
    return result


async def _file_exists(logger, context, path):
    """Return whether a HEAD request for the key of path gives a 200."""
    s3_path = '/' + _key(path)
    empty = b''
    response, _ = await s3_request_full(
        logger, context, 'HEAD', s3_path, {}, {}, empty, s3_hash(empty))
    return response.status == 200


async def _dir_exists(logger, context, path):
    """Return whether a HEAD request for the directory key of path gives a 200."""
    s3_path = '/' + _dir_key(context, path)
    empty = b''
    response, _ = await s3_request_full(
        logger, context, 'HEAD', s3_path, {}, {}, empty, s3_hash(empty))
    return response.status == 200


async def _mkdir(logger, context, path):
    """PUT an empty object at the directory key of path.

    Raises if the response has an error status.
    """
    s3_path = '/' + _dir_key(context, path)
    empty = b''
    response, _ = await s3_request_full(
        logger, context, 'PUT', s3_path, {}, {}, empty, s3_hash(empty))
    response.raise_for_status()


async def _rmdir(logger, context, path):
    """Delete every key under the directory prefix of path, one at a time.

    All the keys are listed before any are deleted. Raises if any delete
    gives a response with an error status, leaving later keys in place.
    """
    suffix = context.bucket.dir_suffix

    def innermost_last(list_key):
        # Sorts by depth, then files before the directory objects, then by
        # key. The list sorted by this is reversed before use, so innermost
        # files and folders are deleted first
        name = list_key.key
        has_dir_suffix = 1 if name[-len(suffix):] == suffix else 0
        return (name.count('/'), has_dir_suffix, name)

    descendants = []
    async for list_key in _list_descendant_keys(logger, context, _dir_prefix(path)):
        descendants.append(list_key)

    descendants.sort(key=innermost_last, reverse=True)

    for list_key in descendants:
        response, _ = await s3_request_full(
            logger, context, 'DELETE', '/' + list_key.key, {}, {}, b'', s3_hash(b''))
        response.raise_for_status()


async def _rename(logger, context, rename_from, rename_to):
    """Rename a file or directory by copying its keys and deleting the originals.

    If rename_from is a directory, every key under its prefix is copied to
    the matching key under rename_to; otherwise only its own key is. The
    originals are deleted only after every copy has succeeded.

    Raises an Exception if rename_from does not exist, and raises if any
    copy or delete gives a response with an error status.
    """
    # The source must exist...
    from_exists = await _exists(logger, context, rename_from)
    if not from_exists:
        raise Exception('The source file does not exist')

    # ... but the target may or may not exist
    to_exists = await _exists(logger, context, rename_to)
    # The message has two placeholders, so needs both the path and the result
    logger.debug('File %s exists: %s', rename_to, to_exists)

    # ... we find the list of keys to rename from/to

    source_is_dir = await _is_dir(logger, context, rename_from)
    rename_from_key = _key(rename_from)
    rename_to_key = _key(rename_to)

    def replace_key_prefix(old_key):
        # Every source key starts with the key of rename_from
        return rename_to_key + old_key[len(rename_from_key):]

    from_keys = \
        [
            key.key
            async for key in _list_descendant_keys(logger, context, _dir_prefix(rename_from))
        ] if source_is_dir else \
        [rename_from_key]

    to_keys = [replace_key_prefix(key) for key in from_keys]

    renames = list(zip(from_keys, to_keys))

    # ... we copy everything first...

    def sort_key(keys):
        # Sorts pairs by the source key: by depth, then files before the
        # directory objects, then by the key itself
        dir_suffix = context.bucket.dir_suffix
        is_dir_file = 1 if keys[0][-len(dir_suffix):] == dir_suffix else 0
        return (keys[0].count('/'), is_dir_file, keys[0])

    for from_key, to_key in sorted(renames, key=sort_key, reverse=True):
        headers = {
            'x-amz-copy-source': f'/{context.bucket.name}/{from_key}',
        }
        response, _ = await s3_request_full(logger, context, 'PUT', '/' + to_key, {}, headers,
                                            b'', s3_hash(b''))
        response.raise_for_status()

    # ... and then delete the originals

    for from_key, _ in sorted(renames, key=sort_key):
        response, _ = await s3_request_full(logger, context, 'DELETE', '/' + from_key, {}, {},
                                            b'', s3_hash(b''))
        response.raise_for_status()


async def _list(logger, context, path):
    """Yield a path, with stat information, for each immediate child of path."""
    children = _list_immediate_child_paths(logger, context, _dir_prefix(path))
    async for child in children:
        yield child


async def _delete(logger, context, path):
    """DELETE the object at the key of path.

    Raises if the response has an error status.
    """
    s3_path = '/' + _key(path)
    empty = b''
    response, _ = await s3_request_full(
        logger, context, 'DELETE', s3_path, {}, {}, empty, s3_hash(empty))
    response.raise_for_status()


@asynccontextmanager
async def _put(logger, context, path):
    """Async context manager that uploads a file to path as a multipart upload.

    The multipart upload is started on entry. The yielded coroutine function
    is to be awaited with each chunk of the file: chunks are buffered until
    there are at least MULTIPART_UPLOAD_MIN_BYTES of them, and then uploaded
    as a part by a background task. On a clean exit any remaining chunks are
    uploaded, all the parts are waited for, and the upload is completed
    while holding the lock on path.

    Raises an Exception from the yielded function if too many part uploads
    are in progress, and on exit if the parent of path is not an existing
    directory or if path itself is a directory. If the body of the context,
    or the exit itself, raises, any part uploads still in progress are
    cancelled.
    """
    upload_id = None
    part_uploads = []

    part_length = None
    part_chunks = None
    part_payload_hash = None

    async def start():
        nonlocal upload_id
        new_part_init()
        upload_id = await _multipart_upload_start(logger, context, path)

    def new_part_init():
        # Start buffering a new part. The lists of chunks of previous parts
        # are not mutated, since they may still be being uploaded
        nonlocal part_length
        nonlocal part_chunks
        nonlocal part_payload_hash
        part_length = 0
        part_chunks = []
        part_payload_hash = hashlib.sha256()

    async def write(chunk):
        # If ingress is faster than egress, need to do something to avoid storing
        # the entire file in memory. Could have something better/fancier but in
        # our case, suspect egress to S3 will always be faster than ingress, so
        # we keep this simple and just prevent from running out of memory in case
        # it happens, but also prevent the connection from dying
        if len(part_uploads) > 2 and not part_uploads[-2].done():
            await asyncio.sleep(MULTIPART_UPLOAD_IF_SLOW_SLEEP_SECONDS)

            in_progress = [upload for upload in part_uploads if not upload.done()]
            if len(in_progress) > MULTIPART_UPLOAD_MAX_CONCURRENT_UPLOADS_PER_FILE:
                raise Exception('Too many incomplete uploads to S3')

        nonlocal part_length
        part_length += len(chunk)
        part_chunks.append(chunk)
        part_payload_hash.update(chunk)

        if part_length >= MULTIPART_UPLOAD_MIN_BYTES:
            upload_part()

    def upload_part():
        # Part numbers start at 1, in the order the parts are started
        part_number = str(len(part_uploads) + 1)
        upload_coro = _multipart_upload_part(
            logger, context, path, upload_id,
            part_number, part_length, part_chunks, part_payload_hash.hexdigest())
        part_uploads.append(asyncio.create_task(upload_coro))
        new_part_init()

    async def end():
        # AWS throws an error if the multipart upload doesn't have at least one part
        # which would happen if write is never called, i.e. for an empty file
        if not part_uploads:
            await write(b'')

        if part_chunks:
            upload_part()

        indexes_and_etags = await asyncio.gather(*part_uploads)

        # Completing the upload is the only action that changes anything (apart from
        # APIs that query multipart uploads), so this all we need to wrap in a lock.
        # A bucket lifecycle policy can cleanup unfinished multipart uploads, so we
        # don't need to do it here
        async with context.lock(logger, [path]):
            if await _is_file(logger, context, path.parent):
                raise Exception('{} is not a directory'.format(path.parent))

            if not await _is_dir(logger, context, path.parent):
                raise Exception('{} does not exist'.format(path.parent))

            if await _is_dir(logger, context, path):
                raise Exception('{} is a directory'.format(path))

            # Overwrites of files are allowed, so there is no need to check if
            # the file aleady exists

            await _multipart_upload_complete(logger, context, path, upload_id, indexes_and_etags)

        new_part_init()

    await start()
    try:
        yield write
        await end()
    finally:
        # If the body of the context or end() raised, there can be uploads of
        # parts still in progress that nothing is waiting for any more:
        # cancel them rather than leave them running in the background. After
        # a clean end() every upload is done, and this does nothing
        for upload in part_uploads:
            if not upload.done():
                upload.cancel()


async def _multipart_upload_start(logger, context, path):
    """Start a multipart upload to the key of path, returning its upload ID.

    Raises if the response has an error status.
    """
    s3_path = '/' + _key(path)
    empty = b''
    response, body = await s3_request_full(
        logger, context, 'POST', s3_path, {'uploads': ''}, {}, empty, s3_hash(empty))
    response.raise_for_status()

    # The ID is pulled straight out of the body, without parsing it as XML
    upload_id_match = re.search(b'<UploadId>(.*)</UploadId>', body)
    return upload_id_match[1].decode('utf-8')


async def _multipart_upload_part(logger, context, path, upload_id, part_number, part_length,
                                 part_chunks, part_payload_hash):
    """Upload one part of a multipart upload to the key of path.

    The chunks of the part are sent as a streamed body of part_length bytes,
    with part_payload_hash the hex SHA-256 of all of them. Returns a tuple of
    the part number and the ETag header of the response, and raises if the
    response has an error status.
    """

    async def chunks_with_yields():
        # The sleep of zero hands control back to the event loop before
        # each chunk is passed on
        for chunk in part_chunks:
            await asyncio.sleep(0)
            yield chunk

    s3_path = '/' + _key(path)
    query = {'partNumber': part_number, 'uploadId': upload_id}
    headers = {'Content-Length': str(part_length)}

    response, _ = await s3_request_full(
        logger, context, 'PUT', s3_path, query, headers,
        chunks_with_yields(), part_payload_hash)
    response.raise_for_status()

    return (part_number, response.headers['ETag'])


async def _multipart_upload_complete(logger, context, path, upload_id, part_numbers_and_etags):
    """Complete a multipart upload to the key of path from its uploaded parts.

    part_numbers_and_etags is an iterable of (part number, ETag) pairs, put
    in the body of the request in the order given. Raises if the response
    has an error status.
    """
    parts_xml = ''.join([
        f'<Part><PartNumber>{part_number}</PartNumber><ETag>{part_etag}</ETag></Part>'
        for part_number, part_etag in part_numbers_and_etags
    ])
    payload_str = '<CompleteMultipartUpload>' + parts_xml + '</CompleteMultipartUpload>'
    payload = payload_str.encode('utf-8')

    s3_path = '/' + _key(path)
    response, _ = await s3_request_full(
        logger, context, 'POST', s3_path, {'uploadId': upload_id}, {},
        payload, s3_hash(payload))
    response.raise_for_status()


async def _get(logger, context, path, chunk_size):
    """Yield the contents of the object at the key of path.

    The contents are yielded in chunks of up to chunk_size bytes as they
    arrive. Raises if the response has an error status.
    """
    # S3 GETs are atomic, so no additional locking is needed: the request
    # either fails with a 404, or returns a consistent object

    method = 'GET'
    query = {}
    s3_path = '/' + _key(path)
    empty = b''
    log_args = [method, context.bucket.host, s3_path, query]

    with logged(logger, 'Request: %s %s %s %s', log_args):
        request = await _s3_request(
            logger, context, method, s3_path, query, {}, empty, s3_hash(empty))
        async with request as response:
            response.raise_for_status()
            async for chunk in response.content.iter_chunked(chunk_size):
                yield chunk


async def _list_immediate_child_paths(logger, context, key_prefix):
    """Yield an S3ListPath for each immediate child of the key prefix.

    For each page of the listing, the directories (common prefixes) are
    yielded before the files (keys). The objects that mark directories,
    i.e. keys ending in the directory suffix of the bucket, are not yielded
    as files.
    """
    # The same naive datetime that datetime.utcfromtimestamp(0) returns,
    # without calling a function deprecated from Python 3.12
    epoch = datetime(1970, 1, 1)
    dir_suffix = context.bucket.dir_suffix

    async for (prefix_page, key_page) in _list_keys(logger, context, key_prefix, S3_DIR_SEPARATOR):
        for list_prefix in prefix_page:
            yield S3ListPath(list_prefix.prefix, stat=Stat(
                # Not completely sure what size should be for a directory
                st_size=0,
                # Can't quite work out an efficient way of working out
                # any sort of meaningful modification/creation time for a
                # directory
                st_mtime=epoch,
                st_ctime=epoch,
                st_nlink=1,
                st_mode=DIR_MODE,
            ))

        for list_key in key_page:
            # Deliberately a slice and not str.endswith: with an empty
            # suffix the slice is the whole key, so no file is skipped,
            # where endswith would be true for all of them
            if list_key.key[-len(dir_suffix):] == dir_suffix:
                continue
            yield S3ListPath(list_key.key, stat=Stat(
                st_size=list_key.size,
                st_mtime=list_key.last_modified,
                st_ctime=list_key.last_modified,
                st_nlink=1,
                st_mode=REG_MODE,
            ))


async def _list_descendant_keys(logger, context, key_prefix):
    """Yield a ListKey for every key, at any depth, that starts with key_prefix."""
    # With an empty delimiter the prefixes of each page are not of interest
    pages = _list_keys(logger, context, key_prefix, '')
    async for (_, keys_page) in pages:
        for descendant in keys_page:
            yield descendant


async def _list_keys(logger, context, key_prefix, delimeter):
    """Yield (prefixes, keys) for each page of a listing of the bucket.

    The listing is of keys starting with key_prefix, grouped using the
    delimiter, in pages of up to 1000. prefixes is a list of ListPrefix and
    keys a list of ListKey. Pages are requested one at a time, as they are
    iterated, until a response has no continuation token. Raises if any
    response has an error status.
    """
    namespace = '{http://s3.amazonaws.com/doc/2006-03-01/}'
    base_query = {
        'max-keys': '1000',
        'list-type': '2',
        'delimiter': delimeter,
        'prefix': key_prefix,
    }

    def child_text(element, tag):
        # The text of the first child with the tag, or None without one
        return next((child.text for child in element if child.tag == tag), None)

    def parse_page(body):
        next_token = ''
        keys = []
        prefixes = []

        for element in ET.fromstring(body):
            tag = element.tag

            if tag == f'{namespace}Contents':
                key = child_text(element, f'{namespace}Key')
                last_modified = datetime.strptime(
                    child_text(element, f'{namespace}LastModified'),
                    '%Y-%m-%dT%H:%M:%S.%fZ')
                size = int(child_text(element, f'{namespace}Size'))
                keys.append(ListKey(key=key, last_modified=last_modified, size=size))

            elif tag == f'{namespace}CommonPrefixes':
                # Prefixes end in '/', which we strip off
                prefix = child_text(element, f'{namespace}Prefix')[:-1]
                prefixes.append(ListPrefix(prefix=prefix))

            elif tag == f'{namespace}NextContinuationToken':
                next_token = element.text

        return (next_token, prefixes, keys)

    async def fetch_page(page_query):
        query = {**base_query, **page_query}
        response, body = await s3_request_full(
            logger, context, 'GET', '/', query, {}, b'', s3_hash(b''))
        response.raise_for_status()
        return parse_page(body)

    # The first page is requested without a token, and each later one with
    # the token from the page before it
    page_query = {}
    while True:
        token, prefixes_page, keys_page = await fetch_page(page_query)
        yield (prefixes_page, keys_page)

        if not token:
            break
        page_query = {'continuation-token': token}


def s3_hash(payload):
    """Return the SHA-256 of the payload bytes as a lowercase hex string."""
    digest = hashlib.sha256()
    digest.update(payload)
    return digest.hexdigest()


async def s3_request_full(logger, context, method, path, query, api_pre_auth_headers,
                          payload, payload_hash):
    """Make a request to the bucket and read the whole body of the response.

    Returns a tuple of the response and the bytes of its body. The status
    of the response is not checked here: that is left to the caller.
    """
    log_args = [method, context.bucket.host, path, query, api_pre_auth_headers]

    with logged(logger, 'Request: %s %s %s %s %s', log_args):
        request = await _s3_request(
            logger, context, method, path, query, api_pre_auth_headers,
            payload, payload_hash)
        async with request as response:
            body = await response.read()
            return response, body


async def _s3_request(logger, context, method, path, query, api_pre_auth_headers,
                      payload, payload_hash):
    """Make a request to the s3 service for a path inside the bucket.

    The path passed on is prefixed with / and the name of the bucket.
    Returns whatever aws_request returns.
    """
    bucket = context.bucket
    bucket_path = f'/{bucket.name}{path}'
    result = await aws_request(
        logger, context.session, 's3', bucket.region, bucket.host, bucket.verify_certs,
        context.credentials, method, bucket_path, query, api_pre_auth_headers,
        payload, payload_hash)
    return result


class Read(asyncio.Future):
    """Lock mode for reading: compatible with anything but a held Write."""

    @staticmethod
    def is_compatible(holds):
        # holds is indexed by mode: presumably it gives how many of each
        # are currently held - confirm against the fifolock documentation
        if holds[Write]:
            return False
        return True


class Write(asyncio.Future):
    """Lock mode for writing: compatible only if no Read or Write is held."""

    @staticmethod
    def is_compatible(holds):
        # holds is indexed by mode: presumably it gives how many of each
        # are currently held - confirm against the fifolock documentation
        return not (holds[Read] or holds[Write])


class _PathLock():

    # https://people.eecs.berkeley.edu/~kubitron/courses/cs262a-F14/projects/reports/project6_report.pdf
    #
    # Inspired by ^ but a simpler version without the distinction between
    # "path" and "data" locks: I couldn't see the benefit of that

    def __init__(self):
        self._locks = weakref.WeakValueDictionary()

    @staticmethod
    def _sort_key(path_lock):
        return len(path_lock[0].parents), path_lock[0].as_posix()

    def _with_locks(self, paths, mode):
        return [
            (path, self._locks.setdefault(path, default=FifoLock()), mode)
            for path in paths
        ]

    @asynccontextmanager
    async def __call__(self, logger, paths):
        writable_paths = set(paths)
        writable_locks = self._with_locks(writable_paths, Write)

        ancestor_paths = _flatten(path.parents for path in paths)
        readable_paths = set(ancestor_paths) - writable_paths
        readable_locks = self._with_locks(readable_paths, Read)

        sorted_locks = sorted(readable_locks + writable_locks, key=self._sort_key)
        async with AsyncExitStack() as stack:
            for path, lock, mode in sorted_locks:
                with logged(logger, 'Locking %s on %s', [mode, path]):
                    await stack.enter_async_context(lock(mode))

            yield

            for path, _, _ in reversed(sorted_locks):
                logger.debug('Unlocking %s', path)


def _flatten(to_flatten):
    return [
        item
        for sub_list in to_flatten
        for item in sub_list
    ]