IlyaGusev/rupo

View on GitHub
rupo/metre/metre_classifier.py

Summary

Maintainability
D
2 days
Test Coverage
# -*- coding: utf-8 -*-
# Автор: Гусев Илья
# Описание: Классификатор метра.

from collections import OrderedDict
from typing import List, Dict, Tuple
import jsonpickle
import logging

from rupo.main.markup import Line, Markup
from rupo.util.mixins import CommonMixin
from rupo.metre.pattern_analyzer import PatternAnalyzer
from rupo.util.preprocess import get_first_vowel_position
from rupo.util.timeit import timeit


class StressCorrection(CommonMixin):
    """
    A single stress fix: which syllable of which word in which line
    should carry the stress.
    """
    def __init__(self, line_number: int, word_number: int, syllable_number: int,
                 word_text: str, stress: int) -> None:
        """
        :param line_number: index of the line in the markup.
        :param word_number: index of the word within the line.
        :param syllable_number: number of the syllable within the word.
        :param word_text: text of the word.
        :param stress: stress position (0-based).
        """
        # Attributes are assigned in the same order as before so that any
        # __dict__-based representation/serialisation keeps its field order.
        self.line_number, self.word_number, self.syllable_number = \
            line_number, word_number, syllable_number
        self.word_text, self.stress = word_text, stress


class ClassificationResult(CommonMixin):
    """
    Result of classifying a poem by metre.

    Per-metre statistics are stored in dicts keyed by the metre names from
    MetreClassifier.metres.
    """
    def __init__(self, count_lines: int=0) -> None:
        """
        :param count_lines: number of lines in the poem.
        """
        # Name of the chosen metre; stays None until the classifier sets it.
        self.metre = None
        self.count_lines = count_lines
        self.errors_count = {k: 0 for k in MetreClassifier.metres.keys()}  # type: Dict[str, int]
        # Stresses moved to a strong position in words that had a single stress.
        self.corrections = {k: [] for k in MetreClassifier.metres.keys()}  # type: Dict[str, List[StressCorrection]]
        # Words with several stresses narrowed down to a strong position.
        self.resolutions = {k: [] for k in MetreClassifier.metres.keys()}  # type: Dict[str, List[StressCorrection]]
        # Stresses added to words that had none.
        self.additions = {k: [] for k in MetreClassifier.metres.keys()}  # type: Dict[str, List[StressCorrection]]

    def get_metre_errors_count(self):
        """
        :return: number of errors for the chosen metre.
        :raises KeyError: if no metre has been chosen yet (self.metre is None).
        """
        return self.errors_count[self.metre]

    def to_json(self):
        """
        :return: JSON serialisation of this object (via jsonpickle).
        """
        return jsonpickle.encode(self)

    @staticmethod
    def str_corrections(collection: List[StressCorrection]) -> str:
        """
        :param collection: list of stress corrections.
        :return: one "(word_text, syllable_number)" tuple per line.
        """
        return "\n".join([str((item.word_text, item.syllable_number)) for item in collection])

    def __str__(self):
        # Before classification self.metre is None and has no entry in the
        # per-metre dicts; show empty lists instead of letting str() raise KeyError.
        resolutions = self.resolutions.get(self.metre, [])
        corrections = self.corrections.get(self.metre, [])
        additions = self.additions.get(self.metre, [])
        st = "Метр: " + str(self.metre) + "\n"
        st += "Снятая омография: \n" + ClassificationResult.str_corrections(resolutions) + "\n"
        st += "Неправильные ударения: \n" + ClassificationResult.str_corrections(corrections) + "\n"
        st += "Новые ударения: \n" + ClassificationResult.str_corrections(additions) + "\n"
        return st


