Yank/analyze.py
#!/usr/local/bin/env python
# ==============================================================================
# MODULE DOCSTRING
# ==============================================================================
"""
Analyze
=======
YANK Specific analysis tools for YANK simulations from the :class:`yank.yank.AlchemicalPhase` classes
Extends classes from the MultiStateAnalyzer package to include the standard state correction observable.
"""
# =============================================================================================
# MODULE IMPORTS
# =============================================================================================
import os
import abc
import copy
import yaml
import mdtraj
import pickle
import logging
from functools import wraps
import mpiplus
import numpy as np
import simtk.unit as units
import openmmtools as mmtools
from pymbar import timeseries
from . import version
from . import utils
from .experiment import ExperimentBuilder
logger = logging.getLogger(__name__)
# =============================================================================================
# MODULE PARAMETERS
# =============================================================================================
# Extend registry to support standard_state_correction
yank_registry = mmtools.multistate.default_observables_registry
yank_registry.register_phase_observable('standard_state_correction')
kB = units.BOLTZMANN_CONSTANT_kB * units.AVOGADRO_CONSTANT_NA
# =============================================================================================
# MODULE CLASSES
# =============================================================================================
class YankPhaseAnalyzer(mmtools.multistate.PhaseAnalyzer):
    """
    Phase analyzer base class that defaults to the YANK observables registry.

    The ``registry`` keyword defaults to the module-level ``yank_registry``, on which
    ``standard_state_correction`` is registered as a phase observable. All other arguments are forwarded
    untouched to :class:`openmmtools.multistate.PhaseAnalyzer`.
    """

    def __init__(self, *args, registry=yank_registry, **kwargs):
        # ``registry`` is keyword-only, so it can never already be a key of ``kwargs``.
        kwargs['registry'] = registry
        super().__init__(*args, **kwargs)

    # Abstract methods
    @abc.abstractmethod
    def analyze_phase(self, *args, **kwargs):
        """
        Run the "auto-analysis" of the phase.

        Entry point for users who do not wish to call every analysis method on their own. Concrete
        subclasses must override it.

        Returns a dictionary of analysis objects
        """
        raise NotImplementedError()
class YankMultiStateSamplerAnalyzer(mmtools.multistate.MultiStateSamplerAnalyzer, YankPhaseAnalyzer):
    """
    Multi-state sampler analyzer that additionally reports the standard state correction of the phase.
    """

    def get_standard_state_correction(self):
        """
        Return the standard state correction free energy of the phase, computing and caching it if needed.

        The value is resolved in this order: the cached observable; a fresh computation from the radially
        symmetric restraint force (only when ``unbias_restraint`` is set and such a restraint is found); the
        ``standard_state_correction`` entry of the reporter ``metadata`` dictionary.

        Returns
        -------
        standard_state_correction : float
            Free energy contribution from the standard_state_correction
        """
        cached_ssc = self._computed_observables['standard_state_correction']
        if cached_ssc is not None:
            return cached_ssc

        # Recompute only if unbiasing was requested AND a radially symmetric restraint can be retrieved.
        recompute = self.unbias_restraint
        try:
            restraint_force, _, _ = self._get_radially_symmetric_restraint_data()
        except (mmtools.forces.NoForceFoundError, TypeError):
            recompute = False

        if recompute:
            end_state = self._get_end_thermodynamic_states()[0]
            energy_cutoff, distance_cutoff = self._get_restraint_cutoffs()
            # TODO: Compute average box volume here to feed to max_volume?
            new_ssc = restraint_force.compute_standard_state_correction(
                end_state, square_well=True, radius_cutoff=distance_cutoff,
                energy_cutoff=energy_cutoff, max_volume='system')
            # Update observable.
            self._computed_observables['standard_state_correction'] = new_ssc
            logger.debug('Computed a new standard state correction of {} kT'.format(new_ssc))

        # Nothing was computed above: fall back on the value stored by the reporter.
        if self._computed_observables['standard_state_correction'] is None:
            stored_ssc = self._reporter.read_dict('metadata')['standard_state_correction']
            self._computed_observables['standard_state_correction'] = stored_ssc
        return self._computed_observables['standard_state_correction']

    def analyze_phase(self, show_mixing=False, cutoff=0.05):
        """
        Auto-analysis function for the phase

        Convenience entry point for those that do not wish to call all the analysis methods on their own.
        This variant computes the following:

        * Equilibration data (accessible through ``n_equilibration_iterations`` and ``statistical_inefficiency``
          properties)
        * Optional mixing statistics printout
        * Free Energy difference between the two reference states with error
        * Enthalpy difference between the two reference states with error (as expectation of the reduced
          potential)
        * Free energy of the Standard State Correction for the phase.

        Parameters
        ----------
        show_mixing : bool, optional. Default: False
            Toggle to show mixing statistics or not. This can be a slow step so is disabled for speed by default.
        cutoff : float, optional. Default: 0.05
            Threshold below which % of mixing from one state to another is not shown.
            Makes the output table more clean to read (rendered as whitespace)

        Returns
        -------
        data : dict
            A dictionary of analysis objects, in this case Delta F, Delta H, and Standard State Free Energy
            In units of kT (dimensionless)
        """
        # Only the number of equilibrated iterations is needed from the equilibration data.
        number_equilibrated, _, _ = self._equilibration_data
        if show_mixing:
            self.show_mixing_statistics(cutoff=cutoff, number_equilibrated=number_equilibrated)

        delta_f_ij, d_delta_f_ij = self.get_free_energy()
        delta_h_ij, d_delta_h_ij = self.get_enthalpy()
        # Matrix element connecting the two reference states.
        ref_pair = (self.reference_states[0], self.reference_states[1])
        return {
            'free_energy_diff': delta_f_ij[ref_pair],
            'free_energy_diff_error': d_delta_f_ij[ref_pair],
            'enthalpy_diff': delta_h_ij[ref_pair],
            'enthalpy_diff_error': d_delta_h_ij[ref_pair],
            'free_energy_diff_standard_state_correction': self.get_standard_state_correction(),
        }
