Yank/utils.py
#!/usr/local/bin/env python
# ==============================================================================
# MODULE DOCSTRING
# ==============================================================================
"""
Utils
=====
Utilities for the YANK modules
Provides many helper functions and common operations used by the various YANK suites
"""
# ==============================================================================
# GLOBAL IMPORTS
# ==============================================================================
import os
import re
import copy
import glob
import shutil
import inspect
import logging
import importlib
import functools
import itertools
import contextlib
import subprocess
import collections
import collections.abc

from pkg_resources import resource_filename

import mdtraj
import parmed
import numpy as np
from simtk import unit
import openmmtools as mmtools
# ========================================================================================
# Logging functions
# ========================================================================================
def is_terminal_verbose():
    """Tell whether the terminal logging is configured to be verbose.

    Handy when something that is not really part of yank's own log (e.g.
    external library output, citations, etc.) should be printed only when
    the user asked for a verbose terminal.

    Returns
    -------
    is_verbose : bool
        True if at least one handler of the root logger is exactly a
        ``logging.StreamHandler`` whose level is ``logging.DEBUG`` or lower,
        False otherwise (including when the root logger has no handlers).

    """
    # logging.FileHandler subclasses logging.StreamHandler, so an exact type
    # comparison is required here: isinstance/issubclass would also match
    # file handlers.
    return any(
        type(root_handler) is logging.StreamHandler and root_handler.level <= logging.DEBUG
        for root_handler in logging.root.handlers
    )
def config_root_logger(verbose, log_file_path=None):
    """Set up the root logger's configuration.

    Any handler already attached to the root logger is removed (and closed)
    before the new ones are installed. Log messages are printed in the
    terminal through a ``logging.StreamHandler`` created with its default
    stream and, if ``log_file_path`` is not None, also saved to file. The
    root logger's configuration is inherited by the loggers created by
    ``logging.getLogger(name)``.

    Different formats are used for the terminal and for the log file. Every
    log file entry shows the level and the logger name, while in the terminal
    these appear only for messages of level WARNING and higher.

    When ``mpiplus.get_mpicomm()`` returns a communicator whose rank is not 0,
    the terminal only shows messages of level WARNING and higher, and the
    rank is appended to the log file name so that each process writes to its
    own file.

    If a log file is requested, an additional ``<basepath>_CRITICAL<ext>``
    file receives only the messages of level CRITICAL. That file is created
    lazily, i.e. only when the first critical message is emitted.

    Parameters
    ----------
    verbose : bool
        Control the verbosity of the messages printed in the terminal. The logger
        displays messages of level logging.INFO and higher when verbose=False.
        Otherwise those of level logging.DEBUG and higher are printed.
    log_file_path : str, optional, default = None
        If not None, this is the path where all the logger's messages of level
        logging.DEBUG or higher are saved.

    """

    class TerminalFormatter(logging.Formatter):
        """Simplified format for INFO and DEBUG level log messages.

        Keeps the format of info/debug messages separated from the other
        levels, where more information may be needed. For example, for
        warning and error messages it is convenient to know also the logger
        that generated them.

        """
        # The two formatters are shared by all instances of this class.
        simple_fmt = logging.Formatter('%(asctime)-15s: %(message)s')
        default_fmt = logging.Formatter('%(asctime)-15s: %(levelname)s - %(name)s - %(message)s')

        def format(self, record):
            if record.levelno <= logging.INFO:
                return self.simple_fmt.format(record)
            return self.default_fmt.format(record)

    # Start from a clean slate if the root logger is already configured.
    # Handlers are closed after being detached so that the streams opened by
    # previously installed FileHandlers are released instead of leaking.
    for stale_handler in list(logging.root.handlers):
        logging.root.removeHandler(stale_handler)
        stale_handler.close()

    # Find out the rank of this process (0 when no MPI communicator is available).
    import mpiplus
    mpicomm = mpiplus.get_mpicomm()
    rank = mpicomm.rank if mpicomm else 0

    # Create different log files for each MPI process
    if rank != 0 and log_file_path is not None:
        basepath, ext = os.path.splitext(log_file_path)
        log_file_path = '{}_{}{}'.format(basepath, rank, ext)

    # Add handler for terminal messages. Processes other than rank 0
    # only report warnings and errors to avoid duplicated output.
    terminal_handler = logging.StreamHandler()
    terminal_handler.setFormatter(TerminalFormatter())
    if rank != 0:
        terminal_handler.setLevel(logging.WARNING)
    elif verbose:
        terminal_handler.setLevel(logging.DEBUG)
    else:
        terminal_handler.setLevel(logging.INFO)
    logging.root.addHandler(terminal_handler)

    # Add file handler to root logger
    file_format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s'
    if log_file_path is not None:
        file_handler = logging.FileHandler(log_file_path)
        file_handler.setLevel(logging.DEBUG)
        file_handler.setFormatter(logging.Formatter(file_format))
        logging.root.addHandler(file_handler)

    # Do not handle logging.DEBUG at all if unnecessary
    if log_file_path is not None:
        logging.root.setLevel(logging.DEBUG)
    else:
        logging.root.setLevel(terminal_handler.level)

    # Setup critical logger file if a logfile is specified.
    # No need to worry about MPI since log_file_path already embeds the rank.
    if log_file_path is not None:
        basepath, ext = os.path.splitext(log_file_path)
        critical_log_path = basepath + "_CRITICAL" + ext
        # delay=True: the file is created only IF a critical message is sent
        critical_file_handler = logging.FileHandler(critical_log_path, delay=True)
        critical_file_handler.setLevel(logging.CRITICAL)
        # Add blank lines to space out critical errors
        critical_file_format = file_format + "\n\n\n"
        critical_file_handler.setFormatter(logging.Formatter(critical_file_format))
        logging.root.addHandler(critical_file_handler)
# =======================================================================================
# Profiling Functions
# This is a series of functions and wrappers used for debugging, hence their private nature
# =======================================================================================
def _profile_block_separator_string(message):
"""Write a simple block spacing separator"""
import time
time_format = '%d %b %Y %H:%M:%S'
current_time = time.strftime(time_format)
spacing_min = 50
spacing = max(len(current_time), len(message), spacing_min)
filler = '{' + '0: ^{}'.format(spacing) + '}'
separator = '#' * (spacing + 2)
output_string = ''
output_string += separator + '\n'
output_string += '#' + filler.format(current_time) + '#\n'
output_string += '#' + filler.format(message) + '#\n'
output_string += separator + '\n'
return output_string
@contextlib.contextmanager
def _profile(output_file='profile.log'):
    """Context manager profiling the code wrapped by ``with _profile():``.

    The wrapped block runs under ``cProfile``. When the block ends, the
    statistics (sorted by file name and cumulative time) are appended to
    ``output_file`` between a START and an END banner. The profiler is
    stopped and the report written even if the wrapped block raises; the
    exception then propagates to the caller.

    Parameters
    ----------
    output_file: str, Default: 'profile.log'
        Name of the profile you want to write to

    """
    # Imports only used for debugging, not making this part of the name space
    import pstats
    import cProfile

    start_string = _profile_block_separator_string('START PROFILE')
    pr = cProfile.Profile()
    pr.enable()
    try:
        yield
    finally:
        # Always stop the profiler, otherwise an exception raised by the
        # wrapped code would leave it enabled and the report unwritten.
        pr.disable()
        end_string = _profile_block_separator_string('END PROFILE')
        sort_by = ['filename', 'cumulative']
        # Append so that successive profiles accumulate in the same file.
        with open(output_file, 'a+') as s:
            s.write(start_string)
            ps = pstats.Stats(pr, stream=s).sort_stats(*sort_by)
            ps.print_stats()
            s.write(end_string)
def _with_profile(output_file='profile.log'):
    """Decorator factory running the decorated function inside :func:`_profile`.

    Parameters
    ----------
    output_file: str, Default: 'profile.log'
        Name of the profile you want to write to

    Returns
    -------
    decorator : callable
        A decorator that wraps a function so that every call is executed
        within ``with _profile(output_file):``. The wrapped function keeps
        the metadata of the original one (``functools.wraps``).

    """
    def decorate(target):
        @functools.wraps(target)
        def profiled_call(*call_args, **call_kwargs):
            with _profile(output_file):
                return target(*call_args, **call_kwargs)
        return profiled_call

    return decorate
