Shoobx/shoobx.mocks3

View on GitHub
src/shoobx/mocks3/models.py

Summary

Maintainability
D
2 days
Test Coverage
###############################################################################
#
# Copyright 2016-2022 by Shoobx, Inc.
#
###############################################################################
"""Shoobx S3 Backend
"""
import base64
import codecs
import collections.abc
import datetime
import hashlib
import json
import os
import shutil

import pytz
import requests.structures
from moto import settings
from moto.cloudformation.exceptions import UnformattedGetAttTemplateException
from moto.core import BackendDict
from moto.core.utils import (
    iso_8601_datetime_with_milliseconds,
    iso_8601_datetime_without_milliseconds_s3,
    rfc_1123_datetime,
)
from moto.s3 import models


def _encode_name(name):
    return name.replace("/", "__sl__")


def _decode_name(name):
    return name.replace("__sl__", "/")


# Account id used for moto's multi-account support; see
# http://docs.getmoto.org/en/latest/docs/multi_account.html
# NOTE(review): moto's own built-in default account id is the 12-digit
# "123456789012"; this 11-digit value differs — confirm that is intentional.
MOTO_DEFAULT_ACCOUNT_ID: str = "12345678910"


class _InfoProperty:
    def __init__(self, name):
        self.name = name

    def __get__(self, inst, cls):
        if not os.path.exists(inst._info_path):
            return None
        with open(inst._info_path) as file:
            return json.load(file).get(self.name)

    def __set__(self, inst, value):
        if isinstance(value, bytes):
            value = value.decode("utf-8")
        with open(inst._info_path) as file:
            info = json.load(file)
        info[self.name] = value
        with open(inst._info_path, "w") as file:
            json.dump(info, file)


class _AclProperty(_InfoProperty):
    def __get__(self, inst, cls):
        raw_data = super().__get__(inst, cls)
        if raw_data is None:
            return models.get_canned_acl("private")
        return models.FakeAcl(
            [
                models.FakeGrant(
                    [models.FakeGrantee(**grantee) for grantee in grant["grantees"]],
                    grant["permissions"],
                )
                for grant in raw_data
            ]
        )

    def __set__(self, inst, value):
        with open(inst._info_path) as file:
            info = json.load(file)
        if value is None:
            info[self.name] = None
        else:
            info[self.name] = [
                {
                    "grantees": [
                        {
                            "grantee_id": grantee.id,
                            "uri": grantee.uri,
                            "display_name": grantee.display_name,
                        }
                        for grantee in grant.grantees
                    ],
                    "permissions": grant.permissions,
                }
                for grant in value.grants
            ]
        with open(inst._info_path, "w") as file:
            json.dump(info, file)


