# codeprep/bpepkg/wild_bpe.py
# SPDX-FileCopyrightText: 2020 Hlib Babii <hlibbabii@gmail.com>
#
# SPDX-License-Identifier: Apache-2.0
import logging
import os
import sys
from collections import defaultdict
from enum import Enum, auto
import time
from typing import List, Dict, Tuple, Set, Generator, Optional
from codeprep.util.misc import PriorityCounter, getsize
# Module-level logger; every diagnostic message in this module is emitted via logger.debug.
logger = logging.getLogger(__name__)
# Version string of this module.
__version__ = '0.2'
class BpePerformanceStatsEntry:
    """Snapshot of the BPE learner's state, as recorded by ``run``.

    In ``run`` the object sizes are ``getsize(...) / 1e+6`` (megabytes) and
    ``time_for_last_merge`` is a ``time.time()`` difference (seconds).

    Note: the constructor keyword is spelled ``n_index_enties`` (sic) because
    callers pass it under that name; the value is stored as ``n_index_entries``.
    """

    def __init__(self, merges_done: int, time_for_last_merge: float, n_priority_queue_entries: int,
                 n_index_enties: int, location_index_obj_size: float, neighbour_index_obj_size: float,
                 priority_counter_obj_size: float):
        # progress
        self.merges_done = merges_done
        self.time_for_last_merge = time_for_last_merge
        # entry counts
        self.n_priority_queue_entries = n_priority_queue_entries
        self.n_index_entries = n_index_enties
        # object sizes
        self.location_index_obj_size = location_index_obj_size
        self.neighbour_index_obj_size = neighbour_index_obj_size
        self.priority_counter_obj_size = priority_counter_obj_size
class Side(Enum):
    """The side of a pair on which a neighbouring pair lies.

    Iteration order is RIGHT, then LEFT (definition order); loops of the form
    ``for side in Side`` in this module visit the sides in that order.
    """
    RIGHT = auto()
    LEFT = auto()

    @staticmethod
    def any():
        """Return a fixed representative side (always LEFT)."""
        return Side.LEFT

    def opposite(self):
        """Return the other side: RIGHT for LEFT and LEFT for RIGHT."""
        return Side.RIGHT if self is Side.LEFT else Side.LEFT
def _detect_file_encoding(path_to_file: str, chunk_size: int = 65536) -> str:
    """Return 'utf-8' if the whole file decodes as UTF-8, otherwise 'ISO-8859-1'.

    ISO-8859-1 maps every byte to a character, so it can always be used as the fallback.
    """
    try:
        with open(path_to_file, encoding='utf-8') as f:
            while f.read(chunk_size):
                pass
    except UnicodeDecodeError:
        return 'ISO-8859-1'
    return 'utf-8'


def get_char_iterator_for_file(path_to_file: str):
    """Yield the escaped characters of a file, decoded as UTF-8 or, failing that, ISO-8859-1.

    The encoding is chosen before anything is yielded. Previously a decoding error late in a
    file caused a retry with ISO-8859-1 after some characters had already been yielded, so
    those characters were emitted twice.

    :param path_to_file: path of the file to read.
    """
    encoding = _detect_file_encoding(path_to_file)
    yield from get_char_iterator_for_file_with_encoding(path_to_file, encoding)
def escape_char(char: str):
    """Return ``char`` as a string, replacing a plain space with a non-breaking space (U+00A0).

    Pairs are written as two tokens separated by a single ' ', so a literal space
    inside a token has to be replaced.
    """
    text = str(char)
    if text == ' ':
        return '\xA0'
    return text
def get_char_iterator_for_file_with_encoding(path_to_file: str, encoding: str,
                                             chunk_size: int = 65536) -> Generator[str, None, None]:
    """Yield the characters of a text file one at a time, each passed through ``escape_char``.

    The file is read in chunks of ``chunk_size`` characters rather than one ``read(1)`` call per
    character; the produced characters are the same (text mode, universal newlines).

    :param path_to_file: path of the file to read.
    :param encoding: text encoding used to decode the file.
    :param chunk_size: number of characters requested per ``read`` call.
    :raises UnicodeDecodeError: if the file is not valid in ``encoding``.
    """
    with open(path_to_file, encoding=encoding) as f:
        # iter(callable, sentinel) keeps calling f.read until it returns '' (end of file)
        for chunk in iter(lambda: f.read(chunk_size), ''):
            for char in chunk:
                yield escape_char(char)
