DragonComputer/Dragonfire

View on GitHub
dragonfire/odqa.py

Summary

Maintainability
C
1 day
Test Coverage
#!/usr/bin/python3
# -*- coding: utf-8 -*-

"""
.. module:: odqa
    :platform: Unix
    :synopsis: the top-level submodule of Dragonfire that contains the classes related to **ODQA**: Dragonfire's DeepPavlov SQuAD BERT model based Open-Domain Question Answering Engine.

.. moduleauthor:: Mehmet Mert Yıldıran <mert.yildiran@bil.omu.edu.tr>
"""

import collections  # Imported to support ordered dictionaries in Python
from random import uniform  # Generate pseudo-random numbers

from dragonfire.utilities import nostderr  # With statement to suppress errors

import requests.exceptions  # HTTP for Humans
import wikipedia  # Provides and API-like functionality to search and access Wikipedia data
import wikipedia.exceptions  # Exceptions of wikipedia library
from nltk.corpus import wordnet as wn  # The WordNet corpus
from nltk.corpus.reader.wordnet import WordNetError  # To catch the errors

from deeppavlov import build_model, configs


class ODQA():
    """Factoid (open-domain) question answering backed by Wikipedia.

    The subject of a "wh-" question is extracted with spaCy and searched on
    Wikipedia. The content of the first matching page is then handed, together
    with the question, to a DeepPavlov SQuAD model that extracts the answer.
    """

    # Penn Treebank tags of "wh-" words (which, who, whose, where, ...).
    _WH_TAGS = frozenset(['WDT', 'WP', 'WP$', 'WRB'])

    # Coarse POS tags removed by phrase_cleaner. 'CONJ' is the pre-UD-v2 name of
    # coordinating conjunctions; Universal Dependencies v2 renamed it 'CCONJ'
    # and added 'SCONJ', so all three are listed.
    _NOISE_POS = frozenset(['PUNCT', 'SYM', 'X', 'CONJ', 'CCONJ', 'SCONJ', 'DET', 'ADP', 'SPACE'])

    def __init__(self, nlp):
        """Initialization method of :class:`dragonfire.odqa.ODQA` class.

        Args:
            nlp:  :mod:`spacy` model instance.
        """

        self.nlp = nlp  # Caller-supplied spaCy pipeline, used for every parse in this class
        # Build the DeepPavlov SQuAD reader; download=True lets DeepPavlov fetch the model files.
        self.model = build_model(configs.squad.squad, download=True)

    @staticmethod
    def _apology(user_prefix, message):
        """Return ``"Sorry, <user_prefix>. <message>"``, or ``"Sorry. <message>"`` without a prefix.

        Avoids the ``TypeError`` that plain string concatenation raises when
        ``user_prefix`` is ``None`` (the default of :meth:`respond`).
        """
        if user_prefix:
            return "Sorry, {0}. {1}".format(user_prefix, message)
        return "Sorry. {0}".format(message)

    @staticmethod
    def _report(message, tts_output, userin, is_server):
        """Print or speak ``message`` for interactive sessions; do nothing in server mode."""
        if is_server:
            return
        if tts_output:
            userin.say(message)
        else:
            print(message)

    def respond(self, com, tts_output=False, userin=None, user_prefix=None, is_server=False):
        """Method to respond the user's input/command using factoid question answering ability.

        Args:
            com (str):  User's command.

        Keyword Args:
            tts_output (bool):      Is text-to-speech output enabled?
            userin:                 :class:`dragonfire.utilities.TextToAction` instance.
            user_prefix (str):      Prefix to address/call user when answering. When empty
                                    or ``None`` the apology messages simply omit it.
            is_server (bool):       Is Dragonfire running as an API server?

        Returns:
            str or bool:  The answer extracted by the model, an apology message when
            Wikipedia could not be searched, or ``False`` when ``com`` is not a
            "wh-" question with a recognizable subject.

        .. note::

            Entry function for :class:`ODQA` class. Dragonfire calls only this function. Unlike :func:`Learner.respond`, it executes TTS because of its late response nature.
            The "Please wait..." notice and the apology messages are printed/spoken here; the extracted answer itself is only returned.

        """

        # semantic_extractor also returns no subject for commands without a
        # "wh-" word, so this single check filters out non-questions as well.
        subject = self.semantic_extractor(com)[0]
        if not subject:
            return False

        query = subject  # Wikipedia search query (same as the subject)
        if not is_server:
            if tts_output:
                userin.say("Please wait...", True, False)  # Gain a few more seconds by saying Please wait...
            else:
                print("Please wait...")

        with nostderr():
            try:
                search_results = wikipedia.search(query)  # run a Wikipedia search with the query
                if not search_results:
                    result = self._apology(user_prefix, "But I couldn't find anything about " + query + " in Wikipedia.")
                    self._report(result, tts_output, userin, is_server)
                    return result
                try:
                    wikipage = wikipedia.page(search_results[0])
                except wikipedia.exceptions.DisambiguationError as disambiguation:
                    # Retry with the first option listed on the disambiguation page
                    # (previously the retry result was discarded and None returned).
                    wikipage = wikipedia.page(disambiguation.options[0])
                return self.model([wikipage.content], [com])[0][0]
            except requests.exceptions.ConnectionError:  # if there is a connection error
                result = self._apology(user_prefix, "But I'm unable to connect to Wikipedia servers.")
                if not is_server and userin is not None:
                    userin.execute([" "], "Wikipedia connection error.")
                self._report(result, tts_output, userin, is_server)
                return result
            except Exception:  # best-effort: any other failure becomes an apology, not a crash
                result = self._apology(user_prefix, "But something went horribly wrong while I'm searching Wikipedia.")
                self._report(result, tts_output, userin, is_server)
                return result

    def phrase_cleaner(self, phrase):
        """Function to clean unnecessary words from the given phrase/string. (Punctuation mark, symbol, unknown, conjunction, determiner, subordinating or preposition and space)

        Args:
            phrase (str):  Noun phrase.

        Returns:
            str:  Cleaned noun phrase (kept words joined by single spaces).
        """

        return ' '.join(word.text for word in self.nlp(phrase) if word.pos_ not in self._NOISE_POS)

    def semantic_extractor(self, string):
        """Function to extract subject, subjects, focus, subject_with_objects from given string.

        Args:
            string (str):  String.

        Returns:
            tuple: ``(subject, subjects, focus, subject_with_objects)`` where ``subject``
            is the Wikipedia query, ``subjects`` the list of nominal subjects, ``focus``
            the cleaned goal phrase (or ``None`` if it is already part of the subject)
            and ``subject_with_objects`` every object/subject phrase joined by spaces.
            ``(None, None, None, None)`` when the string has no "wh-" word or no
            subject/object noun phrase.
        """

        doc = self.nlp(string)  # spaCy does all kinds of NLP analysis in one function

        # Only "wh-" questions are accepted; bail out before any further work.
        if not any(word.tag_ in self._WH_TAGS for word in doc):
            return None, None, None, None

        # Lists, because a string may contain several subjects or objects.
        subjects = []  # nominal subjects (active or passive)
        pobjects = []  # objects of a preposition
        dobjects = []  # direct objects
        # https://nlp.stanford.edu/software/dependencies_manual.pdf - Hierarchy of typed dependencies
        for np in doc.noun_chunks:
            root = np.root
            if root.dep_ in ('nsubj', 'nsubjpass') and root.tag_ != 'WP':
                subjects.append(np.text)
            if root.dep_ == 'pobj':
                pobjects.append(np.text)
            if root.dep_ == 'dobj':
                dobjects.append(np.text)

        # Wikipedia query priority: [Object of a preposition] > [Subject] > [Direct object]
        query_phrases = pobjects or subjects or dobjects
        if not query_phrases:
            return None, None, None, None
        the_subject = ' '.join(query_phrases)

        # Focus (objective/goal) priority: [Direct object] > [Subject] > [Object of a preposition]
        focus = self.phrase_cleaner(' '.join(dobjects or subjects or pobjects))
        if focus in the_subject:
            focus = None

        # Full string of all subjects and objects concatenated
        subject_with_objects = ' '.join(dobjects + subjects + pobjects)

        return the_subject, subjects, focus, subject_with_objects

    def check_how_odqa_performs(self, sample_length=None, success_threshold=0.05):
        """Benchmark :meth:`respond` on the HotpotQA dev (full wiki) set and exit the process.

        Keyword Args:
            sample_length (int):        Number of randomly chosen questions to test;
                                        ``None`` tests the whole dataset.
            success_threshold (float):  Minimum ratio of correct answers required for
                                        a zero exit status.

        Raises:
            SystemExit:  Always; status 0 when the ratio of correct answers reaches
                         ``success_threshold``, 1 otherwise.
        """
        import json
        import multiprocessing
        import random
        import sys
        import threading
        import urllib.request

        from termcolor import colored

        from dragonfire.utilities import split, s_print

        HOTPOTQA_DATASET_URL = 'http://curtis.ml.cmu.edu/datasets/hotpot/hotpot_dev_fullwiki_v1.json'
        THREAD_MULTIPLIER = 1
        THREAD_COUNT = multiprocessing.cpu_count() * THREAD_MULTIPLIER

        # Close the HTTP response as soon as the payload has been read.
        with urllib.request.urlopen(HOTPOTQA_DATASET_URL) as response:
            samples = json.loads(response.read())
        if sample_length is not None:
            samples = random.sample(samples, sample_length)

        for question_number, sample in enumerate(samples, start=1):
            sample['question_number'] = question_number

        print('\nThread Count: {0}\n'.format(THREAD_COUNT))
        print('\nStarting to test {0} questions'.format(len(samples)))

        results = []  # (correct, wrong) pairs appended by the worker threads
        threads = []
        # Iterate over the chunks split() actually produced instead of indexing
        # THREAD_COUNT times, which raises IndexError if fewer chunks come back.
        for thread_number, chunk in enumerate(split(samples, THREAD_COUNT)):
            thread = threading.Thread(
                target=self.async_performer,
                args=(thread_number, chunk, results, s_print),
            )
            thread.daemon = True
            thread.start()
            threads.append(thread)

        for thread in threads:
            thread.join()

        print(colored('\n(Correct, Wrong) Pairs: {0}\n'.format(results), 'yellow'))

        correct_total = sum(pair[0] for pair in results)
        wrong_total = sum(pair[1] for pair in results)
        evaluated = correct_total + wrong_total
        # Avoid ZeroDivisionError when nothing was evaluated (empty or fully skipped dataset).
        success = correct_total / evaluated if evaluated else 0.0

        print(colored('\nPerformance: {0}\n'.format(success), 'yellow'))

        if success >= success_threshold:
            print(colored('SUCCESS!', 'green'))
            sys.exit(0)
        print(colored('FAILURE!', 'red'))
        sys.exit(1)

    @staticmethod
    def _to_ascii(text):
        """Return ``text`` with every non-ASCII character dropped."""
        return text.encode('ascii', 'ignore').decode('ascii')

    @staticmethod
    def _is_correct(answer, correct_answer):
        """Return True when ``answer`` is a non-blank string contained in ``correct_answer``.

        Blank answers are rejected explicitly because ``'' in s`` is always True.
        """
        return isinstance(answer, str) and bool(answer.strip()) and answer in correct_answer

    def async_performer(self, thread_number, samples, results, s_print):
        """Evaluate :meth:`respond` on a chunk of HotpotQA samples (thread target).

        Args:
            thread_number (int):  Zero-based index of this worker.
            samples (list):       Dataset entries with ``question``, ``answer`` and
                                  ``question_number`` keys.
            results (list):       Shared list; one ``(correct, wrong)`` tuple is appended
                                  to it, even if evaluation stops with an exception.
            s_print:              Print function supplied by the caller (presumably
                                  serialized across threads).

        Returns:
            bool:  Always ``True``.
        """
        s_print('\nThread {0} is started.\n'.format(thread_number + 1))
        from termcolor import colored

        correct_counter = 0
        wrong_counter = 0

        try:
            for sample in samples:
                question = sample['question']
                correct_answer = sample['answer']
                out = '\n({0})'.format(sample['question_number'])
                # Check emptiness before encoding, and actually print the notice.
                if not question or not correct_answer:
                    out += colored('\nDataset contains an empty question or answer, so it\'s skipped!', 'yellow')
                    s_print(out)
                    continue
                out += '\nQuestion: {0}'.format(self._to_ascii(question))
                out += '\nCorrect Answer: {0}'.format(self._to_ascii(correct_answer))

                answer = self.respond(question, user_prefix="sir", is_server=True)
                shown = self._to_ascii(answer) if isinstance(answer, str) else answer
                out += '\nOur Answer: {0}'.format(shown)

                if self._is_correct(answer, correct_answer):
                    correct_counter += 1
                    out += colored('\nCORRECT', 'green')
                else:
                    wrong_counter += 1
                    out += colored('\nWRONG', 'red')

                s_print(out)
        finally:
            # Report partial counts even if this worker dies part-way through.
            results.append((correct_counter, wrong_counter))
        return True


if __name__ == '__main__':
    import spacy

    try:
        # Load the English model by its package name (the model named in ODQA.__init__'s history).
        nlp = spacy.load('en_core_web_sm')
    except OSError:
        # Fall back to the legacy 'en' shortcut link for installs that only provide it.
        nlp = spacy.load('en')
    odqa = ODQA(nlp)
    odqa.check_how_odqa_performs()