zincware/MDSuite

View on GitHub
mdsuite/calculators/calculator.py

Summary

Maintainability
A
1 hr
Test Coverage
"""
MDSuite: A Zincwarecode package.

License
-------
This program and the accompanying materials are made available under the terms
of the Eclipse Public License v2.0 which accompanies this distribution, and is
available at https://www.eclipse.org/legal/epl-v20.html

SPDX-License-Identifier: EPL-2.0

Copyright Contributors to the Zincwarecode Project.

Contact Information
-------------------
email: zincwarecode@gmail.com
github: https://github.com/zincware
web: https://zincwarecode.com/

Citation
--------
If you use this module please cite us with:

Summary
-------
Parent class for the calculators.
"""
from __future__ import annotations

import functools
import logging
import warnings
from typing import TYPE_CHECKING, Dict, List, Union

import numpy as np
import tensorflow as tf
from tqdm import tqdm

import mdsuite.database.scheme as db
from mdsuite.database.calculator_database import CalculatorDatabase
from mdsuite.visualizer.d2_data_visualization import DataVisualizer2D

if TYPE_CHECKING:
    from mdsuite import Experiment

tqdm.monitor_interval = 0
warnings.filterwarnings("ignore")

log = logging.getLogger(__name__)


def call(func):
    """
    Decorator for the calculator call method.

    This decorator provides a unified approach for handling run_computation and
    load_data for a single or multiple experiments.
    It handles the `run.<calc>()` method, iterates over experiments and
    loads data if requested! Therefore, the __call__ method does not and can
    not return any values anymore!


    Notes
    -----
    When calling the calculator it will check if a computation with the given
    user arguments was already performed:
    >>> Calculator.get_computation_data() is not None

    if no computations are available it will
    1. prepare a database entry
    >>> Calculator.prepare_db_entry()
    2. save the user arguments
    >>> Calculator.save_computation_args()
    3. Run the analysis
    >>> Calculator.run_analysis()
    4. Save all the data to the database
    >>> Calculator.save_db_data()
    5. Finally query the the data from the database and pass them to the user / plotting
    >>> data = Calculator.get_computation_data()




    Parameters
    ----------
    func: Calculator.__call__ method

    Returns
    -------
    decorated __call__ method

    """

    @functools.wraps(func)
    def inner(self, *args, **kwargs) -> Union[db.Computation, Dict[str, db.Computation]]:
        """Manage the call method.

        Parameters
        ----------
        self: Calculator

        Returns
        -------
        data:
            A dictionary of shape {name: data} when called from the project class
            A list of [data] when called directly from the experiment class
        """
        # This is only true, when called via project.experiments.<exp>.run,
        #  otherwise the experiment will be None
        return_dict = self.experiment is None

        out = {}
        for experiment in self.experiments:
            CLS = self.__class__
            # NOTE: if the calculator accepts more than just experiment/experiments
            #  as init, this has to be changed!
            cls = CLS(experiment=experiment)
            # pass the user args to the calculator
            func(cls, *args, **kwargs)
            data = cls.get_computation_data()
            if data is None:
                # new calculation will be performed
                cls.prepare_db_entry()
                cls.save_computation_args()
                cls.run_analysis()
                cls.save_db_data()
                # Need to reset the user args, if they got change
                # or set to defaults, e.g. n_configurations = - 1 so
                # that they match the query
                func(cls, *args, **kwargs)
                data = cls.get_computation_data()

            if cls.plot:
                """Plot the data"""
                cls.plotter = DataVisualizer2D(
                    title=cls.analysis_name, path=experiment.figures_path
                )
                cls.plot_data(data.data_dict)
                cls.plotter.grid_show(cls.plot_array)

            out[cls.experiment.name] = data

        if return_dict:
            return out
        else:
            return out[self.experiment.name]

    return inner