def get_char_iterator_for_dir(path_to_dir: str, extension: str = '.py') -> Generator[str, None, None]:
    """Yield the characters of every file below ``path_to_dir`` whose name ends with ``extension``.

    Directories and files are visited in sorted order, so the character stream (and therefore
    the learnt merges) does not depend on the file system's listing order. Each matching
    file is followed by three '\\n' characters as a separator.

    :param path_to_dir: root directory to walk recursively.
    :param extension: file-name suffix to include (``str.endswith`` also accepts a tuple).
    """
    for root, dirs, files in os.walk(path_to_dir):
        # sorting in place makes top-down os.walk descend into sub-directories in this order
        dirs.sort()
        for file in sorted(files):
            if file.endswith(extension):
                yield from get_char_iterator_for_file(os.path.join(root, file))
                yield from "\n\n\n"
def swap_pair(pair: str) -> str:
    """Return ``pair`` with its first two space-separated tokens in reverse order."""
    tokens = pair.split(" ")
    return f"{tokens[1]} {tokens[0]}"
def are_symmetric(pair1: str, pair2: str):
    """Tell whether ``pair2`` is ``pair1`` with its two tokens swapped.

    >>> are_symmetric("abc dcba", "dcba abc")
    True
    >>> are_symmetric("abc dfe", "efd cba")
    False
    >>> are_symmetric("a c", "ac")
    False
    """
    # cheap rejection: swapping the tokens never changes the total length
    if len(pair1) != len(pair2):
        return False
    # TODO make it faster
    tokens1, tokens2 = pair1.split(" "), pair2.split(" ")
    return tokens1[0] == tokens2[1] and tokens1[1] == tokens2[0]
def build_indices(it: Generator[str, None, None]) -> Tuple[Dict[str, List[int]], Dict[str, Dict['Side', Set[str]]]]:
    """Index every pair of adjacent characters produced by ``it``.

    A pair is written as its two tokens joined by a single space, e.g. ``'a b'``.

    :param it: the characters of the corpus; any iterable of strings is accepted.
    :return: ``(location_index, neighbour_index)``, both ``defaultdict``s:
        ``location_index`` maps a pair to the ascending list of positions of its first
        character; ``neighbour_index`` maps a pair to ``{Side.LEFT: pairs seen directly to
        its left, Side.RIGHT: pairs seen directly to its right}``.
        Input with fewer than two characters yields two empty indices.
    """
    index = defaultdict(list)
    index_index = defaultdict(lambda: {Side.LEFT: set(), Side.RIGHT: set()})
    chars = iter(it)
    first_char = next(chars, None)
    if first_char is None:
        # Empty input. Letting StopIteration escape here would be turned into a RuntimeError
        # by the calling generator (PEP 479).
        return index, index_index
    last_key = None
    # enumerate starts at 0 for the *second* character, so `position` is the first char's index
    for position, second_char in enumerate(chars):
        key = ' '.join([first_char, second_char])
        index[key].append(position)
        if last_key is not None:
            index_index[last_key][Side.RIGHT].add(key)
            index_index[key][Side.LEFT].add(last_key)
        last_key = key
        first_char = second_char
    return index, index_index
def merge_lists(main_list: List[int], list2: List[int], position_shift: int) -> Tuple[List[int], List[int]]:
    """Match occurrences of the merged pair against occurrences of one neighbour pair.

    A position ``p`` of ``main_list`` matches a position ``q`` of ``list2`` when
    ``p + position_shift == q``. Both lists are walked with two pointers, so they are expected
    to be sorted ascending.

    :return: ``(positions of list2 that matched nothing, start of each matched combination)``,
        where the start is the smaller of the two matched positions.
    """
    unmatched = []
    matched = []
    a = b = 0
    while a < len(main_list) and b < len(list2):
        target = main_list[a] + position_shift
        candidate = list2[b]
        if target == candidate:
            matched.append(min(main_list[a], candidate))
            a += 1
            b += 1
        elif target > candidate:
            unmatched.append(candidate)
            b += 1
        else:
            a += 1
    unmatched += list2[b:]
    return unmatched, matched