# =======================================================================================
# Method dispatch wrapping to better handle cyclomatic complexity with input types
# Relies on functools.singledispatch (Python 3.5+ only)
# https://stackoverflow.com/questions/24601722/how-can-i-use-functools-singledispatch-with-instance-methods/24602374#24602374
# And the comment to use update_wrapper(wrapper, dispatcher) instead
# =======================================================================================
def methoddispatch(func):
    """Single-dispatch decorator for methods.

    ``functools.singledispatch`` selects the implementation from the type of
    the first positional argument, which for a method is ``self``. This
    wrapper dispatches on the class of the second positional argument
    instead.

    Parameters
    ----------
    func : callable
        The default implementation of the method.

    Returns
    -------
    wrapper : callable
        The dispatching function. It exposes ``register`` to add
        type-specific implementations and carries the metadata copied from
        the underlying single-dispatch function.

    """
    single_dispatcher = functools.singledispatch(func)

    def wrapper(*args, **kw):
        # args[0] is the instance; the implementation is chosen on args[1].
        implementation = single_dispatcher.dispatch(args[1].__class__)
        return implementation(*args, **kw)

    wrapper.register = single_dispatcher.register
    functools.update_wrapper(wrapper, single_dispatcher)
    return wrapper
# =======================================================================================
# Combinatorial tree
# =======================================================================================
class CombinatorialLeaf(list):
    """List type that can be expanded combinatorially in :class:`CombinatorialTree`."""

    def __repr__(self):
        # Wrap the standard list representation to make the leaf recognizable.
        list_repr = super().__repr__()
        return "Combinatorial({})".format(list_repr)
