vanheeringen-lab/ANANSE

View on GitHub
ananse/benchmark.py

Summary

Maintainability
A
2 hrs
Test Coverage
F
0%
import os
import re
import sys
import urllib.request
from glob import glob

import numpy as np
import pandas as pd
import seaborn as sns
from loguru import logger
from sklearn.metrics import average_precision_score, roc_auc_score

import ananse

logger.remove()
logger.add(
    sys.stderr, format="<green>{time:YYYY-MM-DD HH:mm:ss}</green> | {level} | {message}"
)

TFP_URL = "https://maayanlab.cloud/Enrichr/geneSetLibrary?mode=text&libraryName=TF_Perturbations_Followed_by_Expression"
TRRUST_URL = "https://www.grnpedia.org/trrust/data/trrust_rawdata.human.tsv"
MSIGDB_URL = "https://data.broadinstitute.org/gsea-msigdb/msigdb/release/7.4/c3.all.v7.4.symbols.gmt"


def download_trrust_reference(outfile):
    edges = []
    with urllib.request.urlopen(
        TRRUST_URL,
    ) as f:
        for line in f.readlines():
            tf, target, regtype, pmid = line.decode().strip().split("\t")
            # Just skip repression for now
            if regtype in ["Activation", "Unknown"]:
                edges.append([tf, target, 1])
    edges = pd.DataFrame(edges, columns=["tf", "target", "interaction"])
    edges.to_csv(outfile, sep="\t", index=False)


def download_msigdb_reference(outfile):
    with urllib.request.urlopen(MSIGDB_URL) as gmt, open(outfile, "w") as fl1:
        for line in gmt:
            a = line.decode("utf-8").split()
            tf = a[0].split("_")[0]
            targets = a[2:]
            for target in targets:
                fl1.write(f"{tf}\t{target}\n")


def fix_columns(df):
    """Make sure network has a tf and a target column."""

    df.columns = df.columns.str.lower()
    df = df.rename(
        columns={
            "source": "tf",
            "source_target": "tf_target",
            "target_gene": "target",
        }
    )

    if "tf_target" in df.columns:
        for sep in [ananse.SEPARATOR, "_", "-"]:
            if df["tf_target"].head(100).str.contains(sep).sum() == 100:
                break

        df[["tf", "target"]] = df["tf_target"].str.split(sep, expand=True).iloc[:, :2]
        df = df.drop(columns=["tf_target"])

    if "tf" not in df.columns:
        raise ValueError("Expect a column named 'source' or 'tf'")
    if "target" not in df.columns:
        raise ValueError("Expect a column named 'target' or 'target_gene'")

    return df


def prepare_reference_network(network, filter_tfs=True):
    """Generate reference network.

    This network contains all possible edges, based on the TFs
    and the target genes in the input. TFs are optionally filtered
    to contain only validated TFs.

    Returns
    -------
        DataFrame with column `"interaction"` having 1 for a validated
        edge and 0 otherwise.
    """
    if isinstance(network, pd.DataFrame):
        df = network.reset_index()
    elif isinstance(network, str):
        if network.endswith("feather"):
            df = pd.read_feather(network)
        else:
            df = pd.read_table(network)
    else:
        raise ValueError("Unknown network type, need DataFrame or filename.")

    df = fix_columns(df)

    interaction_column = None
    for col in df.columns:
        if col in ["tf", "target"]:
            continue
        vals = df[col].unique()

        if len(vals) in [1, 2] and 1 in vals:
            interaction_column = col
            break

    tfs = set(df["tf"].unique())
    if filter_tfs:
        valid_tfs = set(get_tfs())
        tfs = list(tfs.intersection(valid_tfs))

    targets = df["target"].unique()

    # logger.info(
    #     f"{os.path.split(network)[-1]} reference - {len(tfs)} TFs, {len(targets)} targets, {df.shape[0]} edges."
    # )

    total = []
    for tf in tfs:
        for target in targets:
            total.append([tf, target])
    total = pd.DataFrame(total, columns=["tf", "target"]).set_index(["tf", "target"])

    if interaction_column is not None:
        logger.info(f"Using '{interaction_column}' as interaction column.")
        df = df.set_index(["tf", "target"])[[interaction_column]].rename(
            columns={interaction_column: "interaction"}
        )
    else:
        logger.info("No column with 1 found, assuming all lines are positive edges.")
        df = df.set_index(["tf", "target"])
        df["interaction"] = 1

    return total.join(df[["interaction"]]).fillna(0)


