choderalab/protons

View on GitHub
protons/scripts/run_simulation.py

Summary

Maintainability
C
7 hrs
Test Coverage
import logging
import os
import signal
import yaml
from typing import Dict
import sys
import netCDF4
from lxml import etree
from saltswap.wrappers import Salinator
from simtk import openmm as mm, unit
from tqdm import trange

from .. import app
from ..app import log, NCMCProtonDrive
from ..app.proposals import UniformSwapProposal
from ..app.driver import SAMSApproach
from .utilities import (
    timeout_handler,
    xml_to_topology,
    deserialize_openmm_element,
    deserialize_state_vector,
    TimeOutError,
    create_protons_checkpoint_file,
)


def run_main(jsonfile):
    """Resume a constant-pH simulation from a checkpoint and run the main loop.

    The settings file is read as yaml (which also covers json). The function
    rebuilds the topology, system, integrator, proton drive and, if present,
    the saltswap state from the input checkpoint, attaches the requested
    reporters, runs the NCMC update loop, and always writes a new checkpoint
    when the loop ends, whether it completed, timed out or failed.

    Parameters
    ----------
    jsonfile : str
        Path to the yaml/json settings file. The keys read here are:

        * ``format_vars`` (optional): values substituted into the directory
          and file name templates; must contain ``run_idx``.
        * ``run``: ``timeout_sec`` (only read on non-Windows systems),
          ``perturbations_per_ncmc_trial``, ``propagations_per_ncmc_step``,
          ``total_update_attempts`` and ``md_steps_between_updates``.
        * ``input``: ``dir`` and ``checkpoint``.
        * ``output``: ``dir`` and ``basename``.
        * ``reporters``: any of ``metadata``, ``coordinates``, ``titration``,
          ``sams`` and ``ncmc``.

    Raises
    ------
    KeyError
        If a required settings key (including ``format_vars["run_idx"]``)
        is missing.

    Notes
    -----
    The working directory is changed to the output directory while the
    simulation runs and is restored after the checkpoint has been written.
    """

    # TODO Validate yaml/json input with json schema?
    # safe_load only constructs plain Python objects, and calling yaml.load
    # without an explicit Loader is rejected by recent PyYAML releases.
    with open(jsonfile, "r") as settings_file:
        settings = yaml.safe_load(settings_file)

    format_vars: Dict[str, str] = settings.get("format_vars", dict())

    # Retrieve runtime settings
    run = settings["run"]

    # Start timeout to enable clean exit on uncompleted runs
    # Note, does not work on windows!
    timeout_armed = os.name != "nt"
    if timeout_armed:
        signal.signal(signal.SIGALRM, timeout_handler)
        script_timeout = int(run["timeout_sec"])
        signal.alarm(script_timeout)

    # Input files
    inp = settings["input"]
    idir = inp["dir"].format(**format_vars)
    input_checkpoint_file = os.path.abspath(
        os.path.join(idir, inp["checkpoint"].format(**format_vars))
    )
    # Load checkpoint file
    with open(input_checkpoint_file, "r") as checkpoint:
        checkpoint_tree = etree.fromstring(checkpoint.read())

    checkpoint_date = checkpoint_tree.attrib["date"]
    log.info(f"Reading checkpoint from '{checkpoint_date}'.")

    topology_element = checkpoint_tree.xpath("TopologyFile")[0]
    topology: app.Topology = xml_to_topology(topology_element)

    # Quick fix for histidines in topology
    # Openmm relabels them HIS, which leads to them not being detected as
    # titratable. Renaming them fixes this.
    for residue in topology.residues():
        if residue.name == "HIS":
            residue.name = "HIP"
        # TODO doublecheck if ASH GLH need to be renamed
        elif residue.name == "ASP":
            residue.name = "ASH"
        elif residue.name == "GLU":
            residue.name = "GLH"

    # Naming the output files
    out = settings["output"]
    odir = out["dir"].format(**format_vars)
    obasename = out["basename"].format(**format_vars)
    runid = format_vars["run_idx"]
    os.makedirs(odir, exist_ok=True)
    # All output file names below are relative, so run from the output dir
    # and remember where to return to afterwards.
    lastdir = os.getcwd()
    os.chdir(odir)

    # File for resuming simulation
    output_checkpoint_file = f"{obasename}-checkpoint-{runid}.xml"

    # System Configuration
    system_element = checkpoint_tree.xpath("System")[0]
    system: mm.System = deserialize_openmm_element(system_element)

    # Deserialize the integrator
    integrator_element = checkpoint_tree.xpath("Integrator")[0]
    integrator: mm.CompoundIntegrator = deserialize_openmm_element(integrator_element)

    perturbations_per_trial = int(run["perturbations_per_ncmc_trial"])
    propagations_per_step = int(run["propagations_per_ncmc_step"])

    # Deserialize the proton drive
    drive_element = checkpoint_tree.xpath("NCMCProtonDrive")[0]
    temperature = float(drive_element.get("temperature_kelvin")) * unit.kelvin
    if "pressure_bar" in drive_element.attrib:
        pressure = float(drive_element.get("pressure_bar")) * unit.bar
    else:
        pressure = None

    driver = NCMCProtonDrive(
        temperature,
        topology,
        system,
        pressure=pressure,
        perturbations_per_trial=perturbations_per_trial,
        propagations_per_step=propagations_per_step,
    )
    driver.state_from_xml_tree(drive_element)

    # One-site calibration updates are drawn from a dedicated pool.
    if driver.calibration_state is not None:
        if driver.calibration_state.approach == SAMSApproach.ONESITE:
            driver.define_pools({"calibration": driver.calibration_state.group_index})

    # Prefer the CUDA platform; fall back to the OpenMM default only when
    # CUDA is not registered. Any other failure is re-raised.
    try:
        platform = mm.Platform.getPlatformByName("CUDA")
    except Exception as e:
        message = str(e)
        if message != 'There is no registered Platform called "CUDA"':
            raise
        log.error(message)
        log.warning("Resorting to default OpenMM platform and properties.")
        platform = None
        properties = None
    else:
        properties = {
            "CudaPrecision": "mixed",
            "DeterministicForces": "true",
        }
        # Only pin a device when the environment names one; an unset
        # variable previously surfaced as an unrelated KeyError.
        # NOTE(review): confirm that passing CUDA_VISIBLE_DEVICES verbatim
        # as the device index is the intended mapping.
        device_index = os.environ.get("CUDA_VISIBLE_DEVICES")
        if device_index is not None:
            properties["CudaDeviceIndex"] = device_index

    simulation = app.ConstantPHSimulation(
        topology,
        system,
        integrator,
        driver,
        platform=platform,
        platformProperties=properties,
    )

    # Set the simulation state
    state_element = checkpoint_tree.xpath("State")[0]
    state: mm.State = deserialize_openmm_element(state_element)
    boxvec = state.getPeriodicBoxVectors()
    pos = state.getPositions()
    vel = state.getVelocities()

    simulation.context.setPositions(pos)
    simulation.context.setPeriodicBoxVectors(*boxvec)
    simulation.context.setVelocities(vel)

    # Check if the system has an associated salinator
    saltswap_elements = checkpoint_tree.xpath("Saltswap")
    if saltswap_elements:
        # Deserialization workaround
        saltswap_element = saltswap_elements[0]
        salt_concentration = (
            float(saltswap_element.get("salt_concentration_molar")) * unit.molar
        )
        salinator = Salinator(
            context=simulation.context,
            system=system,
            topology=topology,
            ncmc_integrator=integrator.getIntegrator(1),
            salt_concentration=salt_concentration,
            pressure=pressure,
            temperature=temperature,
        )
        swapper = salinator.swapper
        deserialize_state_vector(saltswap_element, swapper)
        # Assumes the parameters are already set and the ions are set if needed
        # Don't set the charge rule
        driver.swapper = swapper
        driver.swap_proposal = UniformSwapProposal(cation_coefficient=0.5)
    else:
        salinator = None

    # Add reporters
    ncfile = netCDF4.Dataset(f"{obasename}-{runid}.nc", "w")
    dcd_output_name = f"{obasename}-{runid}.dcd"
    reporters = settings["reporters"]
    if "metadata" in reporters:
        simulation.update_reporters.append(app.MetadataReporter(ncfile))

    if "coordinates" in reporters:
        freq = int(reporters["coordinates"]["frequency"])
        simulation.reporters.append(
            app.DCDReporter(dcd_output_name, freq, enforcePeriodicBox=True)
        )

    if "titration" in reporters:
        freq = int(reporters["titration"]["frequency"])
        simulation.update_reporters.append(app.TitrationReporter(ncfile, freq))

    if "sams" in reporters:
        freq = int(reporters["sams"]["frequency"])
        simulation.calibration_reporters.append(app.SAMSReporter(ncfile, freq))

    if "ncmc" in reporters:
        freq = int(reporters["ncmc"]["frequency"])
        if "work_interval" in reporters["ncmc"]:
            work_interval = int(reporters["ncmc"]["work_interval"])
        else:
            work_interval = 0
        simulation.update_reporters.append(
            app.NCMCReporter(ncfile, freq, work_interval)
        )

    total_iterations = int(run["total_update_attempts"])
    md_steps_between_updates = int(run["md_steps_between_updates"])

    # MAIN SIMULATION LOOP STARTS HERE

    try:
        for i in trange(total_iterations, desc="NCMC attempts"):
            if i == 2:
                log.info("Simulation seems to be working. Suppressing debugging info.")
                log.setLevel(logging.INFO)
            simulation.step(md_steps_between_updates)
            # Perform a few COOH updates in between
            driver.update("COOH", nattempts=3)
            if driver.calibration_state is not None:
                if driver.calibration_state.approach is SAMSApproach.ONESITE:
                    simulation.update(1, pool="calibration")
                else:
                    simulation.update(1)
                simulation.adapt()
            else:
                simulation.update(1)

    except TimeOutError:
        log.warning("Simulation ran out of time, saving current results.")

    finally:
        # Cancel any pending alarm so it can neither interrupt the
        # checkpoint write below nor fire after this function has returned.
        if timeout_armed:
            signal.alarm(0)
        try:
            # export the context
            create_protons_checkpoint_file(
                output_checkpoint_file,
                driver,
                simulation.context,
                simulation.system,
                simulation.integrator,
                topology_element.text,
                salinator=salinator,
            )
        finally:
            # Release the netCDF file and restore the working directory
            # even if writing the checkpoint failed.
            try:
                ncfile.close()
            finally:
                os.chdir(lastdir)


if __name__ == "__main__":
    if len(sys.argv) != 2:
        # sys.exit with a string prints it to stderr and exits with status 1,
        # so calling scripts and schedulers can detect the usage error.
        sys.exit("Please provide a single json file as input.")
    # Provide the json file to main function
    run_main(sys.argv[1])