giganticode/codeprep

View on GitHub
codeprep/prepconfig.py

Summary

Maintainability
A
1 hr
Test Coverage
# SPDX-FileCopyrightText: 2020 Hlib Babii <hlibbabii@gmail.com>
#
# SPDX-License-Identifier: Apache-2.0

"""
This module encapsulates all the tricky logic of encoding preprocessing options into e.g. 30100
"""

import logging
from enum import Enum
from typing import Dict, List, Type, Optional

import sys

from codeprep.bpepkg.bpe_encode import BpeData, get_bpe_subwords
from codeprep.preprocess.reprconfig import Splitter, ReprConfig
from codeprep.tokentypes.containers import Identifier, StringLiteral, OneLineComment, MultilineComment
from codeprep.tokentypes.noneng import NonEng
from codeprep.tokentypes.numeric import Number
from codeprep.tokentypes.whitespace import NewLine, Tab
from codeprep.tokentypes.word import Word

logger = logging.getLogger(__name__)


class PrepParam(str, Enum):
    """Names of the preprocessing parameters.

    The definition order matters: code that iterates over this enum relies on
    it to map positions in an encoded config string to parameters.
    """

    EN_ONLY = 'enonly'
    COM = 'com'
    STR = 'str'
    SPLIT = 'split'
    TABS_NEWLINES = 'tabsnewlines'
    CASE = 'caps'


def get_possible_str_values() -> List[str]:
    """Return the characters allowed as a value of the STR parameter.

    The result lists the ASCII digits ``0-9``, then the uppercase letters
    ``A-Z``, then the lowercase letters ``a-z`` (62 characters in total).
    """
    # Half-open code point ranges: digits, uppercase letters, lowercase letters.
    code_point_spans = (range(48, 58), range(65, 91), range(97, 123))
    return [chr(code_point) for span in code_point_spans for code_point in span]


def get_max_str_length(ch: str) -> Optional[int]:
    """Decode a STR parameter character into a maximum string length.

    The character is looked up in ``get_possible_str_values()``; its position
    determines the result:

    * position 0 -> ``None``
    * position 1 -> ``sys.maxsize``
    * any later position -> the position itself

    :param ch: one of the characters returned by ``get_possible_str_values()``.
    :return: the decoded value as described above.
    :raises ValueError: if ``ch`` is not among the possible values.
    """
    position = get_possible_str_values().index(ch)
    if position > 1:
        return position
    return sys.maxsize if position == 1 else None


