matteoferla/Fragmenstein

View on GitHub
fragmenstein/laboratory/_score.py

Summary

Maintainability
A
3 hrs
Test Coverage
import contextlib
import operator
from typing import List
from warnings import warn

import pandas as pd
from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem, PandasTools
from rdkit.Chem import rdFingerprintGenerator as rdfpg
from rdkit.Chem import rdMolDescriptors as rdmd
from rdkit.Chem.rdfiltercatalog import FilterCatalogParams, FilterCatalog, FilterCatalogEntry
from rdkit.ML.Cluster import Butina

from .._cli_defaults import cli_default_settings
from ..monster import Monster

# ----- Scoring -----------------------------------
# Module-level PAINS (pan-assay interference compounds) filter catalog,
# built once at import time and shared by get_pains() below.
params = FilterCatalogParams()
params.AddCatalog(FilterCatalogParams.FilterCatalogs.PAINS)
catalog = FilterCatalog(params)

def get_pains(mol) -> List[str]:
    """Return the descriptions of PAINS filters matched by ``mol``.

    Returns an empty list for non-molecules, empty molecules, or when
    sanitisation/matching raises (errors are deliberately swallowed so this
    stays a best-effort scorer usable in a DataFrame ``apply``).
    """
    with contextlib.suppress(Exception):
        entry: FilterCatalogEntry
        if not isinstance(mol, Chem.Mol) or mol.GetNumHeavyAtoms() == 0:
            return []
        AllChem.SanitizeMol(mol)
        return [entry.GetDescription() for entry in catalog.GetMatches(mol)]
    # BUGFIX: previously fell through to an implicit ``None`` when an
    # exception was suppressed, breaking the List[str] contract (and the
    # downstream ``.apply(len)`` in LabScore.score).
    return []

class GetRowSimilarity:
    """Callable scoring a placement row by its best Tanimoto similarity
    to any of the parent hits it was built from."""

    def __init__(self, hits):
        # Precompute an RDKit fingerprint per hit, keyed by molecule title.
        self.fpgen = rdfpg.GetRDKitFPGenerator()
        self.hit2fp = {}
        for hit in hits:
            self.hit2fp[hit.GetProp('_Name')] = self.fpgen.GetFingerprint(hit)

    def __call__(self, row: pd.Series):
        """Return max Tanimoto similarity of ``row.minimized_mol`` to its hits (nan on failure)."""
        with contextlib.suppress(cli_default_settings['supressed_exceptions']):
            mol = row.minimized_mol
            if not isinstance(mol, Chem.Mol):
                return float('nan')
            names = row.hit_names
            if isinstance(names, str):
                # comma-separated serialisation of the hit name list
                names = names.split(',')
            elif not isinstance(names, list):
                return float('nan')
            query_fp = self.fpgen.GetFingerprint(AllChem.RemoveHs(mol))
            return max(DataStructs.TanimotoSimilarity(query_fp, self.hit2fp[name])
                       for name in names)
        return float('nan')


class HitIntxnTallier:

    def __init__(self, hit_replacements):
        self.slim_hits = self.slim_down(hit_replacements)

    def slim_down(self, hit_replacements):
        # the bleaching was fixed cf bleached_name
        # hit_replacements['new_name'] = hit_replacements.name.str.replace('-', '_')
        # undoing bleaching
        hit_replacements['new_name'] = hit_replacements.hit_mols.apply(lambda ms: ms[0].GetProp('_Name'))
        columns = [c for c in hit_replacements.columns if isinstance(c, tuple)]
        return hit_replacements.set_index('new_name')[columns].fillna(0).copy()

    def __call__(self, row: pd.Series):
        with contextlib.suppress(cli_default_settings['supressed_exceptions']):
            if not isinstance(row.minimized_mol, Chem.Mol) or isinstance(row.hit_names, float):
                return float('nan'), float('nan')
            present_tally = 0
            absent_tally = 0
            for hit_name in list(row.hit_names):
                if hit_name not in self.slim_hits.index:
                    raise Exception('Name' + hit_name)
                hit_row = self.slim_hits.loc[hit_name]
                for intxn_name, hit_value in hit_row.items():
                    if not hit_value:
                        continue
                    elif intxn_name not in row.index:
                        absent_tally += 1 if intxn_name[0] != 'hydroph_interaction' else 0.5
                    elif row[intxn_name]:
                        absent_tally += 1 if intxn_name[0] != 'hydroph_interaction' else 0.5
                    else:
                        present_tally += 1 if intxn_name[0] != 'hydroph_interaction' else 0.5
            return present_tally, absent_tally
        return float('nan'), float('nan')