def self_merge(main_list, position_shift):
result = []
i = 0
while i < len(main_list)-1:
if main_list[i] + position_shift == main_list[i+1]:
result.append(main_list[i])
i += 2
else:
i += 1
return [], result
def merge_lists_both(main_list: List[int], list2: List[int], position_shift: Tuple[int, int]) -> Tuple[List[int], List[int]]:
    """Find occurrences of a neighbour pair sandwiched between two occurrences of the merged pair.

    A position ``q`` of ``list2`` is sandwiched when some ``main_list[i] + position_shift[0] == q``
    and ``main_list[i + 1] + position_shift[1] == q``.

    :param position_shift: ``(positive shift, negative shift)``.
    :raises AssertionError: if the shifts do not have those signs, or if both lists are equal.
    :return: ``(positions of list2 that were not sandwiched, main_list[i] for each sandwich)``.

    >>> merge_lists_both([0, 5, 7, 11, 16], [1, 9, 15], (2, -2))
    ([1, 15], [7])
    >>> merge_lists_both([2, 7, 10, 12, 15], [1, 3, 6, 11], (1, -1))
    ([1, 3, 6], [10])
    >>> merge_lists_both([0, 2, 4], [1, 3], (1, -1))
    ([], [0, 2])
    """
    forward_shift = position_shift[0]
    # the shifts must be one positive and one negative number
    if not (forward_shift > 0 and position_shift[1] < 0):
        raise AssertionError()
    backward_shift = position_shift[1]
    if main_list == list2:
        raise AssertionError("")
    remaining = []
    sandwiched = []
    a = b = 0
    last = len(main_list) - 1
    while a < last and b < len(list2):
        expected = main_list[a] + forward_shift
        candidate = list2[b]
        if expected > candidate:
            remaining.append(candidate)
            b += 1
        elif expected < candidate:
            a += 1
        else:
            if main_list[a + 1] + backward_shift == candidate:
                sandwiched.append(main_list[a])
            else:
                remaining.append(candidate)
            a += 1
            b += 1
    remaining.extend(list2[b:])
    return remaining, sandwiched
def is_left(main_pair: str, pair2: str):
    """Return True unless ``pair2`` starts with the right-hand token of ``main_pair``."""
    pair2_first = pair2.split(" ")[0]
    main_second = main_pair.split(" ")[1]
    return pair2_first != main_second
def merge_pair(pair: str) -> str:
    """Glue the tokens of ``pair`` together by removing every space, e.g. ``'a b'`` -> ``'ab'``."""
    return pair.replace(" ", "")
def concat_pairs(main_pair: str, pair2: str, side: Side):
    """Build the pair formed by the merged ``main_pair`` and its neighbour ``pair2``.

    For ``Side.LEFT`` the result is ``"<first token of pair2> <merged main_pair>"``;
    for ``Side.RIGHT`` it is ``"<merged main_pair> <second token of pair2>"``.

    :raises AssertionError: if ``can_be_concat(main_pair, pair2, side)`` is falsy.
    """
    if not can_be_concat(main_pair, pair2, side):
        raise AssertionError()
    merged = merge_pair(main_pair)
    if side.value == Side.LEFT.value:
        outer_token = pair2.split(" ")[0]
        return f"{outer_token} {merged}"
    if side.value == Side.RIGHT.value:
        outer_token = pair2.split(" ")[1]
        return f"{merged} {outer_token}"
def can_be_concat(main_pair: str, pair: str, side: Side):
    """Tell whether ``pair`` fits next to ``main_pair`` on the given side.

    On the RIGHT, the second token of ``main_pair`` must equal the first token of ``pair``;
    on the LEFT, the first token of ``main_pair`` must equal the second token of ``pair``.

    >>> can_be_concat("ab cd", "1 ab", Side.LEFT)
    True
    >>> can_be_concat("ab cd", "1 ab", Side.RIGHT)
    False
    """
    if side.value == Side.RIGHT.value:
        return main_pair.split(" ")[1] == pair.split(" ")[0]
    if side.value == Side.LEFT.value:
        return main_pair.split(" ")[0] == pair.split(" ")[1]