def _read_dorothea_reference(fname):
    dorothea = pd.read_table(fname)
    cols = [
        "is_evidence_chip_seq",
        "is_evidence_curated",
        "is_evidence_inferred",
        "is_evidence_tfbs",
    ]
    dorothea = dorothea.set_index(["tf", "target"])[cols]
    for col in cols:
        dorothea[col] = dorothea[col].astype(int)
    dorothea["dorothea"] = np.any(dorothea[cols] == 1, 1).astype(int)

    dorothea = dorothea.reset_index()

    tfs = set(dorothea["tf"].unique())
    valid_tfs = set(get_tfs())
    tfs = list(tfs.intersection(valid_tfs))
    targets = dorothea["target"].unique()
    logger.info(
        f"Dorothea reference - {len(tfs)} TFs, {len(targets)} targets, {dorothea.shape[0]} edges."
    )

    total = []
    for tf in tfs:
        for target in targets:
            total.append([tf, target])
    total = pd.DataFrame(total, columns=["tf", "target"]).set_index(["tf", "target"])

    dorothea = dorothea.set_index(["tf", "target"])
    dorothea = total.join(dorothea)
    dorothea = dorothea.fillna(0)
    return dorothea


def _read_enrichr_perturbation_reference(fname=None):
    """Uses the TF perturbations from Enrichr[1,2] to create reference edges.

    Targets are defined by up- or down-regulated gened from the following sets:
    Up: INDUCTION, ACTIVATION, OE.
    Down: KD, KO, INACTIVATION, DEPLETION, SIRNA, SHRNA, KNOCKOUT, DELETION INHIBITION.

    The TF and targets in the DataFrame consists of the Cartesian product of
    all TFs and target genes that occur in the set.

    Returns
    -------
        DataFrame with tf-target edges.

    References
    ----------
    .. [1] Chen EY, Tan CM, Kou Y, Duan Q, Wang Z, Meirelles GV, Clark NfR, Ma'ayan A.
       "Enrichr: interactive and collaborative HTML5 gene list enrichment analysis
       tool." BMC Bioinformatics. 2013;128(14)

    .. [2] Kuleshov MV, Jones MR, Rouillard AD, Fernandez NF, Duan Q, Wang Z,
       Koplev S, Jenkins SL, Jagodnik KM, Lachmann A, McDermott MG, Monteiro CD,
       Gundersen GW, Ma'ayan A. "Enrichr: a comprehensive gene set enrichment
       analysis web server 2016 update." Nucleic Acids Research. 2016; gkw377.
    """
    use_online = False
    if fname:
        fopen = open(fname)
    else:
        logger.info(
            "No filename provided for TF perturbations, downloading from Enrichr"
        )
        fopen = urllib.request.urlopen(TFP_URL)
        use_online = True
    p = re.compile(r"(\w+)\s+(\w+)\s+(.+)\s+(\w+)")
    all_info = []
    edges = []
    with fopen as f:
        for line in f:
            if use_online:
                line = line.decode("utf-8")
            vals = line.strip().split("\t")
            m = re.search(p, vals[0])
            all_info.append(m.groups(0))
            if (
                m.group(2) in ["INDUCTION", "ACTIVATION", "OE"] and m.group(4) == "UP"
            ) or (
                m.group(2)
                in [
                    "KD",
                    "KO",
                    "INACTIVATION",
                    "DEPLETION",
                    "SIRNA",
                    "SHRNA",
                    "KNOCKOUT",
                    "DELETION",
                    "INHIBITION",
                ]
                and m.group(4) == "DOWN"
            ):
                tf = m.group(1)
                for target in vals[2:]:
                    edges.append([tf, target])
    all_info = pd.DataFrame(all_info, columns=["tf", "exp", "info", "up_down"])

    perturb_df = pd.DataFrame(edges, columns=["tf", "target"])
    tfs = set(perturb_df["tf"].unique())
    targets = perturb_df["target"].unique()

    logger.info(
        f"TF perturbation reference - {len(tfs)} TFs, {len(targets)} targets, {perturb_df.shape[0]} edges."
    )

    perturb_df["experiments"] = 1
    perturb_df = perturb_df.groupby(["tf", "target"]).count()
    perturb_df["interaction"] = 1
    perturb_df.columns = ["perturb_experiments", "perturb_interaction"]

    valid_tfs = set(get_tfs())
    tfs = list(tfs.intersection(valid_tfs))

    total = []
    for tf in tfs:
        for target in targets:
            total.append([tf, target])
    total = pd.DataFrame(total, columns=["tf", "target"]).set_index(["tf", "target"])

    perturb_df = total.join(perturb_df).fillna(0)
    return perturb_df


