dwhswenson/contact_map

View on GitHub
contact_map/atom_indexer.py

Summary

Maintainability
A
0 mins
Test Coverage
import collections
import numpy as np
import mdtraj as md

def _atom_slice(traj, indices):
    """Cheaply restrict ``traj`` to the atoms listed in ``indices``.

    Stands in for ``mdtraj.Trajectory.atom_slice`` without rebuilding the
    topology.

    How the topology is handled:

    * The original topology is copied.
    * Only its private ``_atoms`` and ``_numAtoms`` attributes are
      overwritten, with ``indices`` and ``len(indices)``.
    * Everything else on the copy is left as in the original.

    NOTE(review): presumably this is just enough for mdtraj to accept the
    reduced coordinate array. Confirm against the mdtraj version in use.

    Parameters
    ----------
    traj : mdtraj.Trajectory
        trajectory to slice
    indices : sequence of int
        atom indices (in ``traj``) to keep, in output order

    Returns
    -------
    mdtraj.Trajectory
        new trajectory built from copies of the selected coordinates, the
        times and (when present) the unit-cell lengths/angles
    """
    # C-contiguous copy of only the selected atoms' coordinates
    new_xyz = np.array(traj.xyz[:, indices], order='C')
    new_topology = traj.topology.copy()

    has_box = traj._have_unitcell
    box_lengths = traj._unitcell_lengths.copy() if has_box else None
    box_angles = traj._unitcell_angles.copy() if has_box else None

    frame_times = traj._time.copy()

    # Hackish to make the smart slicing work: shrink the copied topology's
    # private atom bookkeeping to match the sliced coordinates.
    new_topology._atoms = indices
    new_topology._numAtoms = len(indices)

    return md.Trajectory(
        xyz=new_xyz,
        topology=new_topology,
        time=frame_times,
        unitcell_lengths=box_lengths,
        unitcell_angles=box_angles,
    )

def residue_query_atom_idxs(sliced_query, atom_idx_to_residue_idx):
    """Group the query atoms by the residue each one belongs to.

    Parameters
    ----------
    sliced_query : iterable of int
        atom indices, in the same indexing as ``atom_idx_to_residue_idx``
    atom_idx_to_residue_idx : mapping
        atom index -> residue index

    Returns
    -------
    collections.defaultdict(list)
        residue index -> list of the atoms from ``sliced_query`` in that
        residue, in the order ``sliced_query`` yields them. Missing keys
        produce an empty list.
    """
    by_residue = collections.defaultdict(list)
    for atom in sliced_query:
        by_residue[atom_idx_to_residue_idx[atom]].append(atom)
    return by_residue


class AtomSlicedIndexer(object):
    """Indexer used when the trajectory is atom-sliced.

    Translates between two kinds of atom index:

    * "real" indices: positions in the full topology;
    * "sliced" indices: positions within ``all_atoms``.

    Parameters
    ----------
    topology : topology object (e.g. ``mdtraj.Topology``)
        full topology; its ``atoms`` must each provide ``index`` and
        ``residue.index``
    real_query : iterable of int
        query atom indices in the full topology (must appear in
        ``all_atoms``)
    real_haystack : iterable of int
        haystack atom indices in the full topology (must appear in
        ``all_atoms``)
    all_atoms : sequence of int
        real indices of the atoms kept after slicing. The position of an
        entry is its sliced index.
    """
    def __init__(self, topology, real_query, real_haystack, all_atoms):
        self.all_atoms = all_atoms
        # real index -> sliced index (position in all_atoms)
        self.sliced_idx = {real: pos for pos, real in enumerate(all_atoms)}
        # sliced index -> real index: the inverse of the mapping above
        self.real_idx = {pos: real for real, pos in self.sliced_idx.items()}
        self.query = {self.sliced_idx[q] for q in real_query}
        self.haystack = {self.sliced_idx[h] for h in real_haystack}

        # residue lookup keyed by real index, then re-keyed by sliced index
        self.real_atom_idx_to_residue_idx = {
            atom.index: atom.residue.index for atom in topology.atoms
        }
        real_to_residue = self.real_atom_idx_to_residue_idx
        self.atom_idx_to_residue_idx = {
            pos: real_to_residue[real] for pos, real in enumerate(all_atoms)
        }
        self.residue_query_atom_idxs = residue_query_atom_idxs(
            self.query, self.atom_idx_to_residue_idx
        )

    def ignore_atom_idx(self, atoms, all_atoms_set):
        """Return the sliced indices of ``atoms`` that survive slicing.

        ``atoms`` yields objects with an ``index`` (real index). Only those
        whose index is in ``all_atoms_set`` are kept and translated.
        """
        kept = {atom.index for atom in atoms}
        kept &= all_atoms_set
        return {self.sliced_idx[real] for real in kept}

    def convert_atom_contacts(self, atom_contacts):
        """Re-key a contact mapping from sliced to real atom indices.

        ``atom_contacts`` maps pairs of sliced indices to values. The
        result is a ``collections.Counter`` keyed by ``frozenset`` pairs
        of real indices.
        """
        return collections.Counter({
            frozenset(self.real_idx[i] for i in pair): value
            for pair, value in atom_contacts.items()
        })

    def slice_trajectory(self, trajectory):
        """Return ``trajectory`` restricted to ``all_atoms``.

        The input is returned unchanged if it has no extra atoms.
        """
        # Slicing copies coordinates, so skip it when nothing would be
        # dropped; ContactFrequency may already pass in a sliced
        # trajectory, in which case the atom counts match.
        if len(self.all_atoms) >= trajectory.topology.n_atoms:
            return trajectory
        return _atom_slice(trajectory, self.all_atoms)


class IdentityIndexer(object):
    """Indexer used when no atom slicing is performed.

    Sliced and real atom indices coincide. Every index mapping is
    therefore the identity, and the conversion methods hand their input
    back unchanged.

    Parameters
    ----------
    topology : topology object (e.g. ``mdtraj.Topology``)
        must provide ``n_atoms`` and ``atoms``; each atom must provide
        ``index`` and ``residue.index``
    real_query : iterable of int
        query atom indices
    real_haystack : iterable of int
        haystack atom indices
    all_atoms : sequence of int
        stored as-is; not otherwise used by this class
    """
    def __init__(self, topology, real_query, real_haystack, all_atoms):
        self.all_atoms = all_atoms
        self.topology = topology
        # a single dict object serves as both directions of the identity
        n_total = topology.n_atoms
        self.sliced_idx = self.real_idx = dict(
            zip(range(n_total), range(n_total))
        )
        self.query = set(real_query)
        self.haystack = set(real_haystack)

        residue_of = {}
        for atom in topology.atoms:
            residue_of[atom.index] = atom.residue.index
        self.real_atom_idx_to_residue_idx = residue_of
        # without slicing, the "sliced" lookup is the very same dict
        self.atom_idx_to_residue_idx = residue_of
        self.residue_query_atom_idxs = residue_query_atom_idxs(
            self.query, self.atom_idx_to_residue_idx
        )

    def ignore_atom_idx(self, atoms, all_atoms_set):
        """Return the indices of ``atoms`` as a set.

        ``all_atoms_set`` is accepted for interface parity with
        ``AtomSlicedIndexer`` but is not used here.
        """
        return {atom.index for atom in atoms}

    def convert_atom_contacts(self, atom_contacts):
        """Return ``atom_contacts`` itself; indices are already real."""
        return atom_contacts

    def slice_trajectory(self, trajectory):
        """Return ``trajectory`` itself; no slicing is performed."""
        return trajectory