class YankReplicaExchangeAnalyzer(mmtools.multistate.ReplicaExchangeAnalyzer, YankMultiStateSamplerAnalyzer):
    """
    Combination of :class:`openmmtools.multistate.ReplicaExchangeAnalyzer` and
    :class:`YankMultiStateSamplerAnalyzer`; defines no behavior of its own.
    """
class YankParallelTemperingAnalyzer(mmtools.multistate.ParallelTemperingAnalyzer, YankMultiStateSamplerAnalyzer):
    """
    Combination of :class:`openmmtools.multistate.ParallelTemperingAnalyzer` and
    :class:`YankMultiStateSamplerAnalyzer`; defines no behavior of its own.
    """
def _copyout(wrap):
"""Copy output of function. Small helper to avoid ending each function with the copy call"""
@wraps(wrap)
def make_copy(*args, **kwargs):
return copy.deepcopy(wrap(*args, **kwargs))
return make_copy
class ExperimentAnalyzer(object):
"""
Semi-automated YANK Experiment analysis with serializable data.
This class is designed to replace the older ``analyze_directory`` functions by providing a common analysis data
interface which other classes and methods can draw on. This is designed to semi-automate the combination of
multi-phase data
Each of the main methods fetches the data from each phase and returns them as a dictionary to the user. The total
dump of data to serialized YAML files can also be done.
Each function documents its output data structure. Entries surrounded by curly braces (``{ }``) indicate
variables which change per experiment, often the data.
Output dictionary is of the form:
.. code-block:: python
yank_version: {YANK Version}
phase_names: {Name of each phase, depends on simulation type}
general: {See :func:`get_general_simulation_data`}
equilibration: {See :func:`get_equilibration_data`}
mixing: {See :func:`get_mixing_data`}
free_energy: {See :func:`get_experiment_free_energy_data`}
Parameters
----------
store_directory : string
Location where the analysis.yaml file is and where the NetCDF files are
**analyzer_kwargs
Keyword arguments to pass to the analyzer class. Quantities can be
passed as strings.
Attributes
----------
use_full_trajectory : bool. Analyze with subsampled or complete trajectory
nphases : int. Number of phases detected
phase_names : list of phase names. Used as keys on all attributes below
signs : dict of str. Sign assigned to each phase
analyzers : dict of YankPhaseAnalyzer
iterations : dict of int. Number of maximum iterations in each phase
u_ns : dict of np.ndarray. Timeseries of each phase
nequils : dict of int. Number of equilibrium iterations in each phase
g_ts : dict of int. Subsample rate past nequils in each phase
Neff_maxs : dict of int. Number of effective samples in each phase
Examples
--------
Start with an experiment (Running from the :class:`yank.experiment.ExperimentBuilder` example)
>>> import textwrap
>>> import openmmtools as mmtools
>>> import yank.utils
>>> from yank.experiment import ExperimentBuilder
>>> setup_dir = yank.utils.get_data_filename(os.path.join('..', 'examples',
... 'p-xylene-implicit', 'input'))
>>> pxylene_path = os.path.join(setup_dir, 'p-xylene.mol2')
>>> lysozyme_path = os.path.join(setup_dir, '181L-pdbfixer.pdb')
>>> with mmtools.utils.temporary_directory() as tmp_dir:
... yaml_content = '''
... ---
... options:
... default_number_of_iterations: 1
... output_dir: {}
... molecules:
... T4lysozyme:
... filepath: {}
... p-xylene:
... filepath: {}
... antechamber:
... charge_method: bcc
... solvents:
... vacuum:
... nonbonded_method: NoCutoff
... systems:
... my_system:
... receptor: T4lysozyme
... ligand: p-xylene
... solvent: vacuum
... leap:
... parameters: [leaprc.gaff, leaprc.ff14SB]
... protocols:
... absolute-binding:
... complex:
... alchemical_path:
... lambda_electrostatics: [1.0, 0.9, 0.8, 0.6, 0.4, 0.2, 0.0]
... lambda_sterics: [1.0, 0.9, 0.8, 0.6, 0.4, 0.2, 0.0]
... solvent:
... alchemical_path:
... lambda_electrostatics: [1.0, 0.8, 0.6, 0.3, 0.0]
... lambda_sterics: [1.0, 0.8, 0.6, 0.3, 0.0]
... experiments:
... system: my_system
... protocol: absolute-binding
... '''.format(tmp_dir, lysozyme_path, pxylene_path)
>>> yaml_builder = ExperimentBuilder(textwrap.dedent(yaml_content))
>>> yaml_builder.run_experiments()
Now analyze the experiment
>>> import os
>>> exp_analyzer = ExperimentAnalyzer(os.path.join(tmp_dir, 'experiment'))
>>> analysis_data = exp_analyzer.auto_analyze()
"""
def __init__(self, store_directory, **analyzer_kwargs):
    """
    Read ``analysis.yaml`` from ``store_directory`` and create one analyzer per phase listed in it.

    Parameters
    ----------
    store_directory : string
        Location where the analysis.yaml file is and where the NetCDF files are
    **analyzer_kwargs
        Keyword arguments to pass to the analyzer class. Values that can be parsed as quantities are
        converted before being forwarded.

    Raises
    ------
    RuntimeError
        If ``analysis.yaml`` cannot be found in ``store_directory``.
    """
    # Convert analyzer string quantities into variables.
    for key, value in analyzer_kwargs.items():
        try:
            quantity = utils.quantity_from_string(value)
        except Exception:
            # The value could not be parsed as a quantity: forward it unchanged. Only ``Exception`` is
            # trapped so that KeyboardInterrupt/SystemExit still propagate.
            continue
        analyzer_kwargs[key] = quantity

    # Read in data
    analysis_script_path = os.path.join(store_directory, 'analysis.yaml')
    if not os.path.isfile(analysis_script_path):
        err_msg = 'Cannot find analysis.yaml script in {}'.format(store_directory)
        raise RuntimeError(err_msg)
    with open(analysis_script_path, 'r') as f:
        # An explicit Loader is mandatory: yaml.load() without one is deprecated since PyYAML 5.1 and raises
        # TypeError since PyYAML 6.0. Fall back on yaml.Loader for releases that predate FullLoader.
        analysis = yaml.load(f, Loader=getattr(yaml, 'FullLoader', yaml.Loader))

    phases_names = []
    signs = {}
    analyzers = {}
    # Each entry of analysis.yaml is a (phase name, sign) pair; the phase data lives in "<phase>.nc".
    for phase, sign in analysis:
        phases_names.append(phase)
        signs[phase] = sign
        storage_path = os.path.join(store_directory, phase + '.nc')
        analyzers[phase] = get_analyzer(storage_path, **analyzer_kwargs)
    self.phase_names = phases_names
    self.signs = signs
    self.analyzers = analyzers
    self.nphases = len(phases_names)

    # Additional data
    self.use_full_trajectory = False
    if 'use_full_trajectory' in analyzer_kwargs:
        self.use_full_trajectory = bool(analyzer_kwargs['use_full_trajectory'])

    # Assign flags for other sections along with their global variables
    # General Data
    self._general_run = False
    self.iterations = {}
    # Equilibration
    self._equilibration_run = False
    self.u_ns = {}
    self.nequils = {}
    self.g_ts = {}
    self.Neff_maxs = {}
    self._n_discarded = 0
    # Mixing Run (state)
    self._mixing_run = False
    # Replica mixing
    self._replica_mixing_run = False
    self._free_energy_run = False
    self._serialized_data = {'yank_version': version.version, 'phase_names': self.phase_names}
