vanheeringen-lab/gimmemotifs

View on GitHub
gimmemotifs/motif/cluster.py

Summary

Maintainability
C
7 hrs
Test Coverage
A
90%
# Copyright (c) 2009-2019 Simon van Heeringen <simon.vanheeringen@gmail.com>
#
# This module is free software. You can redistribute it and/or modify it under
# the terms of the MIT License, see the file COPYING included with this
# distribution.
"""Module for motif clustering."""
import logging
import os
import sys
from datetime import datetime

import jinja2

from gimmemotifs import __version__
from gimmemotifs.comparison import MotifComparer
from gimmemotifs.config import MotifConfig
from gimmemotifs.motif.base import Motif
from gimmemotifs.motif.read import read_motifs

logger = logging.getLogger("gimme.cluster")


class MotifTree(object):

    """Binary tree node used by cluster_motifs to record the merge history.

    Leaves wrap the input motifs; every internal node wraps the averaged motif
    created by merging its ``left`` and ``right`` children. After clustering,
    ``checkMerge`` marks a node (and all its ancestors) as ``frontier`` when
    its merge score does not exceed the threshold; the non-frontier nodes
    directly below the frontier are the resulting clusters.

    All traversals are iterative: agglomerative clustering can produce very
    unbalanced (chain-like) trees whose depth approaches the number of motifs,
    which would exceed Python's recursion limit with recursive traversals.
    """

    def __init__(self, motif):
        self.motif = motif
        self.parent = None  # node this one was merged into (None for the root)
        self.left = None  # first merged child (None for a leaf)
        self.right = None  # second merged child (None for a leaf)
        self.mergescore = None  # score of the merge that created this node
        self.maxscore = 0  # score of this node's motif compared with itself
        self.frontier = False  # True when this node is split into its children

    @staticmethod
    def _is_split(node):
        """Return True for frontier nodes that have children to descend into."""
        return node.frontier and node.left is not None

    def _walk(self, descend):
        """Yield, left to right, the nodes where a pre-order walk stops.

        Starting at this node, every node for which ``descend(node)`` is true
        is replaced by its left and then its right child; every other node is
        yielded. An explicit stack is used instead of recursion.
        """
        stack = [self]
        while stack:
            node = stack.pop()
            if descend(node):
                # Push right first so that the left child is visited first.
                stack.append(node.right)
                stack.append(node.left)
            else:
                yield node

    def setFrontier(self, _arg, root):
        """Mark this node and every ancestor up to and including ``root`` as frontier.

        ``_arg`` is ignored; it is kept for backward compatibility.
        """
        node = self
        node.frontier = True
        while node != root:
            node = node.parent
            node.frontier = True

    def checkMerge(self, root, threshold):
        """Walk up from this node and split at the first merge that is too poor.

        Ancestors are followed while their ``mergescore`` exceeds
        ``threshold``. The first node whose merge score does not exceed it is
        marked, together with its ancestors up to ``root``, via
        ``setFrontier``. The walk stops at a node that is already frontier or
        after the top of the tree.
        """
        node = self
        while node is not None and not node.frontier:
            if node.mergescore > threshold:
                node = node.parent
            else:
                node.setFrontier(True, root)
                return

    def printFrontiers(self):
        """Convert ``motif.ppm`` with ``pfm_to_ppm`` for every cluster node.

        NOTE(review): despite its name this method prints nothing; it rewrites
        the ``ppm`` attribute of each non-frontier node directly below the
        frontier in place.
        """
        for node in self._walk(self._is_split):
            if not node.frontier:
                node.motif.ppm = node.motif.pfm_to_ppm(node.motif.ppm)

    def get_clustered_motifs(self):
        """Return the motif of every cluster, i.e. every non-frontier node
        directly below the frontier, in left-to-right order.

        A frontier node without children contributes nothing (an empty list).
        """
        return [
            node.motif
            for node in self._walk(self._is_split)
            if not node.frontier
        ]

    def getResult(self):
        """Return ``[cluster_motif, member_motifs]`` for every cluster.

        ``member_motifs`` are the leaf motifs below the cluster node, as given
        by ``recursive_motif``. A frontier node without children contributes
        nothing (an empty list).
        """
        return [
            [node.motif, node.recursive_motif()]
            for node in self._walk(self._is_split)
            if not node.frontier
        ]

    def recursive_name(self):
        """Return the ids of all leaf motifs below this node, left to right."""
        return [
            node.motif.id for node in self._walk(lambda n: n.left is not None)
        ]

    def recursive_motif(self):
        """Return all leaf motifs below this node, left to right."""
        return [node.motif for node in self._walk(lambda n: n.left is not None)]


def _as_similarity(score, pval):
    """Return a pairwise comparison result with its score oriented "higher is better".

    ``score`` is a list whose first element is the score; the remaining
    elements (position and orientation) are passed through unchanged. When
    ``pval`` is True the first element is an empirical p-value and is
    replaced by ``1 - p``; otherwise ``score`` itself is returned.
    """
    if pval:
        return [1 - score[0]] + score[1:]
    return score