def double_pair(pair: str):
    """Return the pair made of two copies of the merged ``pair``, e.g. ``'a b'`` -> ``'ab ab'``."""
    merged = pair.replace(" ", "")
    return f"{merged} {merged}"
def calc_position_shift(main_pair: str, pair2: str, side: Side):
    """Offset from a position of ``main_pair`` to the position of its neighbour ``pair2``.

    A right neighbour starts after the first token of ``main_pair`` (positive offset).
    A left neighbour starts one first-token-of-``pair2`` earlier (negative offset).
    """
    if side.value == Side.RIGHT.value:
        first_token = main_pair.split(" ")[0]
        return len(first_token)
    if side.value == Side.LEFT.value:
        first_token = pair2.split(" ")[0]
        return -len(first_token)
def add_pairs_to_neighbour_index(index, pair1, pair2, side, location_index):
    """Record ``pair2`` as a ``side`` neighbour of ``pair1`` (and the reverse relation).

    Nothing is recorded when ``pair2`` is absent from ``location_index``.

    :raises AssertionError: if ``pair2`` cannot be a ``side`` neighbour of ``pair1``.
    """
    if not can_be_concat(pair1, pair2, side):
        raise AssertionError("")
    if pair2 not in location_index:
        return
    index[pair1][side].add(pair2)
    index[pair2][side.opposite()].add(pair1)
def choose_positions_to_merge(main_list, position_shift):
    """Split the positions of a self-overlapping pair into merged and absorbed occurrences.

    Scanning left to right, each kept position ``p`` absorbs the next position if it equals
    ``p + position_shift`` (the overlapping occurrence cannot be merged as well).

    :return: ``(kept positions, absorbed positions)``.

    >>> choose_positions_to_merge([0 ,1, 2, 5, 8, 9, 10, 11, 12, 20, 33, 34], 1)
    ([0, 2, 5, 8, 10, 12, 20, 33], [1, 9, 11, 34])
    >>> choose_positions_to_merge([0], 1)
    ([0], [])
    """
    kept, absorbed = [], []
    idx = 0
    last = len(main_list) - 1
    while idx < last:
        current = main_list[idx]
        kept.append(current)
        if main_list[idx + 1] == current + position_shift:
            absorbed.append(main_list[idx + 1])
            idx += 2
        else:
            idx += 1
    # a trailing position that was not absorbed is kept as well
    if idx == last:
        kept.append(main_list[idx])
    return kept, absorbed
def cleanup_location_index(location_index, most_freq_pair, disappearing_pairs):
    """Remove location-index entries left empty by merging ``most_freq_pair``.

    :param location_index: pair -> list of positions; mutated in place.
    :param most_freq_pair: the pair just merged; its entry is always removed.
    :param disappearing_pairs: ``{Side: set of pairs}`` - the neighbours of ``most_freq_pair``.
    """
    for side in Side:
        for pair in disappearing_pairs[side]:
            if len(location_index[pair]) != 0:
                continue
            del location_index[pair]
    # most_freq_pair may already be gone when it was one of its own neighbours
    if most_freq_pair not in location_index:
        return
    del location_index[most_freq_pair]
def cleanup_neighbour_index(location_index, neighbour_index, most_freq_pair):
    """Drop the neighbour entries of ``most_freq_pair`` and of its neighbours that no longer occur.

    The entry of ``most_freq_pair`` is taken out first. Previously, when the pair was one of its
    own neighbours, its entry was deleted while visiting the first side. The ``defaultdict`` then
    recreated it empty for the second side, so that side's vanished neighbours were never removed.

    :param location_index: pair -> positions; only pairs absent from it are dropped.
    :param neighbour_index: pair -> ``{side: set of neighbouring pairs}``; mutated in place.
    :param most_freq_pair: the pair that was just merged.
    """
    # indexing (rather than pop) keeps the defaultdict behaviour for a missing key
    neighbours_by_side = neighbour_index[most_freq_pair]
    del neighbour_index[most_freq_pair]
    for neighbours in neighbours_by_side.values():
        for pair in neighbours:
            if pair not in location_index and pair in neighbour_index:
                del neighbour_index[pair]
