collerek/ormar

View on GitHub
ormar/fields/sqlalchemy_encrypted.py

Summary

Maintainability
A
0 mins
Test Coverage
A
100%
# inspired by sqlalchemy-utils (https://github.com/kvesteri/sqlalchemy-utils)
import abc
import base64
from enum import Enum
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Type, Union

import sqlalchemy.types as types
from sqlalchemy.engine import Dialect

import ormar  # noqa: I100, I202
from ormar import ModelDefinitionError  # noqa: I202, I100
from ormar.fields.parsers import ADDITIONAL_PARAMETERS_MAP

cryptography = None
try:  # pragma: nocover
    import cryptography  # type: ignore
    from cryptography.fernet import Fernet
    from cryptography.hazmat.backends import default_backend
    from cryptography.hazmat.primitives import hashes
except ImportError:  # pragma: nocover
    pass

if TYPE_CHECKING:  # pragma: nocover
    from ormar import BaseField


class EncryptBackend(abc.ABC):
    def _refresh(self, key: Union[str, bytes]) -> None:
        if isinstance(key, str):
            key = key.encode()
        digest = hashes.Hash(hashes.SHA256(), backend=default_backend())
        digest.update(key)
        engine_key = digest.finalize()
        self._initialize_backend(engine_key)

    @abc.abstractmethod
    def _initialize_backend(self, secret_key: bytes) -> None:  # pragma: nocover
        pass

    @abc.abstractmethod
    def encrypt(self, value: Any) -> str:  # pragma: nocover
        pass

    @abc.abstractmethod
    def decrypt(self, value: Any) -> str:  # pragma: nocover
        pass


class HashBackend(EncryptBackend):
    """
    One-way hashing - in example for passwords, no way to decrypt the value!
    """

    def _initialize_backend(self, secret_key: bytes) -> None:
        self.secret_key = base64.urlsafe_b64encode(secret_key)

    def encrypt(self, value: Any) -> str:
        if not isinstance(value, str):  # pragma: nocover
            value = repr(value)
        value = value.encode()
        digest = hashes.Hash(hashes.SHA512(), backend=default_backend())
        digest.update(self.secret_key)
        digest.update(value)
        hashed_value = digest.finalize()
        return hashed_value.hex()

    def decrypt(self, value: Any) -> str:
        if not isinstance(value, str):  # pragma: nocover
            value = str(value)
        return value


class FernetBackend(EncryptBackend):
    """
    Two-way encryption, data stored in db are encrypted but decrypted during query.
    """

    def _initialize_backend(self, secret_key: bytes) -> None:
        self.secret_key = base64.urlsafe_b64encode(secret_key)
        self.fernet = Fernet(self.secret_key)

    def encrypt(self, value: Any) -> str:
        if not isinstance(value, str):
            value = repr(value)
        value = value.encode()
        encrypted = self.fernet.encrypt(value)
        return encrypted.decode("utf-8")

    def decrypt(self, value: Any) -> str:
        if not isinstance(value, str):  # pragma: nocover
            value = str(value)
        decrypted: Union[str, bytes] = self.fernet.decrypt(value.encode())
        if not isinstance(decrypted, str):
            decrypted = decrypted.decode("utf-8")
        return decrypted


class EncryptBackends(Enum):
    NONE = 0
    FERNET = 1
    HASH = 2
    CUSTOM = 3


BACKENDS_MAP = {
    EncryptBackends.FERNET: FernetBackend,
    EncryptBackends.HASH: HashBackend,
}


class EncryptedString(types.TypeDecorator):
    """
    Used to store encrypted values in a database
    """

    impl = types.TypeEngine

    def __init__(
        self,
        encrypt_secret: Union[str, Callable],
        encrypt_backend: EncryptBackends = EncryptBackends.FERNET,
        encrypt_custom_backend: Optional[Type[EncryptBackend]] = None,
        **kwargs: Any,
    ) -> None:
        _field_type = kwargs.pop("_field_type")
        super().__init__()
        if not cryptography:  # pragma: nocover
            raise ModelDefinitionError(
                "In order to encrypt a column 'cryptography' is required!"
            )
        backend = BACKENDS_MAP.get(encrypt_backend, encrypt_custom_backend)
        if (
            not backend
            or not isinstance(backend, type)
            or not issubclass(backend, EncryptBackend)
        ):
            raise ModelDefinitionError("Wrong or no encrypt backend provided!")

        self.backend: EncryptBackend = backend()
        self._field_type: "BaseField" = _field_type
        self._underlying_type: Any = _field_type.column_type
        self._key: Union[str, Callable] = encrypt_secret
        type_ = self._field_type.__type__
        if type_ is None:  # pragma: nocover
            raise ModelDefinitionError(
                f"Improperly configured field " f"{self._field_type.name}"
            )
        self.type_: Any = type_

    def __repr__(self) -> str:  # pragma: nocover
        return "TEXT()"

    def load_dialect_impl(self, dialect: Dialect) -> Any:
        return dialect.type_descriptor(types.TEXT())

    def _refresh(self) -> None:
        key = self._key() if callable(self._key) else self._key
        self.backend._refresh(key)

    def process_bind_param(self, value: Any, dialect: Dialect) -> Optional[str]:
        if value is None:
            return value
        self._refresh()
        try:
            value = self._underlying_type.process_bind_param(value, dialect)
        except AttributeError:
            encoder, additional_parameter = self._get_coder_type_and_params(
                coders=ormar.SQL_ENCODERS_MAP
            )
            if encoder is not None:
                params = [value] + (
                    [additional_parameter] if additional_parameter else []
                )
                value = encoder(*params)

        encrypted_value = self.backend.encrypt(value)
        return encrypted_value

    def process_result_value(self, value: Any, dialect: Dialect) -> Any:
        if value is None:
            return value
        self._refresh()
        decrypted_value = self.backend.decrypt(value)
        try:
            return self._underlying_type.process_result_value(decrypted_value, dialect)
        except AttributeError:
            decoder, additional_parameter = self._get_coder_type_and_params(
                coders=ormar.DECODERS_MAP
            )
            if decoder is not None:
                params = [decrypted_value] + (
                    [additional_parameter] if additional_parameter else []
                )
                return decoder(*params)  # type: ignore

            return self._field_type.__type__(decrypted_value)  # type: ignore

    def _get_coder_type_and_params(
        self, coders: Dict[type, Callable]
    ) -> Tuple[Optional[Callable], Optional[str]]:
        coder = coders.get(self.type_, None)
        additional_parameter: Optional[str] = None
        if self.type_ in ADDITIONAL_PARAMETERS_MAP:
            additional_parameter = getattr(
                self._field_type, ADDITIONAL_PARAMETERS_MAP[self.type_]
            )
        return coder, additional_parameter