class Calculator(CalculatorDatabase):
    """
    Parent class for analysis modules.

    Attributes
    ----------
    experiment : Experiment
                Experiment for which the calculator will be run.
    experiments : List[Experiment]
            List of experiments on which to run the calculator.
    plot : bool
            If true, the results will be plotted.
    system_property: bool (default = False)
            If the calculator returns a value for the whole system such as ionic
            conductivity or viscosity as opposed to a species-specific number.
    experimental : bool (default = False)
            If true, a warning is raised upon calling this calculator with more
            information about why it is experimental.
    selected_species: tuple
            Species currently being studied in a specific loop.
    analysis_name: str
            Name of the analysis to store in the database.
    time : np.ndarray
            Time array over which to integrate and plot.
    plotter : DataVisualizer2D
            Data visualizer class for use in the plotting.
    result_keys : list
            keys to use when storing the results. e.g.
            ["diffusion_coefficient", "uncertainty"]
    result_series_keys : list
            keys to use when storing series results e.g.
            ["time", "msd"]
    prefactor : float (optional)
            can be set if the same pre-factor is required many times.
    x_label : str
            x-label for the plots.
    y_label : str
            y-label for the plots.
    _dtype : object = tf.float64
            dtype required by the analysis.
    plot_array : list
            A list of plot objects to be show together at the end of the
            species loop.
    """

    def __init__(
        self, experiment: Experiment = None, experiments: List[Experiment] = None
    ):
        """
        Constructor for the calculator class.

        Parameters
        ----------
        experiment : Experiment
                Experiment for which the calculator will be run.
        experiments : List[Experiment]
                List of experiments on which to run the calculator.
        """
        # Set upon instantiation of parent class
        super().__init__(experiment)
        # NOTE: if the calculator accepts more than just experiment/experiments
        #  in the init the @call decorator has to be changed!
        self.experiment: Experiment = experiment
        self.experiments: List[Experiment] = experiments
        # Setting the experiment value supersedes setting experiments
        if self.experiment is not None:
            self.experiments = [self.experiment]

        self.plot = False

        # SQL data attributes.
        self.result_keys = None
        self.result_series_keys = None
        self.analysis_name = None
        self.selected_species = None

        # Calculator attributes
        self.system_property = False
        self.experimental = False
        self.time = None
        self.prefactor = None

        # Data attributes
        self._dtype: object = tf.float64

        # Plotting attributes
        self.plotter = None
        self.x_label = None
        self.y_label = None
        self.plot_array = []

    @property
    def dtype(self):
        """Get the dtype used for the calculator."""
        return self._dtype

    def run_visualization(
        self, x_data: np.ndarray, y_data: np.ndarray, title: str, layouts: object = None
    ):
        """
        Run a visualization session on the data.

        Parameters
        ----------
        layouts : object
                Additional plot features that may be added.
                See https://docs.bokeh.org/en/latest/docs/reference/models.html for
                more information.
        x_data : np.ndarray
                Data to be plotted along the x axis
        y_data : np.ndarray
                Data to be plotted along the y-axis
        title : str
                Title of the analysis.

        Returns
        -------
        Updates the plot array with a Bokeh plot object.
        """
        self.plot_array.append(
            self.plotter.construct_plot(
                x_data=x_data,
                y_data=y_data,
                title=title,
                x_label=self.x_label,
                y_label=self.y_label,
                layouts=layouts,
            )
        )

    def run_calculator(self):
        """
        Run the calculation. This should be implemented in each calculator.

        Returns
        -------

        """
        raise NotImplementedError

    def plot_data(self, data):
        """
        Plot the data coming from the database.

        Parameters
        ----------
        data: db.Compution.data_dict
                associated with the current project
        """
        for selected_species, val in data.items():
            self.run_visualization(
                x_data=np.array(val[self.result_series_keys[0]]),
                y_data=np.array(val[self.result_series_keys[1]]),
                title=(
                    f"{selected_species}: {val[self.result_keys[0]]: 0.3E} +-"
                    f" {val[self.result_keys[1]]: 0.3E}"
                ),
            )

    def run_analysis(self):
        """Run the appropriate analysis."""
        if self.experimental:
            log.warning(
                "This is an experimental calculator. Please see the "
                "documentation before using the results."
            )
        self.run_calculator()