megagonlabs/bunkai
bunkai/algorithm/lbd/predict.py

#!/usr/bin/env python3

import argparse
import sys
import typing
from pathlib import Path

import numpy as np
import torch
from more_itertools import chunked
from transformers.models.auto.modeling_auto import AutoModelForTokenClassification

import bunkai.constant
from bunkai.algorithm.lbd.corpus import LABEL_OTHER, LABEL_SEP, annotation2spans
from bunkai.algorithm.lbd.custom_tokenizers import JanomeSubwordsTokenizer, JanomeTokenizer
from bunkai.algorithm.lbd.train import BunkaiConfig, MyDataset
from bunkai.base.annotation import Tokens
from bunkai.third.utils_ner import InputExample, get_labels

StringMorphemeInputType = typing.List[typing.List[str]]  # (batch-size * variable-length of sentence * tokens)


class Predictor(object):
    def __init__(self, modelpath: Path) -> None:
        """
        Use JanomeTokenizer by default if the input is Document(String).

        If the input is Morpheme(String), Tokenizers are not called.

        :param subword_tokenizer_type: pre_tokenize, janome, basic
        """
        self.model = AutoModelForTokenClassification.from_pretrained(str(modelpath))
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.model.to(self.device)

        self.labels = get_labels(str(modelpath.joinpath("labels.txt")))
        self.label_map: typing.Dict[int, str] = {i: label for i, label in enumerate(self.labels)}

        with modelpath.joinpath("bunkai.json").open() as bcf:
            self.bc = BunkaiConfig.from_json(bcf.read())
        # Subword tokenizer backed by Janome, built from the model's vocab file.
        self.path_tokenizer_model: str = str(Path(modelpath).joinpath("vocab.txt"))
        self.tokenizer = JanomeSubwordsTokenizer(self.path_tokenizer_model)

        # hotfix: DistilBERT does not accept token_type_ids, so drop it from the tokenizer's expected inputs.
        if self.model.base_model_prefix == "distilbert" and "token_type_ids" in self.tokenizer.model_input_names:
            self.tokenizer.model_input_names.remove("token_type_ids")

    def _split_long_text(
        self, tokens: typing.List[str]
    ) -> typing.Tuple[typing.List[typing.List[str]], typing.List[typing.List[int]]]:
        """
        Split documents(tokens) into sub-documents(tokens).

        This is because Bert has the maximum token-length of the input.
        """
        # tokenized_spans_list is a temporary stack which holds subword-token and underbar.
        # That is because underbar is replaced into UNK if underbar is put into a subword-tokeniser.
        tokenized_spans_list: typing.List[typing.List[str]] = []
        tmp_stack: typing.List[str] = []
        for __t in tokens:
            if __t == bunkai.constant.METACHAR_LINE_BREAK:
                if len(tmp_stack) > 0:
                    tokenized_spans_list.append(tmp_stack)
                tokenized_spans_list.append([bunkai.constant.METACHAR_LINE_BREAK])
                tmp_stack = []
            else:
                # sub-word tokenize
                tmp_stack.append(__t)
        if len(tmp_stack) > 0:
            tokenized_spans_list.append(tmp_stack)

        # run subword-tokeniser
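        # Greedily pack words into sub-documents: a new sub-document is started as soon as
        # the running subword count reaches max_seq_length, so no sub-document exceeds the
        # model's input limit.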
        processed_tokens: typing.List[typing.List[str]] = [[]]
        processed_num_sws: typing.List[typing.List[int]] = [[]]
        current_count: int = 0
        tokenized_span: typing.List[str]
        for tokenized_span in tokenized_spans_list:
            for word in tokenized_span:
                subwords = self.tokenizer.tokenize(word)
                num_subwords = len(subwords)
                current_count += num_subwords
                if current_count >= self.bc.max_seq_length:
                    processed_tokens.append([])
                    processed_num_sws.append([])
                    current_count = num_subwords
                    assert current_count < self.bc.max_seq_length
                processed_tokens[-1].append(word)
                processed_num_sws[-1].append(num_subwords)
        return processed_tokens, processed_num_sws

    def predict(self, documents_morphemes: StringMorphemeInputType) -> typing.List[typing.Set[int]]:
        """
        Run prediction on incoming inputs. Inputs are 2 dims array with [[sentence]].

        :param spans_list: 2 dims list (batch-size * variable-length of sentence) or
        [['ラウンジ', 'も', '気軽', 'に', '利用', 'でき', '、', '申し分', 'ない', 'です', '。', '▁', '']].
        """
        examples = []

        # Note: this preprocessing is not moved into MyDataset because _split_long_text itself calls a tokenizer.
        num_subwords = []
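        # num_subwords[i][j] is the number of subwords of word j in sub-document i;
        # it is needed later to align subword-level predictions with word positions.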
        for d_id, spans in enumerate(documents_morphemes):
            sentences_within_length, _num_subwords = self._split_long_text(spans)
            num_subwords += _num_subwords
            for local_s_id, tokenized_spans in enumerate(sentences_within_length):
                examples.append(
                    InputExample(
                        guid=f"{d_id}-{local_s_id}",
                        words=tokenized_spans,
                        labels=[LABEL_OTHER] * len(tokenized_spans),
                        is_document_first=bool(local_s_id == 0),
                    )
                )
        ds = MyDataset(examples, self.labels, self.bc.max_seq_length, self.tokenizer, False)

        kwargs = {
            "input_ids": torch.tensor([f.input_ids for f in ds.features], device=self.device),
            "attention_mask": torch.tensor([f.attention_mask for f in ds.features], device=self.device),
        }
        if self.model.base_model_prefix == "bert":  # hotfix: only BERT-style models take token_type_ids
            kwargs["token_type_ids"] = torch.tensor([f.token_type_ids for f in ds.features], device=self.device)

        predictions = self.model(**kwargs).logits.to("cpu").detach().numpy()
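        # predictions: per-subword logits with shape (number of sub-documents, sequence length, number of labels).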

        assert isinstance(
            predictions, np.ndarray
        ), f"Unexpected prediction type: {type(predictions)}. Expected numpy.ndarray."
        assert len(predictions.shape) == 3, (
            f"Unexpected prediction shape: {len(predictions.shape)} dimensions. Expected a 3-dimensional array."
        )

        out: typing.List[typing.Set[int]] = []
        word_idx_offset = 0
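        # Map subword-level predictions back to word indices. word_idx_offset accumulates
        # the word counts of earlier sub-documents of the same document, so indices refer
        # to positions in the original token sequence.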
        for idx, (example, pred) in enumerate(zip(examples, predictions)):
            sw_idx = 0
            if example.is_document_first:
                word_idx_offset = 0  # reset
                out.append(set())
            else:
                word_idx_offset += len(examples[idx - 1].words)

            for word_idx, _ in enumerate(example.words):
                num_sw: int = num_subwords[idx][word_idx]
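                # sw_idx is pre-incremented, so position 0 of pred (the first, typically
                # special, token position) is skipped; the word is marked as a boundary if
                # any of its subword positions is predicted as LABEL_SEP.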
                while num_sw > 0:
                    sw_idx += 1
                    label_high_prob = int(np.argmax(pred[sw_idx]))
                    if self.label_map[label_high_prob] == LABEL_SEP:
                        out[-1].add(word_idx + word_idx_offset)
                    num_sw -= 1

        return out


def get_opts() -> argparse.Namespace:
    oparser = argparse.ArgumentParser()
    oparser.add_argument("--input", "-i", type=argparse.FileType("r"), required=False, default=sys.stdin)
    oparser.add_argument(
        "--output",
        "-o",
        type=argparse.FileType("w"),
        required=False,
        default=sys.stdout,
    )
    oparser.add_argument("--model", "-m", type=Path, required=True)
    oparser.add_argument("--batch", "-b", type=int, default=1, help="Number of documents to feed a batch")
    return oparser.parse_args()


def generate_initial_annotation_obj(
    input_stream: typing.Iterator[str],
) -> typing.Iterator[typing.List[str]]:
    tokenizer = JanomeTokenizer()
    for document in input_stream:
        # each item from the input stream is one document (one line of text)
        assert bunkai.constant.METACHAR_SENTENCE_BOUNDARY not in document
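        # document[:-1] drops the trailing newline from the input line.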
        document_spans: Tokens = annotation2spans(document[:-1])
        document_tokens: typing.List[str] = []
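        # Keep fragments containing the line-break metacharacter as-is; tokenize every
        # other fragment into morphemes with Janome.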
        for fragment in document_spans.spans:
            if bunkai.constant.METACHAR_LINE_BREAK in fragment:
                document_tokens.append(fragment)
            else:
                tokens = tokenizer.tokenize(fragment)
                document_tokens += tokens

        assert "".join(document_tokens) == "".join(document_spans.spans)
        yield document_tokens


def main() -> None:
    opts = get_opts()
    pdt = Predictor(opts.model)

    with opts.input as inf, opts.output as outf:
        for one_batch in chunked(generate_initial_annotation_obj(inf), n=opts.batch):
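            # predict() returns one set of boundary token indices per document; write the
            # tokens back out, adding the sentence-boundary metacharacter after each
            # predicted boundary and a newline after each document.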
            for did, token_ids_seps in enumerate(pdt.predict(one_batch)):
                for tid, token in enumerate(one_batch[did]):
                    outf.write(token)
                    if tid in token_ids_seps:
                        outf.write(bunkai.constant.METACHAR_SENTENCE_BOUNDARY)
                outf.write("\n")


if __name__ == "__main__":
    main()
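
A minimal usage sketch, assuming the bunkai package is installed and a trained model directory is available; "path/to/lbd-model" is a placeholder path.

from pathlib import Path

from bunkai.algorithm.lbd.predict import Predictor

# Placeholder path to a trained model directory containing the model weights
# plus labels.txt, bunkai.json, and vocab.txt.
predictor = Predictor(Path("path/to/lbd-model"))

# One document, already split into morphemes; '▁' is the line-break metacharacter.
documents = [["ラウンジ", "も", "気軽", "に", "利用", "でき", "、", "申し分", "ない", "です", "。", "▁", ""]]

# One set of token indices per document; a sentence boundary is predicted after each listed token.
print(predictor.predict(documents))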