class Key(models.FakeKey):
    """A moto ``FakeKey`` whose data and metadata are stored on disk.

    Every version of a key is a directory
    ``<bucket>/keys/<encoded name>/<version>/`` holding the raw object bytes
    in ``value`` and a JSON document, ``info.json``, that backs the
    ``_InfoProperty``/``_AclProperty`` attributes declared below.
    Constructing a ``Key`` does no I/O; ``create`` writes the version.
    """

    _last_modified = _InfoProperty("last_modified")
    storage_class = _InfoProperty("storage_class")
    metadata = _InfoProperty("metadata")
    _etag = _InfoProperty("etag")
    expiry_date = _InfoProperty("expiry_date")
    acl = _AclProperty("acl")

    def __init__(
        self,
        bucket,
        name,
        version=0,
        is_versioned=False,
        multipart=None,
        bucket_name=None,
        encryption=None,
        kms_key_id=None,
        bucket_key_enabled=None,
        lock_mode=None,
        lock_legal_status="OFF",
        lock_until=None,
        checksum_value=None,
    ):
        # ``checksum_value`` is accepted for signature compatibility but not
        # stored: the ``checksum_value`` property always reports the MD5 etag.
        self.bucket = bucket
        self.name = name
        self.version = version
        self._is_versioned = is_versioned
        self.multipart = multipart
        self._path = os.path.join(bucket._path, "keys", _encode_name(name))
        self._set_version_paths()
        self.bucket_name = bucket_name
        self.encryption = encryption
        self.kms_key_id = kms_key_id
        self.bucket_key_enabled = bucket_key_enabled
        self.lock_mode = lock_mode
        self.lock_legal_status = lock_legal_status
        self.lock_until = lock_until
        self._tick = 0
        self.disposed = None
        self.checksum_algorithm = "md5"

    def _set_version_paths(self):
        """Recompute the on-disk paths that depend on ``self.version``."""
        self._versioned_path = os.path.join(self._path, str(self.version))
        self._info_path = os.path.join(self._versioned_path, "info.json")
        self._value_path = os.path.join(self._versioned_path, "value")

    def __getstate__(self):
        return self.__dict__.copy()

    def __setstate__(self, state):
        # ``value`` is backed by the value file; never restore it from state.
        self.__dict__.update({k: v for k, v in state.items() if k != "value"})

    @property
    def _version_id(self):
        return self.version

    @_version_id.setter
    def _version_id(self, value):
        self.version = value

    @property
    def value(self):
        """The object's full content, read from the value file."""
        with open(self._value_path, "rb") as file:
            return file.read()

    @value.setter
    def value(self, data):
        # str data is stored UTF-8 encoded.
        if not isinstance(data, (bytes, bytearray)):
            data = data.encode("utf-8")
        with open(self._value_path, "wb") as file:
            file.write(data)

    @property
    def etag(self):
        """The quoted etag; computed (and cached) as the MD5 of the value."""
        if self._etag is None:
            with open(self._value_path, "rb") as file:
                # The file might be *very* large. Don't try to do it all at once.
                file_hash = hashlib.md5()
                while chunk := file.read(8192):
                    file_hash.update(chunk)
                self._etag = file_hash.hexdigest()
        return f'"{self._etag}"'

    @property
    def checksum_value(self):
        return self._etag

    @checksum_value.setter
    def checksum_value(self, value):
        # self.etag already generates an md5 hash if needed.
        self.etag

    @property
    def last_modified(self):
        return datetime.datetime.strptime(self._last_modified, "%Y-%m-%dT%H:%M:%S.%fZ")

    @last_modified.setter
    def last_modified(self, value):
        # Accept a datetime (stored in the same format ``create`` uses) or an
        # already-formatted string.
        if isinstance(value, datetime.datetime):
            value = iso_8601_datetime_without_milliseconds_s3(value)
        self._last_modified = value

    @property
    def last_modified_ISO8601(self):
        return self._last_modified

    @property
    def last_modified_RFC1123(self):
        # Different datetime formats depending on how the key is obtained
        # https://github.com/boto/boto/issues/466
        return rfc_1123_datetime(self.last_modified)

    @property
    def response_dict(self):
        """HTTP response headers describing this key."""
        r = {
            "etag": self.etag,
            "last-modified": self.last_modified_RFC1123,
            # Use the file size rather than reading the whole value into memory.
            "content-length": str(self.size),
        }
        if self.storage_class is not None:
            r["x-amz-storage-class"] = self.storage_class
        if self.expiry_date is not None:
            rhdr = 'ongoing-request="false", expiry-date="{0}"'
            r["x-amz-restore"] = rhdr.format(self.expiry_date)

        if self.bucket.is_versioned:
            r["x-amz-version-id"] = str(self.version)

        return r

    @property
    def size(self):
        return os.path.getsize(self._value_path)

    def exists(self):
        return os.path.exists(self._versioned_path)

    def create(self, value, storage="STANDARD", etag=None):
        """Write this version to disk with fresh info and a private ACL."""
        if not os.path.exists(self._versioned_path):
            os.makedirs(self._versioned_path)
        self.value = value
        with open(self._info_path, "w") as file:
            json.dump(
                {
                    "last_modified": iso_8601_datetime_without_milliseconds_s3(
                        datetime.datetime.utcnow()
                    ),
                    "storage_class": storage,
                    "metadata": {},
                    "expiry_date": None,
                    "etag": etag,
                },
                file,
            )
        self.set_acl(models.get_canned_acl("private"))

    def delete(self):
        """Remove the key's whole directory, i.e. every stored version."""
        shutil.rmtree(self._path)

    def copy(self, new_name=None, new_is_versioned=None):
        """Copy this key version to ``new_name`` in the same bucket.

        ``new_name`` defaults to this key's own name. The copy keeps the same
        version number; an existing target at that version is overwritten.
        """
        if new_name is None:
            new_name = self.name
        # Build the target through ``Key`` so the name is encoded the same way
        # as everywhere else (names containing "/" must not nest directories).
        new_key = Key(
            self.bucket, new_name, version=self.version, is_versioned=new_is_versioned
        )
        # Copying a version onto itself would copy every file onto itself.
        if new_key._versioned_path != self._versioned_path:
            shutil.copytree(
                self._versioned_path, new_key._versioned_path, dirs_exist_ok=True
            )
        return new_key

    def set_metadata(self, metadata, replace=False):
        """Merge ``metadata`` into (or, with ``replace``, overwrite) the metadata."""
        md = {} if replace else (self.metadata or {})
        md.update(metadata or {})
        self.metadata = md

    def set_storage_class(self, storage_class):
        self.storage_class = storage_class

    def set_acl(self, acl):
        self.acl = acl

    def append_to_value(self, value):
        """Append ``value`` to the object's data and reset the cached etag.

        In a versioned bucket the key first moves to the next version number
        (its directory is renamed, so the old version number no longer exists
        on disk), then the data is appended.
        """
        if not isinstance(value, (bytes, bytearray)):
            value = value.encode("utf-8")
        if self.bucket.is_versioned:
            old_path = self._versioned_path
            self.version += 1
            self._set_version_paths()
            os.rename(old_path, self._versioned_path)
        # Append in place instead of reading the (possibly large) file back.
        with open(self._value_path, "ab") as file:
            file.write(value)
        self.last_modified = datetime.datetime.utcnow()
        self._etag = None  # must recalculate etag

    def restore(self, days):
        expiry = datetime.datetime.utcnow() + datetime.timedelta(days)
        self.expiry_date = expiry.strftime("%a, %d %b %Y %H:%M:%S GMT")

    @classmethod
    def get_versions(cls, bucket, name):
        """All stored versions of ``name`` in ``bucket``, oldest first."""
        key_dir = os.path.join(bucket._path, "keys", _encode_name(name))
        if not os.path.exists(key_dir):
            return []
        return sorted(
            (Key(bucket, name, int(version)) for version in os.listdir(key_dir)),
            key=lambda k: k.version,
        )


