# codeprep/pipeline/to_repr.py
# SPDX-FileCopyrightText: 2020 Hlib Babii <hlibbabii@gmail.com>
#
# SPDX-License-Identifier: Apache-2.0
import gzip
import logging
import os
import pickle
import platform
from multiprocessing.pool import Pool
from typing import List, Tuple, Set
from typing import Optional
import time
from tqdm import tqdm
from codeprep.bpepkg.bpe_encode import read_merges, BpeData
from codeprep.bpepkg.cache import read_bpe_cache
from codeprep.config import DEFAULT_BPE_DIR, NO_CASE_DIR, CASE_DIR, DEFAULT_BPE_CACHE_DIR, REWRITE_PREPROCESSED_FILE, \
CHUNKSIZE, LIMIT_FILES_SCANNING
from codeprep.pipeline import vocabloader
from codeprep.pipeline.bperegistry import CustomBpeConfig
from codeprep.pipeline.dataset import Dataset, NOT_FINISHED_EXTENSION
from codeprep.prepconfig import PrepParam, PrepConfig
from codeprep.preprocess.core import to_repr_list
from codeprep.preprocess.result import PreprocessingResult
from codeprep.preprocess.placeholders import placeholders
from codeprep.preprocess.tokens import TokenSequence
from codeprep.tokentypes.rootclasses import ParsedToken
from codeprep.tokentypes.word import SpecialToken
from codeprep.util.misc import to_literal_str
logger = logging.getLogger(__name__)
def get_global_bpe_data_if_available() -> Optional[BpeData]:
    """Return the process-global BpeData if it has been initialized, else None."""
    return globals().get('global_bpe_data')
def to_repr(prep_config: PrepConfig, token_list: List[ParsedToken], bpe_data: Optional[BpeData] = None) -> PreprocessingResult:
    """Convert parsed tokens into their preprocessed representation.

    Falls back to the process-global BPE data when `bpe_data` is not supplied.
    For BPE prep configs, word-end markers are inserted into the prepped tokens.
    """
    effective_bpe_data = bpe_data or get_global_bpe_data_if_available()
    repr_config = prep_config.get_repr_config(effective_bpe_data)
    result = to_repr_list(token_list, repr_config)
    if prep_config.is_bpe():
        result.prepped_tokens = insert_word_end_tokens_(result.prepped_tokens)
    return result
def to_token_str(tokens: List) -> str:
    """Render a sequence of tokens as a single space-separated string.

    Each token is converted via `str()`, so any object type is accepted.
    """
    # `map(str, ...)` instead of a redundant `lambda t: str(t)` wrapper.
    return " ".join(map(str, tokens))
def preprocess_and_write(params: Tuple[bytes, bytes, PrepConfig, str], bpe_data: Optional[BpeData] = None):
    """Preprocess a single parsed file and write its text representation.

    :param params: tuple of (source file path, destination file path, prep config,
        folder for per-part non-bpe vocab files or a falsy value to skip saving them).
        File paths are byte strings.
    :param bpe_data: BPE merges/cache to use; when None, falls back to the
        process-global BPE data if it has been initialized (e.g. in a pool worker).
    """
    src_file_path, dest_file_path, prep_config, part_nonbpe_vocab_folder = params
    dest_dirname = os.path.dirname(dest_file_path)
    if not os.path.exists(dest_dirname):
        os.makedirs(dest_dirname, exist_ok=True)
    # Skip files already produced by a previous run unless rewriting is forced.
    if not REWRITE_PREPROCESSED_FILE and os.path.exists(dest_file_path):
        logger.warning(f"File {dest_file_path} already exists! Doing nothing.")
        return
    # Write to a '.not_finished' temp file and rename at the end so an
    # interrupted run never leaves a partially-written file that looks complete.
    not_finished_dest_file_path = dest_file_path + NOT_FINISHED_EXTENSION.encode()
    with gzip.GzipFile(src_file_path, 'rb') as i, open(not_finished_dest_file_path, 'w') as o:
        # Source is a gzipped pickle of parsed tokens produced by the parse stage.
        token_list = pickle.load(i)
        bpe_data = get_global_bpe_data_if_available() if bpe_data is None else bpe_data
        # Append an explicit end-of-content special token before conversion.
        preprocessing_result = to_repr(prep_config, token_list + [SpecialToken(placeholders['ect'])], bpe_data)
        o.write(to_literal_str(to_token_str(preprocessing_result.prepped_tokens._tokens)) + '\n')
        if part_nonbpe_vocab_folder:
            # Timestamp in the filename keeps per-file vocab parts unique.
            # NOTE(review): dest_file_path is bytes, so basename renders as b'...'
            # inside the f-string — presumably harmless for uniqueness; verify.
            save_non_processable_tokens(preprocessing_result.non_processable_tokens, os.path.join(part_nonbpe_vocab_folder, f'{os.path.basename(dest_file_path)}_-_{time.time()}'))
    os.rename(not_finished_dest_file_path, dest_file_path)