class CombinatorialTree(collections.abc.MutableMapping):
    """A tree that can be expanded in a combinatorial fashion.

    Each tree node with its subnodes is represented as a nested dictionary. Nodes can be
    accessed through their specific "path" (i.e. the tuple of the nested dictionary keys
    that lead to the node value).

    Values of leaf nodes that are :class:`CombinatorialLeaf` objects can be expanded
    combinatorially, in the sense that it is possible to iterate over all possible
    combinations of trees that are generated by replacing each combinatorial leaf
    with only one of the values it contains (see Examples).

    Note that, unlike a regular mapping, iterating over the tree yields the expanded
    ``dict`` combinations rather than the keys.

    Examples
    --------

    Set an arbitrary nested path

    >>> tree = CombinatorialTree({'a': {'b': 2}})
    >>> path = ('a', 'b')
    >>> tree[path]
    2
    >>> tree[path] = 3
    >>> tree[path]
    3

    Paths can be accessed also with the usual dict syntax

    >>> tree['a']['b']
    3

    Deletion of a node leave an empty dict!

    >>> del tree[path]
    >>> print(tree)
    {'a': {}}

    Expand all possible combinations of a tree. The iterator return a dict, not another
    CombinatorialTree object. The leaves are expanded following the sorted order of
    their paths, the last one varying fastest.

    >>> import pprint  # pprint sort the dictionary by key before printing
    >>> tree = CombinatorialTree({'a': 1, 'b': CombinatorialLeaf([1, 2]),
    ...                           'c': {'d': CombinatorialLeaf([3, 4])}})
    >>> for t in tree:
    ...     pprint.pprint(t)
    {'a': 1, 'b': 1, 'c': {'d': 3}}
    {'a': 1, 'b': 1, 'c': {'d': 4}}
    {'a': 1, 'b': 2, 'c': {'d': 3}}
    {'a': 1, 'b': 2, 'c': {'d': 4}}

    Expand all possible combinations and assign unique names

    >>> for name, t in tree.named_combinations(separator='_', max_name_length=5):
    ...     print(name)
    1_3
    1_4
    2_3
    2_4

    """

    def __init__(self, dictionary):
        """Build a combinatorial tree from a deep copy of the given dictionary."""
        self._d = copy.deepcopy(dictionary)

    def __getitem__(self, path):
        # A plain key of the top-level dictionary takes precedence. Anything
        # else is interpreted as a sequence of nested keys.
        try:
            return self._d[path]
        except KeyError:
            return self._resolve_path(self._d, path)

    def __setitem__(self, path, value):
        # Retrieve the parent node and set the value on its last key.
        d_node = self.__getitem__(path[:-1])
        d_node[path[-1]] = value

    def __delitem__(self, path):
        # Only the last key is removed: the parent node is left in place.
        d_node = self.__getitem__(path[:-1])
        del d_node[path[-1]]

    def __len__(self):
        return len(self._d)

    def __str__(self):
        return str(self._d)

    def __eq__(self, other):
        return self._d == other

    def __iter__(self):
        """Iterate over all possible combinations of trees.

        The iterator returns dict objects, not other CombinatorialTrees.

        """
        leaf_paths, leaf_vals = self._find_combinatorial_leaves()
        return self._combinations_generator(leaf_paths, leaf_vals)

    def named_combinations(self, separator, max_name_length):
        """Generator to iterate over all possible combinations of trees and assign them unique names.

        The names are generated by gluing together the values of the combinatorial
        leaves only, separated by the given separator and trimmed to fit
        ``max_name_length``. Special characters in the values are discarded: only
        letters, digits and the separator are found in the generated names. Values
        representing paths to existing files contribute to the name only with their
        file name without extensions. When a name is generated more than once, the
        separator and a counter are appended to the repetitions.

        The iterator yields tuples of ``(name, dict)``, not other :class:`CombinatorialTree`'s. If
        there is only a single combination, an empty string is returned for the name.

        Parameters
        ----------
        separator : str
            The string used to separate the words in the name.
        max_name_length : int
            The maximum length of the generated names, excluding disambiguation number.
            If the separators alone exceed this length, the name is made of the
            separators only.

        Yields
        ------
        name : str
            Unique name of the combination. Empty string returned if there is only one combination
        combination : dict
            Combination of leafs that was used to create the name

        """
        leaf_paths, leaf_vals = self._find_combinatorial_leaves()
        generated_names = {}  # name: count, how many times we have generated the same name

        # Compile regular expression used to discard special characters
        special_chars = re.compile(r'[^A-Za-z\d]+')

        # Iterate over combinations
        for combination in self._combinations_generator(leaf_paths, leaf_vals):
            # Retrieve single values of combinatorial leaves
            filtered_vals = [str(self._resolve_path(combination, path)) for path in leaf_paths]

            # Strip down file paths to only the file name without extensions
            for i, val in enumerate(filtered_vals):
                if os.path.exists(val):
                    filtered_vals[i] = os.path.basename(val).split(os.extsep)[0]

            # Filter special characters in values that we don't use for names
            filtered_vals = [special_chars.sub('', val) for val in filtered_vals]

            # Generate name
            if len(filtered_vals) == 0:
                name = ''
            elif len(filtered_vals) == 1:
                name = filtered_vals[0][:max_name_length]
            else:
                name = separator.join(filtered_vals)
                original_vals = filtered_vals[:]
                while len(name) > max_name_length:
                    # Sort the strings by descending length, if two values have the
                    # same length put first the one whose original value is the shortest
                    sorted_vals = sorted(enumerate(filtered_vals), reverse=True,
                                         key=lambda x: (len(x[1]), -len(original_vals[x[0]])))

                    # Find how many strings have the maximum length
                    max_val_length = len(sorted_vals[0][1])
                    if max_val_length == 0:
                        # Nothing is left to trim: the separators alone exceed
                        # max_name_length. Stop here instead of looping forever.
                        break
                    n_max_vals = len([x for x in sorted_vals if len(x[1]) == max_val_length])

                    # We trim the longest str by the necessary number of characters
                    # to reach max_name_length or the second longest value
                    length_diff = len(name) - max_name_length
                    if n_max_vals < len(filtered_vals):
                        second_max_val_length = len(sorted_vals[n_max_vals][1])
                        length_diff = min(length_diff, max_val_length - second_max_val_length)

                    # Trim all the longest strings by few characters
                    for i in range(n_max_vals - 1, -1, -1):
                        # Division truncation ensures that we trim more the
                        # ones whose original value is the shortest
                        char_per_str = length_diff // (i + 1)
                        if char_per_str != 0:
                            idx = sorted_vals[i][0]
                            filtered_vals[idx] = filtered_vals[idx][:-char_per_str]
                        length_diff -= char_per_str

                    name = separator.join(filtered_vals)

            # Disambiguate names that have already been generated
            if name in generated_names:
                generated_names[name] += 1
                name += separator + str(generated_names[name])
            else:
                generated_names[name] = 1
            yield name, combination

    def expand_id_nodes(self, id_nodes_path, update_nodes_paths):
        """Return a new :class:`CombinatorialTree` with id-bearing nodes expanded
        and updated in the rest of the script.

        Parameters
        ----------
        id_nodes_path : tuple of str
            The path to the parent node containing ids.
        update_nodes_paths : list of tuple of str
            A list of all the paths referring to the ids expanded. The string '*'
            means every node.

        Returns
        -------
        expanded_tree : CombinatorialTree
            The tree with id nodes expanded.

        Examples
        --------
        >>> d = {'molecules':
        ...          {'mol1': {'mol_value': CombinatorialLeaf([1, 2])}},
        ...      'systems':
        ...          {'sys1': {'molecules': 'mol1'},
        ...           'sys2': {'prmtopfile': 'mysystem.prmtop'}}}
        >>> update_nodes_paths = [('systems', '*', 'molecules')]
        >>> t = CombinatorialTree(d).expand_id_nodes('molecules', update_nodes_paths)
        >>> t['molecules'] == {'mol1_1': {'mol_value': 1}, 'mol1_2': {'mol_value': 2}}
        True
        >>> t['systems'] == {'sys1': {'molecules': CombinatorialLeaf(['mol1_1', 'mol1_2'])},
        ...                  'sys2': {'prmtopfile': 'mysystem.prmtop'}}
        True

        """
        expanded_tree = copy.deepcopy(self)
        combinatorial_id_nodes = {}  # map combinatorial_id -> list of combination_ids

        for id_node_key, id_node_val in self.__getitem__(id_nodes_path).items():
            # Find all combinations and expand them
            id_node_val = CombinatorialTree(id_node_val)
            combinations = {id_node_key + '_' + name: comb for name, comb
                            in id_node_val.named_combinations(separator='_', max_name_length=30)}

            if len(combinations) > 1:
                # Substitute combinatorial node with all combinations
                del expanded_tree[id_nodes_path][id_node_key]
                expanded_tree[id_nodes_path].update(combinations)
                # We need the combinatorial_id_nodes substituted to an id_node_key
                # to have a deterministic value or MPI parallel processes will
                # iterate over combinations in different orders
                combinatorial_id_nodes[id_node_key] = sorted(combinations.keys())

        # Update ids in the rest of the tree
        for update_path in update_nodes_paths:
            for update_node_key, update_node_val in self._resolve_paths(self._d, update_path):
                # Check if the value is a collection or a scalar
                if isinstance(update_node_val, list):
                    for v in update_node_val:
                        if v in combinatorial_id_nodes:
                            # Replace the id in the list with all its expanded ids
                            i = expanded_tree[update_node_key].index(v)
                            expanded_tree[update_node_key][i:i+1] = combinatorial_id_nodes[v]
                elif update_node_val in combinatorial_id_nodes:
                    comb_leaf = CombinatorialLeaf(combinatorial_id_nodes[update_node_val])
                    expanded_tree[update_node_key] = comb_leaf

        return expanded_tree

    @staticmethod
    def _resolve_path(d, path):
        """Retrieve the value of a nested key in a dictionary.

        Parameters
        ----------
        d : dict
            The nested dictionary.
        path : iterable of keys
            The "path" to the node of the dictionary.

        Return
        ------
        The value contained in the node pointed by the path.

        """
        accum_value = d
        for node_key in path:
            accum_value = accum_value[node_key]
        return accum_value

    @staticmethod
    def _resolve_paths(d, path):
        """Retrieve all the values of a nested key in a dictionary.

        Paths containing the string '*' are interpreted as any node and
        are yielded one by one. A '*' is expanded only when it is not the
        last element of the path.

        Parameters
        ----------
        d : dict
            The nested dictionary.
        path : iterable of str
            The "path" to the node of the dictionary. The character '*'
            means any node.

        Yields
        ------
        resolved_path : tuple
            The path to a node matching ``path``.
        value
            The value of the node. A ``(None, None)`` pair is yielded when
            a key of the path cannot be found at the top level of the call.

        Examples
        --------
        >>> d = {'nested': {'correct1': {'a': 1}, 'correct2': {'a': 2}, 'wrong': {'b': 3}}}
        >>> p = [x for x in CombinatorialTree._resolve_paths(d, ('nested', '*', 'a'))]
        >>> print(sorted(p))
        [(('nested', 'correct1', 'a'), 1), (('nested', 'correct2', 'a'), 2)]

        """
        try:
            if len(path) == 0:
                yield (), d
            elif len(path) == 1:
                yield (path[0],), d[path[0]]
            else:
                if path[0] == '*':
                    keys = d.keys()
                else:
                    keys = [path[0]]
                for key in keys:
                    for p, v in CombinatorialTree._resolve_paths(d[key], path[1:]):
                        # Discard the (None, None) placeholders of missing sub-paths
                        if v is not None:
                            yield (key,) + p, v
        except KeyError:
            yield None, None

    def _find_leaves(self):
        """Traverse a dict tree and find the leaf nodes.

        Returns
        -------
        A tuple containing two lists. The first one is a list of paths to the leaf
        nodes in a tuple format (e.g. the path to ``node['a']['b']`` is ``('a', 'b')``) while
        the second one is a list of all the values of those leaf nodes.

        Examples
        --------
        >>> simple_tree = CombinatorialTree({'simple': {'scalar': 1,
        ...                                             'vector': [2, 3, 4],
        ...                                             'nested': {
        ...                                                 'leaf': ['a', 'b', 'c']}}})
        >>> leaf_paths, leaf_vals = simple_tree._find_leaves()
        >>> leaf_paths
        [('simple', 'scalar'), ('simple', 'vector'), ('simple', 'nested', 'leaf')]
        >>> leaf_vals
        [1, [2, 3, 4], ['a', 'b', 'c']]

        """
        def recursive_find_leaves(node):
            leaf_paths = []
            leaf_vals = []
            for child_key, child_val in node.items():
                if isinstance(child_val, collections.abc.Mapping):
                    subleaf_paths, subleaf_vals = recursive_find_leaves(child_val)
                    # prepend child key to path
                    leaf_paths.extend([(child_key,) + subleaf for subleaf in subleaf_paths])
                    leaf_vals.extend(subleaf_vals)
                else:
                    leaf_paths.append((child_key,))
                    leaf_vals.append(child_val)
            return leaf_paths, leaf_vals

        return recursive_find_leaves(self._d)

    def _find_combinatorial_leaves(self):
        """Traverse a dict tree and find CombinatorialLeaf nodes.

        Returns
        -------
        combinatorial_leaf_paths, combinatorial_leaf_vals : tuple of tuples
            ``combinatorial_leaf_paths`` is a tuple of paths to combinatorial leaf
            nodes in tuple format (e.g. the path to ``node['a']['b']`` is ``('a', 'b')``)
            while ``combinatorial_leaf_vals`` is the tuple of the values of those nodes.
            The paths are sorted. Two empty lists are returned if the tree has no
            combinatorial leaves.

        """
        leaf_paths, leaf_vals = self._find_leaves()

        # Filter leaves that are not combinatorial
        combinatorial_ids = [i for i, val in enumerate(leaf_vals) if isinstance(val, CombinatorialLeaf)]
        combinatorial_leaf_paths = [leaf_paths[i] for i in combinatorial_ids]
        combinatorial_leaf_vals = [leaf_vals[i] for i in combinatorial_ids]

        # Sort leaves by the order of their paths
        if len(combinatorial_leaf_paths) > 0:
            combinatorial_leaf_paths, combinatorial_leaf_vals = zip(*sorted(zip(combinatorial_leaf_paths,
                                                                                combinatorial_leaf_vals)))
        return combinatorial_leaf_paths, combinatorial_leaf_vals

    def _combinations_generator(self, leaf_paths, leaf_vals):
        """Generate all possible combinations of experiments.

        The iterator returns dict objects, not other :class:`CombinatorialTree`s.

        Parameters
        ----------
        leaf_paths : list of tuples of strings
            The list of paths as returned by _find_combinatorial_leaves().
        leaf_vals : list
            The list of the correspondent values as returned by
            _find_combinatorial_leaves().

        """
        template_tree = CombinatorialTree(self._d)

        # All leaf values must be CombinatorialLeafs at this point
        assert all(isinstance(leaf_val, CombinatorialLeaf) for leaf_val in leaf_vals)

        # generating all combinations
        for combination in itertools.product(*leaf_vals):
            # update values of template tree
            for leaf_path, leaf_val in zip(leaf_paths, combination):
                template_tree[leaf_path] = leaf_val
            # Yield a copy so that later updates of the template do not leak
            yield copy.deepcopy(template_tree._d)
