almlab/angst

View on GitHub
tree_lib/node.py

Summary

Maintainability
F
3 days
Test Coverage
import sys
import pdb
import math
import multitree
import branch

class node:
    """A vertex of a (species or gene) tree.

    A node is identified by its name; equality and hashing are based on
    ``hash(name)``, so two nodes with the same name compare equal.  Nodes
    are connected by branch objects that expose ``ends`` (the nodes at
    either end), ``length`` and ``addNode``.  The owning tree object is
    expected to provide ``root``, ``branch_list``, ``node_dict``,
    ``leaf_dict``, ``dist_dict`` and ``height_dict`` for the methods that
    use them.
    """

    def __init__(self, raw_name, arbre):
        """Create a node called *raw_name* that belongs to the tree *arbre*.

        Leading '>' characters are stripped from the name.  The species
        label is the part of the name before the first '_' or '.',
        whichever comes first; a name containing neither separator is
        its own species label.
        """
        # remove any leading angled brackets
        raw_name = raw_name.lstrip(">")
        self.name = raw_name

        # handle mixed underscore and dot environment: cut the name at
        # whichever separator appears first
        cut_points = [raw_name.index(sep) for sep in "_." if sep in raw_name]
        if cut_points:
            # case of gene node
            self.species = raw_name[:min(cut_points)]
        else:
            # case of species node
            self.species = raw_name

        self.tree = arbre               # tree node belongs to
        self.myBranch = None            # used to construct unrooted trees
        self.branch_list = []           # used to construct unrooted trees
        self.child_branches = []        # list of branches to kids
        self.kid_nodes = {}             # child node -> 1
        self.parent_branch = None       # branch leading to the parent
        self.subnodes = dict()          # species of all nodes below this
        self.leaves = None              # leaves below this

        # attach to each node a serial number, to tell nodes apart
        self.serial = hash(self.name)
        self.visited = False            # useful for rooting trees
        self.link_dict_visited = False
        self.unrooted_leaving_visited = False

        # things necessary for UnrootedLeaving
        self.leaf_dict = {}             # neighbor -> leaf names away from it
        self.branch_dict = {}           # neighbor -> connecting branch
        self.other_nodes = []           # neighbors found by GetOtherNodes

        # things for reconciliation
        self.lca_lookups = None
        self.leaf_event_list = None
        self.rec_lca = None
        self.lca_map = None
        self.event_list = None
        self.raw_list = None

        # distance things
        self.leaf_dists = []
        self.dist_ratios = {}

    def __repr__(self):
        """Return the node name."""
        return self.name

    def __eq__(self, other):
        """Nodes are equal when their serial numbers match."""
        if other is None:
            return False
        return self.serial == other.serial

    def __ne__(self, other):
        """Negation of ``__eq__`` (anything differs from None)."""
        if other is None:
            return True
        return self.serial != other.serial

    def __gt__(self, other):
        """Order nodes by name."""
        return self.name > other.name

    def __lt__(self, other):
        """Order nodes by name."""
        return self.name < other.name

    def __hash__(self):
        """Hash on the serial number; 0 if it has not been set yet."""
        try:
            return self.serial
        except AttributeError:
            return 0

    def isLeaf(self):
        """Return True when fewer than two branches touch this node."""
        return len(self.branch_list) < 2

    def addBranch(self, length):
        """Create this node's own branch of the given length and attach
        the node to it."""
        self.myBranch = branch.branch(length)
        self.myBranch.addNode(self)

    # recursively print out nodes
    def treePrint(self, newickString):
        """Append the Newick text of the subtree below this node to
        *newickString* and return the result.

        Kids are written in name order so that trees can be compared in
        a quick and dirty way.  Leaves are written as ``name:length``,
        internal nodes as ``(...):length`` and a node without a parent
        branch gets no suffix.
        """
        sorted_kids = sorted(self.kid_nodes)
        for position, kid_node in enumerate(sorted_kids):
            newickString += ("(" if position == 0 else ",")
            newickString = kid_node.treePrint(newickString)
        # close the group exactly once, whatever the number of kids
        if sorted_kids:
            newickString += ")"

        if self.parent_branch is None:
            return newickString
        length_suffix = ":" + str(self.parent_branch.length)
        if len(self.child_branches) == 0:
            return newickString + self.name + length_suffix
        return newickString + length_suffix

    # join two nodes in a central node
    def unite(self, node2):
        """Create a new node named ``<self.name>-<node2.name>``, attach it
        to the branches of both nodes and return it."""
        # the name is built in call order (self first); it is not sorted
        new_node_name = self.name + "-" + node2.name
        center_node = node(new_node_name, self.tree)
        self.myBranch.addNode(center_node)
        node2.myBranch.addNode(center_node)
        return center_node

    def imposeHierarchy(self):
        """Find all the neighboring nodes that have yet to be visited,
        make them kids of this node and recurse into them.

        Each new kid is marked visited, gets its parent branch set, is
        registered in the tree's ``branch_list`` and ``node_dict`` (keyed
        by species) and receives every other branch it touches as a
        child branch.
        """
        for kid_branch in self.child_branches:
            for end_node in kid_branch.ends:
                if end_node.visited or end_node is self:
                    continue
                self.kid_nodes[end_node] = 1
                end_node.visited = True
                end_node.parent_branch = kid_branch
                end_node.tree.branch_list.append(kid_branch)
                end_node.tree.node_dict[end_node.species] = end_node
                for child_branch in end_node.branch_list:
                    if child_branch is not kid_branch:
                        end_node.child_branches.append(child_branch)
                # recurse once all child branches have been collected
                end_node.imposeHierarchy()

    # recursively label the roots of subtrees w/ the leaves contained
    # below
    def subtreeLabel(self):
        """Store in ``self.leaves`` a dict (species -> 1) of the leaves
        below this node and return the species as a list.

        The result is cached: later calls reuse ``self.leaves``.
        """
        if self.leaves is not None:
            return list(self.leaves)

        leaf_vec = []
        # recurse
        for kid_node in self.kid_nodes:
            leaf_vec.extend(kid_node.subtreeLabel())

        # once you hit a leaf
        if len(self.child_branches) == 0:
            leaf_vec.append(self.species)

        # non-duplicates:
        self.leaves = dict.fromkeys(leaf_vec, 1)
        return list(self.leaves)

    # figure out which species tree node a gene tree node maps to
    def subtreeMap(self, species_node):
        """Return the deepest node at or below *species_node* whose
        leaves contain every leaf of this node.

        Both ``self.leaves`` and the ``leaves`` of the species nodes must
        already be filled in (see ``subtreeLabel``).
        """
        g_leaves = self.leaves
        for kid_node in species_node.kid_nodes:
            s_leaves = kid_node.leaves
            # if you can see somewhere to descend, follow it
            if all(leaf in s_leaves for leaf in g_leaves):
                return self.subtreeMap(kid_node)

        # once you've run out of places to descend, just return
        # wherever you've ended up:
        return species_node

    # recursively store at each node a hash of all nodes below
    def Find_Subnodes(self):
        """Fill ``self.subnodes`` (species -> 1) with this node and every
        node below it, and return that dict."""
        for kid_node in self.kid_nodes:
            for species in kid_node.Find_Subnodes():
                self.subnodes.setdefault(species, 1)
        self.subnodes[self.species] = 1
        return self.subnodes

    def FindLCA(self, node1, node2):
        ''' find the last common ancestor of two nodes.  will descend
        from the subroot (self) looking to see which children possess
        both nodes.  if no child possesses both nodes, return the
        current node.  relies on ``subnodes`` (see Find_Subnodes). '''
        for kid_node in self.kid_nodes:
            below = kid_node.subnodes
            if node1.species in below and node2.species in below:
                return kid_node.FindLCA(node1, node2)
        return self

    # determine if one node is descended from the other ...
    def AreRelated(self, other_node):
        '''do self and other_node share at least one leaf?  (true
        whenever one of them is a descendant of the other)'''
        smaller = self.leaves
        larger = other_node.leaves
        if len(smaller) > len(larger):
            smaller, larger = larger, smaller
        # walk the smaller set and probe the larger one
        return any(leaf in larger for leaf in smaller)

    # method to list all the leaves of this node, keyed by the parent
    def UnrootedLeaving(this_node):
        """Fill ``leaf_dict`` for this node and, recursively, for every
        neighbor whose ``leaf_dict`` is still incomplete.

        Leaf neighbors are registered in the tree's ``leaf_dict`` by
        name; a duplicate leaf name prints a message and exits with
        status 1.
        """
        other_nodes = this_node.GetOtherNodes()
        neighbor_set = set(other_nodes)
        for parent_node in other_nodes:
            child_nodes = list(neighbor_set.difference(set([parent_node])))
            this_node.GetChildLeaves(parent_node, child_nodes)

        # when done, move on to neighboring nodes:
        for neighbor in other_nodes:

            if neighbor.isLeaf():
                if neighbor.name in neighbor.tree.leaf_dict:
                    print("duplicate leaves in a bootstrap.")
                    print("aborting prematurely.")
                    sys.exit(1)
                neighbor.tree.leaf_dict[neighbor.name] = neighbor

            # compare by value: identity of ints is not reliable
            if len(neighbor.leaf_dict) != len(neighbor.branch_list):
                neighbor.UnrootedLeaving()

    def GetLeaves(this_node, parent_node):
        """Return the leaf names reachable from this node when walking
        away from *parent_node*, computing and caching them in
        ``leaf_dict[parent_node]`` if necessary."""
        # if the leaf dict has already been defined:
        if parent_node in this_node.leaf_dict:
            return this_node.leaf_dict[parent_node]

        # if is a leaf
        if len(this_node.branch_list) < 2:
            this_node.leaf_dict[parent_node] = [this_node.name]
            return this_node.leaf_dict[parent_node]

        other_nodes = this_node.GetOtherNodes()
        child_nodes = list(set(other_nodes).difference(set([parent_node])))
        this_node.GetChildLeaves(parent_node, child_nodes)

        return this_node.leaf_dict[parent_node]

    def GetChildLeaves(this_node, parent_node, other_nodes):
        """Merge the leaves of *other_nodes* (the neighbors other than
        *parent_node*) into ``leaf_dict[parent_node]``, sorted.

        An existing entry holding more than one leaf is left untouched.
        Raises KeyError when *other_nodes* is empty and no entry exists
        for *parent_node* yet.
        """
        merge_list = []
        for kid_node in other_nodes:
            # an entry with more than one leaf counts as already done
            if len(this_node.leaf_dict.get(parent_node, ())) > 1:
                continue
            kid_leaves = kid_node.GetLeaves(this_node)

            if isinstance(kid_leaves[0], str):
                merge_list.extend(kid_leaves)
            else:
                # nested list: the names live in the first element
                merge_list.extend(kid_leaves[0])

            # this is a total hack: (re)start the entry so that the
            # extend below has something to append to
            this_node.leaf_dict[parent_node] = []

        merge_list.sort()
        this_node.leaf_dict[parent_node].extend(merge_list)

    def GetNodeLinkDict(this_node, node_link_dict):
        """Record in *node_link_dict*, for every entry of ``leaf_dict``,
        the pair ``(this_node, neighbor)`` under the ``repr`` of the leaf
        list, then recurse into neighbors not yet visited."""
        this_node.link_dict_visited = True
        leaf_dict = this_node.leaf_dict

        for key in leaf_dict:
            subleaves = leaf_dict[key]
            if isinstance(subleaves[0], list):
                subleaves = subleaves[0]
            sub_str = repr(subleaves)
            node_link_dict.setdefault(sub_str, []).append((this_node, key))

        # after that's been done, decide where to recurse
        for relative in leaf_dict:
            # sorts, in place, the first leaf list of the neighbor
            first_key = list(relative.leaf_dict)[0]
            relative.leaf_dict[first_key].sort()

            if not relative.link_dict_visited:
                relative.GetNodeLinkDict(node_link_dict)

    # return a list of all bordering nodes
    def GetOtherNodes(this_node):
        """Return the nodes at the far end of every branch of this node.

        Also records the connecting branch in ``branch_dict`` on both
        sides and stores the result in ``this_node.other_nodes``.
        """
        other_nodes = []
        for link in this_node.branch_list:
            for end_node in link.ends:
                if end_node is not this_node:
                    other_nodes.append(end_node)
                    end_node.branch_dict[this_node] = link
                    this_node.branch_dict[end_node] = link
        this_node.other_nodes = other_nodes

        return other_nodes

    # get the distance between two nodes:
    # call as follows:
    # found, dist = node1.DistTo(node2)
    def DistTo(this_node, that_node, this_branch=None, dist=0):
        """Return ``(found, distance)`` for the path from this node to
        *that_node*.

        *this_branch* is the branch the search arrived through (it is
        not followed back) and *dist* the length accumulated so far.
        When the target is not reachable, ``(False, dist)`` is returned.
        """
        # what to do at the end
        if this_node is that_node:
            return True, dist
        # if not, try every other branch
        for link in this_node.branch_list:
            if link is this_branch:
                continue
            for end_node in link.ends:
                if end_node is this_node:
                    continue
                found, new_dist = end_node.DistTo(that_node, link,
                                                  dist + link.length)
                if found is True:
                    # a tree holds a single path, so stop searching
                    return True, new_dist

        return False, dist

    def FindDists(self):
        ''' recursively enumerate all pairwise distances on the
        species tree; results go to ``self.tree.dist_dict`` as
        ``dist_dict[a][b]`` '''
        dist_dict = self.tree.dist_dict

        if self not in dist_dict:
            dist_dict[self] = {}

        # give own distance
        dist_dict[self][self] = 0.0

        # now, to each of those add the kids
        for kid in self.kid_nodes:

            # snapshot of everyone 'self' is affiliated with; the loop
            # below adds entries to dist_dict[self] itself
            parent_partners = list(dist_dict[self])

            if kid not in dist_dict:
                dist_dict[kid] = {}
            dist_2_parent = kid.parent_branch.length
            for partner in parent_partners:
                this_dist = dist_dict[self][partner] + dist_2_parent
                dist_dict[kid][partner] = this_dist
                dist_dict[partner][kid] = this_dist
            kid.FindDists()

    def ValidTransfer(self, other_node):
        ''' report whether the parent branches of the queried nodes
        overlap temporally.  returns 0 (no transfer, or same node),
        1 (direct transfer) or 2 (phantom transfer).  requires
        ``self.tree.dist_dict`` (see FindDists). '''

        # no auto-transfer
        if self is other_node:
            return 0

        donor_node = self
        accep_node = other_node
        root = self.tree.root
        dist_dict = self.tree.dist_dict

        # note that there can be messiness with >= (as opposed to just
        # >), when you consider overlap between parent and child.
        # since we don't want transfer between parent and child,
        # branch lengths will be drawn in slightly and then compared
        min_dist = math.exp(-10)
        d11 = dist_dict[donor_node][root] - min_dist
        d12 = d11 - donor_node.parent_branch.length + min_dist
        d21 = dist_dict[accep_node][root] - min_dist
        d22 = d21 - accep_node.parent_branch.length + min_dist

        if d12 > d21:
            return 0  # no transfer
        if d22 <= d11:
            return 1  # direct transfer
        return 2  # phantom transfer

    def IsAncestor(self, other_node):
        ''' is self an ancestor of other_node (1), the same as
        other_node (2), or neither (0). '''
        if self is other_node:
            return 2
        lca_node = self.tree.root.FindLCA(self, other_node)
        if lca_node is self:
            return 1
        return 0

    def GetAncestors(self):
        ''' get a list of ancestors from the current node (included in
        list ) to the root of the tree '''
        ancestor_list = [self]
        cur_node = self
        root = self.tree.root
        while cur_node is not root:
            # the parent is the other end of the parent branch
            parent = [end_node for end_node in cur_node.parent_branch.ends
                      if end_node is not cur_node][0]
            ancestor_list.append(parent)
            cur_node = parent
        return ancestor_list

    def GetHeights(self):
        ''' get MEDIAN distance to leaves at all internal nodes on the
        tree.  the median is stored in ``self.tree.height_dict[self]``
        (0.0 for a leaf); ``self.leaf_dists`` receives the distances to
        the leaves measured from the top of the parent branch. '''

        # a node without a parent branch contributes no stem length
        if self.parent_branch is None:
            my_len = 0.0
        else:
            my_len = self.parent_branch.length

        median = 0.0
        if len(self.kid_nodes) < 1:
            self.leaf_dists = [my_len]
        else:
            # get list of child distances
            child_dists = []
            for kid in self.kid_nodes:
                kid.GetHeights()
                child_dists.extend(kid.leaf_dists)
            # find the median
            child_dists.sort()
            half = len(child_dists) // 2
            if len(child_dists) % 2 == 0:
                # even count: mean of the two middle values
                median = sum(child_dists[half - 1:half + 1]) / 2.0
            else:
                median = child_dists[half]
            self.leaf_dists = [val + my_len for val in child_dists]

        self.tree.height_dict[self] = median

    def GetParent(self):
        ''' return the parent node, or None when there is none '''
        if self.parent_branch is None:
            return None
        for end_node in self.parent_branch.ends:
            if end_node is not self:
                return end_node
        return None

    def AreSiblings(self, other_node):
        ''' evaluate whether or not arguments are sibling nodes '''
        return self.GetParent() is other_node.GetParent()

    def GetSibling(self):
        ''' get the node sibling (the last kid of the parent that is
        not self); None when there is no parent or no other kid '''
        node_parent = self.GetParent()
        if node_parent is None:
            return None

        node_sibling = None
        for kid_node in node_parent.kid_nodes:
            if kid_node is not self:
                node_sibling = kid_node
        return node_sibling