giganticode/codeprep

View on GitHub
codeprep/bpepkg/merge.py

Summary

Maintainability
A
0 mins
Test Coverage
# SPDX-FileCopyrightText: 2020 Hlib Babii <hlibbabii@gmail.com>
#
# SPDX-License-Identifier: Apache-2.0

import copy

from typing import List, Tuple, Union, Optional, Iterator, Dict

from codeprep.util.misc import is_python_3_6_and_higher, to_literal_str, to_non_literal_str


# TODO this class should be frozen
class Merge(object):
    """
    A merge of a pair of strings, with an optional frequency and an optional priority.

    Two merges are equal (and hash equally) when their pair, frequency and priority are all equal.

    >>> Merge(('a', 'b'), 34, 0)
    ('a', 'b'): (34, 0)
    >>> Merge(('a', 'b'), 34, 0) == Merge(('a', 'b'), 34, 0)
    True
    >>> Merge(('a', 'b'), 34, 0) == Merge(('a', 'b'), 34, 1)
    False
    """
    def __init__(self, pair: Tuple[str, str], freq: Optional[int] = None, priority: Optional[int] = None):
        """
        :param pair: the two strings that are merged.
        :param freq: number of occurrences of the pair, or None if it is not known.
        :param priority: priority of the merge, or None if it has not been assigned yet.
        """
        self.pair = pair
        self.freq = freq
        self.priority = priority

    @classmethod
    def parse_file_entry(cls, line: str, priority: int) -> "Merge":
        """
        Build a merge from one line of a merges file.

        The line is converted with `to_non_literal_str` and split on single spaces. A line with exactly
        two fields yields a merge without a frequency; otherwise the third field is parsed as the frequency.

        :param line: the line to parse, without the trailing newline.
        :param priority: priority to assign to the resulting merge.
        :raises ValueError: if the line has fewer than two fields or the frequency is not an integer.
        """
        try:
            spl = to_non_literal_str(line).split(" ")
            if len(spl) == 2:
                return cls((spl[0], spl[1]), priority=priority)
            else:
                return cls((spl[0], spl[1]), freq=int(spl[2]), priority=priority)
        except (IndexError, TypeError, ValueError) as err:
            # Chain the cause instead of passing it as a second exception argument, so that
            # the message stays readable and the original traceback is preserved.
            raise ValueError(f"Invalid merge entry format: {line}") from err

    def __str__(self):
        return self.__repr__()

    def __repr__(self):
        return f'{self.pair}: ({self.freq}, {self.priority})'

    def __eq__(self, other):
        return self.__class__ == other.__class__ and self.pair == other.pair and self.priority == other.priority \
               and self.freq == other.freq

    def __hash__(self):
        # Hash over the same fields that __eq__ compares.
        return hash((self.pair, self.priority, self.freq))