class UniquenessMeter:
    """Scores how unusual a row's interactions are relative to the
    dataset-wide interaction tallies."""

    def __init__(self, tallies, intxn_names, k=0.5):
        self.tallies = tallies          # per-interaction totals across all placements
        self.intxn_names = intxn_names  # tuple-named interaction columns
        self.k = k                      # dampening exponent for the rarity ratio

    def __call__(self, row):
        """Sum of (row/total)**k over interactions present in the row; nan on failure."""
        with contextlib.suppress(cli_default_settings['supressed_exceptions']):
            return sum([(row[name] / self.tallies[name]) ** self.k for name in self.intxn_names if
                        row[name] and self.tallies[name]])
        return float('nan')

    def tally_interactions(self, row):
        """Weighted interaction count for a row; hydrophobic contacts count 0.5.

        BUGFIX: the weight previously tested ``self.intxn_names[0]`` — a
        loop-invariant constant — instead of each column ``c``, so the
        hydrophobic down-weighting was never applied per column.
        """
        return sum([row[c] * 0.5 if c[0] == 'hydroph_interaction' else row[c]
                    for c in self.intxn_names])


class PenaltyMeter:
    """Computes a weighted ad hoc penalty for a placement row
    (lower is better; unacceptable outcomes get +inf)."""

    def __init__(self, weights, nan_penalty=10):
        self.weights = weights          # column name -> weight
        self.nan_penalty = nan_penalty  # flat surcharge per missing (nan) value

    def __call__(self, row):
        """Return the weighted penalty for a row; inf if not 'acceptable', nan on failure."""
        with contextlib.suppress(cli_default_settings['supressed_exceptions']):
            penalty = 0
            if row.outcome != 'acceptable':
                return float('inf')
            for col, w in self.weights.items():
                if col not in row.index:
                    warn(f'{col} column is missing from df')
                    continue
                # nan values incur a flat penalty instead of propagating nan
                penalty += row[col] * w if str(row[col]) != 'nan' else self.nan_penalty
            return penalty
        return float('nan')

    def make_weighted_df(self, df) -> pd.DataFrame:
        """
        Inspect whether the weights make sense.

        Returns the per-column weighted contributions plus a 'total' column,
        sorted ascending by total.
        """
        weighted = pd.DataFrame({k: df[k] * w for k, w in self.weights.items()})
        weighted['total'] = df.apply(self, axis=1)
        # BUGFIX: sort_values is not in-place; the sorted result was
        # previously discarded and the frame returned unsorted.
        return weighted.sort_values('total')


def butina_cluster(mol_list, cutoff=0.35):
    """Assign a 1-based Butina cluster id to each molecule (0 = unassigned)."""
    # https://github.com/PatWalters/workshop/blob/master/clustering/taylor_butina.ipynb
    fingerprints = [rdmd.GetMorganFingerprintAsBitVect(AllChem.RemoveAllHs(mol), 3, nBits=2048)
                    for mol in mol_list]
    n_fps = len(fingerprints)
    # build the condensed distance list (1 - Tanimoto) row by row,
    # in the order Butina.ClusterData expects
    distances = []
    for idx in range(1, n_fps):
        similarities = DataStructs.BulkTanimotoSimilarity(fingerprints[idx], fingerprints[:idx])
        distances.extend(1 - sim for sim in similarities)
    clusters = Butina.ClusterData(distances, n_fps, cutoff, isDistData=True)
    assignments = [0] * n_fps
    for cluster_id, members in enumerate(clusters, start=1):
        for member in members:
            assignments[member] = cluster_id
    return assignments

def UFF_Gibbs(mol):
    """Forcefield strain ("free energy cost") of the bound conformer via
    ``Monster.MMFF_score`` (which, per the original note, actually uses UFF).

    Returns nan for non-molecules, empty molecules, or any suppressed failure.
    """
    # free energy cost of bound conformer
    if not isinstance(mol, Chem.Mol) or mol.GetNumHeavyAtoms() == 0:
        return float('nan')
    with contextlib.suppress(cli_default_settings['supressed_exceptions']):
        AllChem.SanitizeMol(mol)
        # this is actually UFF
        copy = Chem.Mol(mol)
        # BUGFIX: score the copy (previously created but unused) so the
        # caller's molecule is not mutated by the scoring step
        return Monster.MMFF_score(None, copy, True)
    return float('nan')