# ========================================================================================
# Miscellaneous functions
# ========================================================================================
def get_data_filename(relative_path):
    """Return the full path to a data file distributed with the ``yank`` package.

    The path is resolved with ``resource_filename('yank', relative_path)``
    and checked for existence before being returned.

    Parameters
    ----------
    relative_path : str
        Path of the file to locate, relative to the installed ``yank``
        package folder.

    Returns
    -------
    fn : str
        Resource Filename

    Raises
    ------
    ValueError
        If the resolved path does not exist.

    """
    fn = resource_filename('yank', relative_path)
    if os.path.exists(fn):
        return fn
    raise ValueError("Sorry! {} does not exist. If you just added it, you'll have to re-install".format(fn))
def find_phases_in_store_directory(store_directory):
    """Map the phase names found in the store directory to their files.

    Every ``*.nc`` file directly inside ``store_directory`` is considered a
    phase whose name is the file name without its extension.

    Parameters
    ----------
    store_directory : str
        The directory to examine for stored phase NetCDF data files.

    Returns
    -------
    phases : dict of str
        A dictionary phase_name -> file_path that maps phase names to its NetCDF
        file path.

    Raises
    ------
    RuntimeError
        If no ``*.nc`` file is found in ``store_directory``.

    """
    nc_pattern = os.path.join(store_directory, '*.nc')
    phases = {
        os.path.splitext(os.path.basename(nc_path))[0]: nc_path
        for nc_path in glob.glob(nc_pattern)
    }
    if not phases:
        raise RuntimeError("Could not find any valid YANK store (*.nc) files in "
                           "store directory: {}".format(store_directory))
    return phases
def update_nested_dict(original, updated):
    """Return a copy of a (possibly) nested dict of arbitrary depth with updated values.

    Values of ``updated`` that are mappings are merged recursively into the
    corresponding nested dict of ``original``; every other value simply
    replaces the original one. Neither ``original`` nor its nested dicts
    that are touched by the update are modified in place.

    Parameters
    ----------
    original : dict
        Original dict which we want to update, can contain nested dicts
    updated : dict
        Dictionary of updated values to place in original

    Returns
    -------
    new : dict
        Copy of original with values updated from updated

    Examples
    --------
    >>> update_nested_dict({'a': 1, 'b': {'c': 2, 'd': 3}}, {'b': {'c': 5}})
    {'a': 1, 'b': {'c': 5, 'd': 3}}

    """
    new = original.copy()
    for key, value in updated.items():
        if isinstance(value, collections.abc.Mapping):
            current = new.get(key, {})
            # A mapping cannot be merged into a non-mapping value (e.g. a
            # scalar): in that case the update replaces it altogether.
            if not isinstance(current, collections.abc.Mapping):
                current = {}
            new[key] = update_nested_dict(current, value)
        else:
            new[key] = value
    return new
# ==============================================================================
# Conversion utilities
# ==============================================================================
def underscore_to_camelcase(underscore_str):
    """Convert the given string from ``underscore_case`` to ``camelCase``.

    Underscores at the beginning and at the end of the string are preserved.
    All underscores in the middle of the string are removed. Every word
    after the first one is capitalized with ``str.capitalize``, i.e. its
    first character is upper-cased and the others lower-cased.

    Parameters
    ----------
    underscore_str : str
        String in underscore_case to convert to camelCase style.

    Returns
    -------
    camelcase_str : str
        String in camelCase style. Strings that are empty or made only of
        underscores are returned unchanged.

    Examples
    --------
    >>> underscore_to_camelcase('__my___variable_')
    '__myVariable_'

    """
    core = underscore_str.strip('_')
    if not core:  # this is empty or contains only '_'s
        return underscore_str

    # Count leading and trailing '_' characters
    total_length = len(underscore_str)
    leading = '_' * (total_length - len(underscore_str.lstrip('_')))
    trailing = '_' * (total_length - len(underscore_str.rstrip('_')))

    # The first word is kept as is, the following ones are capitalized.
    # Empty words generated by consecutive underscores vanish in the join.
    first_word, *other_words = core.split('_')
    capitalized = [word.capitalize() for word in other_words]
    return ''.join([leading, first_word] + capitalized + [trailing])
def camelcase_to_underscore(camelcase_str):
    """Convert the given string from ``camelCase`` to ``underscore_case``.

    An underscore is inserted before every ASCII capital letter, then the
    whole string is cast to lower case. Underscores already in the string,
    including those at the beginning and at the end, are preserved.

    Parameters
    ----------
    camelcase_str : str
        String in camelCase to convert to underscore style.

    Returns
    -------
    underscore_str : str
        String in underscore style.

    Examples
    --------
    >>> camelcase_to_underscore('myVariable')
    'my_variable'
    >>> camelcase_to_underscore('__my_Variable_')
    '__my__variable_'

    """
    # The replacement template is a raw string so that the \g<1> group
    # reference is not parsed as an (invalid) string escape sequence.
    underscore_str = re.sub(r'([A-Z])', r'_\g<1>', camelcase_str)
    return underscore_str.lower()
def quantity_from_string(expression, compatible_units=None):
    """Create a Quantity object from a string expression.

    The expression is evaluated by ``openmmtools.utils.math_eval`` with all
    the ``Unit`` members of the ``simtk.unit`` module available as variables.
    Quotes and whitespace surrounding the expression are discarded, and an
    expression starting with ``/`` is interpreted as the inverse of what
    follows (e.g. ``/second``).

    Parameters
    ----------
    expression : str
        The mathematical expression to rebuild a Quantity as a string.
    compatible_units : simtk.unit.Unit, optional
        If given, the result is checked for compatibility against the
        specified units, and an exception raised if not compatible.
        `Note`: The output is not converted to ``compatible_units``, they
        are only used as a unit to validate the input.

    Returns
    -------
    quantity
        The result of the evaluated expression.

    Raises
    ------
    TypeError
        If ``expression`` is not a string, or if ``compatible_units`` is
        given and the quantity in expression is either unit-less or has
        incompatible units.

    Examples
    --------
    >>> expr = '1.0*second'
    >>> quantity_from_string(expr, compatible_units=unit.femtosecond)
    Quantity(value=1.0, unit=second)

    """
    # The table of units is built only once and cached on the function itself.
    if not hasattr(quantity_from_string, '_units'):
        def is_unit(member):
            return isinstance(member, unit.Unit)
        quantity_from_string._units = dict(inspect.getmembers(unit, is_unit))
    known_units = quantity_from_string._units

    # Eliminate nested quotes and excess whitespace
    try:
        expression = expression.strip('\'" ')
    except AttributeError:
        raise TypeError('The expression {} must be a string defining units, '
                        'not a {} instance'.format(expression, type(expression)))

    # Handle a special case of the unit when it is just "inverse unit",
    # e.g. Hz == /second
    if expression[0] == '/':
        expression = '(' + expression[1:] + ')**(-1)'

    # Evaluate expressions.
    quantity = mmtools.utils.math_eval(expression, variables=known_units)

    # Without reference units there is nothing else to validate.
    if compatible_units is None:
        return quantity

    # Check to make sure units are compatible with expected units.
    try:
        is_compatible = quantity.unit.is_compatible(compatible_units)
    except AttributeError:
        raise TypeError("String {} does not have units attached.".format(expression))
    if not is_compatible:
        raise TypeError("Units of {} must be compatible with {}".format(expression, str(compatible_units)))

    return quantity
