IlyaGusev/rulm

View on GitHub
data_processing/undup.py

Summary

Maintainability
A
55 mins
Test Coverage
import argparse
import json
import re
import os
import hashlib
import fcntl
import multiprocessing
from multiprocessing.pool import ThreadPool
import zstandard
from tqdm import tqdm
from collections import defaultdict

from datasets import load_dataset
from datasketch import MinHash, MinHashLSH, LeanMinHash

from data_processing.util import read_jsonl, PlainArchive, ngrams


# Compiled once at import time. Alternatives, tried left to right:
# Russian Cyrillic words (а-я, ё, hyphen), Latin words (a-z, hyphen),
# digit runs, and any other single non-whitespace character.
_TOKEN_PATTERN = re.compile(r'[а-яё-]+|[a-z-]+|\d+|\S', re.IGNORECASE)


def re_tokenize(text):
    """Split ``text`` into word, number and single-symbol tokens.

    Matching is case-insensitive, hyphens stay inside word tokens, and
    whitespace is discarded. Returns a list of the matched substrings in
    order of appearance.
    """
    return _TOKEN_PATTERN.findall(text)


def calc_fingerprint(record, ngram_size: int = 1, num_perm: int = 128):
    """Build a serialized MinHash signature of ``record["text"]``.

    The text is tokenized with ``re_tokenize``. When ``ngram_size > 1``,
    consecutive tokens are grouped via the project ``ngrams`` helper,
    space-joined and de-duplicated into a set. Each resulting string is
    UTF-8 encoded and fed to a ``MinHash`` with ``num_perm`` permutations.

    Returns a dict ``{"minhash": bytearray}`` holding the serialized
    ``LeanMinHash``. The dict shape suits ``datasets.Dataset.map``.
    """
    words = re_tokenize(record["text"])
    if ngram_size > 1:
        words = {" ".join(gram) for gram in ngrams(words, ngram_size)}

    signature = MinHash(num_perm=num_perm)
    signature.update_batch([word.encode('utf-8') for word in words])

    # LeanMinHash is the compact, serializable form of the signature.
    frozen = LeanMinHash(signature)
    payload = bytearray(frozen.bytesize())
    frozen.serialize(payload)
    return {"minhash": payload}


def _has_near_duplicate(lsh, dataset, minhash, threshold):
    """Return True if an already-indexed row's MinHash has Jaccard > ``threshold``.

    LSH candidates are re-checked against their stored signature (the
    ``minhash`` column of ``dataset``, keyed by row index) because
    ``MinHashLSH.query`` may return false positives.
    """
    for other_idx in lsh.query(minhash):
        other_minhash = LeanMinHash.deserialize(dataset[other_idx]["minhash"])
        if minhash.jaccard(other_minhash) > threshold:
            return True
    return False


def main(
    input_path,
    output_path,
    num_perm,
    threshold: float = 0.95,
    false_positive_weight: float = 0.05,
    ngram_size: int = 1,
    always_keep_sources=("math",),
):
    """Write a near-duplicate-filtered copy of a JSONL corpus.

    Every record is fingerprinted with ``calc_fingerprint`` (parallelised by
    ``datasets.map``). Records are then streamed in order, and a record is
    dropped when some earlier record's MinHash has estimated Jaccard
    similarity strictly greater than ``threshold``.

    Args:
        input_path: JSONL file, loaded through the ``rulm/jsonl_loader.py``
            dataset script.
        output_path: destination handed to ``PlainArchive``.
        num_perm: number of MinHash permutations (must match between
            fingerprinting and the LSH index).
        threshold: Jaccard similarity above which a record counts as a
            duplicate. It is also the LSH threshold.
        false_positive_weight: weight given to false positives when
            ``MinHashLSH`` picks its banding parameters. False negatives get
            ``1 - false_positive_weight``.
        ngram_size: token n-gram size forwarded to ``calc_fingerprint``.
        always_keep_sources: values of ``record["meta"]["source"]`` whose
            records are written even when they are duplicates.
    """
    dataset = load_dataset("rulm/jsonl_loader.py", data_files={"train": [input_path]})["train"]
    dataset = dataset.map(
        function=calc_fingerprint,
        fn_kwargs={
            "num_perm": num_perm,
            "ngram_size": ngram_size,
        },
        num_proc=os.cpu_count(),
        desc="Fingerprinting..."
    )

    # NOTE(review): the archive is never explicitly flushed/closed in this
    # function. Confirm PlainArchive persists everything without a
    # commit/close call.
    archive = PlainArchive(output_path)

    lsh = MinHashLSH(
        threshold=threshold,
        weights=(false_positive_weight, 1 - false_positive_weight),
        num_perm=num_perm,
    )

    # Wrap the dataset (which has a length) rather than the enumerate object
    # (which does not), so tqdm can show a real progress bar and ETA.
    for idx, record in enumerate(tqdm(dataset)):
        minhash = LeanMinHash.deserialize(record["minhash"])

        is_dup = _has_near_duplicate(lsh, dataset, minhash, threshold)
        # The source lookup only happens for duplicates (short-circuit `or`).
        if not is_dup or record["meta"]["source"] in always_keep_sources:
            archive.add_data(text=record["text"], meta=record["meta"])

        # Index every record, including dropped duplicates, under its row
        # index so later candidates can be re-checked via dataset[idx].
        lsh.insert(idx, minhash)


if __name__ == "__main__":
    # CLI: undup.py INPUT_PATH OUTPUT_PATH [--num-perm N]
    cli = argparse.ArgumentParser()
    cli.add_argument("input_path", type=str)
    cli.add_argument("output_path", type=str)
    cli.add_argument("--num-perm", type=int, default=128)
    parsed = cli.parse_args()
    main(
        input_path=parsed.input_path,
        output_path=parsed.output_path,
        num_perm=parsed.num_perm,
    )