IlyaGusev/rulm

rulm/train_tokenizer.py

import argparse
import random

from transformers import PreTrainedTokenizerFast
from datasets import load_dataset
from tokenizers import Tokenizer, models, pre_tokenizers, normalizers, Regex, decoders, trainers, processors


def train_tokenizer(
    dataset_path,
    train_path,
    output_dir,
    sample_rate,
    vocab_size
):
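    """Train a SentencePiece-style Unigram tokenizer on a streamed corpus
    and save it in Hugging Face transformers format."""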
    assert train_path or dataset_path
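    # Prefer an explicit JSONL file (loaded through the repo's jsonl_loader.py script);
    # otherwise stream the named dataset directly.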
    if train_path:
        dataset = load_dataset("rulm/jsonl_loader.py", data_files={"train": [train_path]}, streaming=True)["train"]
    elif dataset_path:
        dataset = load_dataset(dataset_path, streaming=True)["train"]

    def read_texts():
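        # Keep each record with probability sample_rate and cap every document
        # at 1,000,000 whitespace-separated tokens.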
        for r in dataset:
            if random.random() < sample_rate:
                yield " ".join(r["text"].split(" ")[:1000000])

    tokenizer = Tokenizer(models.Unigram())
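    # Normalization: NFKC, collapse runs of spaces, strip leading/trailing whitespace.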
    tokenizer.normalizer = normalizers.Sequence([
        normalizers.NFKC(),
        normalizers.Replace(Regex(" {2,}"), " "),
        normalizers.Strip()
    ])
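    # Pre-tokenization: Metaspace word boundaries, one token per digit,
    # punctuation isolated, and newlines kept as standalone pieces.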
    tokenizer.pre_tokenizer = pre_tokenizers.Sequence([
        pre_tokenizers.Metaspace(),
        pre_tokenizers.Digits(individual_digits=True),
        pre_tokenizers.Punctuation(behavior="isolated"),
        pre_tokenizers.Split(pattern="\n", behavior="isolated")
    ])
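    # The Metaspace decoder reverses the Metaspace pre-tokenization when decoding ids back to text.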
    tokenizer.decoder = decoders.Metaspace()

    special_tokens = ["<pad>", "<unk>", "<s>", "</s>", "<sep>"]
    trainer = trainers.UnigramTrainer(
        vocab_size=vocab_size,
        special_tokens=special_tokens,
        unk_token="<unk>"
    )
    tokenizer.train_from_iterator(read_texts(), trainer=trainer)

    bos_token_id = tokenizer.token_to_id("<s>")
    eos_token_id = tokenizer.token_to_id("</s>")
    sep_token_id = tokenizer.token_to_id("<sep>")
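    # Post-processing: <s> A </s> for single sequences, <s> A <sep> B </s> for pairs
    # (segment B gets type id 1).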
    tokenizer.post_processor = processors.TemplateProcessing(
        single="<s>:0 $A:0 </s>:0",
        pair="<s>:0 $A:0 <sep>:0 $B:1 </s>:1",
        special_tokens=[("<sep>", sep_token_id), ("<s>", bos_token_id), ("</s>", eos_token_id)],
    )
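    # Quick sanity check on a mixed Russian / digit / newline string.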
    encoding = tokenizer.encode("Привет! Как дела? 1994 + 11 = 2005\nПока!")
    print(encoding.tokens)
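    # Wrap in a transformers fast tokenizer so the result can be reloaded with
    # AutoTokenizer.from_pretrained; padding side is set to the left.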
    wrapped_tokenizer = PreTrainedTokenizerFast(
        tokenizer_object=tokenizer,
        bos_token="<s>",
        eos_token="</s>",
        unk_token="<unk>",
        pad_token="<pad>",
        sep_token="<sep>",
        padding_side="left",
    )
    wrapped_tokenizer.save_pretrained(output_dir)


if __name__ == "__main__":
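    # Example invocation (file names are illustrative):
    #   python rulm/train_tokenizer.py --train-path train.jsonl --output-dir tokenizer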
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset-path", default=None)
    parser.add_argument("--train-path", default=None)
    parser.add_argument("--output-dir", required=True)
    parser.add_argument("--sample-rate", type=float, default=1.0)
    parser.add_argument("--vocab-size", type=int, default=50000)
    args = parser.parse_args()
    train_tokenizer(**vars(args))