def get_keyword_args(function, try_mro_from_class=None):
    """Inspect function signature and return keyword args with their default values.

    Parameters
    ----------
    function : callable
        The function to interrogate.
    try_mro_from_class : any Class or None
        If given, and ``function`` accepts ``**kwargs``, the method resolution
        order (MRO) of this class is walked and the keyword arguments of the
        attributes named like ``function`` found in the parent classes are
        collected as well. Classes that do not define such an attribute are
        skipped.

        Note that this does not verify that ``function`` actually forwards
        its ``**kwargs`` to the parent implementation (e.g. through
        ``super()``), so the result may include keywords that are never used.

        In the case of conflicting keywords, the lower MRO function is preferred.

    Returns
    -------
    kwargs : dict
        A dictionary ``{'keyword argument': 'default value'}``. The arguments of the
        function that do not have a default value will not be included.

    """
    def extract_kwargs(input_argspec):
        """Collect the arguments with a default value from a full argspec."""
        default_values = input_argspec.defaults or ()
        arg_names = input_argspec.args
        # Defaults always refer to the last positional arguments.
        first_default_idx = len(arg_names) - len(default_values)
        collected = dict(zip(arg_names[first_default_idx:], default_values))
        # Handle the kwonlyargs for calls with `def F(a, b, *args, x=True, **kwargs)`
        if input_argspec.kwonlydefaults is not None:
            collected.update(input_argspec.kwonlydefaults)
        return collected

    argspec = inspect.getfullargspec(function)
    kwargs = extract_kwargs(argspec)

    # The parents are relevant only if extra keywords can reach them.
    if try_mro_from_class is None or argspec.varkw is None:
        return kwargs

    try:
        mro = inspect.getmro(try_mro_from_class)
    except AttributeError:
        # No MRO
        mro = [try_mro_from_class]

    for parent_cls in mro[1:]:
        try:
            parent_function = getattr(parent_cls, function.__name__)
        except AttributeError:
            # Class does not have a method name
            continue
        inherited = extract_kwargs(inspect.getfullargspec(parent_function))
        # Keywords already collected lower in the MRO take precedence.
        inherited.update(kwargs)
        kwargs = inherited

    return kwargs
def validate_parameters(parameters, template_parameters, check_unknown=False,
                        process_units_str=False, float_to_int=False,
                        ignore_none=True, special_conversions=None):
    """
    Filter, convert and type-check parameters against a template.

    Only the parameters whose key appears in the template are kept. The
    expected type of each parameter is the type of the corresponding
    template value. A template value of None disables the type check for
    that parameter.

    Parameters
    ----------
    parameters : dict
        The parameters to validate.
    template_parameters : dict
        The template used to filter the parameters and infer the types.
    check_unknown : bool
        If True, an exception is raised when parameters contain a key that is
        not contained in ``template_parameters``.
    process_units_str: bool
        If True, parameters whose template value is a simtk.unit.Quantity are
        converted with ``quantity_from_string`` using the template's unit.
    float_to_int : bool
        If True, parameters whose template value is exactly an int (not a
        bool) are passed through ``int()``, which truncates floats.
    ignore_none : bool
        If True, parameters whose value is None are returned untouched and
        are not type-checked.
    special_conversions : dict
        Maps a parameter name to a converter with signature ``convert(arg)``.
        A parameter listed here is converted with that function instead of
        going through the automatic conversions and the type check.

    Returns
    -------
    validate_par : dict
        The converted parameters that are contained both in parameters and
        ``template_parameters``.

    Raises
    ------
    TypeError
        If ``check_unknown`` is True and there are parameters not in
        ``template_parameters``.
    ValueError
        If a parameter has an incompatible type with its template parameter.

    Examples
    --------
    Create the template parameters

    >>> template_pars = dict()
    >>> template_pars['bool'] = True
    >>> template_pars['int'] = 2
    >>> template_pars['unspecified'] = None  # this won't be checked for type compatibility
    >>> template_pars['to_be_converted'] = [1, 2, 3]
    >>> template_pars['length'] = 2.0 * unit.nanometers

    Now the parameters to validate

    >>> input_pars = dict()
    >>> input_pars['bool'] = None  # this will be skipped with ignore_none=True
    >>> input_pars['int'] = 4.3  # this will be truncated to 4 with float_to_int=True
    >>> input_pars['unspecified'] = 'input'  # this can be of any type since the template is None
    >>> input_pars['to_be_converted'] = {'key': 3}
    >>> input_pars['length'] = '1.0*nanometers'
    >>> input_pars['unknown'] = 'test'  # this will be silently filtered if check_unknown=False

    Validate the parameters

    >>> valid = validate_parameters(input_pars, template_pars, process_units_str=True,
    ...                             float_to_int=True, special_conversions={'to_be_converted': list})
    >>> import pprint
    >>> pprint.pprint(valid)
    {'bool': None,
     'int': 4,
     'length': Quantity(value=1.0, unit=nanometer),
     'to_be_converted': ['key'],
     'unspecified': 'input'}

    """
    converters = {} if special_conversions is None else special_conversions

    # Start from the subset of parameters that the template knows about.
    validated_par = {}
    for name in parameters:
        if name in template_parameters:
            validated_par[name] = parameters[name]

    # If something was filtered out, there is at least one unknown parameter.
    if check_unknown and len(validated_par) < len(parameters):
        unknown = set(parameters) - set(template_parameters)
        raise TypeError("found unknown parameter {}".format(', '.join(unknown)))

    for name in list(validated_par):
        raw_value = validated_par[name]
        expected = template_parameters[name]

        # None values are left as they are when requested.
        if ignore_none and raw_value is None:
            continue

        # A user-supplied converter takes priority over everything else.
        if name in converters:
            convert = converters[name]
            validated_par[name] = convert(raw_value)
            continue

        # Automatic conversions. bool is a subclass of int, so the exact
        # type is compared instead of relying on isinstance.
        if float_to_int and type(expected) is int:
            validated_par[name] = int(raw_value)
        elif process_units_str and isinstance(expected, unit.Quantity):
            validated_par[name] = quantity_from_string(raw_value, expected.unit)

        # The (possibly converted) value must have exactly the template type.
        if expected is not None and type(validated_par[name]) != type(expected):
            raise ValueError("parameter {}={} is incompatible with {}".format(
                name, validated_par[name], template_parameters[name]))

    return validated_par
# ==============================================================================
# Stuff to move to openmoltools/ParmEd once they are stable
# ==============================================================================
class Mol2File(object):
    """Thin convenience layer over ParmEd for single-structure mol2 files.

    Every operation goes back to the file on disk, so this class is not
    meant to be efficient. It only offers one-line shortcuts to inspect
    and fix a mol2 file. When several operations must be chained before
    saving, use ParmEd directly.

    Only single-structure mol2 files are supported.

    Parameters
    ----------
    file_path : str
        Path to the mol2 file.

    Attributes
    ----------
    resname
    resnames
    net_charge

    """

    def __init__(self, file_path):
        """Remember the path of the mol2 file; nothing is read yet."""
        self._file_path = file_path

    @property
    def resname(self):
        """The first residue name yielded by ``resnames``.

        This assumes that each molecule in the mol2 file has a single
        residue name.

        """
        names = self.resnames
        return next(names)

    @property
    def resnames(self):
        """Generator over the residue names found in the file (read-only).

        One name is yielded for each ``@<TRIPOS>ATOM`` section, taken from
        the first line that follows the section header. This assumes that
        each molecule in the mol2 file has a single residue name.

        """
        atom_header_seen = False
        with open(self._file_path, 'r') as mol2:
            for record in mol2:
                if atom_header_seen:
                    # First atom record of the section: the residue name
                    # is its 8th whitespace-separated field.
                    fields = record.split()
                    yield fields[7]
                    atom_header_seen = False
                elif record.startswith('@<TRIPOS>'):
                    # Strip the '@<TRIPOS>' prefix to get the section name.
                    if record[9:].strip() == 'ATOM':
                        atom_header_seen = True

    @property
    def net_charge(self):
        """Sum of the atomic charges read from the file (read-only)."""
        loaded = parmed.load_file(self._file_path, structure=True)
        return self._compute_net_charge(loaded)

    def round_charge(self):
        """Round the net charge to the nearest integer to 6-digit precision.

        The file is rewritten only if its net charge differs from the
        nearest integer by at least 1e-6.

        Raises
        ------
        RuntimeError
            If the net charge after rounding differs from the original
            net charge by more than 0.05.

        """
        n_digits = 6

        # Load as a structure because residues are buggy (see ParmEd#898).
        structure = parmed.load_file(self._file_path, structure=True)
        charge_before = self._compute_net_charge(structure)

        # Nothing to do (and nothing to rewrite) if already within precision.
        nearest_integer = round(charge_before)
        if abs(nearest_integer - charge_before) < 10**(-n_digits):
            return

        # fix_charges is only available on residue templates.
        templates = parmed.modeller.ResidueTemplateContainer.from_structure(structure)
        if len(templates) > 1:
            logging.warning("Found mol2 file with multiple residues. The charge of "
                            "each residue will be rounded to the nearest integer.")
        templates.fix_charges(precision=n_digits)

        charge_after = self._compute_net_charge(templates)
        logging.debug('Fixing net charge from {} to {}'.format(charge_before, charge_after))

        # A large shift means the rounding did something unexpected.
        if abs(charge_before - charge_after) > 0.05:
            raise RuntimeError('The rounded net charge is too different from the original one.')

        # Transfer the fixed charges back to the structure, atom by atom.
        for target_residue, template in zip(structure.residues, templates):
            for target_atom, fixed_atom in zip(target_residue.atoms, template.atoms):
                target_atom.charge = fixed_atom.charge

        # Overwrite the original file with the updated charges.
        parmed.formats.Mol2File.write(structure, self._file_path)

    @staticmethod
    def _compute_net_charge(residue):
        """Sum the atomic charges of an object exposing ``atoms`` or of an iterable of such objects."""
        try:
            return sum(atom.charge for atom in residue.atoms)
        except AttributeError:  # residue is a ResidueTemplateContainer
            pass
        total = 0.0
        for member in residue:
            total += sum(atom.charge for atom in member.atoms)
        return total
