thenetcircle/dino

View on GitHub
dino/storage/cassandra.py

Summary

Maintainability
C
1 day
Test Coverage
#!/usr/bin/env python

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging

from zope.interface import implementer
from activitystreams.models.activity import Activity

from dino.storage import IStorage
from dino.config import ConfigKeys
from dino.config import AckStatus
from dino.utils import b64d
from dino.utils.decorators import timeit
from dino import environ

__author__ = 'Oscar Eriksson <oscar.eriks@gmail.com>'

logger = logging.getLogger(__name__)


@implementer(IStorage)
class CassandraStorage(object):
    """
    IStorage implementation that delegates all queries to
    ``dino.storage.cassandra_driver.Driver`` and converts the resulting rows
    into plain dicts (see ``_row_to_json``).
    """

    # assigned in init(); None until init() has been called
    driver = None
    # NOTE(review): not read or assigned anywhere in this class; kept for compatibility
    session = None

    def __init__(self, hosts: list, replications=None, strategy=None, protocol_version: int = 4, key_space='dino'):
        """
        Store connection settings and validate them; no connection is opened here (see init()).

        :param hosts: contact points passed to the Cassandra Cluster
        :param replications: replication factor, defaults to 2
        :param strategy: replication strategy name, defaults to 'SimpleStrategy'
        :param protocol_version: protocol version passed to the Cassandra Cluster
        :param key_space: key space name handed to the driver
        :raises ValueError: from validate() when a setting is invalid (validation is skipped when testing)
        """
        if replications is None:
            replications = 2
        if strategy is None:
            strategy = 'SimpleStrategy'

        self.protocol_version = protocol_version
        self.hosts = hosts
        self.key_space = key_space
        self.strategy = strategy
        self.replications = replications
        self.validate(hosts, replications, strategy)

    def init(self):
        """Connect to the cluster and create/initialize the query driver."""
        # imported lazily so the cassandra package is only needed when this storage is actually used
        from cassandra.cluster import Cluster
        from dino.storage.cassandra_driver import Driver

        cluster = Cluster(self.hosts, protocol_version=self.protocol_version)
        self.driver = Driver(cluster.connect(), self.key_space, self.strategy, self.replications)
        self.driver.init()

    @staticmethod
    def _is_empty(rows) -> bool:
        """Return True if a driver result is None or its current page contains no rows."""
        return rows is None or len(rows.current_rows) == 0

    @timeit(logger, 'on_message_hooks_store')
    def store_message(self, activity: Activity, deleted=False) -> None:
        """
        Insert the message carried by ``activity``.

        The message body and the actor's display name are decoded with ``b64d`` before storing.

        :param activity: the activity holding the message (id, actor, target, object)
        :param deleted: whether the message should be stored as deleted
        """
        message = b64d(activity.object.content)
        actor_name = b64d(activity.actor.display_name)
        self.driver.msg_insert(
                msg_id=activity.id,
                from_user_id=activity.actor.id,
                from_user_name=actor_name,
                target_id=activity.target.id,
                target_name=activity.target.display_name,
                body=message,
                domain=activity.target.object_type,
                sent_time=activity.published,
                channel_id=activity.object.url,
                channel_name=activity.object.display_name,
                deleted=deleted
        )

    def get_statuses(self, message_ids: set, receiver_id: str) -> dict:
        """
        Return the ack status of the given messages for one receiver.

        :return: dict of message_id -> int status; messages without an ack are absent
        """
        rows = self.driver.get_acks_for(message_ids, receiver_id)
        if self._is_empty(rows):
            return dict()
        return {row.message_id: int(row.status) for row in rows}

    def _mark_as_status(self, message_ids: set, receiver_id: str, target_id: str, status: int):
        """
        Set the ack status of the given messages for a receiver, never downgrading it.

        Messages without an ack get a new ack with ``status``; messages whose current
        status is lower than ``status`` are updated; all others are left untouched.
        """
        rows = self.driver.get_acks_for(message_ids, receiver_id)
        if self._is_empty(rows):
            current_acks = dict()
        else:
            current_acks = {row.message_id: int(row.status) for row in rows}

        to_update = set()
        to_add = set()

        for message_id in message_ids:
            current_status = current_acks.get(message_id)
            if current_status is None:
                to_add.add(message_id)
            elif current_status < status:
                # only upgrade; an equal or higher existing status is kept as is
                to_update.add(message_id)

        # pass only the ids that need a change, otherwise existing acks would be
        # downgraded by the update and existing rows re-inserted by the add
        if len(to_update) > 0:
            self.driver.update_acks_with_status(to_update, receiver_id, status)
        if len(to_add) > 0:
            self.driver.add_acks_with_status(to_add, receiver_id, target_id, status)

    @timeit(logger, 'on_cassandra_mark_as_received')
    def mark_as_received(self, message_ids: set, receiver_id: str, target_id: str) -> None:
        """Mark the messages as RECEIVED for the receiver (never downgrades an existing status)."""
        self._mark_as_status(message_ids, receiver_id, target_id, AckStatus.RECEIVED)

    @timeit(logger, 'on_cassandra_mark_as_read')
    def mark_as_read(self, message_ids: set, receiver_id: str, target_id: str) -> None:
        """Mark the messages as READ for the receiver (never downgrades an existing status)."""
        self._mark_as_status(message_ids, receiver_id, target_id, AckStatus.READ)

    @timeit(logger, 'on_cassandra_mark_as_unacked')
    def mark_as_unacked(self, message_id: str, receiver_id: str, target_id: str) -> None:
        """Mark a single message as NOT_ACKED for the receiver (never downgrades an existing status)."""
        self._mark_as_status({message_id}, receiver_id, target_id, AckStatus.NOT_ACKED)

    @timeit(logger, 'on_cassandra_get_messages')
    def get_messages(self, message_ids: set) -> list:
        """Return the messages with the given ids as dicts (deleted ones included)."""
        rows = self.driver.msgs_select_all_in(message_ids)
        if self._is_empty(rows):
            return list()
        return [self._row_to_json(row) for row in rows]

    @timeit(logger, 'on_cassandra_get_undeleted_message_ids_for_user_and_time')
    def get_undeleted_messages_for_user_and_time(self, user_id: str, from_time: int, to_time: int):
        """Return the non-deleted messages of a user within [from_time, to_time] as dicts."""
        rows = self.driver.msgs_select_non_deleted_for_user_and_time(user_id, from_time, to_time)
        if self._is_empty(rows):
            return list()
        return [self._row_to_json(row) for row in rows]

    @timeit(logger, 'on_cassandra_get_undeleted_message_ids_for_user')
    def get_undeleted_message_ids_for_user(self, user_id: str):
        """Return the ids of all non-deleted messages of a user."""
        rows = self.driver.msgs_select_non_deleted_for_user(user_id)
        if self._is_empty(rows):
            return list()
        return [row.message_id for row in rows]

    @timeit(logger, 'on_cassandra_get_all_message_ids_for_user')
    def get_all_message_ids_for_user(self, user_id: str):
        """Return the ids of all messages of a user, deleted ones included."""
        rows = self.driver.msgs_select_all_for_user(user_id)
        if self._is_empty(rows):
            return list()
        return [row.message_id for row in rows]

    @timeit(logger, 'on_cassandra_get_undeleted_message_ids_for_user_and_room')
    def get_undeleted_message_ids_for_user_and_room(self, user_id: str, room_id: str):
        """Return the ids of a user's non-deleted messages in one room."""
        rows = self.driver.msgs_select_non_deleted_for_user_and_room(user_id, room_id)
        if self._is_empty(rows):
            return list()
        return [row.message_id for row in rows]

    @timeit(logger, 'on_cassandra_delete_message')
    def delete_message(self, message_id: str, room_id: str=None, clear_body: bool=True) -> None:
        """
        Delete one message; ``room_id`` is accepted for interface compatibility and not used here.

        :param clear_body: forwarded to the driver (presumably clears the stored body — verify in driver)
        """
        self.driver.msg_delete(message_id, clear_body=clear_body)

    @timeit(logger, 'on_cassandra_delete_messages')
    def delete_messages(self, message_ids: list, room_id: str=None, clear_body: bool=True) -> None:
        """Delete several messages; ``room_id`` is accepted for interface compatibility and not used here."""
        self.driver.msgs_delete(message_ids, clear_body=clear_body)

    # NOTE(review): shares its timeit key with delete_message; kept unchanged so existing metrics stay continuous
    @timeit(logger, 'on_cassandra_delete_message')
    def delete_messages_in_room(self, room_id: str=None, clear_body: bool=False) -> None:
        """
        Delete the messages of a room, one driver call per message.

        Only the rows returned by a single ``msgs_select(room_id, limit=500)`` call are deleted.
        """
        rows = self.driver.msgs_select(room_id, limit=500)
        if self._is_empty(rows):
            return

        msg_ids = [row.message_id for row in rows]
        for msg_id in msg_ids:
            self.driver.msg_delete(msg_id, clear_body=clear_body)

    @timeit(logger, 'on_cassandra_undelete_message')
    def undelete_message(self, message_id: str) -> None:
        """Undo a previous deletion of a message."""
        self.driver.msg_undelete(message_id)

    # NOTE(review): shares its timeit key with get_unread_history; kept unchanged so existing metrics stay continuous
    @timeit(logger, 'on_cassandra_get_unread_history')
    def get_unacked_history(self, user_id: str) -> list:
        """Return the non-deleted messages whose ack for ``user_id`` is NOT_ACKED, as dicts."""
        rows = self.driver.get_acks_for_status(user_id, AckStatus.NOT_ACKED)
        if self._is_empty(rows):
            return list()

        message_ids = {row.message_id for row in rows}
        message_rows = self.driver.msgs_select_all_in(message_ids)
        if message_rows is None:
            return list()

        return [self._row_to_json(row) for row in message_rows if not row.deleted]

    @timeit(logger, 'on_cassandra_get_unread_history')
    def get_unread_history(self, room_id: str, last_read: int) -> list:
        """Return the non-deleted messages of a room selected by ``msgs_select_since_time(room_id, last_read)``."""
        rows = self.driver.msgs_select_since_time(room_id, last_read)
        if self._is_empty(rows):
            return list()

        return [self._row_to_json(row) for row in rows if not row.deleted]

    def _history_rows_for_time_slice(self, room_id: str, from_user_id: str, from_time: int, to_time: int):
        """
        Fetch message rows for a time slice, optionally restricted to a room and/or sender.

        With a non-blank ``room_id`` the time-slice query runs against that room, restricted to
        ``from_user_id`` when that is non-blank too. Otherwise all messages of ``from_user_id``
        are fetched and filtered to ``from_time <= time_stamp <= to_time`` here.

        :return: an iterable of rows, possibly empty
        """
        if room_id is not None and len(room_id.strip()) > 0:
            if from_user_id is not None and len(from_user_id.strip()) > 0:
                rows = self.driver.msgs_select_from_user_to_target_time_slice(from_user_id, room_id, from_time, to_time)
            else:
                rows = self.driver.msgs_select_time_slice(room_id, from_time, to_time)
            if self._is_empty(rows):
                return list()
            return rows

        all_rows = self.driver.msgs_select_from_user(from_user_id)
        if all_rows is None:
            return list()
        # NOTE(review): filters on row.time_stamp while _row_to_json exposes row.sent_time — confirm intended column
        return [row for row in all_rows if from_time <= row.time_stamp <= to_time]

    @timeit(logger, 'on_get_history_for_user_no_limit')
    def get_history_for_user_no_limit(self, room_id: str, from_user_id: str, from_time: int, to_time: int) -> list:
        """Return messages in [from_time, to_time], optionally by room and/or sender, as dicts."""
        rows = self._history_rows_for_time_slice(room_id, from_user_id, from_time, to_time)
        return [self._row_to_json(row) for row in rows]

    # NOTE(review): shares its timeit key with get_history_for_time_slice; kept unchanged so existing metrics stay continuous
    @timeit(logger, 'on_cassandra_get_history_for_time_slice')
    def get_history_pagination(self, room_id: str, to_time: int, limit: int) -> list:
        """Return up to ``limit`` messages of a room as selected by ``msgs_select_pagination``, as dicts."""
        rows = self.driver.msgs_select_pagination(room_id, to_time, limit)
        if rows is None:
            return list()
        return [self._row_to_json(row) for row in rows]

    @timeit(logger, 'on_cassandra_get_history_for_time_slice')
    def get_history_for_time_slice(self, room_id: str, from_user_id: str, from_time: int, to_time: int) -> list:
        """Return messages in [from_time, to_time], optionally by room and/or sender, as dicts."""
        rows = self._history_rows_for_time_slice(room_id, from_user_id, from_time, to_time)
        return [self._row_to_json(row) for row in rows]

    @timeit(logger, 'on_cassandra_get_history')
    def get_history(self, room_id: str, limit: int=100) -> list:
        """Return up to ``limit`` of the latest non-deleted messages of a room, as dicts."""
        rows = self.driver.msgs_select_latest_non_deleted(room_id, limit)
        if self._is_empty(rows):
            return list()
        return [self._row_to_json(row) for row in rows]

    @timeit(logger, 'on_cassandra_msg_select')
    def msg_select(self, message_id: str) -> dict:
        """
        Return one message as a dict, or an empty dict if it does not exist.

        If several rows share the id, a warning is logged and the first row is returned.
        """
        rows = self.driver.msg_select(message_id)
        if self._is_empty(rows):
            return dict()
        if len(rows.current_rows) > 1:
            logger.warning('multiple messages found for id %s', message_id)
        # only interested in the first one if multiple
        return self._row_to_json(next(iter(rows)))

    def _row_to_json(self, row):
        """Convert a message row into the dict format returned by this storage."""
        return {
            'message_id': row.message_id,
            'from_user_id': row.from_user_id,
            'from_user_name': row.from_user_name,
            'target_id': row.target_id,
            'target_name': row.target_name,
            'body': row.body,
            'domain': row.domain,
            'channel_id': row.channel_id,
            'channel_name': row.channel_name,
            'timestamp': row.sent_time,
            'deleted': row.deleted
        }

    def validate(self, hosts, replications, strategy):
        """
        Validate the connection settings; does nothing when the TESTING config flag is set.

        Only logs a warning when ``replications`` exceeds the number of hosts.

        :raises ValueError: if replications is not an int in [1, 99], or strategy is not
                            'SimpleStrategy' or 'NetworkTopologyStrategy'
        """
        if environ.env.config.get(ConfigKeys.TESTING, False):
            return

        if not isinstance(replications, int):
            raise ValueError('replications is not a valid int: "%s"' % str(replications))
        if replications < 1 or replications > 99:
            raise ValueError('replications needs to be in the interval [1, 99]')

        if replications > len(hosts):
            logger.warning('replications (%s) is higher than number of nodes in cluster (%s)',
                           replications, len(hosts))

        if not isinstance(strategy, str):
            raise ValueError('strategy is not a valid string, but of type: "%s"' % str(type(strategy)))

        valid_strategies = ['SimpleStrategy', 'NetworkTopologyStrategy']
        if strategy not in valid_strategies:
            raise ValueError('unknown strategy "%s", valid strategies are: %s' %
                             (str(strategy), ', '.join(valid_strategies)))

        logger.info('connecting to [%s]/%s (%s, %s replications)',
                    ','.join(self.hosts), self.key_space, self.strategy, self.replications)