thenetcircle/dino

View on GitHub
dino/storage/cassandra_driver.py

Summary

Maintainability
D
2 days
Test Coverage
#!/usr/bin/env python

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import pytz
import os

from datetime import datetime
from enum import Enum

from cassandra.cqlengine.operators import EqualsOperator
from cassandra.cqlengine.statements import BaseCQLStatement
from cassandra.cqlengine.statements import UpdateStatement
from zope.interface import implementer

from cassandra.cluster import ResultSet
from cassandra.cluster import Session
from cassandra.query import ValueSequence
from cassandra.cqlengine.query import BatchQuery

from dino.storage.cassandra_interface import IDriver
from dino.config import ConfigKeys

__author__ = 'Oscar Eriksson <oscar.eriks@gmail.com>'

logger = logging.getLogger(__name__)


class StatementKeys(Enum):
    acks_update = 'acks_update'
    acks_insert = 'acks_insert'
    acks_get = 'acks_get'
    acks_get_for_status = 'acks_get_for_status'
    msg_insert = 'msg_insert'
    msg_update = 'msg_update'
    msg_select = 'msg_select'
    msg_select_all = 'msg_select_all'
    msgs_select = 'msgs_select'
    msgs_select_all_in = 'msgs_select_all_in'
    msgs_select_pagination = 'msgs_select_pagination'
    msgs_select_time_slice = 'msgs_select_time_slice'
    msgs_select_by_time_stamp = 'msgs_select_by_time_stamp'
    msgs_select_latest_non_deleted = 'msgs_select_latest_non_deleted'
    msgs_select_from_user = 'msg_select_from_user'
    msgs_select_from_user_to_target = 'msg_select_from_user_to_target'
    msgs_select_from_user_to_target_time_slice = 'msg_select_from_user_to_target_time_slice'
    msg_select_one = 'msg_select_one'
    msg_select_msg_id_from_user_not_deleted = 'msg_select_msg_id_from_user_not_deleted'
    msg_select_msg_id_from_user_all = 'msg_select_msg_id_from_user_all'
    msg_select_msgs_from_user_not_deleted_for_time = 'msg_select_msgs_from_user_not_deleted_for_time'
    msg_select_msg_id_from_user_and_room_not_deleted = 'msg_select_msg_id_from_user_and_room_not_deleted'