class ErrorsTableRecord:
    """
    One cell of the error table: how one line scored against one metre.
    """
    def __init__(self, strong_errors, weak_errors, pattern, failed=False):
        """
        :param strong_errors: strong error count for the line (may later be
            replaced by a normalised share).
        :param weak_errors: weak error count for the line (same remark).
        :param pattern: metre pattern matched to the line ("" if none).
        :param failed: True if pattern analysis failed for this line/metre.
        """
        self.strong_errors, self.weak_errors = strong_errors, weak_errors
        self.pattern = pattern
        self.failed = failed

    def __str__(self):
        return " ".join((self.pattern, str(self.strong_errors), str(self.weak_errors)))

    def __repr__(self):
        return str(self)


class ErrorsTable:
    """
    Error table: for every metre in MetreClassifier.metres, one
    ErrorsTableRecord per line of the poem.
    """
    def __init__(self, num_lines):
        """
        :param num_lines: number of lines in the poem.
        """
        self.data = {}
        self.num_lines = num_lines
        # Multiplier applied to each metre's averaged error score.
        self.coef = OrderedDict(
            [("iambos", 0.3),
             ("choreios", 0.3),
             ("daktylos", 0.4),
             ("amphibrachys", 0.4),
             ("anapaistos", 0.4),
             ("dolnik3", 0.5),
             ("dolnik2", 0.5),
             ("taktovik3", 6.0),
             ("taktovik2", 6.0)
             ])
        # Constant added to each metre's score; being non-zero only for
        # dolnik/taktovik, it lets the classical metres win on equal errors.
        self.sum_coef = OrderedDict(
            [("iambos", 0.0),
             ("choreios", 0.0),
             ("daktylos", 0.0),
             ("amphibrachys", 0.0),
             ("anapaistos", 0.0),
             ("dolnik3", 0.035),
             ("dolnik2", 0.035),
             ("taktovik3", 0.10),
             ("taktovik2", 0.10)
             ])
        for metre_name in MetreClassifier.metres.keys():
            self.data[metre_name] = [ErrorsTableRecord(0, 0, "") for _ in range(num_lines)]

    def add_record(self, metre_name, line_num, strong_errors, weak_errors, pattern, failed=False):
        """
        Store the analysis result of line `line_num` for metre `metre_name`.
        """
        self.data[metre_name][line_num] = ErrorsTableRecord(strong_errors, weak_errors, pattern, failed)

    def get_best_metre(self):
        """
        Pick the metre with the lowest score.

        For every line, strong (and separately weak) errors are divided by
        their total over all metres for that line. This normalisation
        overwrites the stored records in place. A metre's score is then:
        sum_coef + (strong + weak / 2) * coef / num_lines.

        :return: name of the best metre (the first one on ties).
        """
        for line_idx in range(self.num_lines):
            strong_sum = 0
            weak_sum = 0
            for column in self.data.values():
                strong_sum += column[line_idx].strong_errors
                weak_sum += column[line_idx].weak_errors
            for column in self.data.values():
                record = column[line_idx]
                if strong_sum != 0:
                    record.strong_errors = record.strong_errors / float(strong_sum)
                if weak_sum != 0:
                    record.weak_errors = record.weak_errors / float(weak_sum)

        # An empty poem has no lines: all error terms are zero, so avoid the
        # division by zero and let sum_coef alone decide.
        lines_divisor = max(self.num_lines, 1)
        sums = dict()
        for metre_name, column in self.data.items():
            strong_total = 0
            weak_total = 0
            for record in column:
                strong_total += record.strong_errors
                weak_total += record.weak_errors
            sums[metre_name] = self.sum_coef[metre_name] + \
                (strong_total + weak_total / 2.0) * self.coef[metre_name] / lines_divisor
        logging.debug(sums)
        return min(sums, key=sums.get)