class MergeList(object):
    """
    Collection of merges kept in priority order and indexed by pair.

    The priority of each merge equals its position in the list: `append` assigns the next priority
    to a merge that has none and rejects a merge whose priority is not the next one.
    Each pair can be present only once.

    >>> merges = MergeList()
    >>> merges = merges.append(Merge(('a', 'b'), 34, 0)).append(Merge(('b', 'c'), 44, 1))
    >>> [m for m in merges]
    [('a', 'b'): (34, 0), ('b', 'c'): (44, 1)]
    >>> len(merges)
    2
    >>> merges[0]
    ('a', 'b'): (34, 0)
    >>> merges[1]
    ('b', 'c'): (44, 1)
    >>> merges[-1]
    ('b', 'c'): (44, 1)
    >>> merges[0:-1]
    [('a', 'b'): (34, 0)]
    >>> type(merges[0:-1])
    <class 'list'>

    >>> merges[2]
    Traceback (most recent call last):
    ...
    IndexError: list index out of range

    >>> ('a', 'b') in merges
    True
    >>> ('a', 'x') in merges
    False

    >>> merge1 = Merge(('a', 'b'), 34, 0)
    >>> merge2 = Merge(('a', 'b'), 34, 0)
    >>> dct = {merge1: 3}
    >>> dct[merge2]
    3

    >>> merges + MergeList().append(Merge(('d', 'e'), 84, 0))
    [('a', 'b'): (34, 0), ('b', 'c'): (44, 1), ('d', 'e'): (84, 2)]
    >>> merges + [(('d', 'e'), 84, 1)]
    Traceback (most recent call last):
    ...
    TypeError: Cannot add <class 'list'> to a MergeList

    >>> merges + merges
    Traceback (most recent call last):
    ...
    ValueError: Merge for pair ('a', 'b') is already present in the list

    >>> merges.append(Merge(('x', 'y'), 34, 0))
    Traceback (most recent call last):
    ...
    ValueError: It's only possible to add merges in priority order. The priority of the next merge should be 2 but is 0

    >>> merges.append(Merge(('a', 'b'), 34))
    Traceback (most recent call last):
    ...
    ValueError: Merge for pair ('a', 'b') is already present in the list

    >>> merges = merges.append(Merge(('x', 'y'), 34))
    >>> merges
    [('a', 'b'): (34, 0), ('b', 'c'): (44, 1), ('x', 'y'): (34, 2)]
    >>> merges.get_priority(('x', 'y'))
    2
    """
    def __init__(self):
        # Maps a pair to its merge; insertion order of the dict is the priority order.
        self.merges: Dict[Tuple[str, str], "Merge"] = {}

    def __contains__(self, item):
        # Membership is checked by pair, not by Merge object.
        return item in self.merges

    def __len__(self):
        return len(self.merges)

    def __iter__(self) -> Iterator["Merge"]:
        return iter(self._get_sorted_merges())

    def _get_sorted_merges(self) -> List["Merge"]:
        """Return the merges as a new list in priority (i.e. insertion) order."""
        if not is_python_3_6_and_higher():
            # we cannot rely on dict order for python versions lower than 3.6
            raise NotImplementedError()

        return list(self.merges.values())

    def __add__(self, other: 'MergeList'):
        """
        Return a new list with the merges of `self` followed by the merges of `other`.

        Neither operand is modified: both are deep-copied, and the priorities of the copied merges
        of `other` are shifted by the length of `self`.

        :raises TypeError: if `other` is not a MergeList.
        :raises ValueError: if `other` contains a pair that is already present.
        """
        if self.__class__ != other.__class__:
            raise TypeError(f"Cannot add {other.__class__} to a MergeList")

        new_merge_list = copy.deepcopy(self)
        other_copy = copy.deepcopy(other)
        # Captured once: the offset must not grow while merges are being appended.
        first_list_len = len(new_merge_list)
        for merge in other_copy:
            merge.priority += first_list_len
            new_merge_list.append(merge)

        return new_merge_list

    def append(self, merge: "Merge") -> 'MergeList':
        """
        Add a merge to the end of the list and return the list itself (to allow chaining).

        If the merge has no priority, it is assigned the next one (the passed object is modified).

        :raises ValueError: if the pair is already present, or if the merge has a priority
        that is not equal to the current length of the list.
        """
        # Overwriting an existing pair would keep its old position in the dict while giving it
        # a new priority, so position and priority would no longer agree. Reject it up front,
        # before the passed merge is modified.
        if merge.pair in self.merges:
            raise ValueError(f"Merge for pair {merge.pair} is already present in the list")

        # along with the pair we save its priority and the number of its occurrences
        if merge.priority is None:
            merge.priority = len(self.merges)
        elif merge.priority != len(self.merges):
            raise ValueError(f"It's only possible to add merges in priority order. "
                             f"The priority of the next merge should be {len(self.merges)} but is {merge.priority}")

        self.merges[merge.pair] = merge
        return self

    def get_priority(self, pair: Tuple[str, str]) -> int:
        """Return the priority of the merge for `pair`; raises KeyError if the pair is absent."""
        return self.merges[pair].priority

    def __getitem__(self, item) -> Union[List["Merge"], "Merge"]:
        # Indexing and slicing are delegated to a plain list, so a slice yields a list, not a MergeList.
        lst = self._get_sorted_merges()
        return lst[item]

    def __repr__(self):
        return repr(self[:])

    def __eq__(self, other):
        return self.__class__ == other.__class__ and self[:] == other[:]


def read_merges(file: str, n_merges: Optional[int] = None) -> MergeList:
    """
    Load merges from a text file that has one merge entry per line.

    The zero-based line index is used as the priority of the merge parsed from that line.

    :param file: path to the file to read.
    :param n_merges: maximum number of entries to read; None or 0 means no limit.
    :return: a MergeList with the parsed merges in file order.
    """
    result = MergeList()
    with open(file, 'r') as merges_file:
        for priority, raw_line in enumerate(merges_file):
            # A falsy limit (None or 0) disables the cut-off entirely.
            limit_reached = bool(n_merges) and priority >= n_merges
            if limit_reached:
                break
            entry = Merge.parse_file_entry(raw_line.rstrip('\n'), priority)
            result.append(entry)
    return result


def dump_merges(merges: MergeList, file: str):
    """
    Write merges to a text file, one entry per line, in the order the list yields them.

    Each line is the `to_literal_str` form of the space-joined pair, followed by the frequency
    if the merge has one. The priority is not written.

    :param merges: the merges to write.
    :param file: path to the file to write; an existing file is overwritten.
    """
    with open(file, 'w') as f:
        for merge in merges:
            entry = to_literal_str(' '.join(merge.pair))
            # A missing frequency is left out rather than written as the text "None":
            # Merge.parse_file_entry accepts a two-field line, but cannot parse "None" as an integer.
            if merge.freq is not None:
                entry = f"{entry} {merge.freq}"
            f.write(f"{entry}\n")