#TODO make this method independent of actual directory structure
def init_bpe_data(prep_config: PrepConfig, custom_bpe_config: Optional[CustomBpeConfig], force_reinit: bool=True):
    """Initialize the process-global `global_bpe_data` (BPE merges and merges cache).

    :param prep_config: prep configuration; its SPLIT/CASE params select the
        default merges files when no custom BPE config is given.
    :param custom_bpe_config: when provided, merges and cache are loaded from
        the custom BPE codes location instead of the defaults.
    :param force_reinit: when False, keep already-initialized global data as-is.
    """
    if get_global_bpe_data_if_available() and not force_reinit:
        return # already initialized
    global global_bpe_data
    global_bpe_data = BpeData()
    if custom_bpe_config:
        logger.info(f'Using bpe merges file: {custom_bpe_config.codes_file}')
        if custom_bpe_config.can_use_cache_file():
            global_bpe_data.merges_cache = read_bpe_cache(custom_bpe_config.cache_file)
        else:
            global_bpe_data.merges_cache = {}
        global_bpe_data.merges = read_merges(custom_bpe_config.codes_file, custom_bpe_config.n_merges)
        if custom_bpe_config.n_merges:
            logger.info(f'Using first {custom_bpe_config.n_merges} merges.')
        # Non-BPE vocab entries must never be split: map each token to itself.
        nonbpe_vocab = vocabloader.nonbpe(custom_bpe_config.merge_list_id)
        global_bpe_data.merges_cache.update({s: [s] for s in nonbpe_vocab})
    else:
        # SPLIT prep-param value -> subdirectory name of the default merges.
        bpe_n_merges_dict = {'4': '5k', '5': '1k', '6': '10k', '7': '20k', '8': '0'}
        bpe_n_merges = bpe_n_merges_dict[prep_config.get_param_value(PrepParam.SPLIT)]
        # Case-sensitive and case-insensitive merges live in separate dirs.
        bpe_merges_file = os.path.join(DEFAULT_BPE_DIR,
                                       CASE_DIR if prep_config.get_param_value(PrepParam.CASE) == 'u' else NO_CASE_DIR,
                                       str(bpe_n_merges), 'merges.txt')
        bpe_merges_cache_file = os.path.join(DEFAULT_BPE_CACHE_DIR,
                                             CASE_DIR if prep_config.get_param_value(PrepParam.CASE) == 'u' else NO_CASE_DIR,
                                             str(bpe_n_merges), 'merges_cache.txt')
        # The cache file is optional; fall back to an empty cache if absent.
        if os.path.exists(bpe_merges_cache_file):
            global_bpe_data.merges_cache = read_bpe_cache(bpe_merges_cache_file)
        else:
            global_bpe_data.merges_cache = {}
        global_bpe_data.merges = read_merges(bpe_merges_file)
def params_generator(dataset: Dataset, path_to_part_metadata: Optional[str]):
    """Yield one `preprocess_and_write` argument tuple per parsed file."""
    prep_config = dataset.prep_config
    for src_path in dataset.parsed.file_iterator():
        dest_path = dataset.parsed.get_new_file_name(src_path, dataset.preprocessed)
        yield src_path, dest_path, prep_config, path_to_part_metadata