def get_tfs():
    valid_factors = pd.read_excel(
        "https://www.biorxiv.org/content/biorxiv/early/2020/12/07/2020.10.28.359232/DC1/embed/media-1.xlsx",
        engine="openpyxl",
        sheet_name=1,
    )
    valid_factors = valid_factors.loc[
        valid_factors["Pseudogene"].isnull(), "HGNC approved gene symbol"
    ].values
    valid_factors = [f for f in valid_factors if f != "EP300"]
    return valid_factors


def read_network(fname, name=None):
    network = fname
    if fname.endswith("feather"):
        df = pd.read_feather(network)
    else:
        df = pd.read_table(network)
    df = fix_columns(df)

    # Make sure there are no duplicate interactions
    df.drop_duplicates(subset=["tf", "target"], inplace=True)
    df = df.set_index(["tf", "target"])

    # Assuming last column is the edge weight
    df = df.iloc[:, [-1]]
    if name is not None:
        df.columns = [name]

    return df


def _read_correlation_reference(network, corCutoff=0.6):
    tfs_name = f"{os.path.dirname(ananse.__file__)}/db/tfs.txt"
    tfs = pd.read_csv(tfs_name, header=None)[0].tolist()
    edb = pd.read_csv(network, sep="\t")

    edb["iscorrelation"] = [1 if i > corCutoff else 0 for i in edb["correlationRank"]]

    edb[["tf", "target"]] = edb["source_target"].str.split("_", expand=True).iloc[:, :2]
    edb = edb.drop(
        columns=["source_target", "ocorrelation", "correlation", "correlationRank"]
    )
    edb = edb[edb.tf.isin(tfs)]
    edb = edb.set_index(["tf", "target"])
    return edb


def _read_goterm_reference(network, goCutoff=0):
    tfs_name = f"{os.path.dirname(ananse.__file__)}/db/tfs.txt"
    tfs = pd.read_csv(tfs_name, header=None)[0].tolist()
    gdb = pd.read_csv(network, sep="\t", header=None)

    gdb["isgo"] = [1 if i > goCutoff else 0 for i in gdb[2]]

    gdb = gdb.rename(columns={3: "tf", 1: "target"})
    gdb = gdb[gdb.tf.isin(tfs)]

    gdb = gdb.drop(columns=[0, 2])
    gdb = gdb.set_index(["tf", "target"])
    return gdb


def _read_msigdb_reference(network):
    msidb = pd.read_csv(network, sep="\t", header=None)
    msidb = msidb.rename(columns={0: "tf", 1: "target"})
    msidb = msidb.set_index(["tf", "target"])
    msidb["interaction"] = 1
    return msidb


def _read_regnet_reference(network):
    regnet = pd.read_csv(network)
    regnet = regnet.rename(
        columns={"regulator_symbol": "tf", "target_symbol": "target"}
    )
    regnet = regnet.set_index(["tf", "target"])
    regnet["interaction"] = 1
    return regnet[["interaction"]]


def read_reference(name, fname=None):
    """
    Valid reference networks (name):
    - dorothea
    - perturbation
    - correlation
    - goterm
    - msigdb
    - regnet
    - trrust
    """
    if name.lower() == "dorothea":
        return _read_dorothea_reference(fname)
    if name.lower() == "perturbation":
        return prepare_reference_network(_read_enrichr_perturbation_reference(fname))
    if name.lower() == "correlation":
        return prepare_reference_network(_read_correlation_reference(fname, 0.6))
    if name.lower() == "goterm":
        return prepare_reference_network(_read_goterm_reference(fname, 0))
    if name.lower() == "msigdb":
        return prepare_reference_network(_read_msigdb_reference(fname))
    if name.lower() == "regnet":
        return prepare_reference_network(_read_regnet_reference(fname))
    if name.lower() == "trrust":
        return prepare_reference_network(fname)


def validate_files(fnames, ignore_missing=False):
    file_error = False
    for fname in fnames:
        if not os.path.exists(fname):
            logger.error(f"file {fname} does not exist")
            file_error = True
    if not ignore_missing and file_error:
        raise ValueError("One or more files not found!")


def read_networks(network_dict, ignore_missing=False):
    """Read predicted networks.

    Input is a dictionary with name as key and filename as value.
    """
    # Validate files first
    validate_files(network_dict.values(), ignore_missing=ignore_missing)

    df = pd.DataFrame({"tf": [], "target": []}).set_index(["tf", "target"])

    for name, fname in network_dict.items():
        if os.path.exists(fname):
            logger.info(f"Reading {name}")
            tmp = read_network(fname, name=name)
            logger.info(f"Merging {name}")
            df = df.join(tmp, how="outer")

    return df