def update_location_index(location_index, neighbour_index, pair_to_merge):
    """Apply the merge of ``pair_to_merge`` to the location index.

    For every neighbour pair of ``pair_to_merge`` (both sides), the positions where the neighbour
    overlaps an occurrence of ``pair_to_merge`` are moved from the neighbour's list to the
    newly appearing pair (the neighbour with ``pair_to_merge`` glued into one token). Finally,
    ``cleanup_location_index`` removes emptied entries and ``pair_to_merge`` itself.

    :param location_index: pair -> positions (ascending as built by build_indices); mutated in place.
    :param neighbour_index: pair -> {Side: set of neighbouring pairs}; consulted here, updated
        separately by update_neighbour_index (see run).
    :param pair_to_merge: the pair being merged, e.g. ``'a b'``.
    :return: list of ``(appeared_pair, first position of appeared_pair, disappearing_pair,
        first remaining position of disappearing_pair or -1, number of positions moved)``,
        one tuple per non-empty change; run uses them to update the priority counter.
    """
    occurence_changes = []
    disappearing_pairs = neighbour_index[pair_to_merge]
    main_list = location_index[pair_to_merge]
    # A pair that is its own neighbour (e.g. 'a a' in 'aaa') has overlapping occurrences that
    # cannot all be merged: choose_positions_to_merge keeps the positions to merge and sets
    # the overlapping ones aside; they are handled below as a "disappearing" list.
    if pair_to_merge in neighbour_index[pair_to_merge][Side.any()]:
        main_list, disappearing_pair_list_for_merge_pair = choose_positions_to_merge(
            main_list,
            calc_position_shift(pair_to_merge, pair_to_merge, Side.RIGHT)
        )
    for side in Side:
        for disappearing_pair in disappearing_pairs[side]:
            if pair_to_merge != disappearing_pair:
                disappearing_pair_list = location_index[disappearing_pair]
            elif side.value == Side.RIGHT.value:
                # The pair as its own neighbour is processed only once, on the RIGHT side.
                # NOTE(review): this name is bound only by the branch above (membership tested on
                # Side.any(), i.e. LEFT); relies on the neighbour index being symmetric.
                disappearing_pair_list = disappearing_pair_list_for_merge_pair
            else:
                continue
            # The neighbour sits between two occurrences of pair_to_merge (e.g. 'b a' in 'abab'
            # when merging 'a b'): those positions become the doubled pair (e.g. 'ab ab').
            if can_be_concat(disappearing_pair, pair_to_merge, side) and side.value == Side.RIGHT.value:
                appeared_pair = double_pair(pair_to_merge)
                position_shift = (
                    calc_position_shift(pair_to_merge, disappearing_pair, side),
                    calc_position_shift(pair_to_merge, disappearing_pair, side.opposite())
                )
                disappearing_pair_list, appeared_pairs_locations = merge_lists_both(
                    main_list, disappearing_pair_list, position_shift
                )
                if len(appeared_pairs_locations) > 0:
                    location_index[appeared_pair] = appeared_pairs_locations
                    reduced_occurences = len(location_index[appeared_pair])
                    occurence_changes.append((appeared_pair, appeared_pairs_locations[0], disappearing_pair, disappearing_pair_list[0] if disappearing_pair_list else -1, reduced_occurences))
            # Remaining overlaps become the neighbour with pair_to_merge glued into one token.
            appeared_pair = concat_pairs(pair_to_merge, disappearing_pair, side)
            position_shift = calc_position_shift(pair_to_merge, disappearing_pair, side)
            disappearing_pair_list, appeared_pairs_locations = merge_lists(
                main_list, disappearing_pair_list, position_shift
            )
            if len(appeared_pairs_locations) > 0:
                location_index[appeared_pair] = appeared_pairs_locations
                reduced_occurences = len(location_index[appeared_pair])
                occurence_changes.append((appeared_pair, appeared_pairs_locations[0], disappearing_pair, disappearing_pair_list[0] if disappearing_pair_list else -1, reduced_occurences))
            # Positions of the neighbour that were not consumed by the merge.
            location_index[disappearing_pair] = disappearing_pair_list
    cleanup_location_index(location_index, pair_to_merge, disappearing_pairs)
    return occurence_changes