def __del__(self):
# Explicitly close storage
for phase, analyzer in self.analyzers.items():
if analyzer is not None:
del analyzer
@_copyout
def get_general_simulation_data(self):
    """
    Collect, for every phase, the number of iterations, atoms, replicas and states.

    The data is gathered once and cached; later calls return the cached section. As a side effect
    ``self.iterations`` is filled with the iteration count of each phase.

    Output is of the form:

    .. code-block:: python

        {for phase_name in phase_names}
            iterations : {int}
            natoms : {int}
            nreplicas : {int}
            nstates : {int}

    Returns
    -------
    general_data : dict
        General simulation data by phase.
    """
    if self._general_run:
        return self._serialized_data['general']

    per_phase = {}
    for phase_name in self.phase_names:
        analyzer = self.analyzers[phase_name]
        try:
            first_positions = analyzer.reporter.read_sampler_states(0)[0].positions
            natoms, _ = first_positions.shape
        except AttributeError:  # Trap unloaded checkpoint file
            natoms = 'No Cpt.'
        # The dimensions of the energy array give the remaining counts.
        energies, _, _ = analyzer.reporter.read_energies()
        n_iterations, n_replicas, n_states = energies.shape
        per_phase[phase_name] = {
            'iterations': n_iterations,
            'nstates': n_states,
            'natoms': natoms,
            'nreplicas': n_replicas,
        }

    self.iterations = {name: summary['iterations'] for name, summary in per_phase.items()}
    self._serialized_data['general'] = per_phase
    self._general_run = True
    return self._serialized_data['general']
@_copyout
def get_equilibration_data(self, discard_from_start=1):
    """
    Compute the equilibration statistics of every phase: equilibration time, correlation time (subsample
    rate) and number of effective samples.

    The general simulation data is gathered first if it has not been already, because the sample counts
    need the total number of iterations of each phase.

    Output is of the form:

    .. code-block:: python

        {for phase_name in phase_names}
            discarded_from_start : {int}
            effective_samples : {float}
            equilibration_samples : {int}
            subsample_rate : {float}
            iterations_considered : {1D np.ndarray of int}
            subsample_rate_by_iterations_considered : {1D np.ndarray of float}
            effective_samples_by_iterations_considered : {1D np.ndarray of float}
            count_total_equilibration_samples : {int}
            count_decorrelated_samples : {int}
            count_correlated_samples : {int}
            percent_total_equilibration_samples : {float}
            percent_decorrelated_samples : {float}
            percent_correlated_samples : {float}

    Parameters
    ----------
    discard_from_start : int, optional. Default: 1
        Number of initial samples dropped from the timeseries before the equilibration detection. A larger
        ``t0`` stored in the reporter's online analysis data takes precedence.

    Returns
    -------
    equilibration_data : dict
        Dictionary with the equilibration data
    """
    # NOTE(review): when a stored t0 overrides discard_from_start, _n_discarded never equals the argument
    # and this cache check always recomputes. Results are unaffected; confirm whether this is intended.
    if not self._equilibration_run or discard_from_start != self._n_discarded:
        # self.iterations is only filled by get_general_simulation_data(); without it the counts below
        # would fail with a KeyError.
        if not self._general_run:
            self.get_general_simulation_data()

        eq_serial = {}
        for phase_name in self.phase_names:
            serial = {}
            analyzer = self.analyzers[phase_name]
            # Data crunching to get timeseries
            # TODO: Figure out how not to discard the first sample
            # Sample at index 0 is actually the minimized structure and NOT from the equilibrium distribution
            # This throws off all of the equilibrium data
            t0 = discard_from_start
            series = analyzer.get_effective_energy_timeseries()
            # Update discard_from_start to match t0 if present
            try:
                data = analyzer.reporter.read_online_analysis_data(None, 't0')
                t0 = max(t0, int(data['t0'][0]))
                logger.debug('t0 found; using initial t0 = {} instead of 1'.format(t0))
            except Exception as e:
                # No t0 found
                logger.debug('Could not find t0: {}'.format(e))
            if series.size <= t0:
                # Trap case where user has dropped their whole set.
                # Rare, but happens, often with debugging
                t0 = 0
                logger.warning("Alert: analyzed timeseries has the same or fewer number of values as "
                               "discard_from_start! The whole series has been preserved to ensure there is "
                               "*something* to analyze.")
            self._n_discarded = t0
            # Drop exactly the t0 discarded samples from the series already read, so that the offset added
            # to the equilibration index below matches what was actually removed.
            self.u_ns[phase_name] = series[t0:]

            # Timeseries statistics
            i_t, g_i, n_effective_i = mmtools.multistate.get_equilibration_data_per_sample(self.u_ns[phase_name])
            n_effective_max = n_effective_i.max()
            i_max = n_effective_i.argmax()
            # Shift back by t0 to express the equilibration time as an index into the full series.
            n_equilibration = i_t[i_max] + t0
            g_t = g_i[i_max]
            self.Neff_maxs[phase_name] = n_effective_max
            self.nequils[phase_name] = n_equilibration
            self.g_ts[phase_name] = g_t
            serial['discarded_from_start'] = int(t0)
            serial['effective_samples'] = float(self.Neff_maxs[phase_name])
            serial['equilibration_samples'] = int(self.nequils[phase_name])
            serial['subsample_rate'] = float(self.g_ts[phase_name])
            serial['iterations_considered'] = i_t
            serial['subsample_rate_by_iterations_considered'] = g_i
            serial['effective_samples_by_iterations_considered'] = n_effective_i

            # Determine total number of iterations
            n_iter = self.iterations[phase_name]
            # nequils already includes the t0 discarded samples; adding them again would double count.
            eq = self.nequils[phase_name]
            decor = int(np.floor(self.Neff_maxs[phase_name]))
            cor = n_iter - eq - decor
            dat = np.array([decor, cor, eq]) / float(n_iter)
            serial['count_total_equilibration_samples'] = int(eq)
            serial['count_decorrelated_samples'] = int(decor)
            serial['count_correlated_samples'] = int(cor)
            serial['percent_total_equilibration_samples'] = float(dat[2])
            serial['percent_decorrelated_samples'] = float(dat[0])
            serial['percent_correlated_samples'] = float(dat[1])
            eq_serial[phase_name] = serial
        self._serialized_data['equilibration'] = eq_serial
        # Set flag
        self._equilibration_run = True
    return self._serialized_data['equilibration']
