giganticode/codeprep

codeprep/pipeline/bpelearner.py

# SPDX-FileCopyrightText: 2020 Hlib Babii <hlibbabii@gmail.com>
#
# SPDX-License-Identifier: Apache-2.0

import os
from typing import Tuple, Dict, Set, Optional

from codeprep.bpepkg.bpe_config import BpeConfig, BpeParam, BpeConfigNotSupported
from codeprep.bpepkg.bpe_encode import escape
from codeprep.bpepkg.bpe_learn import separate_vocabs, logger, do_merges, create_resulting_vocab, create_bpe_cache
from codeprep.bpepkg.cache import dump_bpe_cache
from codeprep.bpepkg.merge import MergeList, read_merges, dump_merges
from codeprep.pipeline import stages
from codeprep.pipeline.bperegistry import get_max_merges, MERGES_FILE_NAME, MERGES_CACHE_FILE_NAME, \
    RESULTING_VOCAB_FILE_NAME, BPE_REASSEMBLED_VOCAB_FILE_NAME
from codeprep.pipeline.dataset import Dataset
from codeprep.pipeline.vocab import _dump_vocab_dict, _load_vocab_dict
from codeprep.util.misc import to_non_literal_str


def get_base_vocab(dataset: Dataset) -> Tuple[Dict[str, int], Dict[str, int]]:
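    """Run the pipeline up to the base BPE vocab stage, then split the loaded
    vocab into the part that merges will be learned on and the non-BPE part."""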
    stages.run_until_base_bpe_vocab(dataset)
    all_vocab = _load_vocab_dict(dataset.path_to_bpe_vocab_file)
    non_bpe_vocab = load_nonbpe_vocab(dataset)
    return separate_vocabs(all_vocab, non_bpe_vocab)


def load_nonbpe_vocab(dataset: Dataset) -> Set[str]:
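    """Read the dataset's non-BPE vocab file into a set, one entry per line."""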
    non_bpe_vocab = set()
    with open(dataset.path_to_nonbpe_vocab_file, 'r') as f:
        for line in f:
            non_bpe_vocab.add(to_non_literal_str(line.rstrip('\n')))
    return non_bpe_vocab


def check_if_bpe_config_supported(bpe_config: BpeConfig) -> None:
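    """Raise BpeConfigNotSupported for configurations this learner cannot handle yet."""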
    if bpe_config.get_param_value(BpeParam.UNICODE) == 'bytes':
        raise BpeConfigNotSupported('Byte-BPE is not yet supported')

    if bpe_config.get_param_value(BpeParam.WORD_END):
        raise BpeConfigNotSupported('BPE with word-end characters is not yet supported')

    if bpe_config.get_param_value(BpeParam.CASE) == 'prefix':
        raise BpeConfigNotSupported('BPE with case encoded in prefix is not yet supported')


def prepare_vocabs(dataset: Dataset, dir_with_most_merges: Optional[str],
                   starting_from_scratch: bool) -> Tuple[Dict[str, int], Dict[str, int]]:
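    """Build the split vocab that merges will be learned on and the vocab that
    is kept as-is, either from scratch or from a previous merge run."""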
    if starting_from_scratch:
        base_bpe_vocab, other_vocab = get_base_vocab(dataset)  # TODO extract this into stages
        other_vocab = {escape(k, merged=True): v for k, v in other_vocab.items()}
        split_base_vocab = {escape(" ".join(k)): v for k, v in base_bpe_vocab.items()}
    else:
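        # resume from the reassembled vocab saved by the run with the most merges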
        path_to_bpe_vocab_file = os.path.join(dir_with_most_merges, BPE_REASSEMBLED_VOCAB_FILE_NAME)
        non_bpe_vocab = {escape(k, merged=True) for k in load_nonbpe_vocab(dataset)}
        split_base_vocab = _load_vocab_dict(path_to_bpe_vocab_file)
        split_base_vocab, other_vocab = separate_vocabs(split_base_vocab, non_bpe_vocab)

    return split_base_vocab, other_vocab


def get_dir_with_most_merges(dataset_bpe_path: str, n_merges: int) -> Optional[str]:
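    """Locate the subdirectory of dataset_bpe_path that holds the most
    previously computed merges (as reported by get_max_merges); return
    None if no merges have been computed yet."""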
    max_merges = get_max_merges(dataset_bpe_path, n_merges)
    if not max_merges:
        return None

    dir_with_most_merges = os.path.join(dataset_bpe_path, str(max_merges))
    return dir_with_most_merges


def save_results(split_base_vocab: Dict[str, int], merges: MergeList, new_bpe_dir: str) -> None:
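    """Persist all BPE learning artifacts into new_bpe_dir: the resulting vocab,
    the merges cache, the merge list, and the reassembled split vocab."""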

    os.makedirs(new_bpe_dir)

    resulting_vocab = create_resulting_vocab(split_base_vocab)
    resulting_vocab_sorted = sorted(resulting_vocab.items(), key=lambda x: x[1], reverse=True)
    _dump_vocab_dict(resulting_vocab_sorted, os.path.join(new_bpe_dir, RESULTING_VOCAB_FILE_NAME))

    bpe_cache = create_bpe_cache(split_base_vocab)
    dump_bpe_cache(bpe_cache, os.path.join(new_bpe_dir, MERGES_CACHE_FILE_NAME))

    dump_merges(merges, os.path.join(new_bpe_dir, MERGES_FILE_NAME))
    _dump_vocab_dict(split_base_vocab.items(), os.path.join(new_bpe_dir, BPE_REASSEMBLED_VOCAB_FILE_NAME))
    logger.info(f'BPE output files are saved into the {new_bpe_dir} folder')


def run(dataset: Dataset, n_merges: int, bpe_config: BpeConfig) -> None:
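    """Learn n_merges BPE merges for the dataset, reusing merges computed by
    previous runs when possible, and save the results to a new BPE dir."""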

    check_if_bpe_config_supported(bpe_config)
    dataset_bpe_path = dataset.bpe_path

    dir_with_most_merges = get_dir_with_most_merges(dataset_bpe_path, n_merges)

    if dir_with_most_merges:
        logger.info("Using existing merges...")
        already_done_merges = read_merges(os.path.join(dir_with_most_merges, MERGES_FILE_NAME))
    else:
        logger.info("Starting encoding from scratch.    ..")
        already_done_merges = MergeList()

    split_base_vocab, other_vocab = prepare_vocabs(dataset, dir_with_most_merges,
                                                   starting_from_scratch=not dir_with_most_merges)

    logger.info("Learning bpe codes...")
    split_base_vocab, merges = do_merges(split_base_vocab, n_merges - len(already_done_merges))
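    # put the words that were exempt from BPE back into the vocab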
    for k, v in other_vocab.items():
        split_base_vocab[k] = v
    merges = already_done_merges + merges

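    # the output dir is named after the total number of merges computed so far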
    new_bpe_dir = os.path.join(dataset_bpe_path, str(len(merges)))
    if os.path.exists(new_bpe_dir):
        logger.info("Merges already learned!")
        return

    save_results(split_base_vocab, merges, new_bpe_dir)
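

# A minimal usage sketch (hypothetical driver code, not part of this module),
# assuming `dataset` is a codeprep Dataset and `bpe_config` a BpeConfig built
# elsewhere in the pipeline; learning 10000 merges is then a single call:
#
#     run(dataset, n_merges=10000, bpe_config=bpe_config)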