codeprep/api/common.py
# SPDX-FileCopyrightText: 2020 Hlib Babii <hlibbabii@gmail.com>
#
# SPDX-License-Identifier: Apache-2.0
import sys
from typing import Optional
from codeprep.prepconfig import PrepConfig, PrepParam, get_possible_str_values
def create_split_value(split_type: str, bpe_codes_id: Optional[str] = None, full_strings: bool = False,
                       split_numbers: bool = False, ronin: bool = False, stem: bool = False) -> str:
    """Encode a split type (plus its modifier flags) as a one-character code.

    :param split_type: one of ``'nosplit'``, ``'chars'``, ``'basic'`` or ``'bpe'``.
    :param bpe_codes_id: only consulted for ``'bpe'``; ``'1k'``, ``'5k'`` and ``'10k'``
        have dedicated codes, any other value (including ``None``) maps to ``'9'``.
    :param full_strings: only consulted for ``'nosplit'`` (``'F'`` instead of ``'0'``).
    :param split_numbers: ``'basic'`` modifier, lowest precedence.
    :param ronin: ``'basic'`` modifier, overrides ``split_numbers``.
    :param stem: ``'basic'`` modifier, overrides both ``ronin`` and ``split_numbers``.
    :return: the single-character split code.
    :raises AssertionError: if ``split_type`` is not one of the recognised values.
    """
    if split_type == 'nosplit':
        return 'F' if full_strings else '0'

    if split_type == 'chars':
        return '8'

    if split_type == 'basic':
        # Checked in priority order: the first truthy flag decides the code.
        basic_modifiers = ((stem, 's'), (ronin, '3'), (split_numbers, '2'))
        for is_enabled, code in basic_modifiers:
            if is_enabled:
                return code
        return '1'

    if split_type == 'bpe':
        known_bpe_ids = (('1k', '5'), ('5k', '4'), ('10k', '6'))
        for codes_id, code in known_bpe_ids:
            if bpe_codes_id == codes_id:
                return code
        # Any other BPE codes id falls back to the generic BPE code.
        return '9'

    raise AssertionError(f"Invalid split option: {split_type}")
def create_str_value(no_str: bool, max_str_len: int) -> str:
    """Choose the code for the string-literal handling parameter.

    :param no_str: when truthy, return ``'0'`` regardless of ``max_str_len``.
    :param max_str_len: maximum string length. Values 0 and 1 give ``'2'``; values
        from 2 up to (but excluding) ``len(get_possible_str_values())`` are used as an
        index into that list; everything else, including negative values, gives ``'1'``.
    :return: the single code string for the string parameter.
    """
    if no_str:
        return '0'
    if max_str_len < 0:
        return '1'
    if max_str_len < 2:
        return '2'
    # Only consult the list of allowed values once a real length lookup is needed.
    possible_values = get_possible_str_values()
    if max_str_len < len(possible_values):
        return possible_values[max_str_len]
    return '1'
def create_prep_config(spl_type: str, bpe_codes_id: Optional[str] = None, no_spaces: bool = False,
                       no_unicode: bool = False, no_case: bool = False, no_com: bool = False, no_str: bool = False,
                       full_strings: bool = False, max_str_length: int = sys.maxsize, split_numbers: bool = False,
                       ronin: bool = False, stem: bool = False):
    """Build a ``PrepConfig`` from boolean/option flags.

    Each flag is translated into the single-character value of one ``PrepParam``.
    The string and split parameters are delegated to ``create_str_value`` and
    ``create_split_value`` respectively. ``create_split_value`` raises
    ``AssertionError`` for an unknown ``spl_type``, which propagates to the caller.

    :return: a ``PrepConfig`` built from the mapping of ``PrepParam`` to code.
    """
    en_only_value = 'U' if no_unicode else 'u'
    com_value = '0' if no_com else 'c'
    # Keep the original call order: string code first, then split code.
    str_value = create_str_value(no_str, max_str_length)
    split_value = create_split_value(spl_type,
                                     bpe_codes_id=bpe_codes_id,
                                     full_strings=full_strings,
                                     split_numbers=split_numbers,
                                     ronin=ronin,
                                     stem=stem)
    tabs_newlines_value = '0' if no_spaces else 's'
    case_value = 'l' if no_case else 'u'

    param_values = {
        PrepParam.EN_ONLY: en_only_value,
        PrepParam.COM: com_value,
        PrepParam.STR: str_value,
        PrepParam.SPLIT: split_value,
        PrepParam.TABS_NEWLINES: tabs_newlines_value,
        PrepParam.CASE: case_value,
    }
    return PrepConfig(param_values)