class VersionedKeyStore(collections.abc.MutableMapping):
    """Mapping view over the keys stored in ``<bucket>/keys``.

    Looking up a name returns its newest version; ``getlist`` returns every
    version, oldest first. Iteration yields decoded key names.
    """

    def __init__(self, bucket):
        self.bucket = bucket
        self._path = os.path.join(bucket._path, "keys")

    def __getitem__(self, name):
        versions = Key.get_versions(self.bucket, name)
        if not versions:
            raise KeyError(name)
        return versions[-1]

    def __setitem__(self, name, key):
        # Keys are written to disk by ``Key.create``; this only materializes
        # keys that are not on disk yet.
        # NOTE(review): for such a key ``key.value`` reads a value file that
        # does not exist yet, so this raises — confirm no caller relies on it.
        if not key.exists():
            key.create(key.value)

    def __delitem__(self, name):
        # Any stored version counts as existing, not only version 0.
        versions = Key.get_versions(self.bucket, name)
        if not versions:
            raise KeyError(name)
        # ``Key.delete`` removes the key's whole directory (all versions).
        versions[-1].delete()

    def __iter__(self):
        if not os.path.exists(self._path):
            return
        for name in os.listdir(self._path):
            yield _decode_name(name)

    def __len__(self):
        # A bucket that never stored a key has no ``keys`` directory yet;
        # mirror ``__iter__`` and treat that as empty.
        if not os.path.exists(self._path):
            return 0
        return len(os.listdir(self._path))

    def getlist(self, name, default=None):
        """All versions of ``name`` (oldest first), or ``default`` if none."""
        keys = Key.get_versions(self.bucket, name)
        if not keys:
            return default
        return keys

    def iterlists(self):
        """Yield ``(name, versions)`` for every stored key."""
        for name in self.keys():
            yield name, self.getlist(name)