def _best_active_pair(scores, active):
    """Return the highest-scoring pair whose two nodes are both in ``active``.

    ``scores`` maps ``(node1, node2)`` to a result list whose first element
    is the score. Ties go to the qualifying pair inserted into ``scores``
    last, which is exactly what a stable ascending sort followed by a scan
    from the end would select. Returns None when no pair qualifies.
    """
    best_pair = None
    best_value = None
    for pair, score in scores.items():
        first, second = pair
        if first in active and second in active:
            if best_pair is None or score[0] >= best_value:
                best_pair, best_value = pair, score[0]
    return best_pair


def cluster_motifs(
    motifs,
    match="total",
    metric="wic",
    combine="mean",
    pval=True,
    threshold=0.95,
    trim_edges=False,
    edge_ic_cutoff=0.2,
    include_bg=True,
    progress=True,
    ncpus=None,
):
    """
    Clusters a set of sequence motifs. Required arg 'motifs' is a file containing
    positional frequency matrices or a list with motifs.

    Optional args:

    'match', 'metric' and 'combine' specify the method used to compare and score
    the motifs. By default the WIC score is used (metric='wic'), using the
    score over the whole alignment (match='total'), with the total motif score
    calculated as the mean score of all positions (combine='mean').
    'match' can be either 'total' for the total alignment or 'subtotal' for the
    maximum scoring subsequence of the alignment.
    'metric' can be any metric defined in MotifComparer, currently: 'pcc', 'ed',
    'distance', 'wic' or 'chisq'
    'combine' determines how the total score is calculated from the score of
    individual positions and can be either 'sum' or 'mean'

    'pval' can be True or False and determines if the score should be converted to
    an empirical p-value

    'threshold' determines the score (or p-value) cutoff

    If 'trim_edges' is set to True, all motif edges with an IC below
    'edge_ic_cutoff' will be removed before clustering

    When computing the average of two motifs 'include_bg' determines if, at a
    position only present in one motif, the information in that motif should
    be kept, or if it should be averaged with background frequencies. Should
    probably be left set to True.

    'progress' enables an info log message and a progress bar on stderr.

    'ncpus' is forwarded to every MotifComparer.get_all_scores call
    (presumably the number of worker processes for parallel scoring).

    Returns the root MotifTree of the merge tree; use its getResult() or
    get_clustered_motifs() to obtain the clusters.

    Raises ValueError when there are no motifs, or when the best remaining
    pair contains an empty motif (e.g. after edge trimming), which could
    otherwise never be merged.
    """
    # First read pfm or pfm formatted motiffile
    if not isinstance(motifs, list):
        motifs = read_motifs(motifs, fmt="pfm")
    if not motifs:
        raise ValueError("No motifs to cluster")

    # All motifs must have unique ids, used in dictionary below
    motif_ids = [motif.id for motif in motifs]
    assert len(motif_ids) == len(set(motif_ids)), "Motif ids must be unique"

    mc = MotifComparer()

    # Trim edges with low information content
    if trim_edges:
        for motif in motifs:
            motif.trim(edge_ic_cutoff)

    # Make a MotifTree node for every motif
    nodes = [MotifTree(m) for m in motifs]
    motif_nodes = {node.motif.id: node for node in nodes}

    # Determine all pairwise scores and maxscore per motif.
    # Insertion order of ``scores`` matters: it decides ties in
    # _best_active_pair.
    scores = {}
    if progress:
        logger.info("Calculating initial scores")
    result = mc.get_all_scores(
        motifs, motifs, match, metric, combine, pval, parallel=True, ncpus=ncpus
    )
    for m1, other_motifs in result.items():
        for m2, score in other_motifs.items():
            if m1 == m2:
                motif_nodes[m1].maxscore = 1 - score[0] if pval else score[0]
            else:
                scores[(motif_nodes[m1], motif_nodes[m2])] = _as_similarity(
                    score, pval
                )

    active = set(nodes)  # nodes that have not been merged yet
    total = len(nodes)
    ave_count = 1

    while len(active) > 1:
        pair = _best_active_pair(scores, active)
        if pair is None:
            raise RuntimeError("No scored pair left between unmerged motifs")
        n1, n2 = pair

        if len(n1.motif) == 0 or len(n2.motif) == 0:
            # Such a pair can never be merged; skipping it would select the
            # very same pair again forever.
            raise ValueError(
                f"Cannot merge {n1.motif.id} and {n2.motif.id}: empty motif "
                "(check trim_edges / edge_ic_cutoff)"
            )

        score, pos, orientation = scores[(n1, n2)]
        ave_motif = n1.motif.average_motifs(
            n2.motif, pos, orientation, include_bg=include_bg
        )
        ave_motif.trim(edge_ic_cutoff)

        # Check if the motif is not empty
        if len(ave_motif) == 0:
            ave_motif = Motif([[0.25, 0.25, 0.25, 0.25]])

        ave_motif.id = f"Average_{ave_count}"
        ave_count += 1

        new_node = MotifTree(ave_motif)
        self_score = mc.compare_motifs(
            ave_motif, ave_motif, match, metric, combine, pval
        )[0]
        new_node.maxscore = 1 - self_score if pval else self_score
        new_node.mergescore = score

        n1.parent = new_node
        n2.parent = new_node
        new_node.left = n1
        new_node.right = n2

        # Unmerged nodes other than the new one, in creation order.
        cmp_nodes = [node for node in nodes if node.parent is None]

        if progress:
            # Separate variable: overwriting ``progress`` would switch the bar
            # off as soon as the percentage rounds down to 0.
            pct = int((1 - len(cmp_nodes) / float(total)) * 100)
            filled = pct // 10
            bar = "#" * filled + " " * (10 - filled)
            sys.stderr.write(f"\rClustering [{bar}] {pct}%")  # TODO: tqdm

        result = mc.get_all_scores(
            [new_node.motif],
            [node.motif for node in cmp_nodes],
            match,
            metric,
            combine,
            pval,
            parallel=True,
            ncpus=ncpus,
        )
        new_scores = result[new_node.motif.id]
        for node in cmp_nodes:
            scores[(new_node, node)] = _as_similarity(
                new_scores[node.motif.id], pval
            )

        nodes.append(new_node)
        active.discard(n1)
        active.discard(n2)
        active.add(new_node)

    if progress:
        sys.stderr.write("\n")

    # The last node created is the root (or the only leaf for a single motif).
    root = nodes[-1]
    for node in nodes:
        # Leaves without a parent only occur for a single input motif.
        if node.left is None and node.parent is not None:
            node.parent.checkMerge(root, threshold)

    return root