@_copyout
def get_mixing_data(self):
    """
    Get state diffusion mixing arrays

    The mixing statistics of each phase are generated from the equilibrated part of the simulation, so
    :func:`get_equilibration_data` must have been run first. Results are cached after the first call.

    Output is of the form:

    .. code-block:: python

        {for phase_name in phase_names}
            transitions : {[nstates, nstates] np.ndarray of float}
            eigenvalues : {[nstates] np.ndarray of float}
            stat_inefficiency : {float}

    Returns
    -------
    mixing_data : dict
        Dictionary of mixing data

    Raises
    ------
    RuntimeError
        If the equilibration data has not been computed yet.
    """
    if not self._mixing_run:
        if not self._equilibration_run:
            # Same guard as get_experiment_free_energy_data: fail with an actionable message instead of a
            # bare KeyError on self.nequils.
            raise RuntimeError("Cannot run mixing analysis without first running the equilibration. Please "
                               "run the corresponding function/cell first!")
        mixing_serial = {}
        for phase_name in self.phase_names:
            serial = {}
            # Generate mixing statistics.
            analyzer = self.analyzers[phase_name]
            mixing_statistics = analyzer.generate_mixing_statistics(
                number_equilibrated=self.nequils[phase_name])
            transition_matrix, eigenvalues, statistical_inefficiency = mixing_statistics
            serial['transitions'] = transition_matrix
            serial['eigenvalues'] = eigenvalues
            serial['stat_inefficiency'] = statistical_inefficiency
            mixing_serial[phase_name] = serial
        self._serialized_data['mixing'] = mixing_serial
        self._mixing_run = True
    return self._serialized_data['mixing']
