Ptrskay3/PySprint

View on GitHub
pysprint/core/methods/minmax.py

Summary

Maintainability
A
0 mins
Test Coverage
import logging
import warnings

import numpy as np

from pysprint.config import _get_config_value
from pysprint.core.bases.dataset import Dataset
from pysprint.mpl_tools.peak import EditPeak
from pysprint.core.phase import Phase
from pysprint.core._evaluate import min_max_method, is_inside
from pysprint.utils import (
    _maybe_increase_before_cwt,
    _calc_envelope,
)
from pysprint.utils import PySprintWarning


# Names exported by `from ... import *`.
__all__ = ["MinMaxMethod"]

# Record layout: source file and line number, the function name padded
# to 20 characters, then the log message itself.
FORMAT = "[ %(filename)s:%(lineno)s - %(funcName)20s() ] %(message)s"

# Applies FORMAT through the root logger's default handler.
logging.basicConfig(format=FORMAT)

# Module-level logger used by everything defined in this file.
logger = logging.getLogger(__name__)


class MinMaxMethod(Dataset):
    """
    Interface for Minimum-Maximum Method.

    Extremal points are collected (optionally through an interactive
    editing session), converted to a phase by `build_phase` and fitted
    by `calculate`.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # The most recently built phase; set by `build_phase`.
        self.phase = None
        # True when the last editing session kept only minimums or only
        # maximums. Used by `build_phase` to warn about mismatched usage.
        self._is_onesided = False

    def init_edit_session(self, engine="normal", **kwargs):
        """
        Function to initialize peak editing on a plot.
        Right clicks (`d` key later) will delete the closest point,
        left clicks(`i` key later) will add a new point. Just close
        the window when finished. Must be called with interactive
        backend. The best practice is to call this function inside
        `~pysprint.interactive` context manager.

        Parameters
        ----------
        engine : str, optional
            Must be 'cwt', 'normal' or 'slope'.
            Peak detection algorithm to use.
            Default is normal.
        kwargs : dict, optional
            pmax, pmin, threshold, except_around : forwarded to
            `detect_peak` when the engine is 'normal'.
            widths, floor_thres : forwarded to `detect_peak_cwt` when
            the engine is 'cwt'.
            side : 'both', 'min' or 'max', selects which detected
            points are offered for editing. Default is 'both'.
            Options that belong to another engine are accepted and
            ignored.

        Returns
        -------
        positions : array-like
            The extremal point positions recorded by the editor.

        Raises
        ------
        ValueError
            If `engine` or `side` is not one of the allowed values.
        TypeError
            If an unknown keyword argument is given.
        """
        engines = ("cwt", "normal", "slope")

        if engine not in engines:
            raise ValueError(f"Engine must be in {str(engines)}")

        # Consume every documented option up front, whichever engine is
        # selected, so that all arguments are validated before any
        # detection runs.
        pmax = kwargs.pop("pmax", 0.1)
        pmin = kwargs.pop("pmin", 0.1)
        threshold = kwargs.pop("threshold", 0.1)
        except_around = kwargs.pop("except_around", None)
        widths = kwargs.pop("widths", np.arange(1, 20))
        floor_thres = kwargs.pop("floor_thres", 0.05)
        side = kwargs.pop("side", "both")

        if side not in ("both", "min", "max"):
            raise ValueError("Side must be 'both', 'min' or 'max'.")

        if kwargs:
            raise TypeError(f"Invalid argument:{kwargs}")

        if engine == "normal":
            _x, _y, _xx, _yy = self.detect_peak(
                pmax=pmax,
                pmin=pmin,
                threshold=threshold,
                except_around=except_around,
                side=side,
            )

        elif engine == "cwt":
            _x, _y, _xx, _yy = self.detect_peak_cwt(
                widths=widths, floor_thres=floor_thres, side=side
            )

        else:  # engine == "slope"
            x, _, _, _ = self._safe_cast()
            y = np.copy(self.y_norm)
            shifted = _maybe_increase_before_cwt(y)
            if shifted:
                y += 2
            _, lp, lloc = _calc_envelope(y, np.arange(len(y)), "l")
            _, up, uloc = _calc_envelope(y, np.arange(len(y)), "u")
            # Undo the offset only if it was actually applied above,
            # otherwise the points would end up displaced by 2.
            if shifted:
                lp -= 2
                up -= 2
            # NOTE(review): `_x` comes from the "l" envelope and `_xx`
            # from the "u" envelope, while the side selection below
            # treats `_xx` as the minimums. Confirm this matches the
            # return order of `detect_peak` / `detect_peak_cwt`.
            _x, _xx = x[lloc], x[uloc]
            _y, _yy = lp, up

        # Remember how the detection was done for every engine, so the
        # warnings in `build_phase` never rely on a stale flag.
        self._is_onesided = side != "both"

        if side == "both":
            _xm = np.append(_x, _xx)
            _ym = np.append(_y, _yy)
        elif side == "min":
            _xm, _ym = _xx, _yy
        else:  # side == "max"
            _xm, _ym = _x, _y

        try:
            _editpeak = EditPeak(self.x, self.y_norm, _xm, _ym)
        except ValueError:
            # Fall back to the raw signal when the editor rejects y_norm.
            _editpeak = EditPeak(self.x, self.y, _xm, _ym)
        # Automatically propagate these points to the mins and maxes.
        # Distribute these points between min and max, just in case
        # the default argrelextrema is definitely not called
        # in `pysprint.core.evaluate.min_max_method`.

        recorded = _editpeak.get_dat[0]
        half = len(recorded) // 2
        self.xmin = recorded[:half]
        self.xmax = recorded[half:]
        print(f"{len(recorded)} extremal points were recorded.")
        return recorded  # we should return None

    def calculate(
        self,
        reference_point,
        order,
        SPP_callbacks=None,
        show_graph=False,
        scan=False,
        onesided=False,
    ):
        """
        MinMaxMethod's calculate function.

        Parameters
        ----------
        reference_point : float
            reference point on x axis
        order : int
            Polynomial (and maximum dispersion) order to fit. Must be in [1, 5].
        SPP_callbacks : number, or numeric list-like
            The positions of SPP's on the interferogram. If not given it will check
            if there's any SPP position set on the object.
        show_graph : bool, optional
            Shows the final graph of the spectral phase and fitted curve.
            Default is False.
        scan : bool, optional
            If True, the phase on the left and on the right side of the
            reference point is fitted separately. When the relative
            difference of the two results exceeds the `scan_threshold`
            config value, the side with more points is used, otherwise
            the two results are averaged. It is turned off when the
            reference point is not inside the phase's x range.
            Default is False.
        onesided : bool
            Use only minimums or maximums to build the phase. It also works for
            one characteristic point per oscillation period (e.g. zero-crossings).
            Default is False.

        Returns
        -------
        dispersion : array-like
            [GD, GDD, TOD, FOD, QOD, SOD]
        dispersion_std : array-like
            Standard deviations due to uncertainty of the fit.
            They are only calculated if lmfit is installed.
            [GD_std, GDD_std, TOD_std, FOD_std, QOD_std, SOD_std]
        fit_report : str
            lmfit report if installed, else empty string.

        Note
        ----
        This is a thin wrapper around `_calculate`; the results are
        returned to the caller and are not printed by this method.
        """

        return self._calculate(
            reference_point,
            order,
            SPP_callbacks,
            show_graph,
            scan,
            onesided
        )

    def _calculate(
        self,
        reference_point,
        order,
        SPP_callbacks=None,
        show_graph=False,
        scan=False,
        onesided=False,
    ):
        """
        Build the phase and fit it; see `calculate` for the parameters
        and the returned values.
        """
        phase = self.build_phase(
            reference_point=reference_point, SPP_callbacks=SPP_callbacks, onesided=onesided
        )

        # Scanning needs data on both sides of the reference point.
        # Only report it as disabled if it was requested at all.
        if scan and not is_inside(reference_point, phase.x):
            scan = False
            logger.info("Scan is disabled, reference_point is on the border.")

        if not scan:
            dispersion, dispersion_std, fit_report = phase._fit(reference_point, order)
            if show_graph:
                phase.plot()
            return dispersion, dispersion_std, fit_report

        left_phase = phase.slice(None, reference_point, inplace=False)
        right_phase = phase.slice(reference_point, None, inplace=False)

        left_d, left_ds, left_fit_report = left_phase._fit(reference_point, order)
        right_d, right_ds, right_fit_report = right_phase._fit(reference_point, order)

        logger.info(f"left side evaluated to {left_d}, used {len(left_phase.x)} points.")
        logger.info(f"right side evaluated to {right_d}, used {len(right_phase.x)} points.")

        # Make the right side results carry the same signs as the left
        # side ones before they are compared or averaged.
        right_d = np.where(np.sign(left_d) != np.sign(right_d), -right_d, right_d)
        right_ds = np.where(np.sign(left_ds) != np.sign(right_ds), -right_ds, right_ds)

        # NOTE(review): trim_zeros strips leading/trailing zeros only;
        # this assumes both results trim to the same length.
        diffs = np.abs(np.trim_zeros(right_d) - np.trim_zeros(left_d)) / np.trim_zeros(left_d)

        thres = _get_config_value("scan_threshold")

        if (np.abs(diffs[~np.isnan(diffs)]) > thres).any():
            # The two sides disagree: trust the one built from more points
            # (the left side wins ties).
            if len(left_phase.x) >= len(right_phase.x):
                side, used_phase = "left", left_phase
                dispersion, dispersion_std, fit_report = left_d, left_ds, left_fit_report
            else:
                side, used_phase = "right", right_phase
                dispersion, dispersion_std, fit_report = right_d, right_ds, right_fit_report
            logger.info(f"Max relative difference is too high, using {side} side.")

            if show_graph:
                phase.plot()
                used_phase.plot(label="used part")
        else:
            # The two sides agree: average them and keep both reports.
            dispersion = np.mean([left_d, right_d], axis=0)
            dispersion_std = np.mean([left_ds, right_ds], axis=0)
            fit_report = ''.join([left_fit_report, right_fit_report])
            if show_graph:
                left_phase.plot()
                right_phase.plot()

        return dispersion, dispersion_std, fit_report

    def build_phase(self, reference_point, SPP_callbacks=None, onesided=False):
        """
        Build **only the phase** using reference point and SPP positions.

        The result is also stored on the object as `self.phase`.

        Parameters
        ----------
        reference_point : float
            The reference point from where the phase building starts.
        SPP_callbacks : number, or numeric list-like
            The positions of SPP's on the interferogram. If not given it will check
            if there's any SPP position set on the object.
        onesided : bool
            If `True`, use only the minimums or maximums to build the phase. It also
            works for one characteristic point per oscillation period (e.g. zero-crossings).
            Default is False.

        Returns
        -------
        phase : pysprint.core.phase.Phase
            The phase object. See its docstring for more info.

        Warns
        -----
        PySprintWarning
            If `onesided` does not match the way the extremal points
            were detected in the last editing session.
        """
        # Fall back to the SPP positions stored on the object.
        if SPP_callbacks is None and self._positions is not None:
            SPP_callbacks = np.array(self._positions)

        if onesided and not self._is_onesided:
            warnings.warn(
                "Trying to build phase as one-sided, but the detection was two-sided. Use `onesided=False`.",
                PySprintWarning
            )

        if not onesided and self._is_onesided:
            warnings.warn(
                "Trying to build phase as two-sided, but the detection was one-sided. Use `onesided=True`.",
                PySprintWarning
            )

        x, y = min_max_method(
            self.x,
            self.y,
            self.ref,
            self.sam,
            ref_point=reference_point,
            maxx=self.xmax,
            minx=self.xmin,
            SPP_callbacks=SPP_callbacks,
            onesided=onesided,
        )
        self.phase = Phase(x, y)
        return self.phase