def update_neighbour_index(location_index, neighbour_index, pair_to_merge):
    """Link the pairs created by merging ``pair_to_merge`` to their neighbours.

    run calls this right after update_location_index. A newly created pair is linked only
    if it is present in ``location_index``, and add_pairs_to_neighbour_index links only
    neighbours that are present there too. Finally cleanup_neighbour_index drops the entry of
    ``pair_to_merge`` and the entries of neighbours that no longer occur.

    :param location_index: pair -> positions, already updated for this merge; only read here.
    :param neighbour_index: pair -> {Side: set of neighbouring pairs}; mutated in place.
    :param pair_to_merge: the pair that was just merged.
    """
    for side in Side:
        disappearing_pairs = neighbour_index[pair_to_merge][side]
        for disappearing_pair in disappearing_pairs:
            # The neighbour can lie between two occurrences of pair_to_merge, so the doubled
            # pair (e.g. 'ab ab') may have appeared; link it to the side's new pairs.
            if can_be_concat(disappearing_pair, pair_to_merge, side):
                appeared_pair = double_pair(pair_to_merge)
                if appeared_pair in location_index:
                    for disappeared_pair2 in disappearing_pairs:
                        mm = concat_pairs(pair_to_merge, disappeared_pair2, side)
                        add_pairs_to_neighbour_index(neighbour_index, appeared_pair, mm, side, location_index)
                        if can_be_concat(pair_to_merge, mm, side.opposite()):
                            mm_concat = concat_pairs(pair_to_merge, mm, side.opposite())
                            add_pairs_to_neighbour_index(neighbour_index, appeared_pair,
                                                         mm_concat, side, location_index)
            # The pair formed by this neighbour and the merged token.
            appeared_pair = concat_pairs(pair_to_merge, disappearing_pair, side)
            if appeared_pair in location_index:
                # Same side: the neighbours of the old neighbour pair, plus (when it fits) the
                # form of such a neighbour with pair_to_merge glued onto it.
                for neighbour_of_neighbour in neighbour_index[disappearing_pair][side]:
                    add_pairs_to_neighbour_index(neighbour_index, appeared_pair, neighbour_of_neighbour, side,
                                                 location_index)
                    if can_be_concat(pair_to_merge, neighbour_of_neighbour, side.opposite()):
                        neighbour_of_neighbour_concat = concat_pairs(pair_to_merge, neighbour_of_neighbour,
                                                                     side.opposite())
                        add_pairs_to_neighbour_index(neighbour_index, appeared_pair, neighbour_of_neighbour_concat,
                                                     side, location_index)
                # Opposite side: pairs built from the merged token and pair_to_merge's neighbours there.
                op_side = side.opposite()
                for neighbour_of_neighbour in neighbour_index[pair_to_merge][op_side]:
                    cc = concat_pairs(pair_to_merge, neighbour_of_neighbour, op_side)
                    add_pairs_to_neighbour_index(neighbour_index, appeared_pair, cc, op_side, location_index)
                    if can_be_concat(pair_to_merge, cc, op_side.opposite()):
                        cc_concat = concat_pairs(pair_to_merge, cc, op_side.opposite())
                        add_pairs_to_neighbour_index(neighbour_index, appeared_pair, cc_concat, op_side,
                                                     location_index)
    cleanup_neighbour_index(location_index, neighbour_index, pair_to_merge)
def _measure_performance(merges_done: int, time_for_last_merge: float, location_index, neighbour_index,
                         priority_counter) -> BpePerformanceStatsEntry:
    """Snapshot the entry counts and sizes (``getsize`` / 1e6, i.e. MB) of the learner's data structures."""
    return BpePerformanceStatsEntry(
        merges_done=merges_done,
        time_for_last_merge=time_for_last_merge,
        n_priority_queue_entries=len(priority_counter.pq),
        n_index_enties=len(location_index),
        location_index_obj_size=getsize(location_index) / 1e+6,
        neighbour_index_obj_size=getsize(neighbour_index) / 1e+6,
        priority_counter_obj_size=getsize(priority_counter) / 1e+6,
    )


