IlyaGusev/rupo

View on GitHub
rupo/stress/predictor.py

Summary

Maintainability
B
4 hrs
Test Coverage
# -*- coding: utf-8 -*-
# Автор: Гусев Илья
# Описание: Класс для определения ударения.

from typing import List

from rupo.stress.dict import StressDict
from rupo.util.preprocess import count_vowels, get_first_vowel_position
from rupo.settings import CMU_DICT, ZALYZNYAK_DICT, RU_STRESS_DEFAULT_MODEL
from rupo.stress.word import Stress

from russ.stress.model import StressModel


class StressPredictor:
    def predict(self, word: str) -> List[int]:
        raise NotImplementedError()


class DictStressPredictor(StressPredictor):
    def __init__(self, language="ru", raw_dict_path=None, trie_path=None,
                 zalyzniak_dict=ZALYZNYAK_DICT, cmu_dict=CMU_DICT):
        self.stress_dict = StressDict(language, raw_dict_path=raw_dict_path, trie_path=trie_path,
                                      zalyzniak_dict=zalyzniak_dict, cmu_dict=cmu_dict)

    def predict(self, word: str) -> List[int]:
        """
        Определение ударения в слове по словарю. Возможно несколько вариантов ударения.

        :param word: слово для простановки ударений.
        :return stresses: позиции букв, на которые падает ударение.
        """
        stresses = []
        if count_vowels(word) == 0:
            # Если гласных нет, то и ударений нет.
            pass
        elif count_vowels(word) == 1:
            # Если одна гласная, то на неё и падает ударение.
            stresses.append(get_first_vowel_position(word))
        elif word.find("ё") != -1:
            # Если есть буква "ё", то только на неё может падать ударение.
            stresses.append(word.find("ё"))
        else:
            # Проверяем словарь на наличие форм с ударениями.
            stresses = self.stress_dict.get_stresses(word, Stress.Type.PRIMARY) +\
                       self.stress_dict.get_stresses(word, Stress.Type.SECONDARY)
            if 'е' not in word:
                return stresses
            # Находим все возможные варинаты преобразований 'е' в 'ё'.
            positions = [i for i in range(len(word)) if word[i] == 'е']
            beam = [word[:positions[0]]]
            for i in range(len(positions)):
                new_beam = []
                for prefix in beam:
                    n = positions[i+1] if i+1 < len(positions) else len(word)
                    new_beam.append(prefix + 'ё' + word[positions[i]+1:n])
                    new_beam.append(prefix + 'е' + word[positions[i]+1:n])
                    beam = new_beam
            # И проверяем их по словарю.
            for permutation in beam:
                if len(self.stress_dict.get_stresses(permutation)) != 0:
                    yo_pos = permutation.find("ё")
                    if yo_pos != -1:
                        stresses.append(yo_pos)
        return stresses


class CombinedStressPredictor(StressPredictor):
    def __init__(self, language="ru", stress_model_path: str=RU_STRESS_DEFAULT_MODEL, raw_stress_dict_path=None,
                 stress_trie_path=None, zalyzniak_dict=ZALYZNYAK_DICT, cmu_dict=CMU_DICT):
        self.rnn = StressModel.load(stress_model_path)
        self.dict = DictStressPredictor(language, raw_stress_dict_path, stress_trie_path, zalyzniak_dict, cmu_dict)

    def predict(self, word: str) -> List[int]:
        stresses = self.dict.predict(word)
        if len(stresses) == 0:
            return self.rnn.predict(word)
        else:
            return stresses