median-research-group/LibMTL
examples/xtreme/propocess_data/conll.py
# https://github.com/google-research/xtreme/issues/63

try:
    import networkx as nx
except ModuleNotFoundError:
    # This module relies on the networkx 1.x API (G.node, list-returning
    # topological_sort), hence the pin to networkx==1.11; see the issue above.
    from pip._internal import main as pip
    pip(['install', '--user', 'networkx==1.11'])
    import networkx as nx

from collections import Counter
import re


#TODO make these parse functions static methods of CoNLLReader
def parse_id(id_str):
    if id_str == '_':
        return None
    if "." in id_str:
        return None
    ids = tuple(map(int, id_str.split("-")))
    if len(ids) == 1:
        return ids[0]
    else:
        return ids

def parse_feats(feats_str):
    if feats_str == '_':
        return {}
    feat_pairs = [pair.split("=") for pair in feats_str.split("|")]
    return {k: v for k, v in feat_pairs}

def parse_deps(dep_str):
    if dep_str == '_':
        return []
    dep_pairs = [pair.split(":") for pair in dep_str.split("|")]
    return [(int(pair[0]), pair[1]) for pair in dep_pairs if pair[0].isdigit()]
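
# Illustrative behaviour of the parse helpers on typical CoNLL-U fields
# (example values chosen here, not taken from the original file):
#   parse_id("7")    -> 7
#   parse_id("3-4")  -> (3, 4)   # multiword-token range
#   parse_id("5.1")  -> None     # empty nodes are skipped
#   parse_feats("Case=Nom|Number=Sing") -> {'Case': 'Nom', 'Number': 'Sing'}
#   parse_deps("2:nsubj|4:conj")        -> [(2, 'nsubj'), (4, 'conj')]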




class DependencyTree(nx.DiGraph):
    """
    A DependencyTree as networkx graph:
    nodes store information about tokens
    edges store edge related info, e.g. dependency relations
    """

    def __init__(self):
        nx.DiGraph.__init__(self)

    def pathtoroot(self, child):
        path = []
        newhead = self.head_of(child)
        while newhead:
            path.append(newhead)
            newhead = self.head_of(newhead)
        return path

    def head_of(self, n):
        for u, v in self.edges():
            if v == n:
                return u
        return None

    def get_sentence_as_string(self, printid=False):
        out = []
        for token_i in range(1, max(self.nodes()) + 1):
            if printid:
                out.append(str(token_i) + ":" + self.node[token_i]['form'])
            else:
                out.append(self.node[token_i]['form'])
        return u" ".join(out)

    def subsumes(self, head, child):
        return head in self.pathtoroot(child)

    def remove_arabic_diacritics(self):
        # The following code is based on nltk.stem.isri.
        # It is equivalent to an iterative application of isri.norm(word, num=1),
        # i.e. we do not remove any hamza characters.

        re_short_vowels = re.compile(r'[\u064B-\u0652]')
        for n in self.nodes():
            self.node[n]["form"] = re_short_vowels.sub('', self.node[n]["form"])


    def get_highest_index_of_span(self, span):  # retrieves the node index that is closest to root
        #TODO: CANDIDATE FOR DEPRECATION
        distancestoroot = [len(self.pathtoroot(x)) for x in span]
        shortestdistancetoroot = min(distancestoroot)
        spanhead = span[distancestoroot.index(shortestdistancetoroot)]
        return spanhead

    def get_deepest_index_of_span(self, span):  # retrieves the node index that is farthest from root
        #TODO: CANDIDATE FOR DEPRECATION
        distancestoroot = [len(self.pathtoroot(x)) for x in span]
        longestdistancetoroot = max(distancestoroot)
        lownode = span[distancestoroot.index(longestdistancetoroot)]
        return lownode

    def span_makes_subtree(self, initidx, endidx):
        G = nx.DiGraph()
        span_nodes = list(range(initidx, endidx + 1))
        G.add_nodes_from(span_nodes)
        for h, d in self.edges():
            if h in span_nodes and d in span_nodes:
                G.add_edge(h, d)
        return nx.is_tree(G)
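    # E.g. with edges 2 -> 3 and 2 -> 4, span_makes_subtree(2, 4) is True:
    # the three span nodes are weakly connected by exactly two edges.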

    def _choose_spanhead_from_heuristics(self, span_nodes, pos_precedence_list):
        distancestoroot = [len(nx.ancestors(self, x)) for x in span_nodes]
        shortestdistancetoroot = min(distancestoroot)
        distance_counter = Counter(distancestoroot)

        highest_nodes_in_span = []
        # Heuristic Nr 1: If there is one single highest node in the span, it becomes the head
        # N.B. no need for the subspan to be a tree if there is one single highest element
        if distance_counter[shortestdistancetoroot] == 1:
            spanhead = span_nodes[distancestoroot.index(shortestdistancetoroot)]
            return spanhead

        # Heuristic Nr 2: Choose by POS ranking the best head out of the highest nodes
        for x in span_nodes:
            if len(nx.ancestors(self,x)) == shortestdistancetoroot:
                highest_nodes_in_span.append(x)

        best_rank = len(pos_precedence_list) + 1
        candidate_head = -1
        span_upos = [self.node[x]["cpostag"] for x in highest_nodes_in_span]
        for upos, idx in zip(span_upos, highest_nodes_in_span):
            if pos_precedence_list.index(upos) < best_rank:
                best_rank = pos_precedence_list.index(upos)
                candidate_head = idx
        return candidate_head
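
    # Example of heuristic 2 (hypothetical precedence list): with
    # pos_precedence_list = ['VERB', 'NOUN', 'ADP'] and two equally high
    # candidates tagged NOUN and ADP, the NOUN node is chosen because it
    # ranks earlier in the list.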

    def _remove_node_properties(self, fields):
        for n in sorted(self.nodes()):
            for fieldname in self.node[n].keys():
                if fieldname in fields:
                    self.node[n][fieldname] = "_"

    def _remove_deprel_suffixes(self):
        for h, d in self.edges():
            if ":" in self[h][d]["deprel"]:
                self[h][d]["deprel"] = self[h][d]["deprel"].split(":")[0]

    def _keep_fused_form(self, posPreferenceDicts):
        # For a fused span A,B with external dependents C (A > B > C) we have to:
        #  - make A the head of the span,
        #  - attach the C-level tokens to A,
        #  - remove the B-level tokens, i.e. the subtokens of the fused form
        #    (e.g. "della" -> "de la").

        if not self.graph["multi_tokens"]:
            return

        spanheads = []
        spanhead_fused_token_dict = {}
        # This double iteration is redundant (one could skip the spanhead
        # identification), but this way we avoid modifying the tree as we read it.
        for fusedform_idx in sorted(self.graph["multi_tokens"]):
            fusedform_start, fusedform_end = self.graph["multi_tokens"][fusedform_idx]["id"]
            fuseform_span = list(range(fusedform_start,fusedform_end+1))
            spanhead = self._choose_spanhead_from_heuristics(fuseform_span,posPreferenceDicts)
            spanheads.append(spanhead)
            spanhead_fused_token_dict[spanhead] = fusedform_idx

        # Break any cycles before ordering the span heads:
        # nx.topological_sort raises NetworkXUnfeasible on cyclic graphs.
        self = remove_all_cycle(self)
        bottom_up_order = [x for x in nx.topological_sort(self) if x in spanheads]
        for spanhead in bottom_up_order:
            fusedform_idx = spanhead_fused_token_dict[spanhead]
            fusedform = self.graph["multi_tokens"][fusedform_idx]["form"]
            fusedform_start, fusedform_end = self.graph["multi_tokens"][fusedform_idx]["id"]
            fuseform_span = list(range(fusedform_start,fusedform_end+1))

            if spanhead:
                # Step 1: replace the form of the span head (A) with the fused
                # form; this way we keep the lemma and features, if any.
                self.node[spanhead]["form"] = fusedform

                # Step 2: reattach the C-level tokens (external dependents) to A.
                internal_dependents = set(fuseform_span) - set([spanhead])
                external_dependents = [nx.bfs_successors(self,x) for x in internal_dependents]
                for depdict in external_dependents:
                    for localhead in depdict:
                        for ext_dep in depdict[localhead]:
                            if ext_dep in self[localhead]:
                                deprel = self[localhead][ext_dep]["deprel"]
                                self.remove_edge(localhead,ext_dep)
                                self.add_edge(spanhead,ext_dep,deprel=deprel)

                # Step 3: remove the B-level tokens.
                for int_dep in internal_dependents:
                    self.remove_edge(self.head_of(int_dep),int_dep)
                    self.remove_node(int_dep)

        # Step 4: renumber the nodes and rebuild the tree.
        new_index_dict = {}
        for new_node_index, old_node_index in enumerate(sorted(self.nodes())):
            new_index_dict[old_node_index] = new_node_index

        T = DependencyTree()  # transfer DiGraph, to replace self

        for n in sorted(self.nodes()):
            T.add_node(new_index_dict[n], self.node[n])

        for h, d in self.edges():
            T.add_edge(new_index_dict[h], new_index_dict[d], deprel=self[h][d]["deprel"])
        # Step 4a: quick removal of all edges and nodes.
        self.__init__()

        # Step 4b: rewrite the dependency tree into self.
        # TODO There must be a more elegant way to rewrite self -- self = T, for instance?
        for n in sorted(T.nodes()):
            self.add_node(n, T.node[n])

        for h, d in T.edges():
            self.add_edge(h, d, T[h][d])

        # Step 5: remove all fused forms from the multi_tokens field.
        self.graph["multi_tokens"] = {}


    def filter_sentence_content(self, replace_subtokens_with_fused_forms=False, lang=None, posPreferenceDict=None, node_properties_to_remove=None, remove_deprel_suffixes=False, remove_arabic_diacritics=False):
        if replace_subtokens_with_fused_forms:
            self._keep_fused_form(posPreferenceDict)
        if remove_deprel_suffixes:
            self._remove_deprel_suffixes()
        if node_properties_to_remove:
            self._remove_node_properties(node_properties_to_remove)
        if remove_arabic_diacritics:
            self.remove_arabic_diacritics()

def remove_all_cycle(G):
    GC = nx.DiGraph(G.edges())
    cycles = list(nx.simple_cycles(GC))
    for cycle in cycles:
        # Remove every existing edge between two nodes of the cycle.
        for i in range(len(cycle) - 1):
            for j in range(i + 1, len(cycle)):
                a, b = cycle[i], cycle[j]
                if G.has_edge(a, b):
                    G.remove_edge(a, b)
    return G
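
# Sanity-check sketch for remove_all_cycle (illustrative, not in the
# original file):
#   G = nx.DiGraph([(1, 2), (2, 1), (2, 3)])
#   G = remove_all_cycle(G)
#   assert nx.is_directed_acyclic_graph(G)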


class CoNLLReader(object):
    """
    CoNLL input/output
    """

    """Static properties"""
    CONLL06_COLUMNS = [('id',int), ('form',str), ('lemma',str), ('cpostag',str), ('postag',str), ('feats',str), ('head',int), ('deprel',str), ('phead', str), ('pdeprel',str)]
    CONLL06DENSE_COLUMNS = [('id',int), ('form',str), ('lemma',str), ('cpostag',str), ('postag',str), ('feats',str), ('head',int), ('deprel',str), ('edgew',str)]
    CONLL_U_COLUMNS = [('id', parse_id), ('form', str), ('lemma', str), ('cpostag', str),
                   ('postag', str), ('feats', str), ('head', parse_id), ('deprel', str),
                   ('deps', parse_deps), ('misc', str)]



    def __init__(self):
        pass

    def read_conll_2006(self, filename):
        sentences = []
        sent = DependencyTree()
        for line_num, conll_line in enumerate(open(filename)):
            parts = conll_line.strip().split("\t")
            if len(parts) in (8, 10):
                token_dict = {key: conv_fn(val) for (key, conv_fn), val in zip(self.CONLL06_COLUMNS, parts)}

                sent.add_node(token_dict['id'], token_dict)
                sent.add_edge(token_dict['head'], token_dict['id'], deprel=token_dict['deprel'])
            elif len(parts) == 1 and parts[0] == "":
                sentences.append(sent)
                sent = DependencyTree()
            else:
                raise Exception("Invalid input format in line nr {}: {} ({})".format(line_num, conll_line, filename))

        return sentences
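
    # Expected input for read_conll_2006: CoNLL-X (2006) format, eight or ten
    # tab-separated columns per token (id, form, lemma, cpostag, postag,
    # feats, head, deprel, phead, pdeprel), with sentences separated by blank
    # lines. Illustrative row (tabs shown as spaces, values invented here):
    #   1   Pierre   Pierre   N   NNP   _   2   nsubj   _   _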

    def read_conll_2006_dense(self, filename):
        sentences = []
        sent = DependencyTree()
        for conll_line in open(filename):
            parts = conll_line.strip().split("\t")
            if len(parts) == 9:
                token_dict = {key: conv_fn(val) for (key, conv_fn), val in zip(self.CONLL06DENSE_COLUMNS, parts)}

                sent.add_node(token_dict['id'], token_dict)
                sent.add_edge(token_dict['head'], token_dict['id'], deprel=token_dict['deprel'])
            elif len(parts) == 1 and parts[0] == "":
                sentences.append(sent)
                sent = DependencyTree()
            else:
                raise Exception("Invalid input format in line: {} ({})".format(conll_line, filename))

        return sentences



    def write_conll(self, list_of_graphs, conll_path, conllformat, print_fused_forms=False, print_comments=False):
        # TODO add comment writing
        if conllformat == "conllu":
            columns = [colname for colname, fname in self.CONLL_U_COLUMNS]
        else:
            columns = [colname for colname, fname in self.CONLL06_COLUMNS]

        with conll_path.open('w') as out:
            for sent_i, sent in enumerate(list_of_graphs):
                if sent_i > 0:
                    print("", file=out)
                if print_comments:
                    for c in sent.graph["comment"]:
                        print(c, file=out)
                for token_i in range(1, max(sent.nodes()) + 1):
                    token_dict = dict(sent.node[token_i])
                    head_i = sent.head_of(token_i)
                    if head_i is None:
                        token_dict['head'] = 0
                        token_dict['deprel'] = ''
                    else:
                        token_dict['head'] = head_i
                        token_dict['deprel'] = sent[head_i][token_i]['deprel']
                    token_dict['id'] = token_i
                    row = [str(token_dict.get(col, '_')) for col in columns]
                    if print_fused_forms and token_i in sent.graph["multi_tokens"]:
                        currentmulti = sent.graph["multi_tokens"][token_i]
                        currentmulti["id"] = str(currentmulti["id"][0]) + "-" + str(currentmulti["id"][1])
                        currentmulti["feats"] = "_"
                        currentmulti["head"] = "_"
                        rowmulti = [str(currentmulti.get(col, '_')) for col in columns]
                        print(u"\t".join(rowmulti), file=out)
                    print(u"\t".join(row), file=out)

            # empty line afterwards
            print(u"", file=out)


    def read_conll_u(self, filename, keepFusedForm=False, lang=None, posPreferenceDict=None):
        sentences = []
        sent = DependencyTree()
        multi_tokens = {}

        for line_no, line in enumerate(open(filename)):
            line = line.strip("\n")
            if not line:
                # Add extra properties to the ROOT node, if present
                if 0 in sent:
                    for key in ('form', 'lemma', 'cpostag', 'postag'):
                        sent.node[0][key] = 'ROOT'

                # Handle multi-tokens
                sent.graph['multi_tokens'] = multi_tokens
                multi_tokens = {}
                sentences.append(sent)
                sent = DependencyTree()
            elif line.startswith("#"):
                if 'comment' not in sent.graph:
                    sent.graph['comment'] = [line]
                else:
                    sent.graph['comment'].append(line)
            else:
                parts = line.split("\t")
                if len(parts) != len(self.CONLL_U_COLUMNS):
                    error_msg = 'Invalid number of columns in line {} (found {}, expected {})'.format(line_no, len(parts), len(self.CONLL_U_COLUMNS))
                    raise Exception(error_msg)

                token_dict = {key: conv_fn(val) for (key, conv_fn), val in zip(self.CONLL_U_COLUMNS, parts)}
                if isinstance(token_dict['id'], int):
                    sent.add_edge(token_dict['head'], token_dict['id'], deprel=token_dict['deprel'])
                    sent.node[token_dict['id']].update({k: v for (k, v) in token_dict.items()
                                                        if k not in ('head', 'id', 'deprel', 'deps')})
                    for head, deprel in token_dict['deps']:
                        sent.add_edge(head, token_dict['id'], deprel=deprel, secondary=True)
                elif token_dict['id'] is not None:
                    first_token_id = int(token_dict['id'][0])
                    multi_tokens[first_token_id] = token_dict
        return sentences
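

if __name__ == "__main__":
    # Minimal usage sketch (file names are placeholders, not part of the
    # original module): read a CoNLL-U file, strip deprel subtypes, and
    # write the result back out.
    from pathlib import Path

    reader = CoNLLReader()
    sentences = reader.read_conll_u("sample.conllu")
    for sent in sentences:
        sent.filter_sentence_content(remove_deprel_suffixes=True)
    reader.write_conll(sentences, Path("sample_out.conllu"), "conllu")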