class Part:
    """One uploaded part of a multipart upload, stored on disk.

    Lives in ``<multipart dir>/<part number>.part/`` with the raw bytes in
    ``value`` and ``last_modified``/``etag`` in ``info.json``.
    """

    _last_modified = _InfoProperty("last_modified")
    # Stored quoted, i.e. '"<md5 hex>"' (see the ``value`` setter).
    etag = _InfoProperty("etag")

    def __init__(self, multipart, name):
        self.multipart = multipart
        self.name = name
        self._path = os.path.join(multipart._path, str(name) + ".part")
        self._info_path = os.path.join(self._path, "info.json")
        self._value_path = os.path.join(self._path, "value")

    def exists(self):
        return os.path.exists(self._path)

    @property
    def value(self):
        with open(self._value_path, "rb") as file:
            return file.read()

    @value.setter
    def value(self, data):
        # Accept str like ``Key.value`` does, storing it UTF-8 encoded.
        if not isinstance(data, (bytes, bytearray)):
            data = data.encode("utf-8")
        with open(self._value_path, "wb") as file:
            file.write(data)
        self.etag = f'"{hashlib.md5(data).hexdigest()}"'

    @property
    def size(self):
        return os.path.getsize(self._value_path)

    @property
    def last_modified(self):
        return datetime.datetime.strptime(self._last_modified, "%Y-%m-%dT%H:%M:%S.%fZ")

    @property
    def last_modified_ISO8601(self):
        return self._last_modified

    @property
    def last_modified_RFC1123(self):
        # Different datetime formats depending on how the key is obtained
        # https://github.com/boto/boto/issues/466
        return rfc_1123_datetime(self.last_modified)

    @property
    def response_dict(self):
        return {
            "etag": self.etag,
            "last-modified": self.last_modified_RFC1123,
        }

    def create(self, value):
        """Write the part's info and data; the etag is set from the data."""
        if not os.path.exists(self._path):
            os.makedirs(self._path)
        with open(self._info_path, "w") as file:
            json.dump(
                {
                    "last_modified": iso_8601_datetime_with_milliseconds(
                        datetime.datetime.utcnow()
                    ),
                    "etag": None,
                },
                file,
            )
        # Must come after info.json exists: the setter also stores the etag.
        self.value = value

    def delete(self):
        shutil.rmtree(self._path)


