IlyaGusev/rupo

View on GitHub
rupo/generate/transforms.py

Summary

Maintainability
A
3 hrs
Test Coverage
from copy import deepcopy
import numpy as np
from collections import defaultdict

from rulm.transform import Transform

from rupo.rhymes.rhymes import Rhymes
from rupo.main.vocabulary import StressVocabulary
from rupo.stress.word import StressedWord


class PoemTransform(Transform):
    """
    Фильтр по шаблону метра.
    """
    def __init__(self,
                 stress_vocabulary: StressVocabulary,
                 metre_pattern: str,
                 rhyme_pattern: str,
                 n_syllables: int,
                 eos_index: int,
                 letters_to_rhymes: dict=None,
                 score_border=4):
        self.stress_vocabulary = stress_vocabulary

        self.n_syllables = n_syllables

        mul = n_syllables // len(metre_pattern)
        if n_syllables % len(metre_pattern) != 0:
            mul += 1

        self.metre_pattern = metre_pattern * mul
        self.stress_position = len(self.metre_pattern) - 1
        self.eos_index = eos_index

        self.rhyme_pattern = rhyme_pattern
        self.rhyme_position = len(self.rhyme_pattern) - 1
        self.score_border = score_border

        self.letters_to_rhymes = defaultdict(set)
        if letters_to_rhymes is not None:
            for letter, words in letters_to_rhymes.items():
                for word in words:
                    self.letters_to_rhymes[letter].add(word)

    def __call__(self, probabilities: np.array) -> np.array:
        if self.rhyme_position < 0 and self.stress_position == len(self.metre_pattern) - 1:
            probabilities = np.zeros(probabilities.shape, dtype="float")
            probabilities[self.eos_index] = 1.
            return probabilities

        for index in range(probabilities.shape[0]):
            word = self.stress_vocabulary.get_word(index)
            is_good_by_stress = self._filter_word_by_stress(word)
            is_good_by_rhyme = True
            if self.stress_position == len(self.metre_pattern) - 1:
                is_good_by_rhyme = self._filter_word_by_rhyme(word)
            if not is_good_by_stress or not is_good_by_rhyme:
                probabilities[index] = 0.

        assert np.sum(probabilities > 0) != 0, "Poem transform filtered out all words"
        return probabilities

    def advance(self, index: int):
        word = self.stress_vocabulary.get_word(index)
        syllables_count = len(word.syllables)

        if self.stress_position == len(self.metre_pattern) - 1:
            letter = self.rhyme_pattern[self.rhyme_position]
            self.letters_to_rhymes[letter].add(word)
            self.rhyme_position -= 1

        self.stress_position -= syllables_count

        if self.stress_position < 0:
            self.stress_position = len(self.metre_pattern) - 1

    def _filter_word_by_stress(self, word: StressedWord) -> bool:
        syllables = word.syllables
        syllables_count = len(syllables)
        if syllables_count == 0:
            return False
        if self.stress_position - syllables_count < -1:
            return False
        for i in range(syllables_count):
            syllable = syllables[i]
            syllable_number = self.stress_position - syllables_count + i + 1
            if syllables_count >= 2 and syllable.stress == -1 and self.metre_pattern[syllable_number] == "+":
                for j in range(syllables_count):
                    other_syllable = syllables[j]
                    other_syllable_number = other_syllable.number - syllable.number + syllable_number
                    if i != j and other_syllable.stress != -1 and self.metre_pattern[other_syllable_number] == "-":
                        return False
        return True

    def _filter_word_by_rhyme(self, word: StressedWord) -> bool:
        if len(word.syllables) <= 1:
            return False
        rhyming_words = self.letters_to_rhymes[self.rhyme_pattern[self.rhyme_position]]
        if len(rhyming_words) == 0:
            return True
        first_word = list(rhyming_words)[0]

        is_rhyme = Rhymes.is_rhyme(first_word, word,
                                   score_border=self.score_border,
                                   syllable_number_border=2) and first_word.text != word.text
        return is_rhyme

    def __copy__(self):
        obj = type(self)(self.stress_vocabulary, self.metre_pattern, self.rhyme_pattern, self.n_syllables,
                         self.eos_index, self.letters_to_rhymes, self.score_border)
        obj.stress_position = self.stress_position
        obj.rhyme_position = self.rhyme_position
        obj.letters_to_rhymes = deepcopy(self.letters_to_rhymes)
        return obj