def cluster_motifs_with_report(infile, outfile, outdir, threshold, title=None):
    """Cluster the motifs in 'infile' and write an HTML report and the clusters.

    'infile' is read with read_motifs(..., fmt="pfm"). 'threshold' is
    converted with float() and passed to cluster_motifs (WIC score, p-values).
    Logos are written to '<outdir>/images', the report to
    '<outdir>/gimme.clustereds.html', and the cluster motifs (Motif.to_ppm())
    to 'outfile'. 'title' is shown in the report and defaults to 'infile'.

    Returns a list of [cluster_motif, member_motifs] pairs, or [] (writing
    nothing) when 'infile' contains no motifs.
    """
    if title is None:
        title = infile

    motifs = read_motifs(infile, fmt="pfm")
    if len(motifs) == 0:
        return []

    trim_ic = 0.2
    tree = None  # stays None when there is only a single motif
    if len(motifs) == 1:
        clusters = [[motifs[0], motifs]]
    else:
        logger.info(f"clustering {len(motifs)} motifs.")
        # Reuse the already parsed motifs instead of reading infile again.
        tree = cluster_motifs(
            motifs,
            "total",
            "wic",
            "mean",
            True,
            threshold=float(threshold),
            include_bg=True,
            progress=False,
        )
        clusters = tree.getResult()

    def logo_name(motif_id):
        # Spaces are replaced for every logo, so a single-member cluster (whose
        # cluster motif is the member itself) points at the file it wrote.
        return f"{motif_id.replace(' ', '_')}.png"

    mc = MotifComparer()
    img_dir = os.path.join(outdir, "images")
    os.makedirs(img_dir, exist_ok=True)

    ids = []
    for cluster, members in clusters:
        cluster.trim(trim_ic)
        png = os.path.join("images", logo_name(cluster.id))
        cluster.plot_logo(fname=os.path.join(outdir, png))

        if len(members) > 1:
            scores = {
                motif: mc.compare_motifs(
                    cluster, motif, "total", "wic", "mean", pval=True
                )
                for motif in members
            }
            # Each member logo is padded on the left by its alignment offset
            # relative to the smallest alignment position.
            add_pos = min(score[1] for score in scores.values())
            for motif in members:
                _score, pos, strand = scores[motif]
                if strand not in (1, "+"):
                    rc = motif.rc()
                    rc.id = motif.id
                    motif = rc
                motif.plot_logo(
                    fname=os.path.join(img_dir, logo_name(motif.id)),
                    add_left=pos - add_pos,
                )

        member_imgs = [
            {
                "src": os.path.join("images", logo_name(motif.id)),
                "alt": motif.id.replace(" ", "_"),
            }
            for motif in members
        ]
        ids.append([cluster.id, {"src": png}, member_imgs])

    config = MotifConfig()
    env = jinja2.Environment(
        loader=jinja2.FileSystemLoader([config.get_template_dir()])
    )
    template = env.get_template("cluster_template.jinja.html")
    result = template.render(
        motifs=ids,
        inputfile=title,
        date=datetime.today().strftime("%d/%m/%Y"),
        version=__version__,
    )

    cluster_report = os.path.join(outdir, "gimme.clustereds.html")
    with open(cluster_report, "wb") as f:
        f.write(result.encode("utf-8"))

    with open(outfile, "w") as f:
        if tree is None:
            f.write(f"{clusters[0][0].to_ppm()}\n")
        else:
            for motif in tree.get_clustered_motifs():
                f.write(f"{motif.to_ppm()}\n")

    logger.debug(f"Clustering done. See the result in {cluster_report}")
    return clusters