class PrepConfig(object):
    """Validated set of preprocessing options: one single-character value per ``PrepParam``.

    A config is built either from a ``{PrepParam: value}`` dict or from its
    compact encoded string (see ``from_encoded_string``), in which the i-th
    character is the value of the i-th member of ``PrepParam``.

    Instances define ``__eq__`` without ``__hash__`` and are therefore unhashable.
    """

    # Values of PrepParam.SPLIT for which is_bpe() returns True.
    BPE_SPLIT_VALUES = ['4', '5', '6', '7', '8', '9']

    # Allowed single-character values for every parameter.
    possible_param_values = {
        PrepParam.EN_ONLY: ['u', 'U'],
        PrepParam.COM: ['0', 'c'],
        PrepParam.STR: get_possible_str_values(),
        PrepParam.SPLIT: ['0', 'F', '1', '2', '3', 's', '4', '5', '6', '7', '8', '9'],
        PrepParam.TABS_NEWLINES: ['s', '0'],
        PrepParam.CASE: ['u', 'l'],
    }

    # Human-readable description of every allowed value of every parameter.
    human_readable_values = {
        PrepParam.EN_ONLY: {'u': 'multilang',
                            'U': 'asci_only'},
        PrepParam.COM: {'0': 'NO_comments',
                        'c': 'comments'},
        PrepParam.STR: {k: k for k in get_possible_str_values()},
        PrepParam.SPLIT: {'0': 'NO_splitting',
                          'F': 'No splitting + full strings',
                          '1': 'camel+underscore',
                          '2': 'camel+underscore+numbers',
                          '3': 'numbers+ronin',
                          's': 'camel+underscore+numbers+stemming',
                          '4': 'No splitting+bpe_5k',
                          '5': 'No splitting+bpe_1k',
                          '6': 'No splitting+bpe_10k',
                          '7': 'No splitting+bpe_20k',
                          '8': 'No splitting+bpe_0',
                          '9': 'No splitting+bpe_custom'},
        PrepParam.TABS_NEWLINES: {'s': 'tabs+newlines',
                                  '0': 'NO_tabs+NO_newlines'},
        PrepParam.CASE: {
            'u': 'case_preserved',
            'l': 'lowercased'
        }
    }

    @staticmethod
    def __check_param_number(n_passed_params: int):
        """Ensure that exactly one value per ``PrepParam`` member was passed.

        :param n_passed_params: number of parameter values supplied by the caller.
        :raises ValueError: if the number differs from the number of ``PrepParam`` members.
        """
        n_expected_params = len(PrepParam)
        if n_passed_params != n_expected_params:
            raise ValueError(f'Expected {n_expected_params} params, got {n_passed_params}')

    @classmethod
    def from_encoded_string(cls, s: str):
        """Build a config from its encoded string form.

        :param s: string with one character per ``PrepParam`` member, in the
            definition order of ``PrepParam``.
        :return: a new config instance.
        :raises ValueError: if the string has a wrong length or encodes an
            invalid or unsupported combination of values.
        """
        PrepConfig.__check_param_number(len(s))

        return cls({pp: ch for ch, pp in zip(s, PrepParam)})

    @staticmethod
    def __check_invariants(params: Dict[PrepParam, str]):
        """Validate parameter values and reject unsupported combinations.

        :param params: mapping from every ``PrepParam`` member to its value.
        :raises ValueError: if the number of params is wrong, a value is not
            among ``possible_param_values``, or lowercasing is requested
            together with a splitting mode that does not support it.
        :raises KeyError: if the number of params is right but one of the
            ``PrepParam`` members is missing from ``params``.
        """
        PrepConfig.__check_param_number(len(params))
        for pp in PrepParam:
            if params[pp] not in PrepConfig.possible_param_values[pp]:
                raise ValueError(f'Invalid value {params[pp]} for prep param {pp}, '
                                 f'possible values are: {PrepConfig.possible_param_values[pp]}')

        lowercased = params[PrepParam.CASE] == 'l'
        split_value = params[PrepParam.SPLIT]

        if lowercased and split_value in ['0', 'F']:
            raise ValueError("Combination NOSPLIT and LOWERCASED is not supported: "
                             "basic splitting needs to be done to lowercase the subword.")

        # Note: '8' (bpe_0) is not in this list, so it is accepted together with lowercasing.
        if lowercased and split_value in ['4', '5', '6', '7', '9']:
            raise ValueError("Combination BPE and LOWERCASE is not supported:")

    def __init__(self, params: Dict[PrepParam, str]):
        """Validate ``params`` and store them.

        :param params: mapping from every ``PrepParam`` member to its value.
        :raises ValueError: if the params are invalid (see ``__check_invariants``).
        """
        PrepConfig.__check_invariants(params)

        self.params = params

    def __str__(self) -> str:
        """Return the encoded form: the parameter values joined in ``PrepParam`` order."""
        return "".join(self.params[k] for k in PrepParam)

    def __repr__(self):
        return str(self.params)

    def get_param_value(self, param: PrepParam) -> str:
        """Return the value stored for ``param``."""
        return self.params[param]

    def __eq__(self, other):
        """Two configs are equal when their parameter mappings are equal."""
        if not isinstance(other, PrepConfig):
            # Let Python try the reflected comparison instead of failing with AttributeError.
            return NotImplemented
        return self.params == other.params

    def get_number_splitter(self) -> Splitter:
        """Return the callable used to split numbers, chosen by the SPLIT parameter.

        The returned callable takes two arguments ``(s, c)``:

        * for '0', 'F', '1' it returns ``[s]`` (no splitting);
        * for '2', '3', 's' it returns the list of elements of ``s``
          (the characters, when ``s`` is a string);
        * for the BPE values it delegates to ``get_bpe_subwords(s, c)``.

        :raises ValueError: if the SPLIT value is none of the above.
        """
        split_param_value = self.get_param_value(PrepParam.SPLIT)
        if split_param_value in ['0', 'F', '1']:
            return lambda s, c: [s]
        elif split_param_value in ['2', '3', 's']:
            return lambda s, c: list(s)
        elif split_param_value in PrepConfig.BPE_SPLIT_VALUES:
            return lambda s, c: get_bpe_subwords(s, c)
        else:
            raise ValueError(f"Invalid SPLIT param value: {split_param_value}")

    def get_word_splitter(self) -> Optional[Splitter]:
        """Return the callable used to split words, chosen by the SPLIT parameter.

        The returned callable takes two arguments ``(s, c)``:

        * for the BPE values it delegates to ``get_bpe_subwords(s, c)``;
        * for '1', '2' it returns ``[s]``;
        * for '3' it returns ``ronin.split(s)``;
        * for 's' it returns ``ronin.split(s)`` with ``stem`` applied to every part.

        For '0' and 'F' no splitter is returned.

        :return: the splitter, or ``None`` for SPLIT values '0' and 'F'.
        :raises ValueError: if the SPLIT value is none of the above.
        """
        split_param_value = self.get_param_value(PrepParam.SPLIT)
        if split_param_value in PrepConfig.BPE_SPLIT_VALUES:
            return lambda s, c: get_bpe_subwords(s, c)
        elif split_param_value in ['1', '2']:
            return lambda s, c: [s]
        elif split_param_value == '3':
            # Imported lazily so that spiral is only needed when this mode is selected.
            from spiral import ronin
            return lambda s, c: ronin.split(s)
        elif split_param_value == 's':
            from codeprep.stemming import stem
            from spiral import ronin
            return lambda s, c: [stem(part) for part in ronin.split(s)]
        elif split_param_value in ['0', 'F']:
            return None
        else:
            raise ValueError(f"Invalid SPLIT param value: {split_param_value}")

    def get_types_to_be_repr(self) -> List[Type]:
        """Return the token types selected by this config.

        The list is passed as the first argument to ``ReprConfig`` by
        ``get_repr_config``. Types are appended in this order:

        * ``Identifier`` and ``Word`` unless SPLIT is '0' or 'F';
        * ``Number`` unless SPLIT is '0', 'F' or '1';
        * ``OneLineComment`` and ``MultilineComment`` if COM is '0';
        * ``StringLiteral`` if STR is '0';
        * ``NonEng`` if EN_ONLY is 'U';
        * ``NewLine`` and ``Tab`` if TABS_NEWLINES is '0'.
        """
        res = []
        if self.get_param_value(PrepParam.SPLIT) in ['1', '2', '3', '4', '5', '6', '7', '8', '9', 's']:
            res.extend([Identifier, Word])
        if self.get_param_value(PrepParam.SPLIT) in ['2', '3', '4', '5', '6', '7', '8', '9', 's']:
            res.append(Number)
        if self.get_param_value(PrepParam.COM) == '0':
            res.extend([OneLineComment, MultilineComment])
        if self.get_param_value(PrepParam.STR) == '0':
            res.append(StringLiteral)
        if self.get_param_value(PrepParam.EN_ONLY) == 'U':
            res.append(NonEng)
        if self.get_param_value(PrepParam.TABS_NEWLINES) == '0':
            res.extend([NewLine, Tab])
        return res

    def get_repr_config(self, bpe_data: Optional[BpeData]):
        """Build the ``ReprConfig`` that corresponds to this config.

        :param bpe_data: BPE data; it is forwarded only when ``is_bpe()`` is
            True, otherwise ``None`` is passed instead.
        :return: a ``ReprConfig`` constructed from, in positional order: the
            types from ``get_types_to_be_repr()``, the BPE data (or ``None``),
            whether CASE is 'l', the number splitter, the word splitter,
            whether SPLIT is 'F', and the max string length decoded from STR.
        """
        return ReprConfig(self.get_types_to_be_repr(),
                          bpe_data if self.is_bpe() else None,
                          self.get_param_value(PrepParam.CASE) == 'l',
                          self.get_number_splitter(),
                          self.get_word_splitter(),
                          self.get_param_value(PrepParam.SPLIT) == 'F',
                          get_max_str_length(self.get_param_value(PrepParam.STR)))

    def is_bpe(self):
        """
        Check if this config corresponds to preprocessing with BPE.
        Note: splitting into chars is implemented as BPE with 0 merges, so in this case this method will also return True.

        :return: True if this config corresponds to preprocessing with BPE, False otherwise.
        """
        return self.get_param_value(PrepParam.SPLIT) in PrepConfig.BPE_SPLIT_VALUES

    #TODO make use of basic_bpe mask
    def is_base_bpe_config(self):
        """Check whether this config has the fixed "base BPE" parameter values.

        :return: True if COM is '0', STR is 'E', SPLIT is 'F', TABS_NEWLINES is 's'
            and CASE is 'u'; the EN_ONLY value is not taken into account.
        """
        return (self.get_param_value(PrepParam.COM) == '0'
                and self.get_param_value(PrepParam.STR) == 'E'
                and self.get_param_value(PrepParam.SPLIT) == 'F'
                and self.get_param_value(PrepParam.TABS_NEWLINES) == 's'
                and self.get_param_value(PrepParam.CASE) == 'u')