@_copyout
def get_experiment_free_energy_data(self):
    """
    Get the YANK experiment free energy and enthalpy, broken down by phase and totalled over the experiment.

    Every phase is analyzed with ``analyze_phase()``; each signed phase contribution (including its standard
    state correction) is subtracted from the running total, and the phase errors are combined in
    quadrature. Results are cached after the first call.

    Output is of the form:

    .. code-block:: python

        {for phase_name in phase_names}
            sign : {numeric sign of the phase as read from analysis.yaml}
            kT : {units.quantity}
            free_energy_diff : {float (has units of kT)}
            free_energy_diff_error : {float (has units of kT)}
            free_energy_diff_standard_state_correction : {float (has units of kT)}
            enthalpy_diff : {float (has units of kT)}
            enthalpy_diff_error : {float (has units of kT)}
        free_energy_diff : {float (has units of kT)}
        free_energy_diff_error : {float (has units of kT)}
        free_energy_diff_unit : {units.quantity compatible with energy/mole. Corrected for different phase kT}
        free_energy_diff_error_unit : {units.quantity compatible with energy/mole. Corrected for different phase kT}
        enthalpy_diff : {float (has units of kT)}
        enthalpy_diff_error : {float (has units of kT)}
        enthalpy_diff_unit : {units.quantity compatible with energy/mole. Corrected for different phase kT}
        enthalpy_diff_error_unit : {units.quantity compatible with energy/mole. Corrected for different phase kT}

    Returns
    -------
    free_energy_data : dict
        Dictionary of free energy data

    Raises
    ------
    RuntimeError
        If the equilibration data has not been computed yet.
    """
    # TODO: Possibly rename this function to not confuse with "experimental" free energy?
    if not self._free_energy_run:
        if not self._equilibration_run:
            raise RuntimeError("Cannot run free energy without first running the equilibration. Please run the "
                               "corresponding function/cell first!")
        fe_serial = dict()
        data = dict()
        for phase_name in self.phase_names:
            analyzer = self.analyzers[phase_name]
            data[phase_name] = analyzer.analyze_phase()

        # Compute free energy and enthalpy
        delta_f = 0.0
        delta_f_err = 0.0
        delta_h = 0.0
        delta_h_err = 0.0
        # Not assigning units to be more general to whatever kt is later
        delta_f_unit = 0.0
        delta_f_err_unit = 0.0
        delta_h_unit = 0.0
        delta_h_err_unit = 0.0
        for phase_name in self.phase_names:
            serial = {}
            kt = self.analyzers[phase_name].kT
            serial['kT'] = kt
            if not isinstance(delta_f_unit, units.Quantity):
                # Assign units to the float if not assigned
                held_unit = kt.unit
                delta_f_unit *= held_unit
                delta_f_err_unit *= held_unit ** 2  # Errors are held in square until the end
                delta_h_unit *= held_unit
                delta_h_err_unit *= held_unit ** 2
            sign = self.signs[phase_name]
            serial['sign'] = sign

            phase_delta_f = data[phase_name]['free_energy_diff']
            phase_delta_f_ssc = data[phase_name]['free_energy_diff_standard_state_correction']
            phase_delta_f_err = data[phase_name]['free_energy_diff_error']
            serial['free_energy_diff'] = phase_delta_f
            serial['free_energy_diff_error'] = phase_delta_f_err
            serial['free_energy_diff_standard_state_correction'] = phase_delta_f_ssc
            phase_exp_delta_f = sign * (phase_delta_f + phase_delta_f_ssc)
            delta_f -= phase_exp_delta_f
            delta_f_err += phase_delta_f_err**2
            delta_f_unit -= phase_exp_delta_f * kt
            delta_f_err_unit += (phase_delta_f_err * kt)**2

            phase_delta_h = data[phase_name]['enthalpy_diff']
            phase_delta_h_err = data[phase_name]['enthalpy_diff_error']
            serial['enthalpy_diff'] = phase_delta_h
            # Report the phase error as returned by analyze_phase(), exactly like the free energy error.
            # It is squared below, so taking a square root here would misreport the per-phase value.
            serial['enthalpy_diff_error'] = phase_delta_h_err
            phase_exp_delta_h = sign * (phase_delta_h + phase_delta_f_ssc)
            delta_h -= phase_exp_delta_h
            delta_h_err += phase_delta_h_err ** 2
            delta_h_unit -= phase_exp_delta_h * kt
            delta_h_err_unit += (phase_delta_h_err * kt) ** 2
            fe_serial[phase_name] = serial

        delta_f_err = np.sqrt(delta_f_err)  # np.sqrt is fine here since these dont have units
        delta_h_err = np.sqrt(delta_h_err)
        # Use ** 0.5 instead of np.sqrt since the function strips units, see github issue pandegroup/openmm#2106
        delta_f_err_unit = delta_f_err_unit ** 0.5
        delta_h_err_unit = delta_h_err_unit ** 0.5
        fe_serial['free_energy_diff'] = delta_f
        fe_serial['enthalpy_diff'] = delta_h
        fe_serial['free_energy_diff_error'] = delta_f_err
        fe_serial['enthalpy_diff_error'] = delta_h_err
        fe_serial['free_energy_diff_unit'] = delta_f_unit
        fe_serial['enthalpy_diff_unit'] = delta_h_unit
        fe_serial['free_energy_diff_error_unit'] = delta_f_err_unit
        fe_serial['enthalpy_diff_error_unit'] = delta_h_err_unit
        self._serialized_data['free_energy'] = fe_serial
        self._free_energy_run = True
    return self._serialized_data['free_energy']
@_copyout
def auto_analyze(self):
    """
    Run every analysis section in order and return the complete serialized data.

    The sections are computed as general simulation data, equilibration, mixing and finally free energy;
    each one stores its result under its own key.

    Output is of the form:

    .. code-block:: python

        yank_version: {YANK Version}
        phase_names: {Name of each phase, depends on simulation type}
        general: {See :func:`get_general_simulation_data`}
        equilibration: {See :func:`get_equilibration_data`}
        mixing: {See :func:`get_mixing_data`}
        free_energy: {See :func:`get_experiment_free_energy_data`}

    Returns
    -------
    serialized_data : dict
        Dictionary of all the auto-analysis calls organized by section headers.
        See each of the functions to see each of the sub-dictionary structures

    See Also
    --------
    get_general_simulation_data
    get_equilibration_data
    get_mixing_data
    get_experiment_free_energy_data
    """
    # Order matters: the later sections rely on the data gathered by the earlier ones.
    sections = (
        self.get_general_simulation_data,
        self.get_equilibration_data,
        self.get_mixing_data,
        self.get_experiment_free_energy_data,
    )
    for compute_section in sections:
        compute_section()
    return self._serialized_data
def dump_serial_data(self, path):
    """
    Write the serialized analysis data to a YAML file.

    Parameters
    ----------
    path : str
        File name to dump the data to. ``.yaml`` is appended when the name has no extension.
    """
    base, extension = os.path.splitext(path)
    # Keep a user-supplied extension; only an extension-less name gets the default one.
    target = base + (extension if extension else '.yaml')
    with open(target, 'w') as handle:
        handle.write(yaml.dump(self._serialized_data))
@staticmethod
def report_version():
    """Print the YANK version the analysis was rendered with."""
    message = "Rendered with YANK Version {}".format(version.version)
    print(message)
# =============================================================================================
# MODULE FUNCTIONS
# =============================================================================================
def get_analyzer(file_base_path, **analyzer_kwargs):
    """
    Open a storage file read-only and wrap it in the analyzer for that phase.

    For now this is mostly a placeholder since every storage file is analyzed with
    :class:`YankReplicaExchangeAnalyzer`, but it creates the API for the user to work with.

    Parameters
    ----------
    file_base_path : string
        Complete path to the storage file with filename and extension.
    **analyzer_kwargs
        Keyword arguments to pass to the analyzer.

    Returns
    -------
    analyzer : instance of implemented :class:`Yank*Analyzer`
        Analyzer for the specific phase.
    """
    # Eventually extend this to get more reporters, but for now simple placeholder
    reporter = mmtools.multistate.MultiStateReporter(file_base_path, open_mode='r')
    # Eventually change this to auto-detect simulation from reporter; until then there is a single choice.
    return YankReplicaExchangeAnalyzer(reporter, **analyzer_kwargs)