def is_modeller_installed():
    """
    Check if the Salilab Modeller package can be imported.

    Importing Modeller also triggers its license check, so a failed import
    covers both a missing installation and an invalid license.

    Returns
    -------
    installed : bool
        True if ``import modeller`` succeeds, False if it raises any
        ``Exception``.

    """
    try:
        import modeller
    except Exception:
        # This has to be broad because we can't trap the ModellerError raised
        # on an invalid license: merely importing Modeller triggers it, and
        # it is NOT an ImportError. ``Exception`` is still broad enough for
        # that, but unlike a bare ``except`` it lets KeyboardInterrupt and
        # SystemExit propagate.
        return False
    return True
# -----------------
# OpenEye functions
# -----------------
def is_openeye_installed(oetools=('oechem', 'oequacpac', 'oeiupac', 'oeomega')):
    """
    Tell whether the requested OpenEye tools are importable and licensed.

    Parameters
    ----------
    oetools : str or iterable of strings, Optional, Default: ('oechem', 'oequacpac', 'oeiupac', 'oeomega')
        Names of the tools to check. The default is the complete set that
        this function knows about. A single tool can also be given as a
        plain string instead of an iterable of length 1.

    Returns
    -------
    all_installed : bool
        True if every tool in ``oetools`` can be imported from the
        ``openeye`` package and its license function returns a true value,
        False otherwise.

    Raises
    ------
    ValueError
        If ``oetools`` contains a name that is not one of the known tools.

    """
    # Name of the license-check function exposed by each known module.
    license_checks = {'oechem': 'OEChemIsLicensed',
                      'oequacpac': 'OEQuacPacIsLicensed',
                      'oeiupac': 'OEIUPACIsLicensed',
                      'oeomega': 'OEOmegaIsLicensed'}

    # A bare string stands for a single tool.
    if type(oetools) is str:
        oetools = (oetools,)

    # Refuse tool names we do not know how to check.
    requested = set(oetools)
    if not requested <= set(license_checks):
        raise ValueError("Expected an OpenEye tools subset of {}, but instead "
                         "got {}".format(tuple(license_checks), oetools))

    for tool_name in oetools:
        # The tool must be importable...
        try:
            tool_module = importlib.import_module('openeye.' + tool_name)
        except ImportError:
            return False
        # ...and report a valid license.
        is_licensed = getattr(tool_module, license_checks[tool_name])
        if not is_licensed():
            return False
    return True
def load_oe_molecules(file_path, molecule_idx=None):
    """Read one or more molecules from a file.

    Requires OpenEye Toolkit. Several formats are supported (including
    mol2, sdf and pdb).

    Parameters
    ----------
    file_path : str
        Complete path to the file on disk.
    molecule_idx : None or int, optional, default: None
        Index of the molecule on the file. If None, all of them are
        returned.

    Returns
    -------
    molecule : openeye.oechem.OEMol or list of openeye.oechem.OEMol
        The molecules stored in the file. If molecule_idx is specified
        only one molecule is returned, otherwise a list (even if the
        file contain only 1 molecule).

    """
    from openeye import oechem
    extension = os.path.splitext(file_path)[1][1:]  # Remove dot.

    # Open input file stream
    ifs = oechem.oemolistream()
    if extension == 'mol2':
        mol2_flavor = (oechem.OEIFlavor_Generic_Default |
                       oechem.OEIFlavor_MOL2_Default |
                       oechem.OEIFlavor_MOL2_Forcefield)
        ifs.SetFlavor(oechem.OEFormat_MOL2, mol2_flavor)
    if not ifs.open(file_path):
        oechem.OEThrow.Fatal('Unable to open {}'.format(file_path))

    # Read all molecules. Each one is copied into its own OEMol, so the
    # stream can be released as soon as the file has been consumed.
    try:
        molecules = [oechem.OEMol(mol) for mol in ifs.GetOEMols()]
    finally:
        ifs.close()

    # Select the molecule of interest
    if molecule_idx is not None:
        return molecules[molecule_idx]
    return molecules
def write_oe_molecule(oe_mol, file_path, mol2_resname=None):
    """Write a molecule to file, picking the format from the file extension.

    Requires OpenEye Toolkit

    Parameters
    ----------
    oe_mol : OpenEye Molecule
        Molecule to write to file
    file_path : str
        Complete path to file with filename and extension. The upper-cased
        extension selects the ``oechem.OEFormat_*`` constant to use.
    mol2_resname : None or str, Optional, Default: None
        If given and the extension is ``mol2``, every occurrence of ``<0>``
        in the written file is replaced by this name.

    """
    from openeye import oechem

    # Map the extension (without the dot) to the OpenEye format constant
    extension = os.path.splitext(file_path)[1][1:]
    format_name = 'OEFormat_' + extension.upper()
    oe_format = getattr(oechem, format_name)

    # Write the molecule through an output stream
    ofs = oechem.oemolostream()
    ofs.SetFormat(oe_format)
    if not ofs.open(file_path):
        oechem.OEThrow.Fatal('Unable to create {}'.format(file_path))
    oechem.OEWriteMolecule(ofs, oe_mol)
    ofs.close()

    # Only mol2 files with a requested residue name need post-processing
    # TODO when you merge to openmoltools, incapsulate this and add to molecule_to_mol2()
    if mol2_resname is None or extension != 'mol2':
        return

    # Swap the '<0>' placeholder for the requested name
    with open(file_path, 'r') as mol2_in:
        content = mol2_in.read()
    with open(file_path, 'w') as mol2_out:
        mol2_out.write(content.replace('<0>', mol2_resname))
def get_oe_mol_positions(molecule, conformer_idx=0):
    """
    Collect the atom coordinates of an OpenEye Molecule in a numpy array

    Requires OpenEye Toolkit

    Parameters
    ----------
    molecule : OpenEye Molecule
        Molecule to extract coordinates from
    conformer_idx : int, Optional, Default: 0
        Index of the conformer to read. With the default 0 (or any
        non-positive value) the molecule is used as given.

    Returns
    -------
    molecule_pos : numpy.ndarray
        Array of shape (number of atoms, 3) with one row of coordinates
        per atom, in the order returned by ``molecule.GetAtoms()``.

    Raises
    ------
    ValueError
        If ``conformer_idx`` is positive and not smaller than the number of
        conformers of the molecule.

    """
    from openeye import oechem

    # Switch to the requested conformer, if any
    if conformer_idx > 0:
        try:
            n_conformers = molecule.NumConfs()
            if not conformer_idx < n_conformers:
                raise UnboundLocalError  # same error message
            conformer = molecule.GetConf(oechem.OEHasConfIdx(conformer_idx))
            molecule = oechem.OEGraphMol(conformer)
        except UnboundLocalError:
            raise ValueError('conformer_idx {} out of range'.format(conformer_idx))

    # Copy the coordinates atom by atom through a reusable 3-float buffer
    xyz_buffer = oechem.OEFloatArray(3)
    n_atoms = molecule.NumAtoms()
    molecule_pos = np.zeros((n_atoms, 3))
    for row, oe_atom in enumerate(molecule.GetAtoms()):
        molecule.GetCoords(oe_atom, xyz_buffer)
        molecule_pos[row] = xyz_buffer
    return molecule_pos