def run(generator: Generator[str, None, None], n_merges: int = sys.maxsize,
        include_performance_stats_every_n_merges: int = 0,
        memory_limit_mb: float = 3072) \
        -> Generator[Tuple[str, int, Optional[List[BpePerformanceStatsEntry]]], None, None]:
    """Learn BPE merges from a stream of characters, yielding them lazily one by one.

    :param generator: the characters of the corpus.
    :param n_merges: maximum number of merges. The loop also stops when
        ``priority_counter.pop_pair`` raises KeyError (presumably when no pairs are left).
    :param include_performance_stats_every_n_merges: when non-zero, a list of
        BpePerformanceStatsEntry is kept, starting with a snapshot taken before any merge. When
        positive, a snapshot is added after the first merge and after every n-th merge.
    :param memory_limit_mb: when snapshots are taken, stop after the merge at which the last
        measured sizes of the three data structures add up to more than this many MB.
    :return: a generator of ``(merged_pair, occurrences, stats)``. ``occurrences`` excludes
        overlapping occurrences of a self-neighbouring pair. ``stats`` is the (growing) list
        of snapshots, or None when stats are disabled.
    """
    checkpoint = time.time()
    location_index, neighbour_index = build_indices(generator)
    priority_counter = PriorityCounter({k: (len(v), v[0]) for k, v in location_index.items()}, automatic_count=False)
    build_time = time.time() - checkpoint
    if logger.isEnabledFor(logging.DEBUG):
        # getsize walks the whole object graph; only pay for it when the output is wanted
        logger.debug(f'Size of location index: {getsize(location_index) / 1e+6} (MB)')
        logger.debug(f'Size of neighbour index: {getsize(neighbour_index) / 1e+6} (MB)')
        logger.debug(f'Index build in : {build_time} s')

    bpe_performance_stats = None
    if include_performance_stats_every_n_merges:
        bpe_performance_stats = [
            _measure_performance(0, 0, location_index, neighbour_index, priority_counter)
        ]
    collect_stats = include_performance_stats_every_n_merges > 0

    for i in range(n_merges):
        merges_done = i + 1
        checkpoint = time.time()
        try:
            most_freq_pair, occurences = priority_counter.pop_pair()
        except KeyError:
            break
        logger.debug('Merge %d: %s %s', merges_done, most_freq_pair, occurences)

        occurence_changes = update_location_index(location_index, neighbour_index, most_freq_pair)
        for (appeared_pair, first_appeared_pair, disappearing_pair,
             first_disappearing_pair, n_occurences) in occurence_changes:
            priority_counter.add(appeared_pair, n_occurences, first_appeared_pair)
            if disappearing_pair != most_freq_pair:
                priority_counter.add(disappearing_pair, -n_occurences, first_disappearing_pair)
            else:
                # overlapping occurrences of a self-neighbouring pair were absorbed into a new
                # pair instead of being merged, so they do not count towards this merge
                occurences -= n_occurences
        update_neighbour_index(location_index, neighbour_index, most_freq_pair)
        time_per_merge = time.time() - checkpoint

        if collect_stats and (merges_done == 1
                              or merges_done % include_performance_stats_every_n_merges == 0):
            entry = _measure_performance(merges_done, time_per_merge, location_index, neighbour_index,
                                         priority_counter)
            logger.debug(f"--------------------------- After merge {merges_done}")
            logger.debug(f"Last merge was done in {time_per_merge} s")
            logger.debug(f'The number of keys in the index: {entry.n_index_entries}')
            logger.debug(f'Length of pq {entry.n_priority_queue_entries}')
            logger.debug(f'Size of location index: {entry.location_index_obj_size} (MB)')
            logger.debug(f'Size of neighbour index: {entry.neighbour_index_obj_size} (MB)')
            logger.debug(f'Size of priority counter: {entry.priority_counter_obj_size} (MB)')
            bpe_performance_stats.append(entry)

        yield most_freq_pair, occurences, bpe_performance_stats

        if collect_stats:
            # the last snapshot always exists here: one is taken right after the first merge
            latest = bpe_performance_stats[-1]
            total_mb = (latest.location_index_obj_size + latest.neighbour_index_obj_size
                        + latest.priority_counter_obj_size)
            if total_mb > memory_limit_mb:
                return