class MetreClassifier(object):
    """
    Metre classifier: counts how far each line deviates from the standard
    rhythm templates (metres) and picks the metre that fits the poem best.
    """
    # Metre name -> rhythm template passed to PatternAnalyzer.count_errors.
    # Lines are encoded as strings of "S" (stressed) / "U" (unstressed) slots
    # (see __get_line_pattern). The template syntax itself (letter case,
    # groups, ?/* quantifiers) is interpreted by PatternAnalyzer.
    metres = OrderedDict(
        [("iambos", '(us)*(uS)(U)?(U)?'),
         ("choreios", '(su)*(S)(U)?(U)?'),
         ("daktylos", '(suu)*(S)(U)?(U)?'),
         ("amphibrachys", '(usu)*(uS)(U)?(U)?'),
         ("anapaistos",  '(uus)*(uuS)(U)?(U)?'),
         ("dolnik3", '(u)?(u)?((su)(u)?)*(S)(U)?(U)?'),
         ("dolnik2", '(u)?(u)?((s)(u)?)*(S)(U)?(U)?'),
         ("taktovik3", '(u)?(u)?((su)(u)?(u)?)*(S)(U)?(U)?'),
         ("taktovik2", '(u)?(u)?((s)(u)?(u)?)*(S)(U)?(U)?')
         ])

    # Lines with more syllables than this are not analysed.
    border_syllables_count = 20

    @staticmethod
    @timeit
    def classify_metre(markup):
        """
        Classify the poem's metre.

        First pass: every analysable line is matched against every metre, and
        the error counts go into an ErrorsTable, which picks the best metre.
        Second pass: stress corrections for the chosen metre are collected.

        :param markup: markup of the poem.
        :return: classification result.
        """
        result = ClassificationResult(len(markup.lines))
        num_lines = len(markup.lines)
        errors_table = ErrorsTable(num_lines)
        for l, line in enumerate(markup.lines):
            # The syllable count and the stress pattern do not depend on the
            # metre, so compute them once per line.
            line_syllables_count = sum(len(word.syllables) for word in line.words)
            # Lines longer than border_syllables_count syllables (or without any
            # syllables) are skipped and keep their default table records.
            if line_syllables_count > MetreClassifier.border_syllables_count or line_syllables_count == 0:
                continue
            line_pattern = MetreClassifier.__get_line_pattern(line)
            for metre_name, metre_pattern in MetreClassifier.metres.items():
                # Error border passed to PatternAnalyzer; smaller for the
                # dolnik and taktovik families.
                error_border = 7
                if metre_name in ("dolnik2", "dolnik3"):
                    error_border = 3
                elif metre_name in ("taktovik2", "taktovik3"):
                    error_border = 2
                pattern, strong_errors, weak_errors, analysis_errored = \
                    PatternAnalyzer.count_errors(metre_pattern, line_pattern, error_border)
                if analysis_errored or len(pattern) == 0:
                    errors_table.add_record(metre_name, l, strong_errors, weak_errors, pattern, True)
                    continue
                # Stresses that would have to be moved also count as strong errors.
                corrections = MetreClassifier.__get_line_pattern_matching_corrections(line, l, pattern)[0]
                strong_errors += len(corrections)
                errors_table.add_record(metre_name, l, strong_errors, weak_errors, pattern)
        result.metre = errors_table.get_best_metre()

        # Collect all corrections for the chosen metre.
        for l, line in enumerate(markup.lines):
            record = errors_table.data[result.metre][l]
            if record.failed or len(record.pattern) == 0:
                continue
            corrections, resolutions, additions = \
                MetreClassifier.__get_line_pattern_matching_corrections(line, l, record.pattern)
            result.corrections[result.metre] += corrections
            result.resolutions[result.metre] += resolutions
            result.additions[result.metre] += additions
            result.errors_count[result.metre] += len(corrections)
        return result

    @staticmethod
    def __get_line_pattern(line: Line) -> str:
        """
        Encode the line's current stresses as a pattern string.

        Each syllable becomes "S" if it carries a stress and "U" otherwise.
        A word without syllables contributes a single "U".

        :param line: line.
        :return: pattern string such as "USUS".
        """
        slots = []
        for word in line.words:
            if len(word.syllables) == 0:
                slots.append("U")
                continue
            slots.extend("S" if syllable.stress != -1 else "U" for syllable in word.syllables)
        return "".join(slots)

    @staticmethod
    def __get_line_pattern_matching_corrections(line: Line, line_number: int, pattern: str) \
            -> Tuple[List[StressCorrection], List[StressCorrection], List[StressCorrection]]:
        """
        Compare the line's stresses with the metre pattern matched to it.

        A stress may fall on a weak position only as long as no other syllable
        of the same word could take a strong position; otherwise it is an error.

        :param line: line.
        :param line_number: line index (stored in the produced corrections).
        :param pattern: metre pattern matched to the line. It is assumed to be
            aligned slot-by-slot with __get_line_pattern(line);
            's'/'S' = strong position, 'u'/'U' = weak position.
        :return: (corrections, resolutions, additions):
            corrections - a single-stress word whose stress should move to a
            strong position;
            resolutions - a word with several stresses narrowed to a strong
            position;
            additions - stresses added to words without any stress.
        """
        corrections = []
        resolutions = []
        additions = []
        number_in_pattern = 0
        for w, word in enumerate(line.words):
            syllables_count = len(word.syllables)
            if syllables_count == 0:
                # __get_line_pattern gives a syllable-less word one "U" slot;
                # step over it so the following words stay aligned.
                number_in_pattern += 1
                continue
            if number_in_pattern + syllables_count > len(pattern):
                # Pattern shorter than the line: stop instead of reading past its end.
                break
            if syllables_count == 1:
                # Monosyllable: only add a missing stress on a strong position.
                if pattern[number_in_pattern].lower() == "s" and word.syllables[0].stress == -1:
                    additions.append(StressCorrection(line_number, w, 0, word.text, word.syllables[0].vowel()))
                number_in_pattern += 1
                continue
            stress_count = word.count_stresses()
            for syllable in word.syllables:
                if stress_count == 0 and pattern[number_in_pattern].lower() == "s":
                    # No stresses at all: add every one the metre allows (possibly several).
                    additions.append(StressCorrection(line_number, w, syllable.number, word.text, syllable.vowel()))
                elif pattern[number_in_pattern].lower() == "u" and syllable.stress != -1:
                    # A stressed syllable sits on a weak position: every other
                    # syllable of the word that sits on a strong position is a
                    # candidate fix.
                    for other_syllable in word.syllables:
                        other_number_in_pattern = other_syllable.number - syllable.number + number_in_pattern
                        if syllable.number == other_syllable.number or pattern[other_number_in_pattern].lower() != "s":
                            continue
                        ac = StressCorrection(line_number, w, other_syllable.number, word.text, other_syllable.vowel())
                        if stress_count == 1 and other_syllable.stress == -1:
                            corrections.append(ac)
                        else:
                            resolutions.append(ac)
                number_in_pattern += 1
        return corrections, resolutions, additions

    @staticmethod
    def get_improved_markup(markup: Markup, result: ClassificationResult) -> Markup:
        """
        Apply the stress fixes found for result.metre to the markup in place.

        Corrections and resolutions clear every stress of the word and then set
        the chosen syllable. Additions set a stress without clearing others.

        :param markup: initial markup (mutated).
        :param result: classification result.
        :return: the same, now improved, markup object.
        """
        for pos in result.corrections[result.metre] + result.resolutions[result.metre]:
            syllables = markup.lines[pos.line_number].words[pos.word_number].syllables
            for syllable in syllables:
                syllable.stress = -1
                if syllable.number == pos.syllable_number:
                    syllable.stress = syllable.begin + get_first_vowel_position(syllable.text)
        for pos in result.additions[result.metre]:
            syllable = markup.lines[pos.line_number].words[pos.word_number].syllables[pos.syllable_number]
            syllable.stress = syllable.begin + get_first_vowel_position(syllable.text)

        return markup

    @staticmethod
    def improve_markup(markup: Markup) -> \
            Tuple[Markup, ClassificationResult]:
        """
        Improve the markup with the metre classifier.

        :param markup: initial markup (mutated by get_improved_markup).
        :return: (improved markup, classification result).
        """
        result = MetreClassifier.classify_metre(markup)
        improved_markup = MetreClassifier.get_improved_markup(markup, result)
        return improved_markup, result