class Multipart:
    """An in-progress multipart upload stored in ``<bucket>/multiparts/<id>/``.

    Upload-level attributes live in ``info.json``; each part is a
    ``<part number>.part`` sub-directory (see ``Part``).
    """

    key_name = _InfoProperty("key_name")
    metadata = _InfoProperty("metadata")
    tags = _InfoProperty("tags")
    acl = _InfoProperty("acl")
    sse_encryption = _InfoProperty("sse_encryption")
    kms_key_id = _InfoProperty("kms_key_id")
    # Persisted (rather than a plain attribute) so the storage class survives
    # rebuilding the Multipart from its id. Reads as None until set.
    storage = _InfoProperty("storage")

    def __init__(self, bucket, id=None):
        self.id = id
        if id is None:
            rand_b64 = base64.b64encode(os.urandom(models.UPLOAD_ID_BYTES))
            # Strip characters that are awkward in URLs and file names.
            self.id = (
                rand_b64.decode("utf-8")
                .replace("=", "")
                .replace("+", "")
                .replace("/", "")
            )

        self._path = os.path.join(bucket._path, "multiparts", self.id)
        self._info_path = os.path.join(self._path, "info.json")

    def exists(self):
        return os.path.exists(self._path)

    def create(self, key_name, metadata, tags=None):
        """Create the upload directory and its ``info.json``.

        ``tags`` defaults to None so two-argument callers keep working.
        """
        os.makedirs(self._path, exist_ok=True)
        # Make metadata json serialization friendly
        if isinstance(metadata, requests.structures.CaseInsensitiveDict):
            metadata = dict(metadata)
        with open(self._info_path, "w") as file:
            json.dump({"key_name": key_name, "metadata": metadata, "tags": tags}, file)

    def delete(self):
        """Remove the upload; return False if it did not exist."""
        if not os.path.exists(self._path):
            return False
        shutil.rmtree(self._path)
        return True

    def complete(self, body):
        """Assemble the final object from the parts listed in ``body``.

        ``body`` is an iterable of ``(part number, etag)`` pairs, in the order
        the parts are concatenated. Returns ``(data, etag, checksum)``, where
        ``etag`` is ``"<md5 of the part md5s>-<part count>"`` and ``checksum``
        is always None.

        Raises ``models.InvalidPart`` if a part is missing or its etag does not
        match (surrounding quotes are ignored on both sides), and
        ``models.EntityTooSmall`` if any part but the last is smaller than
        ``settings.S3_UPLOAD_PART_MIN_SIZE``.
        """
        total = bytearray()
        md5s = bytearray()

        last = None
        count = 0
        for pn, etag in body:
            part = self.get_part(pn)
            if part is None:
                raise models.InvalidPart()
            part_etag = (part.etag or "").replace('"', "")
            # Clients may send the etag with or without quotes.
            if not part_etag or part_etag != etag.replace('"', ""):
                raise models.InvalidPart()
            # Only the previous part is checked, so the final part may be small.
            if last is not None and last.size < settings.S3_UPLOAD_PART_MIN_SIZE:
                raise models.EntityTooSmall()
            md5s.extend(bytes.fromhex(part_etag))
            total.extend(part.value)
            last = part
            count += 1

        digest = hashlib.md5(bytes(md5s))
        return total, f"{digest.hexdigest()}-{count}", None

    def get_part(self, part_id):
        """The stored ``Part`` for ``part_id``, or None if it does not exist."""
        part = Part(self, part_id)
        if not part.exists():
            return None
        return part

    def set_part(self, part_id, value):
        """Store ``value`` as part ``part_id``; ids below 1 are ignored."""
        if part_id < 1:
            return

        part = Part(self, part_id)
        part.create(value)
        return part

    def list_parts(self):
        """Yield the stored parts in ascending part-number order."""
        parts = sorted(
            (fn[:-5] for fn in os.listdir(self._path) if fn.endswith(".part")),
            key=lambda v: int(v),
        )
        for part in parts:
            yield self.get_part(part)