def _sanitize_tleap_unit_name(func):
    """Decorator version of ``TLeap._sanitize_unit_name``.

    The unit name is taken from the keyword argument ``unit_name`` when
    given, otherwise from the second positional argument (the first one
    being ``self``). The wrapped function is called with the sanitized
    name and its return value is handed back to the caller.

    """
    @functools.wraps(func)
    def _wrapper(*args, **kwargs):
        if 'unit_name' in kwargs:
            kwargs['unit_name'] = TLeap._sanitize_unit_name(kwargs['unit_name'])
        else:
            # Tuples are immutable so we need to use concatenation.
            args = args[:1] + (TLeap._sanitize_unit_name(args[1]),) + args[2:]
        # Return the result so that the decorator is transparent to callers.
        return func(*args, **kwargs)
    return _wrapper


class TLeap:
    """
    Programmatic interface to write and run AmberTools' ``tLEaP`` scripts.

    To avoid problems with special characters in file paths, the class runs
    the tleap script in a temporary folder with hardcoded names for files
    and then copies the output files to their respective folders.

    Internally the script is stored as a ``str.format`` template in which
    every file is referenced by a placeholder (``{moliN}`` for inputs,
    ``{moloN}`` for outputs). Literal braces needed by LEaP are therefore
    stored doubled.

    Attributes
    ----------
    script

    """

    def __init__(self):
        self._script = ''
        self._file_paths = {}  # paths of input/output files to copy in/from temp dir
        self._loaded_parameters = set()  # parameter files already loaded

    @property
    def script(self):
        """
        Complete and return the finalized script string

        File placeholders are replaced by the original file paths and a
        ``quit`` command is added to the end of the script.

        """
        return self._script.format(**self._file_paths) + '\nquit\n'

    def add_commands(self, *args):
        """
        Append commands to the script

        Parameters
        ----------
        args : iterable of strings
            Individual commands to add to the script written in full as strings.
            Newline characters are added after each command

        """
        for command in args:
            self._script += command + '\n'

    def load_parameters(self, *args):
        """
        Load the LEaP parameters into the working TLEaP script if not already loaded

        This adds to the script

        Uses ``loadAmberParams`` for files with ``frcmod`` in their base name or a ``.dat`` extension

        Uses ``loadOff`` for ``*.off`` and ``*.lib`` files

        Uses ``source`` for other files.

        Parameters
        ----------
        args : iterable of strings
            File names for each type of leap file that can be loaded.
            Method to load them is automatically determined from file extension or base name

        """
        for par_file in args:
            # Check that this is not already loaded
            if par_file in self._loaded_parameters:
                continue

            # Check whether this is a user file or a tleap file, and
            # update list of input files to copy in temporary folder before run
            if os.path.isfile(par_file):
                local_name = 'moli{}'.format(len(self._file_paths))
                self._file_paths[local_name] = par_file
                local_name = '{' + local_name + '}'
            else:  # tleap file
                local_name = par_file

            # use loadAmberParams if this is a frcmod file and source otherwise
            base_name = os.path.basename(par_file)
            extension = os.path.splitext(base_name)[1]
            if 'frcmod' in base_name or extension == '.dat':
                self.add_commands('loadAmberParams ' + local_name)
            elif extension == '.off' or extension == '.lib':
                self.add_commands('loadOff ' + local_name)
            else:
                self.add_commands('source ' + local_name)

            # Update loaded parameters cache
            self._loaded_parameters.add(par_file)

    @_sanitize_tleap_unit_name
    def load_unit(self, unit_name, file_path):
        """
        Load a Unit into LEaP, this is typically a molecule or small complex.

        This adds to the script

        Accepts ``*.mol2`` or ``*.pdb`` files

        Parameters
        ----------
        unit_name : str
            Name of the unit as it should be represented in LEaP
        file_path : str
            Full file path with extension of the file to read into LEaP as a new unit

        Raises
        ------
        ValueError
            If the file extension is neither ``.mol2`` nor ``.pdb``.

        """
        extension = os.path.splitext(file_path)[1]
        if extension == '.mol2':
            load_command = 'loadMol2'
        elif extension == '.pdb':
            load_command = 'loadPdb'
        else:
            raise ValueError('cannot load format {} in tLeap'.format(extension))
        local_name = 'moli{}'.format(len(self._file_paths))
        self.add_commands('{} = {} {{{}}}'.format(unit_name, load_command, local_name))

        # Update list of input files to copy in temporary folder before run
        self._file_paths[local_name] = file_path

    @_sanitize_tleap_unit_name
    def combine(self, unit_name, *args):
        """
        Combine units in LEaP

        This adds to the script

        Parameters
        ----------
        unit_name : str
            Name of LEaP unit to assign the combination to
        args : iterable of strings
            Name of LEaP units to combine into a single unit called unit_name

        """
        # Sanitize unit names.
        components = ' '.join(self._sanitize_unit_name(arg) for arg in args)
        # The braces are doubled twice: once for this format call and once
        # for the format call that resolves the file placeholders later.
        self.add_commands('{} = combine {{{{ {} }}}}'.format(unit_name, components))

    @_sanitize_tleap_unit_name
    def add_ions(self, unit_name, ion, num_ions=0, replace_solvent=False):
        """
        Add ions to a unit in LEaP

        This adds to the script

        Parameters
        ----------
        unit_name : str
            Name of the existing LEaP unit which Ions will be added into
        ion : str
            LEaP recognized name of ion to add
        num_ions : int, optional
            Number of ions of type ion to add to unit_name. If 0, the unit
            is neutralized (default is 0).
        replace_solvent : bool, optional
            If True, ions will replace solvent molecules rather than being
            added.

        """
        leap_command = 'addIonsRand' if replace_solvent else 'addIons2'
        self.add_commands('{} {} {} {}'.format(leap_command, unit_name, ion, num_ions))

    @_sanitize_tleap_unit_name
    def solvate(self, unit_name, solvent_model, clearance, box_geometry="cubic"):
        """
        Solvate a unit in LEaP isometrically

        This adds to the script

        Parameters
        ----------
        unit_name : str
            Name of the existing LEaP unit which will be solvated
        solvent_model : str
            LEaP recognized name of the solvent model to use, e.g. "TIP3PBOX"
        clearance : unit.Quantity
            Add solvent up to clearance distance away (units of length) from the unit_name (radial)
        box_geometry : "cubic" or "truncated_octahedral"
            Shape of the box to be solvated (Default is "cubic").

        Raises
        ------
        ValueError
            If ``box_geometry`` is not one of the supported values.

        """
        if box_geometry == "cubic":
            solvate_command = 'solvateBox'
        elif box_geometry == "truncated_octahedral":
            solvate_command = 'solvateOct'
        else:
            # Implicit concatenation keeps source indentation out of the message.
            raise ValueError('The argument box_geometry must take one of the following values: '
                             '"cubic" or "truncated_octahedral".')
        self.add_commands('{} {} {} {} iso'.format(solvate_command, unit_name, solvent_model,
                                                   str(clearance.value_in_unit(unit.angstroms))))

    @_sanitize_tleap_unit_name
    def save_unit(self, unit_name, output_path):
        """
        Write a LEaP unit to file.

        Accepts either ``*.prmtop``, ``*.inpcrd``, or ``*.pdb`` files

        This adds to the script

        Parameters
        ----------
        unit_name : str
            Name of the unit to save
        output_path : str
            Full file path with extension to save.
            Outputs with multiple files (e.g. Amber Parameters) have their names derived from this instead

        Raises
        ------
        ValueError
            If the extension is not supported. In this case neither the
            script nor the list of output files is modified.

        """
        file_name = os.path.basename(output_path)
        file_name, extension = os.path.splitext(file_name)

        # Validate before touching any state, so that a rejected call does
        # not leave behind an output file that no command will ever create.
        if extension not in ('.prmtop', '.inpcrd', '.pdb'):
            raise ValueError('cannot export format {} from tLeap'.format(extension[1:]))

        # Update list of output files to copy from temporary folder after run
        local_name = 'molo{}'.format(len(self._file_paths))
        self._file_paths[local_name] = output_path

        if extension == '.pdb':
            self.add_commands('savePDB {} {{{}}}'.format(unit_name, local_name))
            return

        # saveAmberParm always writes a prmtop/inpcrd pair (in this order),
        # so the file that was not requested explicitly is registered too.
        local_name2 = 'molo{}'.format(len(self._file_paths))
        command = 'saveAmberParm ' + unit_name + ' {{{}}} {{{}}}'
        if extension == '.inpcrd':
            extension2 = '.prmtop'
            command = command.format(local_name2, local_name)
        else:
            extension2 = '.inpcrd'
            command = command.format(local_name, local_name2)
        output_path2 = os.path.join(os.path.dirname(output_path), file_name + extension2)
        self._file_paths[local_name2] = output_path2
        self.add_commands(command)

    @_sanitize_tleap_unit_name
    def transform(self, unit_name, transformation):
        """Transformation is an array-like representing the affine transformation matrix."""
        command = 'transform {} {}'.format(unit_name, transformation)
        # LEaP expects braces; they are doubled to survive the later format call.
        command = command.replace(r'[', '{{').replace(r']', '}}')
        # Put the matrix on a single line and collapse double spaces.
        command = command.replace('\n', '').replace('  ', ' ')
        self.add_commands(command)

    def new_section(self, comment):
        """Adds a comment line to the script"""
        self.add_commands('\n# ' + comment)

    def export_script(self, file_path):
        """
        Write script to file

        Parameters
        ----------
        file_path : str
            Full file path with extension of the script to save

        """
        with open(file_path, 'w') as f:
            f.write(self.script)

    def run(self):
        """Run script and return warning messages in leap log file.

        The script is executed with ``tleap`` inside a temporary directory.
        Input files are copied there under their local names, and output
        files (plus ``leap.log``, saved next to the first output file) are
        copied back afterwards.

        Returns
        -------
        warnings : list of str
            The text following every ``WARNING: `` found in tleap's output.

        Raises
        ------
        RuntimeError
            If an output file could not be copied back or a known error
            pattern is found in tleap's output.

        """
        def create_dirs_and_copy(path_to_copy, copied_path):
            """Create directories before copying the file."""
            output_dir_path = os.path.dirname(copied_path)
            os.makedirs(output_dir_path, exist_ok=True)
            shutil.copy(path_to_copy, copied_path)

        # Transform paths in absolute paths since we'll change the working directory
        input_files = {local + os.path.splitext(path)[1]: os.path.abspath(path)
                       for local, path in self._file_paths.items() if 'moli' in local}
        output_files = {local + os.path.splitext(path)[1]: os.path.abspath(path)
                        for local, path in self._file_paths.items() if 'molo' in local}

        # Resolve all the names in the script
        local_files = {local: local + os.path.splitext(path)[1]
                       for local, path in self._file_paths.items()}
        script = self._script.format(**local_files) + 'quit\n'

        with mdtraj.utils.enter_temp_directory():
            # Copy input files
            for local_file, file_path in input_files.items():
                shutil.copy(file_path, local_file)

            # Save script and run tleap
            with open('leap.in', 'w') as f:
                f.write(script)
            leap_output = subprocess.check_output(['tleap', '-f', 'leap.in']).decode()

            # Save leap.log in directory of first output file
            log_path = ''
            if len(output_files) > 0:
                first_output_path = next(iter(output_files.values()))
                first_output_name = os.path.basename(first_output_path).split('.')[0]
                first_output_dir = os.path.dirname(first_output_path)
                log_path = os.path.join(first_output_dir, first_output_name + '.leap.log')
                create_dirs_and_copy('leap.log', log_path)

            # Copy back output files. If something goes wrong, some files may not exist
            known_error_msg = []
            try:
                for local_file, file_path in output_files.items():
                    create_dirs_and_copy(local_file, file_path)
            except IOError:
                known_error_msg.append("Could not create one of the system files.")

            # Look for errors in log that don't raise CalledProcessError
            error_patterns = [r'Argument #\d+ is type \S+ must be of type: \S+']
            for pattern in error_patterns:
                m = re.search(pattern, leap_output)
                if m is not None:
                    known_error_msg.append(m.group(0))
                    break

            # Analyze log file for water mismatch
            m = re.search(r"Could not find bond parameter for: EP - \w+W", leap_output)
            if m is not None:
                # Found mismatch water and missing parameters
                known_error_msg.append('It looks like the water used has virtual sites, but '
                                       'missing parameters.\nMake sure your leap parameters '
                                       'use the correct water model as specified by '
                                       'solvent_model.')

            if len(known_error_msg) > 0:
                final_error = ('Some things went wrong with LEaP\nWe caught a few but their may be more.\n'
                               'Please see the log file for LEaP for more info:\n{}\n============\n{}')
                raise RuntimeError(final_error.format(log_path, '\n---------\n'.join(known_error_msg)))

            # Check for and return warnings
            return re.findall('WARNING: (.+)', leap_output)

    @staticmethod
    def _sanitize_unit_name(unit_name):
        """Sanitize tleap unit names.

        Leap doesn't like names that start with digits so, in this case, we
        prepend an arbitrary character ('M').

        Parameters
        ----------
        unit_name : str
            The name to sanitize.

        Returns
        -------
        unit_name : str
            The name, prefixed with 'M' if it started with a digit.

        """
        if unit_name[0].isdigit():
            unit_name = 'M' + unit_name
        return unit_name