def analyze_directory(source_directory, **analyzer_kwargs):
    """
    Run the automatic analysis on a directory of store files, print the summary and return the data.

    This function preserves the old auto-analysis style of YANK. What it exactly does can be refined when
    more analyzers and simulations are made available. For now this function exposes the API.

    Parameters
    ----------
    source_directory : string
        The location of the simulation storage files.
    **analyzer_kwargs
        Keyword arguments to pass to the analyzer.

    Returns
    -------
    analysis_data : dict
        Dictionary containing all the automatic analysis data
    """
    experiment_analyzer = ExperimentAnalyzer(source_directory, **analyzer_kwargs)
    results = experiment_analyzer.auto_analyze()
    print_analysis_data(results)
    # Clean up analyzer, forcing NetCDF file to close.
    del experiment_analyzer
    return results
@mpiplus.on_single_node(0)
def print_analysis_data(analysis_data, header=None):
    """
    Print a formatted summary of the payload returned by :func:`ExperimentAnalyzer.auto_analyze`

    When the payload does not have the expected free energy entries, the raw payload is printed instead
    and nothing else is reported.

    Parameters
    ----------
    analysis_data : dict
        Output from :func:`ExperimentAnalyzer.auto_analyze`
    header : str, Optional
        Optional header string to print before the formatted helper, useful if you plan to make this call multiple
        times but want to divide the outputs.
    """
    if header is not None:
        print(header)

    total_keys = ('free_energy_diff', 'enthalpy_diff', 'free_energy_diff_error', 'enthalpy_diff_error',
                  'free_energy_diff_unit', 'enthalpy_diff_unit', 'free_energy_diff_error_unit',
                  'enthalpy_diff_error_unit')
    try:
        fe_data = analysis_data['free_energy']
        totals = {key: fe_data[key] for key in total_keys}
    except (KeyError, TypeError):
        # Trap error in formatted data
        print("Error in reading analysis data! It may be the analysis threw an error, please see below for what "
              "was received instead:\n\n{}\n\n".format(analysis_data))
        return

    kcal_per_mole = units.kilocalories_per_mole
    phase_names = analysis_data['phase_names']

    # Attempt to guess type of calculation; the last matching phase wins.
    calculation_type = ''
    for phase in phase_names:
        if 'complex' in phase:
            calculation_type = ' of binding'
        elif 'solvent1' in phase:
            calculation_type = ' of solvation'

    free_energy_line = 'Free energy{:<13}: {:9.3f} +- {:.3f} kT ({:.3f} +- {:.3f} kcal/mol)'.format(
        calculation_type, totals['free_energy_diff'], totals['free_energy_diff_error'],
        totals['free_energy_diff_unit'] / kcal_per_mole,
        totals['free_energy_diff_error_unit'] / kcal_per_mole)
    print(free_energy_line)

    for phase in phase_names:
        phase_data = fe_data[phase]
        delta_f_phase = phase_data['free_energy_diff']
        delta_f_err_phase = phase_data['free_energy_diff_error']
        delta_f_ssc_phase = phase_data['free_energy_diff_standard_state_correction']
        print('DeltaG {:<17}: {:9.3f} +- {:.3f} kT'.format(phase, delta_f_phase, delta_f_err_phase))
        # The correction is only worth a line when it contributes.
        if delta_f_ssc_phase != 0.0:
            print('DeltaG {:<17}: {:18.3f} kT'.format('standard state correction', delta_f_ssc_phase))
    print('')

    enthalpy_line = 'Enthalpy{:<16}: {:9.3f} +- {:.3f} kT ({:.3f} +- {:.3f} kcal/mol)'.format(
        calculation_type, totals['enthalpy_diff'], totals['enthalpy_diff_error'],
        totals['enthalpy_diff_unit'] / kcal_per_mole,
        totals['enthalpy_diff_error_unit'] / kcal_per_mole)
    print(enthalpy_line)
class MultiExperimentAnalyzer(object):
"""
Automatic Analysis tool for Multiple YANK Experiments from YAML file
This class takes in a YAML file, infers the experiments from expansion of all combinatorial options,
then does the automatic analysis run from :func:`ExperimentAnalyzer.auto_analyze` yielding a final
dictionary output.
Parameters
----------
script : str or dict
Full path to the YAML file which made the YANK experiments.
OR
The loaded yaml content of said script.
builder_kwargs
Additional keyword arguments which normally are passed to the :class:`ExperimentBuilder` constructor.
The experiments are not setup or built, only the output structure is referenced.
See Also
--------
ExperimentAnalyzer.auto_analyze
"""
def __init__(self, script, **builder_kwargs):
self.script = script
self.builder = ExperimentBuilder(script=script, **builder_kwargs)
self.paths = self.builder.get_experiment_directories()
def run_all_analysis(self, serialize_data=True, serial_data_path=None, **analyzer_kwargs):
"""
Run all the automatic analysis through the :func:`ExperimentAnalyzer.auto_analyze`
Parameters
----------
serialize_data : bool, Default: True
Choose whether or not to serialize the data
serial_data_path: str, Optional
Name of the serial data file. If not specified, name will be {YAML file name}_analysis.pkl`
analyzer_kwargs
Additional keywords which will be fed into the :class:`YankMultiStateSamplerAnalyzer` for each phase of
each experiment.
Returns
-------
serial_output : dict
Dictionary of each experiment's output of format
{exp_name: ExperimentAnalyzer.auto_analyze() for exp_name in ExperimentBuilder's Experiments}
The sub-dictionary of each key can be seen in :func:`ExperimentAnalyzer.auto_analyze()` docstring
See Also
--------
ExperimentAnalyzer.auto_analyze
"""
if serial_data_path is None:
serial_ending = 'analysis.pkl'
try:
if not os.path.isfile(self.script):
# Not a file, string like input
raise TypeError
script_base, _ = os.path.splitext(self.script)
serial_data_path = script_base + '_' + serial_ending
except TypeError:
# Traps both YAML content as string, and YAML content as dict
serial_data_path = os.path.join('.', serial_ending)
analysis_serials = mpiplus.distribute(self._run_analysis,
self.paths,
**analyzer_kwargs)
output = {}
for path, analysis in zip(self.paths, analysis_serials):
name = os.path.split(path)[-1] # Get the name of the experiment
# Corner case where user has specified a singular experiment and name is just an empty directory
# Impossible with any type of combinatorial action
if name == '':
name = 'experiment'
# Check to ensure the output is stable
output[name] = analysis if not isinstance(analysis, Exception) else str(analysis)
if serialize_data:
self._serialize(serial_data_path, output)
return output
@staticmethod
def _serialize(serial_path, payload):
"""
Helper function to serialize data which can be subclassed
Parameters
----------
serial_path : str
Path of serial file
payload : object
Object to serialize (pickle)
"""
with open(serial_path, 'wb') as f:
pickle.dump(payload, f)
def _run_analysis(self, path, **analyzer_kwargs):
"""
Helper function to allow parallel through MPI analysis of the experiments
Parameters
----------
path : str
Location of YANK experiment output data.
analyzer_kwargs
Additional keywords which will be fed into the :class:`YankMultiStateSamplerAnalyzer` for each phase of
each experiment.
Returns
-------
payload : dict or Exception
Results from automatic analysis output or the exception that was thrown.
Having the exception trapped but thrown later allows for one experiment to fail but not stop the others
from analyzing.
"""
try:
# Allow the _run_specific_analysis to be subclassable without requiring the error trap to be re-written.
return self._run_specific_analysis(path, **analyzer_kwargs)
except Exception as e:
# Trap any error in a non-crashing way
return e
@staticmethod
def _run_specific_analysis(path, **analyzer_kwargs):
    """
    Run the automatic analysis of one experiment; override point for subclasses.

    Parameters
    ----------
    path : str
        Passed as the first argument to :class:`ExperimentAnalyzer`.
    analyzer_kwargs
        Extra keyword arguments passed through to :class:`ExperimentAnalyzer`.

    Returns
    -------
    The return value of :func:`ExperimentAnalyzer.auto_analyze`.

    """
    experiment_analyzer = ExperimentAnalyzer(path, **analyzer_kwargs)
    return experiment_analyzer.auto_analyze()