class Multiparts(collections.abc.MutableMapping):
    """Mapping of upload id -> ``Multipart`` for ``<bucket>/multiparts``."""

    def __init__(self, bucket):
        self.bucket = bucket
        self._path = os.path.join(bucket._path, "multiparts")

    def __getitem__(self, name):
        mp = Multipart(self.bucket, name)
        if not mp.exists():
            raise KeyError(name)
        return mp

    def __setitem__(self, name, mp):
        # Uploads are normally created on disk before being assigned here, in
        # which case this is a no-op.
        # NOTE(review): ``create()`` is called without its required arguments,
        # so assigning a not-yet-created Multipart raises TypeError.
        if not mp.exists():
            mp.create()

    def __delitem__(self, name):
        mp = Multipart(self.bucket, name)
        if not mp.exists():
            raise KeyError(name)
        mp.delete()

    def __iter__(self):
        if not os.path.exists(self._path):
            return
        yield from os.listdir(self._path)

    def __len__(self):
        # Consistent with ``__iter__``: no ``multiparts`` directory means no
        # uploads, not an error.
        if not os.path.exists(self._path):
            return 0
        return len(os.listdir(self._path))


class Bucket(models.FakeBucket):
    """A moto ``FakeBucket`` persisted as ``<s3 directory>/<name>.bucket/``.

    The directory holds ``info.json`` (backing the ``_InfoProperty``
    attributes and ``info``), an optional ``lifecycle.json``, an optional
    ``website_configuration.xml``, and the ``keys``/``multiparts`` trees.
    """

    policy = _InfoProperty("policy")
    versioning_status = _InfoProperty("versioning_status")
    acl = _AclProperty("acl")

    def __init__(self, s3, name):
        self.s3 = s3
        self.name = name
        self.region_name = None

        self.cors = []
        self.logging = {}
        self.notification_configuration = None
        self.accelerate_configuration = None
        self.payer = "BucketOwner"
        self.public_access_block = None
        self.encryption = None
        self.object_lock_enabled = False
        self.default_lock_mode = ""
        self.default_lock_days = 0
        self.default_lock_years = 0

        self._path = os.path.join(s3.directory, self.name + ".bucket")
        self._info_path = os.path.join(self._path, "info.json")
        self._lifecyle_path = os.path.join(self._path, "lifecycle.json")
        self._ws_config_path = os.path.join(self._path, "website_configuration.xml")
        # NOTE(review): not persisted — this is the time this handle was built,
        # not when the bucket was created on disk.
        self.creation_date = datetime.datetime.now(tz=pytz.utc)

    @property
    def info(self):
        with open(self._info_path, "rb") as file:
            return json.load(file)

    @info.setter
    def info(self, value):
        # json.dump writes str, so the file must be opened in text mode.
        with open(self._info_path, "w") as file:
            json.dump(value, file)

    @property
    def keys(self):
        return VersionedKeyStore(self)

    @property
    def multiparts(self):
        return Multiparts(self)

    @property
    def location(self):
        return self.info.get("region_name")

    @property
    def rules(self):
        """Lifecycle rules rebuilt from ``lifecycle.json`` ([] if absent)."""
        if not os.path.exists(self._lifecyle_path):
            return []
        with open(self._lifecyle_path) as file:
            raw_rules = json.load(file)
        rules = []
        for rule in raw_rules:
            exp = rule.get("Expiration")
            tran = rule.get("Transition")
            transitions = []
            if tran:
                transitions.append(
                    models.LifecycleTransition(
                        date=tran.get("Date") or None,
                        # A transition may be date-based and carry no "Days".
                        days=tran.get("Days"),
                        storage_class=tran["StorageClass"],
                    )
                )
            rules.append(
                models.LifecycleRule(
                    rule_id=rule.get("ID"),
                    prefix=rule["Prefix"],
                    status=rule["Status"],
                    expiration_days=exp.get("Days") if exp else None,
                    expiration_date=exp.get("Date") if exp else None,
                    transitions=transitions,
                    noncurrent_version_transitions=[],
                )
            )
        return rules

    @property
    def website_configuration(self):
        """The stored website configuration XML, or [] when none is set."""
        if not os.path.exists(self._ws_config_path):
            return []
        with open(self._ws_config_path) as file:
            return file.read()

    @website_configuration.setter
    def website_configuration(self, website_configuration):
        if isinstance(website_configuration, bytes):
            website_configuration = website_configuration.decode("utf-8")
        if website_configuration is None:
            # Removing a configuration that was never set is a no-op.
            try:
                os.remove(self._ws_config_path)
            except FileNotFoundError:
                pass
            return
        with open(self._ws_config_path, "w") as file:
            file.write(website_configuration)

    def exists(self):
        return os.path.exists(self._path)

    def create(self, region_name=None):
        os.mkdir(self._path)
        self.region_name = region_name
        with open(self._info_path, "w") as file:
            json.dump({"region_name": region_name}, file)

    def delete(self):
        """Delete the bucket; return False if it is missing or not empty."""
        if not os.path.exists(self._path):
            return False
        # Iterate rather than call ``len``: a bucket that never stored a key
        # has no ``keys`` directory, which iteration treats as empty.
        if next(iter(self.keys), None) is not None:
            return False
        shutil.rmtree(self._path)
        return True

    def set_lifecycle(self, rules):
        with open(self._lifecyle_path, "w") as file:
            json.dump(rules, file)

    def delete_lifecycle(self):
        # Deleting a lifecycle configuration that was never set is a no-op.
        try:
            os.remove(self._lifecyle_path)
        except FileNotFoundError:
            pass

    def get_cfn_attribute(self, attribute_name):
        if attribute_name == "DomainName":
            raise NotImplementedError('"Fn::GetAtt" : [ "{0}" , "DomainName" ]"')
        elif attribute_name == "WebsiteURL":
            raise NotImplementedError('"Fn::GetAtt" : [ "{0}" , "WebsiteURL" ]"')
        raise UnformattedGetAttTemplateException()


