# mdsuite/transformations/transformations.py
"""
MDSuite: A Zincwarecode package.
License
-------
This program and the accompanying materials are made available under the terms
of the Eclipse Public License v2.0 which accompanies this distribution, and is
available at https://www.eclipse.org/legal/epl-v20.html
SPDX-License-Identifier: EPL-2.0
Copyright Contributors to the Zincwarecode Project.
Contact Information
-------------------
email: zincwarecode@gmail.com
github: https://github.com/zincware
web: https://zincwarecode.com/
Citation
--------
If you use this module please cite us with:
Summary
-------
Parent class for the transformations.
"""
from __future__ import annotations
import abc
import collections.abc
import copy
import logging
import os
import time
import typing
from typing import TYPE_CHECKING, Union
import numpy as np
import tensorflow as tf
import tqdm
import mdsuite.database.simulation_database
from mdsuite.database.data_manager import DataManager
from mdsuite.database.simulation_database import Database
from mdsuite.memory_management.memory_manager import MemoryManager
from mdsuite.utils import DatasetKeys
from mdsuite.utils.meta_functions import join_path
if TYPE_CHECKING:
from mdsuite.experiment import Experiment
class CannotFindTransformationError(Exception):
"""Exception for when a transformation cannot be found."""
pass
class CannotFindPropertyError(Exception):
"""Exception for when a property cannot be found."""
pass
class Transformations:
"""
Parent class for MDSuite transformations.
Attributes
----------
    _database : Database
        database class object for data loading and storing
    _experiment : Experiment
        Experiment class instance the transformation operates on
    batch_size : int
        batch size for the computation
    n_batches : int
        Number of batches to be looped over
    remainder : int
        Number of configurations remaining after all full batches are processed
    data_manager : DataManager
        data manager for handling the data transfer
    memory_manager : MemoryManager
        memory manager for the computation.
"""
def __init__(
self,
input_properties: typing.Iterable[
mdsuite.database.simulation_database.PropertyInfo
] = None,
output_property: mdsuite.database.simulation_database.PropertyInfo = None,
scale_function=None,
dtype=tf.float64,
):
"""
Init of the transformer base class.
Parameters
----------
input_properties : typing.Iterable[
mdsuite.database.simulation_database.PropertyInfo]
            The properties needed to perform the transformation,
            e.g. unwrapped positions for the wrap_coordinates transformation.
            Properties from this list are provided to self.transform_batch(),
            in which the actual transformation happens.
output_property : mdsuite.database.simulation_database.PropertyInfo
The property that is the result of the transformation
scale_function :
specifies memory requirements of the transformation
dtype :
data type of the processed values
"""
self._experiment = None
self._database = None
self.input_properties = input_properties
self.output_property = output_property
self.logger = logging.getLogger(__name__)
self.scale_function = scale_function
self.dtype = dtype
self.batch_size: int
self.n_batches: int
self.remainder: int
self.offset = 0
self.data_manager: DataManager
self.memory_manager: MemoryManager
@property
def database(self):
"""Update the database.
replace for https://github.com/zincware/MDSuite/issues/404
"""
if self._database is None:
self._database = Database(self.experiment.database_path / "database.hdf5")
return self._database
@property
def experiment(self) -> Experiment:
"""TODO replace for https://github.com/zincware/MDSuite/issues/404."""
return self._experiment
@experiment.setter
def experiment(self, value):
self._experiment = value
def _run_dataset_check(self, path: str):
"""
Check to see if the database dataset already exists. If it does, the
transformation should extend the dataset and add data to the end of
it rather than try to add data.
Parameters
----------
path : str
dataset path to check.
Returns
-------
outcome : bool
If True, the dataset already exists and should be extended.
If False, a new dataset should be built.
"""
return self.database.check_existence(path)
def _save_output(
self,
        data: Union[tf.Tensor, np.ndarray],
index: int,
data_structure: dict,
):
"""
Save the tensor_values into the database_path
# todo for the future: this should not be part of the transformation.
# the transformation should yield a batch and the experiment should take care of
# storing it in the correct place, just as with file inputs
Returns
-------
saves the tensor_values to the database_path.
"""
# turn data into trajectory chunk
# data_structure is dict {'/path/to/property':{'indices':irrelevant,
# 'columns':deduce->deduce n_dims, 'length':n_particles}
species_list = []
# data structure only has 1 element
key, val = list(data_structure.items())[0]
path = str(copy.copy(key))
        path = path.rstrip("/")
path = path.split("/")
prop_name = path[-1]
sp_name = path[-2]
n_particles = val.get("length")
if n_particles is None:
try:
# if length is not available try indices next
n_particles = len(val.get("indices"))
except TypeError:
raise TypeError("Could not determine number of particles")
if len(np.shape(data)) == 2:
# data not for multiple particles, instead one value for all
# -> create the n_particle axis
data = data[np.newaxis, :, :]
prop = mdsuite.database.simulation_database.PropertyInfo(
name=prop_name, n_dims=len(val["columns"])
)
species_list.append(
mdsuite.database.simulation_database.SpeciesInfo(
name=sp_name, properties=[prop], n_particles=n_particles
)
)
chunk = mdsuite.database.simulation_database.TrajectoryChunkData(
chunk_size=np.shape(data)[1], species_list=species_list
)
# data comes from transformation with time in 1st axis, add_data needs it
# in 0th axis
chunk.add_data(
data=np.swapaxes(data, 0, 1),
config_idx=0,
species_name=sp_name,
property_name=prop_name,
)
        try:
            self.database.add_data(chunk=chunk)
        except OSError:
            # On Windows and under WSL the file is sometimes reported as still
            # open when it should already be closed, so we wait briefly and
            # retry.
            time.sleep(0.5)
            self.database.add_data(chunk=chunk)
def _prepare_monitors(self, data_path: Union[list, np.array]):
"""
        Prepare the data manager and memory manager.
Parameters
----------
data_path : list
List of tensor_values paths to load from the hdf5
database_path.
Returns
-------
"""
self.memory_manager = MemoryManager(
data_path=data_path,
database=self.database,
memory_fraction=0.5,
scale_function=self.scale_function,
offset=self.offset,
)
(
self.batch_size,
self.n_batches,
self.remainder,
) = self.memory_manager.get_batch_size()
self.data_manager = DataManager(
data_path=data_path,
data_range=1,
database=self.database,
batch_size=self.batch_size,
n_batches=self.n_batches,
remainder=self.remainder,
offset=self.offset,
)
def _prepare_database_entry(self, species: str, system_tensor=False):
"""
Add or extend the dataset in which the transformation result is stored.
Parameters
----------
species : str
Species for which transformation is performed
Returns
-------
tensor_values structure for use in saving the tensor_values to the
database_path.
"""
if system_tensor:
output_length = 1
path = join_path(DatasetKeys.OBSERVABLES, self.output_property.name)
else:
try:
output_length = self.experiment.species[species].n_particles
except KeyError:
output_length = self.experiment.molecules[species].n_particles
path = join_path(species, self.output_property.name)
n_dims = self.output_property.n_dims
existing = self._run_dataset_check(path)
if existing:
old_shape = self.database.get_data_size(path)
resize_structure = {
path: (
output_length,
self.experiment.number_of_configurations - old_shape[0],
n_dims,
)
}
self.offset = old_shape[0]
self.database.resize_datasets(resize_structure)
else:
number_of_configurations = self.experiment.number_of_configurations
dataset_structure = {path: (output_length, number_of_configurations, n_dims)}
self.database.add_dataset(dataset_structure)
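        # The structure below tells _save_output where and how to store the
        # result; for a hypothetical 3-dimensional property of 500 'Na'
        # particles it would read
        # {'Na/My_Property': {'indices': np.s_[:], 'columns': [0, 1, 2],
        #  'length': 500}}.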
data_structure = {
path: {
"indices": np.s_[:],
"columns": list(range(n_dims)),
"length": output_length,
}
}
return data_structure
    def find_property_per_config(self, sp_name, prop) -> typing.Union[None, str]:
        """Return the database path of the per-configuration data for this
        property and species if it exists, otherwise None.
        """
        would_be_path = join_path(sp_name, prop.name)
        if self.database.check_existence(would_be_path):
            return would_be_path
        return None
    def find_property_single_val(self, sp_name, prop):
        """Look up a constant value for the property (e.g. a mass or charge
        from the metadata), first per species, then system-wide. Return None
        if there is none.
        """
        # TODO: properties in species_dict are all lowercase,
        # whereas the properties in the database are upper case
        try:
            species_info = self.experiment.species[sp_name]
        except KeyError:
            species_info = self.experiment.molecules[sp_name]
        per_sp_value = getattr(species_info, prop.name.lower(), None)
        if per_sp_value is not None:
            return per_sp_value
        # if not species-specific, try to find it system-wide
        # todo need a dictionary of these values instead of the experiment property
        try:
            return getattr(self.experiment, prop.name.lower())
        except AttributeError:
            return None
    def get_prop_through_transformation(self, sp_name, prop):
        """Try to produce the property by running a transformation that yields
        it, then return the database path of the result.
        """
        # todo prevent infinite recursion
        # (e.g. unwrap_pos calls wrap_pos calls unwrap_pos calls ...)
        from mdsuite.transformations.transformation_dict import (
            property_to_transformation_dict,
        )

        if prop in property_to_transformation_dict:
trafo_cls = property_to_transformation_dict[prop]
if not isinstance(trafo_cls, typing.Iterable):
# simple case: only one trafo for the property
trafo = trafo_cls()
self.experiment.cls_transformation_run(trafo, species=[sp_name])
else:
# go through the list until one works
one_trafo_worked = False
for trafo_class in trafo_cls:
trafo = trafo_class()
try:
self.experiment.cls_transformation_run(trafo, species=[sp_name])
except CannotFindPropertyError:
pass
else:
one_trafo_worked = True
break
if not one_trafo_worked:
raise CannotFindTransformationError(
f"was asked to get '{prop.name}' for '{sp_name}'. "
f"There are transformations to get this property ({trafo_cls}), "
"but none of them have the required data"
)
return self.find_property_per_config(sp_name, prop)
else:
raise CannotFindTransformationError(
f"was asked to get '{prop.name}' for '{sp_name}', but there is no"
" transformation to get that property"
)
    def get_generator_type_spec_and_const_data(self, species_names):
        """
        Build the tf.TensorSpec dictionary for the batch generator and collect
        constant (single-valued) input data for each requested species.

        For each input property, per-configuration data in the database takes
        precedence; otherwise a constant value from the metadata is used;
        failing that, the data is produced by another transformation.
        """
        type_spec = {}
        const_input_data = {}
for species_name in species_names:
const_input_data[species_name] = {}
for prop in self.input_properties:
# find out if the requested input data is there for all time steps
# for the requested species
path = self.find_property_per_config(species_name, prop)
if path is not None:
type_spec[str.encode(path)] = tf.TensorSpec(
shape=(None, None, prop.n_dims), dtype=self.dtype
)
# if not, fall back to const value for that species
else:
val = self.find_property_single_val(species_name, prop)
if val is not None:
# give single value the same dimensionality as if it was there
# for each time step and particle (i.e. add 2 axes)
if not isinstance(val, collections.abc.Iterable):
val = [val]
val = tf.convert_to_tensor(val, dtype=self.dtype)
val = val[None, None, :]
const_input_data[species_name].update({prop.name: val})
# if not there, try to produce the data
else:
try:
path = self.get_prop_through_transformation(
species_name, prop
)
type_spec[str.encode(path)] = tf.TensorSpec(
shape=(None, None, prop.n_dims), dtype=self.dtype
)
except CannotFindTransformationError:
raise CannotFindPropertyError(
"While performing transformation"
f" '{self.output_property.name}': Property '{prop.name}'"
f" for species '{species_name}' cannot be found in the"
" simulation database nor in the simulation metadata, nor"
" can it be obtained by a transformation"
)
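        # type_spec maps encoded database paths to tensor specs, e.g.
        # {b'Na/My_Property': tf.TensorSpec(shape=(None, None, 3), ...)}, and
        # const_input_data maps species names to their constant inputs, e.g.
        # {'Na': {'Mass': <1 x 1 x 1 tensor>}} ('Na' and the property names
        # here are hypothetical).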
return type_spec, const_input_data
@abc.abstractmethod
    def run_transformation(self, species: typing.Iterable[str] = None):
        """Perform the transformation on the given species."""
        raise NotImplementedError
class SingleSpeciesTrafo(Transformations):
"""
Base class for transformations where the transformation is applied to each species
separately.
"""
def run_transformation(self, species: typing.Iterable[str] = None):
"""
        Perform the batching and data loading for the transformation,
        then call transform_batch on each batch.
Parameters
----------
species : Iterable[str]
Names of the species on which to perform the transformation.
Returns
-------
"""
        # species should be provided by the caller (the experiment); for now we
        # use the usual pseudo-global variable
if species is None:
species = self.experiment.species.keys()
for species_name in species:
# this check should be done by the caller
if self.database.check_existence(
os.path.join(species_name, self.output_property.name)
):
self.logger.info(
f"{self.output_property.name} already exists for {species_name}, "
"skipping transformation"
)
continue
output_data_structure = self._prepare_database_entry(species_name)
type_spec, const_input_data = self.get_generator_type_spec_and_const_data(
[species_name]
)
const_input_data = const_input_data[species_name]
self._prepare_monitors(list(type_spec.keys()))
batch_generator, batch_generator_args = self.data_manager.batch_generator()
type_spec.update(
{
str.encode("data_size"): tf.TensorSpec(shape=(), dtype=tf.int32),
}
)
data_set = tf.data.Dataset.from_generator(
batch_generator, args=batch_generator_args, output_signature=type_spec
)
data_set = data_set.prefetch(tf.data.experimental.AUTOTUNE)
carryover = None
for index, batch_dict in tqdm.tqdm(
enumerate(data_set),
ncols=70,
desc=(
f"Applying transformation '{self.output_property.name}' to"
f" '{species_name}'"
),
total=self.n_batches,
):
# remove species information from batch:
# the transformation only has to know about the property
# ideally, the keys of the batch dict are already PropertyInfo instances
batch_dict.pop(str.encode("data_size"))
batch_dict_wo_species = {}
for key, val in batch_dict.items():
batch_dict_wo_species[key.decode().split("/")[-1]] = val
batch_dict_wo_species.update(const_input_data)
ret = self.transform_batch(batch_dict_wo_species, carryover=carryover)
if isinstance(ret, tuple):
transformed_batch, carryover = ret
else:
transformed_batch = ret
self._save_output(
data=transformed_batch,
data_structure=output_data_structure,
index=index * self.batch_size,
)
@abc.abstractmethod
def transform_batch(
self, batch: typing.Dict[str, tf.Tensor], carryover: typing.Any = None
) -> typing.Union[tf.Tensor, typing.Tuple[tf.Tensor, typing.Any]]:
"""
Do the actual transformation.
Parameters
----------
        batch : dict
            The batch to be transformed. Structure is
            {'Property1': tensordata, ...}
carryover : any
if the transformation batching is only possible with carryover,
this argument will provide it.
        Returns
        -------
        Either the transformed batch (tf.Tensor),
        or a tuple of (<transformed batch>, <carryover>),
        where the carryover can have any type.
        The carryover will be passed as the optional argument with the next batch.
        """
raise NotImplementedError("transformation of a batch must be implemented")
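# What follows is an illustrative sketch, not part of the MDSuite API: a
# minimal SingleSpeciesTrafo subclass demonstrating the carryover protocol of
# transform_batch. The property names 'Positions' and 'Displacements' are
# hypothetical.
class _ExampleDisplacementTrafo(SingleSpeciesTrafo):
    """Compute per-configuration displacements from positions (sketch only)."""

    def __init__(self):
        pos_info = mdsuite.database.simulation_database.PropertyInfo(
            name="Positions", n_dims=3
        )
        out_info = mdsuite.database.simulation_database.PropertyInfo(
            name="Displacements", n_dims=3
        )
        super().__init__(input_properties=[pos_info], output_property=out_info)

    def transform_batch(self, batch, carryover=None):
        pos = batch["Positions"]  # shape (n_particles, n_configs, 3)
        # reference configuration: the last configuration of the previous
        # batch, or the first configuration of this batch if there is none
        ref = pos[:, :1, :] if carryover is None else carryover
        displacement = pos - tf.concat([ref, pos[:, :-1, :]], axis=1)
        # return the result and hand the last configuration to the next batch
        return displacement, pos[:, -1:, :]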
class MultiSpeciesTrafo(Transformations):
"""
Base class for all transformations, where information of multiple species is combined
in the transformation of a new property.
"""
def run_transformation(self, species: typing.Iterable[str] = None) -> None:
"""
        Perform the batching and data loading for the transformation,
        then call transform_batch on each batch.
Parameters
----------
species : Iterable[str]
Names of the species on which to perform the transformation.
Returns
-------
"""
        # species should be provided by the caller (the experiment); for now we
        # use the usual pseudo-global variable
if species is None:
species = self.experiment.species.keys()
# this check should be done by the caller
if self.database.check_existence(
            join_path(DatasetKeys.OBSERVABLES, self.output_property.name)
):
self.logger.info(
f"{self.output_property.name} already exists for this experiment, "
"skipping transformation"
)
return
output_data_structure = self._prepare_database_entry(
self.output_property.name, system_tensor=True
)
type_spec, const_input_data = self.get_generator_type_spec_and_const_data(species)
self._prepare_monitors(list(type_spec.keys()))
batch_generator, batch_generator_args = self.data_manager.batch_generator()
type_spec.update(
{
str.encode("data_size"): tf.TensorSpec(shape=(), dtype=tf.int32),
}
)
data_set = tf.data.Dataset.from_generator(
batch_generator, args=batch_generator_args, output_signature=type_spec
)
data_set = data_set.prefetch(tf.data.experimental.AUTOTUNE)
carryover = None
for index, batch_dict in tqdm.tqdm(
enumerate(data_set),
ncols=70,
desc=f"Applying transformation '{self.output_property.name}'",
total=self.n_batches,
):
batch_dict.pop(str.encode("data_size"))
            batch_dict_hierarchical = {sp_name: {} for sp_name in species}
            for key, val in batch_dict.items():
                sp_name, prop_name = key.decode().split("/")
                batch_dict_hierarchical[sp_name][prop_name] = val
            for sp_name in batch_dict_hierarchical.keys():
                batch_dict_hierarchical[sp_name].update(const_input_data[sp_name])
            ret = self.transform_batch(batch_dict_hierarchical, carryover=carryover)
if isinstance(ret, tuple):
transformed_batch, carryover = ret
else:
transformed_batch = ret
self._save_output(
data=transformed_batch,
data_structure=output_data_structure,
index=index * self.batch_size,
)
@abc.abstractmethod
def transform_batch(
self,
batch: typing.Dict[str, typing.Dict[str, tf.Tensor]],
carryover: typing.Any = None,
    ) -> typing.Union[tf.Tensor, typing.Tuple[tf.Tensor, typing.Any]]:
"""
Do the actual transformation.
Parameters
----------
        batch : dict
            The batch to be transformed. Structure is
            {'Species1': {'Property1': tensordata, ...}, ...}
        carryover : any
            if the transformation batching is only possible with carryover,
            this argument will provide it.
        Returns
        -------
        Either the transformed batch (tf.Tensor),
        or a tuple of (<transformed batch>, <carryover>),
        where the carryover can have any type.
        The carryover will be passed as the optional argument with the next batch.
        """
raise NotImplementedError("transformation of a batch must be implemented")
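# Illustrative sketch, not part of the MDSuite API: a minimal MultiSpeciesTrafo
# subclass that sums a hypothetical 'Velocities' property over all particles of
# all species, yielding one system-wide vector per configuration. No carryover
# is needed, so a bare tensor is returned.
class _ExampleSystemSum(MultiSpeciesTrafo):
    """Sum a property over all species into one system observable (sketch)."""

    def __init__(self):
        vel_info = mdsuite.database.simulation_database.PropertyInfo(
            name="Velocities", n_dims=3
        )
        out_info = mdsuite.database.simulation_database.PropertyInfo(
            name="Summed_Velocities", n_dims=3
        )
        super().__init__(input_properties=[vel_info], output_property=out_info)

    def transform_batch(self, batch, carryover=None):
        # batch is {species_name: {property_name: tensor}}; every tensor has
        # shape (n_particles, n_configs, 3). Sum over the particle axis of
        # each species, then over species.
        return tf.add_n(
            [tf.reduce_sum(data["Velocities"], axis=0) for data in batch.values()]
        )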