# ==========================================
# HELPER FUNCTIONS FOR TRAJECTORY EXTRACTION
# ==========================================
def extract_u_n(ncfile):
    r"""
    Extract timeseries of u_n = - log q(X_n) from store file

    where q(X_n) = \prod_{k=1}^K \exp[-u_{s_{nk}}(x_{nk})]

    with X_n = [x_{n1}, ..., x_{nK}] is the current collection of replica configurations
    s_{nk} is the current state of replica k at iteration n
    u_k(x) is the kth reduced potential

    In other words, for every iteration the energies are re-ordered from replica
    order to state order and the diagonal (each state evaluated with the
    configuration currently assigned to it) is summed.

    TODO: Figure out a way to remove this function

    Parameters
    ----------
    ncfile : netCDF4.Dataset
        Open NetCDF file to analyze. Must expose ``variables['energies']`` and
        ``variables['states']``; energies are assumed to be indexed as
        (iteration, replica, state) and states as (iteration, replica) -- TODO confirm
        against the reporter that writes them.

    Returns
    -------
    u_n : numpy array of numpy.float64
        u_n[n] is -log q(X_n)

    """
    energies = ncfile.variables['energies']
    states = ncfile.variables['states']

    # Get current dimensions.
    niterations = energies.shape[0]
    nstates = energies.shape[1]

    # Deconvolute replicas and accumulate the total negative log probability one iteration
    # at a time, so that only a single (nstates, nstates) matrix is held in memory.
    logger.info("Deconvoluting replicas...")
    u_n = np.zeros([niterations], np.float64)
    for iteration in range(niterations):
        state_indices = states[iteration, :]
        # Start from zeros every iteration: a state that no replica occupies
        # contributes 0 to the sum.
        u_kl = np.zeros([nstates, nstates], np.float64)
        u_kl[state_indices, :] = energies[iteration, :, :]
        u_n[iteration] = np.sum(np.diagonal(u_kl))
    logger.info("Done.")

    return u_n
