gunthercox/ChatterBot

View on GitHub
chatterbot/storage/django_storage.py

Summary

Maintainability
B
5 hrs
Test Coverage
from chatterbot.storage import StorageAdapter
from chatterbot import constants


class DjangoStorageAdapter(StorageAdapter):
    """
    Storage adapter that allows ChatterBot to interact with
    Django storage backends.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        self.django_app_name = kwargs.get(
            'django_app_name',
            constants.DEFAULT_DJANGO_APP_NAME
        )

    def get_statement_model(self):
        from django.apps import apps
        return apps.get_model(self.django_app_name, 'Statement')

    def get_tag_model(self):
        from django.apps import apps
        return apps.get_model(self.django_app_name, 'Tag')

    def count(self):
        Statement = self.get_model('statement')
        return Statement.objects.count()

    def filter(self, **kwargs):
        """
        Returns a list of statements in the database
        that match the parameters specified.
        """
        from django.db.models import Q

        Statement = self.get_model('statement')

        kwargs.pop('page_size', 1000)
        order_by = kwargs.pop('order_by', None)
        tags = kwargs.pop('tags', [])
        exclude_text = kwargs.pop('exclude_text', None)
        exclude_text_words = kwargs.pop('exclude_text_words', [])
        persona_not_startswith = kwargs.pop('persona_not_startswith', None)
        search_text_contains = kwargs.pop('search_text_contains', None)

        # Convert a single sting into a list if only one tag is provided
        if type(tags) == str:
            tags = [tags]

        if tags:
            kwargs['tags__name__in'] = tags

        statements = Statement.objects.filter(**kwargs)

        if exclude_text:
            statements = statements.exclude(
                text__in=exclude_text
            )

        if exclude_text_words:
            or_query = [
                ~Q(text__icontains=word) for word in exclude_text_words
            ]

            statements = statements.filter(
                *or_query
            )

        if persona_not_startswith:
            statements = statements.exclude(
                persona__startswith='bot:'
            )

        if search_text_contains:
            or_query = Q()

            for word in search_text_contains.split(' '):
                or_query |= Q(search_text__contains=word)

            statements = statements.filter(
                or_query
            )

        if order_by:
            statements = statements.order_by(*order_by)

        for statement in statements.iterator():
            yield statement

    def create(self, **kwargs):
        """
        Creates a new statement matching the keyword arguments specified.
        Returns the created statement.
        """
        Statement = self.get_model('statement')
        Tag = self.get_model('tag')

        tags = kwargs.pop('tags', [])

        if 'search_text' not in kwargs:
            kwargs['search_text'] = self.tagger.get_text_index_string(kwargs['text'])

        if 'search_in_response_to' not in kwargs:
            if kwargs.get('in_response_to'):
                kwargs['search_in_response_to'] = self.tagger.get_text_index_string(kwargs['in_response_to'])

        statement = Statement(**kwargs)

        statement.save()

        tags_to_add = []

        for _tag in tags:
            tag, _ = Tag.objects.get_or_create(name=_tag)
            tags_to_add.append(tag)

        statement.tags.add(*tags_to_add)

        return statement

    def create_many(self, statements):
        """
        Creates multiple statement entries.
        """
        Statement = self.get_model('statement')
        Tag = self.get_model('tag')

        tag_cache = {}

        for statement in statements:

            statement_data = statement.serialize()
            tag_data = statement_data.pop('tags', [])

            statement_model_object = Statement(**statement_data)

            if not statement.search_text:
                statement_model_object.search_text = self.tagger.get_text_index_string(statement.text)

            if not statement.search_in_response_to and statement.in_response_to:
                statement_model_object.search_in_response_to = self.tagger.get_text_index_string(statement.in_response_to)

            statement_model_object.save()

            tags_to_add = []

            for tag_name in tag_data:
                if tag_name in tag_cache:
                    tag = tag_cache[tag_name]
                else:
                    tag, _ = Tag.objects.get_or_create(name=tag_name)
                    tag_cache[tag_name] = tag
                tags_to_add.append(tag)

            statement_model_object.tags.add(*tags_to_add)

    def update(self, statement):
        """
        Update the provided statement.
        """
        Statement = self.get_model('statement')
        Tag = self.get_model('tag')

        if hasattr(statement, 'id'):
            statement.save()
        else:
            statement = Statement.objects.create(
                text=statement.text,
                search_text=self.tagger.get_text_index_string(statement.text),
                conversation=statement.conversation,
                in_response_to=statement.in_response_to,
                search_in_response_to=self.tagger.get_text_index_string(statement.in_response_to),
                created_at=statement.created_at
            )

        for _tag in statement.tags.all():
            tag, _ = Tag.objects.get_or_create(name=_tag)

            statement.tags.add(tag)

        return statement

    def get_random(self):
        """
        Returns a random statement from the database
        """
        Statement = self.get_model('statement')

        statement = Statement.objects.order_by('?').first()

        if statement is None:
            raise self.EmptyDatabaseException()

        return statement

    def remove(self, statement_text):
        """
        Removes the statement that matches the input text.
        Removes any responses from statements if the response text matches the
        input text.
        """
        Statement = self.get_model('statement')

        statements = Statement.objects.filter(text=statement_text)

        statements.delete()

    def drop(self):
        """
        Remove all data from the database.
        """
        Statement = self.get_model('statement')
        Tag = self.get_model('tag')

        Statement.objects.all().delete()
        Tag.objects.all().delete()