@implementer(IDriver)
class Driver(object):
    def __init__(self, session: Session, key_space: str, strategy: str, replications: int):
        self.session: Session = session
        self.statements = dict()
        self.key_space = key_space
        self.key_space_test = key_space + 'test'
        self.strategy = strategy
        self.replications = replications
        self.logger = logging.getLogger(__name__)

    def init(self):
        def create_test_key_space():
            self.logger.debug('creating test keyspace...')
            create_key_space_stmt = self.session.prepare(
                    """
                    CREATE KEYSPACE IF NOT EXISTS %s
                    WITH replication = {'class': '%s', 'replication_factor': '%s'}
                    """ % (self.key_space + 'test', self.strategy, str(self.replications))
            )
            self.session.execute(create_key_space_stmt)

        def create_key_space():
            self.logger.debug('creating keyspace...')
            create_key_space_stmt = self.session.prepare(
                    """
                    CREATE KEYSPACE IF NOT EXISTS %s
                    WITH replication = {'class': '%s', 'replication_factor': '%s'}
                    """ % (self.key_space, self.strategy, str(self.replications))
            )
            self.session.execute(create_key_space_stmt)

        def set_key_space():
            self.logger.debug('switching to key space: %s' % self.key_space)
            self.session.set_keyspace(self.key_space)

        def set_test_key_space():
            self.logger.debug('switching to key space: %s' % self.key_space_test)
            self.session.set_keyspace(self.key_space_test)

        def create_tables():
            self.logger.debug('creating tables...')
            self.session.execute(
                """
                CREATE TABLE IF NOT EXISTS messages (
                    message_id varchar,
                    from_user_id text,
                    from_user_name text,
                    target_id text,
                    target_name text,
                    body text,
                    domain text,
                    sent_time varchar,
                    time_stamp int,
                    channel_id varchar,
                    channel_name text,
                    deleted boolean,
                    PRIMARY KEY (target_id, from_user_id, sent_time, time_stamp)
                );
                """
            )
            self.session.execute(
                """
                CREATE TABLE IF NOT EXISTS msg_acks (
                    for_user_id text, 
                    message_id varchar, 
                    status int, 
                    target_id varchar,
                    primary key(for_user_id, message_id)
                );
                """
            )

        def create_views():
            self.session.execute(
                    """
                    CREATE MATERIALIZED VIEW IF NOT EXISTS acks_by_user_and_status AS
                        SELECT * from msg_acks
                            WHERE
                                for_user_id IS NOT NULL AND
                                message_id IS NOT NULL AND
                                status IS NOT NULL AND
                                target_id IS NOT NULL
                        PRIMARY KEY (for_user_id, status, message_id)
                        WITH CLUSTERING ORDER BY (status DESC);
                    """
            )
            self.session.execute(
                    """
                    CREATE MATERIALIZED VIEW IF NOT EXISTS messages_by_id AS
                        SELECT * from messages
                            WHERE
                                message_id IS NOT NULL AND
                                target_id IS NOT NULL AND
                                from_user_id IS NOT NULL AND
                                sent_time IS NOT NULL AND
                                time_stamp IS NOT NULL
                        PRIMARY KEY (message_id, target_id, from_user_id, sent_time, time_stamp)
                        WITH CLUSTERING ORDER BY (time_stamp DESC);
                    """
            )
            self.session.execute(
                    """
                    CREATE MATERIALIZED VIEW IF NOT EXISTS messages_by_time_stamp AS
                        SELECT * from messages
                            WHERE
                                message_id IS NOT NULL AND
                                body IS NOT NULL AND
                                target_id IS NOT NULL AND
                                from_user_id IS NOT NULL AND
                                sent_time IS NOT NULL AND
                                time_stamp IS NOT NULL
                        PRIMARY KEY (target_id, time_stamp, from_user_id, sent_time)
                        WITH CLUSTERING ORDER BY (time_stamp DESC);
                    """
            )
            self.session.execute(
                    """
                    CREATE MATERIALIZED VIEW IF NOT EXISTS messages_by_time_stamp_non_deleted AS
                        SELECT * from messages
                            WHERE
                                message_id IS NOT NULL AND
                                body IS NOT NULL AND
                                target_id IS NOT NULL AND
                                from_user_id IS NOT NULL AND
                                sent_time IS NOT NULL AND
                                deleted IS NOT NULL AND
                                time_stamp IS NOT NULL
                        PRIMARY KEY (target_id, deleted, time_stamp, from_user_id, sent_time)
                        WITH CLUSTERING ORDER BY (time_stamp DESC);
                    """
            )
            self.session.execute(
                    """
                    CREATE MATERIALIZED VIEW IF NOT EXISTS messages_by_from_user_id AS
                        SELECT * from messages
                            WHERE
                                message_id IS NOT NULL AND
                                body IS NOT NULL AND
                                target_id IS NOT NULL AND
                                from_user_id IS NOT NULL AND
                                sent_time IS NOT NULL AND
                                time_stamp IS NOT NULL
                        PRIMARY KEY (from_user_id, target_id, time_stamp, sent_time)
                        WITH CLUSTERING ORDER BY (time_stamp DESC);
                    """
            )

        def prepare_statements():
            self.statements[StatementKeys.msg_insert] = self.session.prepare(
                    """
                    INSERT INTO messages (
                        message_id,
                        from_user_id,
                        from_user_name,
                        target_id,
                        target_name,
                        body,
                        domain,
                        sent_time,
                        time_stamp,
                        channel_id,
                        channel_name,
                        deleted
                    )
                    VALUES (
                        ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?
                    )
                    """
            )
            self.statements[StatementKeys.msg_update] = self.session.prepare(
                    """
                    UPDATE messages SET body = ?, deleted = ? 
                    WHERE 
                      target_id = ? AND 
                      from_user_id = ? AND
                      sent_time = ? AND
                      time_stamp = ?
                    """
            )
            self.statements[StatementKeys.acks_update] = self.session.prepare(
                """
                UPDATE msg_acks SET status = ? where for_user_id = ? and message_id in ?
                """
            )
            self.statements[StatementKeys.acks_insert] = self.session.prepare(
                """
                INSERT INTO msg_acks (for_user_id, message_id, status, target_id) values(?, ?, ?, ?)
                """
            )
            self.statements[StatementKeys.acks_get] = self.session.prepare(
                """
                SELECT * FROM msg_acks WHERE for_user_id = ? and message_id in ?
                """
            )
            self.statements[StatementKeys.acks_get_for_status] = self.session.prepare(
                """
                SELECT * FROM acks_by_user_and_status WHERE for_user_id = ? and status = ?
                """
            )
            self.statements[StatementKeys.msgs_select] = self.session.prepare(
                    """
                    SELECT * FROM messages WHERE target_id = ? LIMIT ?
                    """
            )
            self.statements[StatementKeys.msgs_select_all_in] = self.session.prepare(
                    """
                    SELECT * FROM messages_by_id WHERE message_id in ?
                    """
            )
            self.statements[StatementKeys.msgs_select_by_time_stamp] = self.session.prepare(
                    """
                    SELECT * FROM messages_by_time_stamp WHERE target_id = ? AND time_stamp > ?
                    """
            )
            self.statements[StatementKeys.msgs_select_latest_non_deleted] = self.session.prepare(
                    """
                    SELECT * FROM messages_by_time_stamp_non_deleted WHERE target_id = ? AND deleted = False LIMIT ?
                    """
            )
            self.statements[StatementKeys.msgs_select_pagination] = self.session.prepare(
                    """
                    SELECT * FROM messages_by_time_stamp_non_deleted WHERE target_id = ? AND deleted = False AND time_stamp < ? LIMIT ?
                    """
            )
            self.statements[StatementKeys.msgs_select_time_slice] = self.session.prepare(
                    """
                    SELECT * FROM messages_by_time_stamp WHERE target_id = ? AND time_stamp > ? AND time_stamp < ?
                    """
            )
            self.statements[StatementKeys.msgs_select_from_user] = self.session.prepare(
                    """
                    SELECT * FROM messages_by_from_user_id WHERE from_user_id = ? LIMIT ?
                    """
            )
            self.statements[StatementKeys.msgs_select_from_user_to_target] = self.session.prepare(
                    """
                    SELECT * FROM messages_by_from_user_id WHERE from_user_id = ? AND target_id = ? LIMIT ?
                    """
            )
            self.statements[StatementKeys.msgs_select_from_user_to_target_time_slice] = self.session.prepare(
                    """
                    SELECT * FROM messages_by_from_user_id WHERE from_user_id = ? AND target_id = ? AND time_stamp > ? AND time_stamp < ? LIMIT ?
                    """
            )
            self.statements[StatementKeys.msg_select] = self.session.prepare(
                    """
                    SELECT target_id, from_user_id, sent_time FROM messages_by_id WHERE message_id = ?
                    """
            )
            self.statements[StatementKeys.msg_select_all] = self.session.prepare(
                    """
                    SELECT * FROM messages_by_id WHERE message_id = ?
                    """
            )
            self.statements[StatementKeys.msg_select_one] = self.session.prepare(
                    """
                    SELECT * FROM messages WHERE target_id = ? AND from_user_id = ? AND sent_time = ?
                    """
            )
            self.statements[StatementKeys.msg_select_msg_id_from_user_not_deleted] = self.session.prepare(
                    """
                    SELECT
                        message_id FROM messages_by_from_user_id
                    WHERE
                        from_user_id = ? AND
                        deleted = False AND
                        domain = 'room'
                    ALLOW FILTERING
                    """
            )
            self.statements[StatementKeys.msg_select_msg_id_from_user_all] = self.session.prepare(
                    """
                    SELECT
                        message_id FROM messages_by_from_user_id
                    WHERE
                        from_user_id = ? AND
                        domain = 'room'
                    ALLOW FILTERING
                    """
            )
            self.statements[StatementKeys.msg_select_msgs_from_user_not_deleted_for_time] = self.session.prepare(
                    """
                    SELECT
                        * FROM messages_by_from_user_id
                    WHERE
                        from_user_id = ? AND
                        deleted = False AND
                        domain = 'room' AND
                        time_stamp > ? AND
                        time_stamp < ?
                    ALLOW FILTERING
                    """
            )
            self.statements[StatementKeys.msg_select_msg_id_from_user_and_room_not_deleted] = self.session.prepare(
                    """
                    SELECT
                        message_id FROM messages_by_from_user_id
                    WHERE
                        from_user_id = ? AND
                        target_id = ? AND
                        deleted = False AND
                        domain = 'room'
                    ALLOW FILTERING
                    """
            )

        if os.getenv("CREATE_DB") == "true":
            # create keyspace and tables for tests
            create_test_key_space()
            set_test_key_space()
            create_tables()
            create_views()
            # create keyspace and tables for other
            create_key_space()
            set_key_space()
            create_tables()
            create_views()

        set_key_space()
        prepare_statements()

    def msg_insert(self, msg_id, from_user_id, from_user_name, target_id, target_name, body, domain, sent_time, channel_id, channel_name, deleted=False) -> None:
        dt = datetime.strptime(sent_time, ConfigKeys.DEFAULT_DATE_FORMAT)
        dt = pytz.timezone('utc').localize(dt, is_dst=None)
        time_stamp = int(dt.astimezone(pytz.utc).strftime('%s'))
        self._execute(
                StatementKeys.msg_insert, msg_id, from_user_id, from_user_name, target_id, target_name,
                body, domain, sent_time, time_stamp, channel_id, channel_name, deleted)

    def msg_update(self, from_user_id, target_id, body, sent_time, deleted=False) -> None:
        dt = datetime.strptime(sent_time, ConfigKeys.DEFAULT_DATE_FORMAT)
        dt = pytz.timezone('utc').localize(dt, is_dst=None)
        time_stamp = int(dt.astimezone(pytz.utc).strftime('%s'))
        self._execute(StatementKeys.msg_update, body, deleted, target_id, from_user_id, sent_time, time_stamp)

    def get_acks_for(self, message_ids: set, receiver_id: str) -> ResultSet:
        return self._execute(StatementKeys.acks_get, receiver_id, message_ids)

    def get_acks_for_status(self, user_id: str, status: int) -> ResultSet:
        return self._execute(StatementKeys.acks_get_for_status, user_id, status)

    def add_acks_with_status(self, message_ids: set, receiver_id: str, target_id: str, status: int):
        for message_id in message_ids:
            self._execute(StatementKeys.acks_insert, receiver_id, message_id, status, target_id)

    def update_acks_with_status(self, message_ids: set, receiver_id: str, status: int):
        return self._execute(StatementKeys.acks_update, status, receiver_id, message_ids)

    def msgs_select_pagination(self, target_id: str, to_time: int, limit: int):
        return self._execute(StatementKeys.msgs_select_pagination, target_id, to_time, limit)

    def msgs_select_time_slice(self, target_id: str, from_time: int, to_time: int) -> ResultSet:
        return self._execute(StatementKeys.msgs_select_time_slice, target_id, from_time, to_time)

    def msgs_select_from_user(self, from_user_id: str, limit: int=500) -> ResultSet:
        return self._execute(StatementKeys.msgs_select_from_user, from_user_id, limit)

    def msgs_select_from_user_to_target(self, from_user_id: str, target_id: str, limit: int=500) -> ResultSet:
        return self._execute(StatementKeys.msgs_select_from_user_to_target, from_user_id, target_id, limit)

    def msgs_select_from_user_to_target_time_slice(self, from_user_id: str, target_id: str, from_time: int, to_time: int, limit: int=500) -> ResultSet:
        return self._execute(StatementKeys.msgs_select_from_user_to_target_time_slice, from_user_id, target_id, from_time, to_time, limit)

    def msgs_select(self, target_id: str, limit: int=100) -> ResultSet:
        return self._execute(StatementKeys.msgs_select, target_id, limit)

    def msgs_select_all_in(self, message_ids: set) -> ResultSet:
        return self._execute(StatementKeys.msgs_select_all_in, ValueSequence(list(message_ids)))

    def msg_select(self, message_id) -> ResultSet:
        return self._execute(StatementKeys.msg_select_all, message_id)

    def msgs_select_latest_non_deleted(self, target_id: str, limit: int=100) -> ResultSet:
        return self._execute(StatementKeys.msgs_select_latest_non_deleted, target_id, limit)

    def msgs_select_since_time(self, target_id: str, time_stamp: int) -> ResultSet:
        return self._execute(StatementKeys.msgs_select_by_time_stamp, target_id, time_stamp)

    def msgs_select_non_deleted_for_user(self, from_user_id: str) -> ResultSet:
        return self._execute(StatementKeys.msg_select_msg_id_from_user_not_deleted, from_user_id)

    def msgs_select_all_for_user(self, from_user_id: str) -> ResultSet:
        return self._execute(StatementKeys.msg_select_msg_id_from_user_all, from_user_id)

    def msgs_select_non_deleted_for_user_and_time(self, from_user_id: str, from_time: int, to_time: int) -> ResultSet:
        return self._execute(
            StatementKeys.msg_select_msgs_from_user_not_deleted_for_time,
            from_user_id, from_time, to_time)

    def msgs_select_non_deleted_for_user_and_room(self, from_user_id: str, target_id: str) -> ResultSet:
        return self._execute(StatementKeys.msg_select_msg_id_from_user_and_room_not_deleted, from_user_id, target_id)

    def msg_undelete(self, message_id: str) -> None:
        self._msg_delete(message_id, deleted=False)

    def msg_delete(self, message_id: str, clear_body=True) -> None:
        self._msg_delete(message_id, deleted=True, clear_body=clear_body)

    def msgs_delete(self, message_ids: list, clear_body=True) -> None:
        to_update = list()

        for message_id in message_ids:
            keys = self._execute(StatementKeys.msg_select, message_id)
            if keys is None or len(keys.current_rows) == 0:
                # not found
                continue

            if len(keys.current_rows) > 1:
                logger.warning('found %s msgs when deleting with message_id %s' % (len(keys.current_rows), message_id))

            for key in keys.current_rows:
                target_id, from_user_id, sent_time = key.target_id, key.from_user_id, key.sent_time
                message_rows = self._execute(StatementKeys.msg_select_one, target_id, from_user_id, sent_time)

                if len(message_rows.current_rows) > 1:
                    logger.warning(
                        'found %s msgs when deleting with target_id %s, from_user_id %s and sent_time %s' %
                        (len(message_rows.current_rows), target_id, from_user_id, sent_time))

                for message_row in message_rows.current_rows:
                    body = message_row.body

                    if clear_body:
                        body = ''

                    to_update.append((from_user_id, target_id, body, sent_time))

        update_query = self.session.prepare("""
            UPDATE messages SET body = ?, deleted = true
            WHERE target_id = ? AND from_user_id = ? AND sent_time = ? AND time_stamp = ?
        """)

        for from_user_id, target_id, body, sent_time in to_update:
            dt = datetime.strptime(sent_time, ConfigKeys.DEFAULT_DATE_FORMAT)
            dt = pytz.timezone('utc').localize(dt, is_dst=None)
            time_stamp = int(dt.astimezone(pytz.utc).strftime('%s'))

            self.session.execute(update_query.bind([
                body, target_id, from_user_id, sent_time, time_stamp
            ]))

    def _msg_delete(self, message_id: str, deleted: bool, clear_body: bool = True) -> None:
        """
        We're doing three queries here, one to get primary index of messages table from message_id, then getting the
        complete row from messages table, and finally updating that row. This could be lowered to two queries by
        duplicating everything from messages table to messages_by_id materialized view, but would also double storage
        requirements. Since message deletion is likely not a frequent operation we can accept doing three queries.

        :param message_id: the uuid of the message to 'delete' (will only flag as deleted, will not remove)
        """
        keys = self._execute(StatementKeys.msg_select, message_id)
        if keys is None or len(keys.current_rows) == 0:
            # not found
            return

        if len(keys.current_rows) > 1:
            logger.warning('found %s msgs when deleting with message_id %s' % (len(keys.current_rows), message_id))

        for key in keys.current_rows:
            target_id, from_user_id, timestamp = key.target_id, key.from_user_id, key.sent_time
            message_rows = self._execute(StatementKeys.msg_select_one, target_id, from_user_id, timestamp)

            if len(message_rows.current_rows) > 1:
                logger.warning(
                        'found %s msgs when deleting with target_id %s, from_user_id %s and timestamp %s' %
                        (len(message_rows.current_rows), target_id, from_user_id, timestamp))

            for message_row in message_rows.current_rows:
                body = message_row.body

                if clear_body:
                    body = ''

                self.msg_update(from_user_id, target_id, body, timestamp, deleted)

    def _execute(self, statement_key, *params) -> ResultSet:
        if params is not None and len(params) > 0:
            return self.session.execute(self.statements[statement_key].bind(params))
        return self.session.execute(self.statements[statement_key])