def run_from_file(path_to_file: str, n_merges: int = sys.maxsize,
                  include_performance_stats_every_n_merges: int = 0) \
        -> Generator[Tuple[str, int, Optional[List[BpePerformanceStatsEntry]]], None, None]:
    """Learn BPE merges from the characters of a single file.

    Returns the generator produced by ``run``. It lazily yields ``(merged_pair, occurrences,
    performance_stats)`` tuples; ``performance_stats`` is None unless
    ``include_performance_stats_every_n_merges`` is non-zero.
    """
    it = get_char_iterator_for_file(path_to_file)
    return run(it, n_merges, include_performance_stats_every_n_merges)
def run_from_dir(path_to_dir: str, n_merges: int = sys.maxsize,
                 include_performance_stats_every_n_merges: int = 0) \
        -> Generator[Tuple[str, int, Optional[List[BpePerformanceStatsEntry]]], None, None]:
    """Learn BPE merges from the files under a directory (as selected by get_char_iterator_for_dir).

    Returns the generator produced by ``run``. It lazily yields ``(merged_pair, occurrences,
    performance_stats)`` tuples; ``performance_stats`` is None unless
    ``include_performance_stats_every_n_merges`` is non-zero.
    """
    it = get_char_iterator_for_dir(path_to_dir)
    return run(it, n_merges, include_performance_stats_every_n_merges)
def run_from_text(text: str, n_merges: int = sys.maxsize,
                  include_performance_stats_every_n_merges: int = 0) \
        -> Generator[Tuple[str, int, Optional[List[BpePerformanceStatsEntry]]], None, None]:
    """Learn BPE merges from an in-memory string.

    Returns the generator produced by ``run``. It lazily yields ``(merged_pair, occurrences,
    performance_stats)`` tuples; ``performance_stats`` is None unless
    ``include_performance_stats_every_n_merges`` is non-zero.

    >>> def run_and_get_merges(text: str):
    ...     return [(m, occ) for m, occ, _ in run_from_text(text)]
    >>> run_and_get_merges("a")
    []
    >>> run_and_get_merges("ab")
    [('a b', 1)]
    >>> run_and_get_merges("abcdbc")
    [('b c', 2), ('a bc', 1), ('abc d', 1), ('abcd bc', 1)]
    >>> run_and_get_merges("aaa")
    [('a a', 1), ('aa a', 1)]
    >>> run_and_get_merges("aaaa")
    [('a a', 2), ('aa aa', 1)]
    >>> run_and_get_merges("aaaaa")
    [('a a', 2), ('aa aa', 1), ('aaaa a', 1)]
    >>> run_and_get_merges("aaaaaa")
    [('a a', 3), ('aa aa', 1), ('aaaa aa', 1)]
    >>> run_and_get_merges("aaaaaab")
    [('a a', 3), ('aa aa', 1), ('aaaa aa', 1), ('aaaaaa b', 1)]
    >>> run_and_get_merges("aaaaaaaa")
    [('a a', 4), ('aa aa', 2), ('aaaa aaaa', 1)]
    >>> run_and_get_merges("there|is|a|thin|tooth|in|the|tooth")
    [('t h', 5), ('th e', 2), ('| i', 2), ('n |', 2), ('t o', 2), ('to o', 2), ('too th', 2), \
    ('the r', 1), ('ther e', 1), ('there |i', 1), ('there|i s', 1), ('there|is |', 1), ('there|is| a', 1), \
    ('there|is|a |', 1), ('there|is|a| th', 1), ('there|is|a|th i', 1), ('there|is|a|thi n|', 1), \
    ('there|is|a|thin| tooth', 1), ('there|is|a|thin|tooth |i', 1), ('there|is|a|thin|tooth|i n|', 1), \
    ('there|is|a|thin|tooth|in| the', 1), ('there|is|a|thin|tooth|in|the |', 1), \
    ('there|is|a|thin|tooth|in|the| tooth', 1)]
    """
    return run(iter(text), n_merges, include_performance_stats_every_n_merges)
# Executing this module as a script does nothing: the guard body is an empty placeholder.
if __name__ == '__main__':
    pass