# ==============================================================================
# Extract trajectory from NetCDF4 file
# ==============================================================================
def extract_trajectory(nc_path, nc_checkpoint_file=None, state_index=None, replica_index=None,
                       start_frame=0, end_frame=-1, skip_frame=1, keep_solvent=True,
                       discard_equilibration=False, image_molecules=False):
    """Extract phase trajectory from the NetCDF4 file.

    Parameters
    ----------
    nc_path : str
        Path to the primary nc_file storing the analysis options
    nc_checkpoint_file : str or None, Optional
        File name of the checkpoint file housing the main trajectory
        Used if the checkpoint file is differently named from the default one chosen by the nc_path file.
        Default: None
    state_index : int, optional
        The index of the alchemical state for which to extract the trajectory.
        One and only one between state_index and replica_index must be not None
        (default is None).
    replica_index : int, optional
        The index of the replica for which to extract the trajectory. One and
        only one between state_index and replica_index must be not None (default
        is None).
    start_frame : int, optional
        Index of the first frame to include in the trajectory (default is 0).
    end_frame : int, optional
        Index of the last frame to include in the trajectory. If negative, will
        count from the end (default is -1).
    skip_frame : int, optional
        Extract one frame every skip_frame (default is 1).
    keep_solvent : bool, optional
        If False, solvent molecules are ignored (default is True).
    discard_equilibration : bool, optional
        If True, initial equilibration frames are discarded (see the method
        pymbar.timeseries.detectEquilibration() for details, default is False).
    image_molecules : bool, optional
        If True and the system is periodic, the molecules are re-imaged with
        ``mdtraj.Trajectory.image_molecules`` using the receptor atoms as anchor
        (or the solute atoms when there are no receptor atoms). If the system is
        not periodic a warning is logged and nothing is imaged (default is False).

    Returns
    -------
    trajectory: mdtraj.Trajectory
        The trajectory extracted from the netcdf file.

    Raises
    ------
    ValueError
        If both or neither of ``state_index`` and ``replica_index`` are given,
        if ``nc_path`` is not an existing file, or if the requested frame range
        is empty.

    """
    # Check correct input
    if (state_index is None) == (replica_index is None):
        raise ValueError('One and only one between "state_index" and '
                         '"replica_index" must be specified.')
    if not os.path.isfile(nc_path):
        raise ValueError('Cannot find file {}'.format(nc_path))

    # Import simulation data
    reporter = None
    try:
        reporter = mmtools.multistate.MultiStateReporter(nc_path, open_mode='r', checkpoint_storage=nc_checkpoint_file)
        metadata = reporter.read_dict('metadata')
        reference_system = mmtools.utils.deserialize(metadata['reference_state']).system
        topography = mmtools.utils.deserialize(metadata['topography'])
        topology = topography.topology

        # Determine if system is periodic
        is_periodic = reference_system.usesPeriodicBoundaryConditions()
        logger.info('Detected periodic boundary conditions: {}'.format(is_periodic))

        # Get dimensions
        # Assume full iteration until proven otherwise.
        # full_iteration must be bound here: it is read in the equilibration branch
        # below even when keep_solvent is True and the test underneath never runs.
        full_iteration = True
        # NOTE(review): last_checkpoint is never updated, although the comment below
        # suggests it was meant to follow the full_iteration test -- confirm.
        last_checkpoint = True
        trajectory_storage = reporter._storage_checkpoint
        if not keep_solvent:
            # If tracked solute particles, use any last iteration, set with this logic test
            full_iteration = len(reporter.analysis_particle_indices) == 0
            if not full_iteration:
                trajectory_storage = reporter._storage_analysis
                topology = topology.subset(reporter.analysis_particle_indices)

        n_iterations = reporter.read_last_iteration(last_checkpoint=last_checkpoint)
        n_frames = trajectory_storage.variables['positions'].shape[0]
        n_atoms = trajectory_storage.variables['positions'].shape[2]
        logger.info('Number of frames: {}, atoms: {}'.format(n_frames, n_atoms))

        # Determine frames to extract.
        # Convert negative indices to last indices.
        if start_frame < 0:
            start_frame = n_frames + start_frame
        if end_frame < 0:
            end_frame = n_frames + end_frame + 1
        frame_indices = range(start_frame, end_frame, skip_frame)
        if len(frame_indices) == 0:
            raise ValueError('No frames selected')
        logger.info('Extracting frames from {} to {} every {}'.format(
            start_frame, end_frame, skip_frame))

        # Discard equilibration samples
        if discard_equilibration:
            u_n = extract_u_n(reporter._storage_analysis)
            # Discard frame 0 with minimized energy which throws off automatic equilibration detection.
            n_equil_iterations, _, n_eff = timeseries.detectEquilibration(u_n[1:])
            n_equil_iterations += 1
            logger.info(("Discarding initial {} equilibration samples (leaving {} "
                         "effectively uncorrelated samples)...").format(n_equil_iterations, n_eff))
            # Find first frame post-equilibration.
            if not full_iteration:
                for iteration in range(n_equil_iterations, n_iterations):
                    n_equil_frames = reporter._calculate_checkpoint_iteration(iteration)
                    if n_equil_frames is not None:
                        break
            else:
                n_equil_frames = n_equil_iterations
            # NOTE(review): the slice end of -1 also drops the last selected frame -- confirm intended.
            frame_indices = frame_indices[n_equil_frames:-1]

        # Determine the number of frames that the trajectory will have.
        if state_index is None:
            n_trajectory_frames = len(frame_indices)
        else:
            # With SAMS, an iteration can have 0 or more replicas in a given state.
            # Deconvolute state indices.
            state_indices = [None for _ in frame_indices]
            for i, iteration in enumerate(frame_indices):
                replica_indices = reporter._storage_analysis.variables['states'][iteration, :]
                state_indices[i] = np.where(replica_indices == state_index)[0]
            n_trajectory_frames = sum(len(x) for x in state_indices)

        # Initialize positions and box vectors arrays.
        # MDTraj Cython code expects float32 positions.
        positions = np.zeros((n_trajectory_frames, n_atoms, 3), dtype=np.float32)
        if is_periodic:
            box_vectors = np.zeros((n_trajectory_frames, 3, 3), dtype=np.float32)

        # Extract state positions and box vectors.
        if state_index is not None:
            logger.info('Extracting positions of state {}...'.format(state_index))

            # Extract state positions and box vectors.
            # One output frame is written per replica found in the requested state.
            frame_idx = 0
            for i, iteration in enumerate(frame_indices):
                for state_replica_index in state_indices[i]:
                    positions[frame_idx, :, :] = trajectory_storage.variables['positions'][iteration, state_replica_index, :, :].astype(np.float32)
                    if is_periodic:
                        box_vectors[frame_idx, :, :] = trajectory_storage.variables['box_vectors'][iteration, state_replica_index, :, :].astype(np.float32)
                    frame_idx += 1

        else:  # Extract replica positions and box vectors
            logger.info('Extracting positions of replica {}...'.format(replica_index))

            for i, iteration in enumerate(frame_indices):
                positions[i, :, :] = trajectory_storage.variables['positions'][iteration, replica_index, :, :].astype(np.float32)
                if is_periodic:
                    box_vectors[i, :, :] = trajectory_storage.variables['box_vectors'][iteration, replica_index, :, :].astype(np.float32)
    finally:
        # Always release the storage files, also when extraction fails midway.
        if reporter is not None:
            reporter.close()

    # Create trajectory object
    logger.info('Creating trajectory object...')
    trajectory = mdtraj.Trajectory(positions, topology)
    if is_periodic:
        trajectory.unitcell_vectors = box_vectors

    # Force periodic boundary conditions to molecules positions
    if image_molecules and is_periodic:
        logger.info('Applying periodic boundary conditions to molecules positions...')
        # Use the receptor as an anchor molecule.
        anchor_atom_indices = set(topography.receptor_atoms)
        if len(anchor_atom_indices) == 0:  # Hydration free energy.
            anchor_atom_indices = set(topography.solute_atoms)
        anchor_molecules = [{a for a in topology.atoms if a.index in anchor_atom_indices}]
        trajectory.image_molecules(inplace=True, anchor_molecules=anchor_molecules)
    elif image_molecules:
        logger.warning('The molecules will not be imaged because the system is non-periodic.')

    return trajectory