class NetworkBenchmark:
    def __init__(self, ref_dir, outdir):
        logger.info("Reading reference networks")
        self.dorothea = read_reference(
            "dorothea", fname=f"{ref_dir}/dorothea/entire_database.txt"
        )
        self.perturb = read_reference("perturbation")
        self.trrust = read_reference("trrust", fname=f"{ref_dir}/trrust.txt")

        self.references = [
            [self.dorothea, "dorothea", "is_evidence_curated"],
            [self.perturb, "TF perturbations", "interaction"],
            [self.trrust, "TRRUST", "interaction"],
        ]
        self.outdir = outdir
        if not os.path.exists(outdir):
            os.makedirs(outdir)

        self.benchmark_results = []

    def _merge_files(self):
        for network in self.networks:
            outfile = f"{self.outdir}/{network}.feather"

            if os.path.exists(outfile):
                continue

            logger.info(f"Creating merged file for {network} networks.")
            network_dfs = []
            fnames = glob(self.network_files[network])
            for fname in fnames:
                name = fname
                for part in self.network_files[network].split("*"):
                    name = name.replace(part, "")
                logger.info(fname)
                df = read_network(fname)
                df.columns = [f"{network}.{name}"]
                network_dfs.append(df)

            logger.info("Concatenating...")
            df = pd.concat(network_dfs, axis=1)
            df.reset_index().to_feather(outfile)
            logger.info("Done.")

    def benchmark(self, network_files):
        self.network_files = network_files
        self.networks = list(self.network_files.keys())

        # Create a feather file with all individual GRNs merged
        # Saves time if (part of) the benchmark needs to be rerun
        self._merge_files()

        for network in self.networks:  # networks:
            fname = f"{self.outdir}/{network}.feather"
            logger.info(f"reading {fname}")
            df = pd.read_feather(fname)
            logger.info("done")
            df.set_index(df.columns[:2].tolist(), inplace=True)
            cols = df.columns
            min_value = df.min().min() - 1

            for ref_network, name, ref_col in self.references:
                logger.info(f"joining {name}")
                test_network = ref_network.join(df)
                logger.info("done")
                test_network = test_network.fillna(min_value)

                for col in cols:
                    pr_auc = average_precision_score(
                        test_network[ref_col], test_network[col]
                    )
                    roc_auc = roc_auc_score(test_network[ref_col], test_network[col])
                    exp_name = os.path.splitext(os.path.basename(col))[0]
                    logger.info(
                        f"{name}\t{network}\t{exp_name}\t{pr_auc:0.5f}\t{roc_auc:0.2f}"
                    )
                    self.benchmark_results.append(
                        [name, network, exp_name, pr_auc, roc_auc]
                    )

    def plot(self, outname=None):
        bench_df = pd.DataFrame(
            [row[:2] + [row[2][-1]] + row[3:] for row in self.benchmark_results],
            columns=["reference", "network", "tissue", "pr_auc", "roc_auc"],
        )

        for metric in ["pr_auc", "roc_auc"]:
            sns.set_context(
                "paper",
                rc={
                    "font.size": 16,
                    "axes.titlesize": 16,
                    "axes.labelsize": 16,
                    "xtick.labelsize": 16,
                    "ytick.labelsize": 16,
                },
            )
            order = (
                bench_df[bench_df["network"] != "baseline"]
                .groupby("network")
                .mean()[metric]
                .sort_values()
                .index[::-1]
            )

            g = sns.catplot(
                data=bench_df[bench_df["network"] != "baseline"],
                y="network",
                x=metric,
                col="reference",
                kind="box",
                order=order,
                fliersize=0,
                boxprops={"alpha": 0.5},
                sharex=False,
                aspect=1.3,
            )

            for subplot in g.axes:
                for ax in subplot:
                    for line in ax.lines:
                        line.set_alpha(0.3)

            g.map(
                sns.stripplot,
                metric,
                "network",
                color="black",
                # data=bench_df[bench_df["network"] != "baseline"],
                s=8,
                jitter=0.25,
                alpha=0.7,
                order=order,
            )
            # g.set_xticklabels(rotation=30)

            if outname:
                extension = os.path.splitext(outname)[1]
                g.savefig(outname.replace(extension, f".{metric}{extension}"))