def get_n_cpus_to_be_used():
    """Return the number of worker processes to use on this platform.

    Multiprocessing is restricted to a single process on Windows and macOS;
    elsewhere all available cores are used (at least one).
    """
    current_platform = platform.system()
    if current_platform in ('Windows', 'Darwin'):
        n_cpus = 1
    else:
        n_cpus = os.cpu_count() or 1
    logger.info(f"Platform: {current_platform}, n cores to be used: {n_cpus}")
    return n_cpus
def run(dataset: Dataset, custom_bpe_config: Optional[CustomBpeConfig]) -> None:
    """Preprocess every parsed file of `dataset` and write the results.

    Initializes BPE data if the prep config requires it, counts input files
    (to size the progress bar), then processes files either in a multiprocessing
    pool or sequentially depending on the platform. Exits the process with
    code 3 if the parsed dataset directory does not exist.

    :param dataset: dataset whose parsed files are to be preprocessed.
    :param custom_bpe_config: custom BPE codes to use, or None for the defaults.
    """
    path_to_parsed_dataset = dataset.parsed.path
    if not os.path.exists(path_to_parsed_dataset):
        logger.error(f"Dir does not exist: {path_to_parsed_dataset}")
        exit(3)
    logger.info(f"Reading parsed files from: {path_to_parsed_dataset}")
    if dataset.prep_config.is_bpe():
        init_bpe_data(dataset.prep_config, custom_bpe_config)
    # For base BPE configs whose non-BPE vocab hasn't been built yet, workers
    # write per-file vocab parts into this folder; they are merged at the end.
    if not os.path.exists(dataset.path_to_nonbpe_vocab_file) and dataset.prep_config.is_base_bpe_config():
        path_to_part_metadata = f'{dataset.path_to_nonbpe_vocab_file}_part'
    else:
        path_to_part_metadata = None
    if path_to_part_metadata and not os.path.exists(path_to_part_metadata):
        os.makedirs(path_to_part_metadata)
    logger.info(f"Writing preprocessed files to {dataset.preprocessed.path}")
    if dataset.files_need_to_be_saved():
        # Count files for the progress bar, but stop counting beyond the scan
        # limit (tqdm then runs without a known total).
        files_total = 0
        for _ in dataset.get_all_files():
            files_total += 1
            print(f"Files scanned: {files_total}", end='\r')
            if files_total > LIMIT_FILES_SCANNING:
                files_total = None
                logger.info(f"Total files to be preprocessed: {LIMIT_FILES_SCANNING}+")
                break
    else:
        files_total = len([f for f in dataset.get_all_files()])
    n_cpus = get_n_cpus_to_be_used()
    if n_cpus > 1:
        # Workers inherit the global BPE data (fork); results are discarded,
        # the loop exists only to drive the progress bar.
        with Pool(processes=n_cpus) as pool:
            it = pool.imap_unordered(preprocess_and_write, params_generator(dataset, path_to_part_metadata), chunksize=CHUNKSIZE)
            for _ in tqdm(it, total=files_total):
                pass
    else:
        for params in tqdm(params_generator(dataset, path_to_part_metadata), total=files_total):
            preprocess_and_write(params, get_global_bpe_data_if_available())
    if path_to_part_metadata:
        # Merge the per-file non-BPE vocab parts written by the workers.
        vocabloader.gather_non_bpe_vocab(dataset)
    dataset.preprocessed.set_ready()
def save_non_processable_tokens(non_processable_tokens: Set[str], save_to: bytes) -> None:
    """Write each non-processable token, escaped, on its own line of `save_to`."""
    lines = [f'{to_literal_str(token)}\n' for token in non_processable_tokens]
    with open(save_to, 'w') as output_file:
        output_file.writelines(lines)
def insert_word_end_tokens_(token_seq: TokenSequence) -> TokenSequence:
    """Return a new TokenSequence with a word-end marker appended to the last
    sub-token of every full token. Requires markers not already present."""
    assert not token_seq.word_end_token_added

    def mark_last_subtoken(subtokens):
        # Append the compound-word-end placeholder to the final sub-token only.
        return subtokens[:-1] + [subtokens[-1] + placeholders['compound_word_end']]

    flattened = [token
                 for full_token in token_seq.fulltokens.with_format(formatter=mark_last_subtoken)
                 for token in full_token]
    return TokenSequence.create(flattened, token_seq.metadata, word_end_token_added=True)