def generate_development_feature(feature_dict):
    """
    Build a class that flags classes, tests, and functions as developmental.

    The returned class is not quite a mixin because it has to be the first
    base class, due to the check performed in its ``__init__``.

    Parameters
    ----------
    feature_dict : dict
        Dictionary of form "test_string : pre-computed test" where
        "test_string" is just an identifier and "pre-computed test" is a
        boolean-like object, usually the result of some test. All
        pre-computed tests are cast to bool.

    Returns
    -------
    DevelopmentFeature : class
        Class built from ``feature_dict`` that can be used in several ways:

        * Class inherited: its ``__init__()`` raises a ``RuntimeError`` if
          at least one feature is not met.
        * True/False check: the class attribute ``dev_validate`` is True
          only if all the features are met.
        * True/False decorator: the classmethod ``dev_validation`` returns
          the decorated function unchanged if ``dev_validate`` is True,
          otherwise a function that accepts any argument and simply
          returns. Helpful for running tests.
        * Dict of reasons: the class attribute ``dev_reasons`` holds the
          features that are not met.
        * Dict of all: the class attribute ``dev_features`` is
          ``feature_dict`` itself.
        * Error message: the class attribute ``DEV_ERROR`` is the message
          raised by ``__init__``, or None if all the features are met.

        With the exception of ``__init__``, everything is available on the
        class and does not require instantiation. All the names carry the
        ``dev_`` prefix to avoid clashes with the classes it is mixed into.

    """
    header = ('This feature cannot be used because it has been marked as "Developmental" and '
              'the following conditions have not been met:\n')

    # Keep only the features whose test is falsy, each mapped to its bool value
    unmet = {}
    for label, outcome in feature_dict.items():
        passed = bool(outcome)
        if not passed:
            unmet[label] = passed
    all_met = not unmet

    # One bullet line per unmet feature
    bullets = ["\t- {}: {}\n".format(label, flag) for label, flag in unmet.items()]
    message = header + ''.join(bullets)

    class DevelopmentFeature(object):
        DEV_ERROR = None if all_met else message
        dev_reasons = unmet
        dev_features = feature_dict
        dev_validate = all_met

        def __init__(self, *args, **kwargs):
            if self.dev_validate:
                return
            raise RuntimeError(self.DEV_ERROR)

        @classmethod
        def dev_validation(cls, wrapped_function):
            """
            Decorator which leaves ``wrapped_function`` untouched if
            ``dev_validate`` is True, and replaces it with a function that
            does nothing otherwise.
            """
            if cls.dev_validate:
                return wrapped_function

            def _empty_function(*args, **kwargs):
                return
            return _empty_function

    return DevelopmentFeature
# =============================================================================================
# Main and tests
# =============================================================================================
if __name__ == "__main__":
    # When executed as a script, run the doctests embedded in the
    # docstrings of this module.
    import doctest
    doctest.testmod()