codeprep/pipeline/bperegistry.py
# SPDX-FileCopyrightText: 2020 Hlib Babii <hlibbabii@gmail.com>
#
# SPDX-License-Identifier: Apache-2.0
import logging
import os
from typing import Optional, Dict, Callable, Tuple, List
import regex
import sys
import time
from codeprep.bpepkg.bpe_config import BpeConfig
from codeprep.bpepkg.bpe_encode import read_merges
from codeprep.bpepkg.merge import MergeList
from codeprep.config import USER_BPE_DIR, USER_VOCAB_DIR
logger = logging.getLogger(__name__)
OTHER_VOCAB_FILE_NAME = "other_vocab"
BPE_REASSEMBLED_VOCAB_FILE_NAME = "bpe_vocab_reassembled.txt"
RESULTING_VOCAB_FILE_NAME = "vocab_res.txt"
MERGES_FILE_NAME = "merges.txt"
MERGES_CACHE_FILE_NAME = "merges_cache.txt"
BPE_CODES_ID_FILENAME = '.name'
USER_PREDEFINED_BPE_CODES = ['1k', '5k', '10k']
PREDEFINED_BPE_CODES = USER_PREDEFINED_BPE_CODES + ['0']
class InvalidBpeCodesIdError(Exception):
pass
class CustomBpeConfig(object):
def __init__(self, merge_list_id: str, n_merges: int, codes_file: str, cache_file: str):
self.merge_list_id = merge_list_id
self.n_merges = n_merges
self.codes_file = codes_file
self.cache_file = cache_file
def can_use_cache_file(self):
return self.cache_file and not self.n_merges
@staticmethod
def from_id(id_str: str) -> 'CustomBpeConfig':
return CustomBpeConfig.create(*parse_merge_list_id(id_str))
@staticmethod
def create(merge_list_id: str, n_merges: int) -> 'CustomBpeConfig':
dataset_bpe_dir = get_dataset_bpe_dir(merge_list_id)
min_merges = get_min_merges(dataset_bpe_dir, limit=n_merges)
# check if not None?
dir_with_min_merges = os.path.join(dataset_bpe_dir, str(min_merges))
if min_merges:
if not n_merges == 0 or min_merges == n_merges:
cache_file = os.path.join(dir_with_min_merges, MERGES_CACHE_FILE_NAME)
else:
cache_file = None
return CustomBpeConfig(merge_list_id, n_merges, os.path.join(dir_with_min_merges, MERGES_FILE_NAME), cache_file)
else:
raise InvalidBpeCodesIdError(
f"{n_merges} merges has not been computed for {merge_list_id}."
f"Max possible value: {get_max_merges(dataset_bpe_dir)}")
def __repr__(self):
return f'{self.__class__.__name__} ({self.merge_list_id}, {self.n_merges}, {self.codes_file}, {self.cache_file})'
def is_predefined_id(id: str):
"""
>>> is_predefined_id('1k')
True
>>> is_predefined_id('5k')
True
>>> is_predefined_id('10k')
True
>>> is_predefined_id('abc')
False
>>> is_predefined_id('abc-10')
False
"""
return id in PREDEFINED_BPE_CODES
def get_codes_id_by_bpe_path(dataset_bpe_dir: str) -> Optional[str]:
file_with_id = os.path.join(dataset_bpe_dir, BPE_CODES_ID_FILENAME)
if not os.path.exists(file_with_id):
return None
else:
with open(file_with_id, 'r') as f:
return f.read().strip()
def create_new_id_from(path: str, bpe_config: BpeConfig, predefined_bpe_codes_id: Optional[str] = None) -> str:
if predefined_bpe_codes_id:
return predefined_bpe_codes_id
else:
name_parts = [os.path.basename(path)]
id_suffix = bpe_config.to_suffix()
if id_suffix:
name_parts.append(id_suffix)
id_base = '_'.join(name_parts)
existing_ids = _get_all_custom_bpe_codes_and_max_merges().keys()
if id_base not in existing_ids:
return id_base
else:
def extract_number(full_id: str, id_base: str) -> int:
m = regex.fullmatch(f"{id_base}_([0-9]+)", full_id)
return int(m[1]) if m else 0
numbers = list(map(lambda d: extract_number(d, id_base), existing_ids))
new_number = max(numbers) + 1
return f'{id_base}_{new_number}'
def write_bpe_codes_id(dataset_bpe_dir: str, bpe_codes_id: str) -> None:
file_with_id = os.path.join(dataset_bpe_dir, BPE_CODES_ID_FILENAME)
with open(file_with_id, 'w') as f:
f.write(bpe_codes_id)
def parse_merge_list_id(s: str) -> Tuple[str, int]:
"""
>>> parse_merge_list_id("python-no-case-1000")
('python-no-case', 1000)
>>> parse_merge_list_id("python_1000")
Traceback (most recent call last):
...
codeprep.pipeline.bperegistry.InvalidBpeCodesIdError: Invalid id format: "python_1000". \
Format should be: "(.*)-([1-9][0-9]*)$"
>>> parse_merge_list_id("python-")
Traceback (most recent call last):
...
codeprep.pipeline.bperegistry.InvalidBpeCodesIdError: Invalid id format: "python-". \
Format should be: "(.*)-([1-9][0-9]*)$"
"""
REGEX = "(.*)-([1-9][0-9]*)$"
m = regex.match(REGEX, s)
if m:
return m[1], int(m[2])
else:
raise InvalidBpeCodesIdError(f'Invalid id format: "{s}". Format should be: "{REGEX}"')
def get_base_vocab_dir(bpe_list_id: str) -> str:
dataset_bpe_dir = get_dataset_bpe_dir(bpe_list_id)
prep_config_str = os.path.basename(dataset_bpe_dir)
#TODO do not hard code date and dir format in general
m = regex.fullmatch(r'(.*?)((?:_-_.*)?)', prep_config_str)
if not m:
raise ValueError(f'Invalid dir format: {prep_config_str}')
bpe_config = BpeConfig.from_suffix(m[2])
base_prep_config = bpe_config.to_prep_config()
return os.path.join(USER_VOCAB_DIR, f'{m[1]}_-_{base_prep_config}')
def get_dataset_bpe_dir(bpe_list_id: str) -> str:
if not os.path.exists(USER_BPE_DIR):
raise InvalidBpeCodesIdError(f"No custom bpe codes has been trained yet. "
f"To train a custom bpe code run: `codeprep learn-bpe` command")
bpe_dirs = next(os.walk(USER_BPE_DIR))[1]
for dir in bpe_dirs:
current_bpe_dir = os.path.join(USER_BPE_DIR, dir)
current_id = get_codes_id_by_bpe_path(current_bpe_dir)
if current_id and current_id == bpe_list_id:
return current_bpe_dir
raise InvalidBpeCodesIdError(f"Bpe id: {bpe_list_id} is not found. Possible values:\n {format_available_merge_list_ids()}")
def get_bpe_dir(merge_list_id: str, n_merges: int) -> str:
bpe_dir = os.path.join(get_dataset_bpe_dir(merge_list_id), str(n_merges))
if os.path.exists(bpe_dir):
return bpe_dir
else:
raise InvalidBpeCodesIdError(f'Dir {bpe_dir} not found.')
def load_bpe_merges(merge_list_id: str, n_merges: int) -> MergeList:
custom_bpe_config = CustomBpeConfig.create(merge_list_id, n_merges)
return read_merges(custom_bpe_config.codes_file, n_merges)
def format_available_merge_list_ids() -> str:
res = ""
for k, v in _get_all_custom_bpe_codes_and_max_merges().items():
res += f'{k}-[1..{v}]\n'
return res
def _get_extreme_n_merges(root_bpe_dir: str, limit: int, init_val: int, in_order: Callable[[int, int, int], bool]):
subdirs = _get_all_bpe_merges_dirs(root_bpe_dir)
extreme_value = init_val
for subdir in subdirs:
try:
num = int(subdir)
if in_order(extreme_value, num, limit):
extreme_value = num
except ValueError:
pass # for the case of part_vocab folder for example
if extreme_value != init_val:
return extreme_value
else:
return None
def get_min_merges(dataset_bpe_dir: str, limit: Optional[int]=0) -> Optional[int]:
return _get_extreme_n_merges(dataset_bpe_dir, limit, sys.maxsize, lambda e,n,l: e > n >= l)
def get_max_merges(dataset_bpe_dir: str, limit: Optional[int]=sys.maxsize) -> Optional[int]:
return _get_extreme_n_merges(dataset_bpe_dir, limit, 0, lambda e,n,l: e < n <= l)
def _get_all_custom_bpe_codes_and_max_merges() -> Dict[str, int]:
result = {}
bpe_dirs = next(os.walk(USER_BPE_DIR))[1]
for bpe_dir in bpe_dirs:
path = os.path.join(USER_BPE_DIR, bpe_dir)
code = get_codes_id_by_bpe_path(path)
max_merges = get_max_merges(path)
if code and max_merges:
result[code] = max_merges
return result
def _get_all_bpe_merges_dirs(dataset_bpe_dir: str) -> List[str]:
if not os.path.exists(dataset_bpe_dir):
raise FileNotFoundError(f'Directory {dataset_bpe_dir} does not exist!')
return next(os.walk(dataset_bpe_dir))[1]
def archive_existing_common_bpe_folder(dataset_bpe_dir: str) -> None:
if os.path.exists(dataset_bpe_dir):
logger.info(f'Archiving existing bpe dir. '
f'{dataset_bpe_dir} -> {dataset_bpe_dir}.{str(int(time.time()))}')
os.rename(dataset_bpe_dir, f'{dataset_bpe_dir}.{str(int(time.time()))}')