deeplearning4j/deeplearning4j

View on GitHub
libnd4j/auto_vectorization/inverted_index.py

Summary

Maintainability
B
5 hrs
Test Coverage
'''
@author : Abdelrauf rauf@konduit.ai
'''

#  /* ******************************************************************************
#   *
#   *
#   * This program and the accompanying materials are made available under the
#   * terms of the Apache License, Version 2.0 which is available at
#   * https://www.apache.org/licenses/LICENSE-2.0.
#   *
#   *  See the NOTICE file distributed with this work for additional
#   *  information regarding copyright ownership.
#   * Unless required by applicable law or agreed to in writing, software
#   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
#   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
#   * License for the specific language governing permissions and limitations
#   * under the License.
#   *
#   * SPDX-License-Identifier: Apache-2.0
#   ******************************************************************************/

import json


def get_compressed_indices_list(set_a):
    '''
     Delta-encode a collection of indices into a compressed sorted list.

     The indices are de-duplicated and sorted. The first element of the result
     is the smallest index, and every following element is the gap to its
     predecessor (e.g. {7, 2, 3} -> [2, 1, 4]). This is the compressed form
     that intersect_compressed_sorted_list expects.

     Parameters:
     set_a  iterable of hashable, orderable indices (normally ints); a list
            with repeated values is accepted
     Returns:
     list:  delta-encoded indices; [] for empty input
    '''
    # set() drops repeats, which would otherwise be stored as zero gaps and
    # make the merge-intersection report the same index more than once.
    ordered = sorted(set(set_a))
    if not ordered:
        return []
    gaps = [cur - prev for prev, cur in zip(ordered, ordered[1:])]
    return [ordered[0]] + gaps


def intersect_compressed_sorted_list(list1, list2):
    '''
     Intersect two delta-encoded (compressed) sorted index lists.

     Each input stores its first element as an absolute value and every
     following element as the gap from its predecessor. That is the format
     produced by get_compressed_indices_list. Both lists are walked once,
     in a single merge pass.

     Parameters:
     list1  compressed sorted list
     list2  compressed sorted list
     Returns:
     list:  the common values, uncompressed (absolute), in the order found
    '''
    result = []
    pos_a = pos_b = 0
    # absolute value of the most recently consumed element of each list
    base_a = base_b = 0
    while pos_a < len(list1) and pos_b < len(list2):
        value_a = base_a + list1[pos_a]
        value_b = base_b + list2[pos_b]
        if value_a == value_b:
            # common value: record it and step past it in both lists
            result.append(value_a)
            base_a = value_a
            base_b = value_b
            pos_a += 1
            pos_b += 1
        elif value_a < value_b:
            # advance whichever list currently holds the smaller value
            base_a = value_a
            pos_a += 1
        else:
            base_b = value_b
            pos_b += 1
    return result


class InvertedIndex:
    '''
     InvertedIndex for the auto_vect generated inverted-index json format.

     The parsed JSON document is kept in self.index_obj. The methods below
     read these keys from it:
       'messages', 'files', 'functions'  lists; an entry's position in its
                                         list is the integer index used by
                                         every other method
       'msg_entries'                     maps str(message index) to that
                                         message's postings, each posting
                                         being [file_index, line,
                                         compressed_functions]
    '''

    def __init__(self, file_name):
        '''
         Load the inverted index from a JSON file.
         Parameters:
         file_name  path of the auto_vect inverted-index JSON file
        '''
        # JSON text is UTF-8 by specification; pin the encoding instead of
        # relying on the locale default, which is not UTF-8 on every platform.
        with open(file_name, "r", encoding="utf-8") as ifx:
            self.index_obj = json.load(ifx)

    def get_all_index(self, entry_name, predicate):
        '''
         Indexes of the entries of a list-valued key accepted by predicate.
         Parameters:
         entry_name  one of 'messages', 'files', 'functions'
         predicate   callable taking one entry; a truthy result keeps it
         Returns:
         list:   ascending list of matching indexes
        '''
        return [idx for idx, x in enumerate(self.index_obj[entry_name]) if predicate(x)]

    def get_all_index_value(self, entry_name, predicate):
        '''
         Like get_all_index, but also returns the matching entries.
         Parameters:
         entry_name  one of 'messages', 'files', 'functions'
         predicate   callable taking one entry; a truthy result keeps it
         Returns:
         list:   [(index, value), ...] in ascending index order
        '''
        return [(idx, x) for idx, x in enumerate(self.index_obj[entry_name]) if predicate(x)]

    def get_function_index(self, predicate=lambda x: True):
        '''Indexes of the 'functions' entries accepted by predicate (default: all).'''
        return self.get_all_index('functions', predicate)

    def get_msg_index(self, predicate=lambda x: True):
        '''Indexes of the 'messages' entries accepted by predicate (default: all).'''
        return self.get_all_index('messages', predicate)

    def get_file_index(self, predicate=lambda x: True):
        '''Indexes of the 'files' entries accepted by predicate (default: all).'''
        return self.get_all_index('files', predicate)

    def get_msg_postings(self, index):
        '''
         Gets postings for the given message.
         Parameters:
         index   message index (looked up as str(index) in 'msg_entries')
         Returns:
         [[file index, line position, [compressed functions]]]: the stored
         postings of the message, or [] when it has no entry
        '''
        return self.index_obj['msg_entries'].get(str(index), [])

    @staticmethod
    def _filter_postings_by_files(postings, sorted_files):
        '''Keep the postings whose file index occurs in sorted_files (merge join).'''
        # Linear merge: relies on postings also being ordered by file index.
        # NOTE(review): confirm the index generator emits them in that order.
        kept = []
        i, j = 0, 0
        len_1, len_2 = len(postings), len(sorted_files)
        while i < len_1 and j < len_2:
            file_1 = postings[i][0]
            file_2 = sorted_files[j]
            if file_1 < file_2:
                i += 1
            elif file_1 > file_2:
                j += 1
            else:
                kept.append(postings[i])
                # Advance only the postings side: several postings can share
                # the same file index and every one of them must be kept.
                i += 1
        return kept

    def intersect_postings(self, posting1, compressed_sorted_functions, sorted_files=None):
        '''
         Intersects postings with the given functions and sorted_files.
         Parameters:
         posting1  postings: [[file_id, line, [compressed_functions]], ...]
         compressed_sorted_functions  compressed sorted function indexes to
                                      intersect with
         sorted_files  ascending file indexes to restrict to (default: None,
                       meaning no file restriction)
         Returns:
         [[file_id, line, [function indexes]], ...] containing only postings
         with at least one common function; the function indexes are
         uncompressed
        '''
        if sorted_files is not None:
            posting1 = self._filter_postings_by_files(posting1, sorted_files)
        new_list = []
        for p in posting1:
            px = intersect_compressed_sorted_list(compressed_sorted_functions, p[2])
            if px:
                new_list.append([p[0], p[1], px])
        return new_list

    def get_results_for_msg(self,  msg_index, functions, sorted_files=None):
        '''
         Return filtered postings for the given message index.
         Parameters:
         msg_index     message index
         functions     function index list
         sorted_files  also intersect with these sorted file indexes
                       (default: None)
         Returns:
         [] when nothing matches; otherwise a one-element list wrapping the
         filtered postings: [[[file, line, [function index]], ...]]
        '''
        postings = self.intersect_postings(
            self.get_msg_postings(msg_index),
            get_compressed_indices_list(functions),
            sorted_files)
        # The extra list level is kept for compatibility with existing callers.
        return [postings] if postings else []

    def get_results_for_msg_grouped_by_func(self,  msg_index, functions, sorted_files=None):
        '''
         Return {function index: set((file index, line))} for the given message.
         Parameters:
         msg_index     message index
         functions     function index list
         sorted_files  also intersect with these sorted file indexes
                       (default: None)
         Returns:
         dict mapping each matched function index to the set of
         (file index, line) positions where the message occurs in it
        '''
        result = {}
        postings = self.intersect_postings(
            self.get_msg_postings(msg_index),
            get_compressed_indices_list(functions),
            sorted_files)
        for p in postings:
            location = (p[0], p[1])
            for f in p[2]:
                result.setdefault(f, set()).add(location)
        return result


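'''
Example:

import re
import inverted_index as helper

rfile = 'vecmiss_fsave_inverted_index.json'
ind_obj = helper.InvertedIndex(rfile)
reg_ops = re.compile(r'simdOps')
reg_msg = re.compile(r'success')
simdOps = ind_obj.get_function_index(lambda x: reg_ops.search(x))
msgs    = ind_obj.get_msg_index(lambda x: reg_msg.search(x))
files   = ind_obj.get_file_index(lambda x: 'cublas' in x)
# postings of the first matching message, restricted to the simdOps functions
res     = ind_obj.get_results_for_msg(msgs[0], simdOps)
# the same, additionally restricted to the matching (ascending) file indexes
res_in_files = ind_obj.get_results_for_msg(msgs[0], simdOps, sorted_files=files)
# grouped as {function index: {(file index, line), ...}}
by_func = ind_obj.get_results_for_msg_grouped_by_func(msgs[0], simdOps)

'''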