class LabScore:
    """Mixin with the scoring and SDF-export steps of the Laboratory pipeline."""

    @classmethod
    def score(cls,
              placements: pd.DataFrame,
              hit_replacements: pd.DataFrame,
              weights: dict,
              **settings) -> pd.DataFrame:
        """
        This is very much a method for the CLI.
        A real Pythonic usage would be to address the individual components.

        Adds scoring columns to ``placements`` in place and returns it.
        Each scoring stage is wrapped in ``contextlib.suppress`` so one
        failing metric does not abort the rest.
        """
        if 'minimized_mol' not in placements.columns:
            cls.Victor.journal.critical('No minimized_mol column')
            return placements
        # tanimoto: best similarity of each placement to any of its parent hits
        hits: List[Chem.Mol] = hit_replacements.hit_mols.apply(operator.itemgetter(0)).to_list()
        get_similarity = GetRowSimilarity(hits)
        placements['max_hit_Tanimoto'] = placements.apply(get_similarity, axis=1)
        # properties: substitute a blank molecule for failed placements
        m = placements.minimized_mol.apply(lambda mol: mol if isinstance(mol, Chem.Mol) else Chem.Mol())
        # macrocyclics... yuck.
        placements['largest_ring'] = m.apply(lambda mol: max([0] + list(map(len, mol.GetRingInfo().AtomRings()))))
        # interactions
        with contextlib.suppress(cli_default_settings['supressed_exceptions']):
            cls.fix_intxns(placements)
            tally_hit_intxns = HitIntxnTallier(hit_replacements)
            hit_checks = placements.apply(tally_hit_intxns, axis=1)
            placements['N_interactions_kept'] = hit_checks.apply(operator.itemgetter(0))  # .fillna(0).astype(int)
            placements['N_interactions_lost'] = hit_checks.apply(operator.itemgetter(1))  # .fillna(99).astype(int)
            intxn_names = [c for c in placements.columns if isinstance(c, tuple)]
            tallies = placements[intxn_names].sum()
            ratioed = UniquenessMeter(tallies, intxn_names, k=0.5)
            placements['interaction_uniqueness_metric'] = placements.apply(ratioed, axis=1)
            placements['N_interactions'] = placements.apply(ratioed.tally_interactions, axis=1)
        with contextlib.suppress(cli_default_settings['supressed_exceptions']):
            placements['PAINSes'] = placements.minimized_mol.apply(get_pains)
            placements['N_PAINS'] = placements.PAINSes.apply(len)
        with contextlib.suppress(cli_default_settings['supressed_exceptions']):
            placements['UFF_Gibbs'] = placements.minimized_mol.apply(UFF_Gibbs)
            # assumes an N_HA (heavy-atom count) column exists upstream — TODO confirm;
            # the epsilon avoids division by zero
            placements['strain_per_HA'] = placements.UFF_Gibbs / (placements.N_HA + 0.0001)
        with contextlib.suppress(cli_default_settings['supressed_exceptions']):
            penalize = PenaltyMeter(weights)
            placements['ad_hoc_penalty'] = placements.apply(penalize, axis=1)
        with contextlib.suppress(cli_default_settings['supressed_exceptions']):
            placements['cluster'] = butina_cluster(m.to_list())
        # BUGFIX: previously returned placements on the error path above but an
        # implicit None on success; return the (mutated) frame consistently.
        return placements

    @staticmethod
    def export_sdf(df: pd.DataFrame,
                   penalty_col: str = 'ad_hoc_penalty',
                   filename: str = 'fragmenstein.sdf',
                   target_name: str = 'DOESNT_MATTER_UNLESS_YOU_WANT_TO_REMOVE_IT'):
        """Write the acceptable placements, sorted by ``penalty_col``, to an SDF,
        exporting only scalar (str/float) columns as SD properties."""
        def fix(mol: Chem.Mol) -> None:
            # strip all computed and named properties in place so the SDF stays clean
            assert isinstance(mol, Chem.Mol)
            assert mol.GetNumAtoms()
            mol.ClearComputedProps()
            for name in mol.GetPropNames():
                mol.ClearProp(name)

        with pd.option_context('mode.chained_assignment', None):
            df = df.loc[df.outcome == 'acceptable'] \
                .sort_values(penalty_col) \
                .rename(columns={c: ':'.join(map(str, c)) for c in df.columns if isinstance(c, tuple)}) \
                .reset_index() \
                .copy()
            # list of str to str w/ comma-separator
            df['ref_mols'] = df.hit_names.apply(lambda l: ','.join([v.replace(f'{target_name}-', '') for v in l]))
            # fix() works purely by side effect, so this column holds only None
            df['washed_mol'] = df.minimized_mol.apply(fix)
            # NOTE(review): the target prefix is hardcoded here rather than
            # derived from ``target_name`` — confirm whether this is intended
            df['name'] = df['name'].apply(lambda v: v.split('-D68EV3CPROA')[0])
            # columns whose values are not plain scalars and must not become SD tags
            not_okay = ('name', 'minimized_mol', 'ref_mols', 'washed_mol', 'mode',
                        'runtime', 'error', 'outcome',
                        'smiles',
                        'regarded', 'disregarded', 'hit_names',
                        '∆G_bound',
                        '∆G_unbound',
                        'unmin_binary',
                        'min_binary',
                        'hit_binaries',
                        'minimized_mol',
                        'hit_mols', 'unminimized_mol', 'hit_names')
        # NOTE(review): this check admits only str/float values, so int (and
        # numpy-scalar) columns are silently dropped from the export — confirm
        good_columns = df.columns[~df.map(lambda x: not isinstance(x, (float, str))).any()]
        extras: List[str] = [c for c in df.columns if c in good_columns and c not in not_okay]
        PandasTools.WriteSDF(df, out=filename, properties=extras, molColName='minimized_mol')