class ShoobxS3Backend(models.S3Backend):
    """moto ``S3Backend`` that keeps buckets, keys and uploads on disk.

    Everything lives under ``directory`` (``./data`` by default), one
    ``<bucket name>.bucket`` directory per bucket (see ``Bucket``).
    """

    def __init__(self, region_name="us-east-42", account_id="deadbeef00d"):
        self.region_name = region_name
        self.account_id = account_id
        self.directory = "./data"
        super().__init__(self.region_name, self.account_id)

    @property
    def directory(self):
        return self._directory

    @directory.setter
    def directory(self, dir):
        self._directory = dir

    @property
    def _url_module(self):
        # Prevent a circular import
        import shoobx.mocks3.urls as backend_urls_module

        # No reload is necessary since we don't allow for overwriting urls
        return backend_urls_module

    def create_bucket(self, bucket_name, region_name):
        """Create and return a new bucket; raise if it already exists."""
        new_bucket = Bucket(self, bucket_name)
        if new_bucket.exists():
            raise models.BucketAlreadyExists(bucket=bucket_name)
        new_bucket.create(region_name)
        return new_bucket

    def list_buckets(self):
        """All buckets, sorted by name; [] if the data directory is missing."""
        if not os.path.isdir(self.directory):
            return []
        suffix = ".bucket"
        return [
            Bucket(self, fn[: -len(suffix)])
            for fn in sorted(os.listdir(self.directory))
            if fn.endswith(suffix)
        ]

    def get_bucket(self, bucket_name):
        bucket = Bucket(self, bucket_name)
        if not bucket.exists():
            raise models.MissingBucket(bucket=bucket_name)
        return bucket

    def delete_bucket(self, bucket_name):
        bucket = Bucket(self, bucket_name)
        return bucket.delete()

    def put_object(
        self,
        bucket_name,
        key_name,
        value,
        storage=None,
        etag=None,
        multipart=None,
        encryption=None,
        kms_key_id=None,
        bucket_key_enabled=None,
        lock_mode=None,
        lock_legal_status="OFF",
        lock_until=None,
        checksum_value=None,  # noqa
    ):
        """Store ``value`` under ``key_name`` and return the new ``Key``.

        In a versioned bucket the new key gets the next version number;
        otherwise it is (re)written as version 0.
        """
        bucket = self.get_bucket(bucket_name)

        old_key = bucket.keys.get(key_name, None)
        if old_key is not None and bucket.is_versioned:
            new_version = old_key.version + 1
        else:
            new_version = 0

        new_key = Key(
            bucket,
            key_name,
            version=new_version,
            is_versioned=bucket.is_versioned,
            multipart=multipart,
            encryption=encryption,
            kms_key_id=kms_key_id,
            bucket_key_enabled=bucket_key_enabled,
            lock_mode=lock_mode,
            lock_legal_status=lock_legal_status,
            lock_until=lock_until,
        )
        new_key.create(value=value, storage=storage, etag=etag)

        return new_key

    def initiate_multipart(self, bucket_name, key_name, metadata):
        bucket = self.get_bucket(bucket_name)
        new_multipart = Multipart(bucket)
        # ``Multipart.create`` also takes the upload's tags; none here.
        new_multipart.create(key_name, metadata, None)
        bucket.multiparts[new_multipart.id] = new_multipart
        return new_multipart

    def complete_multipart(self, bucket_name, multipart_id, body):
        """Assemble the upload into a key, drop the upload, return the key."""
        bucket = self.get_bucket(bucket_name)
        multipart = bucket.multiparts[multipart_id]
        # ``Multipart.complete`` returns (value, etag, checksum).
        value, etag, _checksum = multipart.complete(body)
        if value is None:
            return
        key = self.put_object(
            bucket_name, multipart.key_name, value, etag=etag, multipart=multipart
        )
        key.set_metadata(multipart.metadata or {})

        del bucket.multiparts[multipart_id]

        return key

    def create_multipart_upload(
        self,
        bucket_name,
        key_name,
        metadata,
        storage_type,
        tags,
        acl,
        sse_encryption,
        kms_key_id,
    ):
        """Start a multipart upload and return its (random) upload id.

        ``acl`` is accepted but not persisted.
        """
        bucket = self.get_bucket(bucket_name)
        # Let Multipart generate a random id: using the key name as the id
        # made concurrent uploads of one key collide and nested directories
        # for key names containing "/".
        new_multipart = Multipart(bucket)
        new_multipart.create(key_name, metadata, tags)
        new_multipart.storage = storage_type
        new_multipart.sse_encryption = sse_encryption
        new_multipart.kms_key_id = kms_key_id
        bucket.multiparts[new_multipart.id] = new_multipart
        return new_multipart.id

    def complete_multipart_upload(self, bucket_name, multipart_id, body):
        bucket = self.get_bucket(bucket_name)
        multipart = bucket.multiparts[multipart_id]
        value, etag, checksum = multipart.complete(body)
        return multipart, value, etag, checksum

    def reset(self):
        # For every key and multipart, Moto opens a TemporaryFile to write the
        # value of those keys. Close any such buffers so no filehandles stay
        # open. This backend's own on-disk Key/Part objects have no
        # ``_value_buffer``, hence the getattr guards.
        for bucket in self.buckets.values():
            for key in bucket.keys.values():
                buffer = getattr(key, "_value_buffer", None)
                if buffer is not None:
                    buffer.close()
                multipart = getattr(key, "multipart", None)
                parts = getattr(multipart, "parts", None) or {}
                for part in parts.values():
                    part_buffer = getattr(part, "_value_buffer", None)
                    if part_buffer is not None:
                        part_buffer.close()
        super().reset()


# Module-level registry of ShoobxS3Backend instances for the "s3" service,
# managed by moto's BackendDict (presumably one backend per account id and
# region, created on first access — confirm against moto's BackendDict docs).
# use_boto3_regions=False: the region list is not taken from boto3; only
# the extra "global" region listed here is added.
s3_backends = BackendDict(
    ShoobxS3Backend,
    service_name="s3",
    use_boto3_regions=False,
    additional_regions=["global"],
)