Ptrskay3/PySprint

View on GitHub
pysprint/core/methods/wft.py

Summary

Maintainability
A
2 hrs
Test Coverage
import sys
import warnings
from inspect import isfunction

import numpy as np
import matplotlib.pyplot as plt

from pysprint.core.methods.fftmethod import FFTMethod
from pysprint.core.phase import Phase
from pysprint.core._fft_tools import find_roi
from pysprint.core._fft_tools import find_center
from pysprint.utils.decorators import inplacify
from pysprint.utils import NotCalculatedException
from pysprint.utils import PySprintWarning
from pysprint.utils.misc import find_nearest
from pysprint.core.window import GaussianWindow
from pysprint.core.window import WindowBase

try:
    from dask import delayed, compute
    from dask.diagnostics import ProgressBar
    CAN_PARALLELIZE = True
except ImportError:
    CAN_PARALLELIZE = False

    def delayed(func=None, *args, **kwargs):
        if isfunction(func):
            return func


class WFTMethod(FFTMethod):
    """Basic interface for Windowed Fourier Transform Method.
    The `window_class` attribute can be set up for custom windowing.
    """

    def __init__(self, *args, **kwargs):
        self.window_class = kwargs.pop("window_class", GaussianWindow)
        assert issubclass(self.window_class, WindowBase), "window_class must subclass pysprint.core.window.WindowBase"
        super().__init__(*args, **kwargs)
        self.window_seq = {}
        self.found_centers = {}
        self.GD = None
        self.cachedlen = 0
        self.X_cont = np.array([])
        self.Y_cont = np.array([])
        self.Z_cont = np.array([])
        self.fastmath = True
        self.errorcounter = 0

    @inplacify
    def add_window(self, center, **kwargs):
        """
        Add a Gaussian window to the interferogram.

        Parameters
        ----------
        center : float
            The center of the Gaussian window.
        kwargs : dict
            Keyword arguments to pass to the `window_class`.
        """
        window = self.window_class(self.x, center=center, **kwargs)
        self.window_seq[center] = window
        return self

    @property
    def windows(self):
        return self.window_seq

    @property
    def centers(self):
        return self.window_seq.keys()

    @inplacify
    def add_window_generic(self, array, **kwargs):
        """
        Build a window sequence of given parameters with centers
        specified with ``array`` argument.

        Parameters
        ----------
        array : list, np.ndarray
            The array containing the centers of windows.
        kwargs : dict
            Keyword arguments to pass to the `window_class`.
        """
        if not isinstance(array, (list, np.ndarray)):
            raise TypeError("Expected list-like as ``array``.")
        for center in array:
            self.add_window(center=center, **kwargs)
        return self

    @inplacify
    def add_window_arange(self, start, stop, step, **kwargs):
        """
        Build a window sequence of given parameters to apply on ifg.
        Works similar to numpy.arange.

        Parameters
        ----------
        start : float
            The start of the centers.
        stop : float
            The end value of the center
        step : float
            The step value to increment center.
        kwargs : dict
            Keyword arguments to pass to the `window_class`.
        """
        arr = np.arange(start, stop, step)
        for cent in arr:
            self.add_window(center=cent, **kwargs)
        return self

    @inplacify
    def add_window_linspace(self, start, stop, num, **kwargs):
        """
        Build a window sequence of given parameters to apply on ifg.
        Works similar to numpy.linspace.

        Parameters
        ----------
        start : float
            The start of the centers.
        stop : float
            The end value of the center
        num : float
            The number of Gaussian windows.
        kwargs : dict
            Keyword arguments to pass to the `window_class`.
        """
        arr = np.linspace(start, stop, num)
        for cent in arr:
            self.add_window(center=cent, **kwargs)
        return self

    @inplacify
    def add_window_geomspace(self, start, stop, num, **kwargs):
        """
        Build a window sequence of given parameters to apply on ifg.
        Works similar to numpy.geomspace.

        Parameters
        ----------
        start : float
            The start of the centers.
        stop : float
            The end value of the center
        num : float
            The number of Gaussian windows.
        kwargs : dict
            Keyword arguments to pass to the `window_class`.
        """
        arr = np.geomspace(start, stop, num)
        for cent in arr:
            self.add_window(center=cent, **kwargs)
        return self

    def view_windows(self, ax=None, maxsize=80, **kwargs):
        """
        Gives a rough view of the different windows along with the ifg.

        Parameters
        ----------
        ax : matplotlib.axes.Axes, optional
            An axis to draw the plot on. If not given, it will plot
            of the last used axis.
        maxsize : int, optional
            The maximum number of Gaussian windows to display on plot.
            Default is 80, but be aware that setting a high value can
            drastically reduce performance.
        kwargs : dict, optional
            Additional keyword arguments to pass to plot function.
        """
        winlen = len(self.window_seq)
        if maxsize != 0:
            ratio = winlen // maxsize
            if winlen > maxsize:
                warnings.warn(
                    "Image seems crowded, displaying only a subsample of the given windows.",
                    PySprintWarning
                )
                for i, (_, val) in enumerate(self.window_seq.items()):
                    if i % ratio == 0:
                        val.plot(ax=ax, scalefactor=np.max(self.y) * .75, **kwargs)
            else:
                for _, val in self.window_seq.items():
                    val.plot(ax=ax, scalefactor=np.max(self.y) * .75, **kwargs)
        self.plot(ax=ax)

    @inplacify
    def remove_all_windows(self):
        """
        Remove all the Gaussian windows.
        """
        self.window_seq.clear()
        return self

    @inplacify
    def reset_state(self):
        """
        Reset the object's state fully: delete all the
        calculated GD, caches, heatmaps and window sequences.
        """
        self.remove_all_windows()
        self.found_centers.clear()
        self.X_cont = np.array([])
        self.Y_cont = np.array([])
        self.Z_cont = np.array([])
        self.GD = None
        self.cachedlen = 0
        self.fastmath = True
        return self

    @inplacify
    def remove_window_at(self, center):
        """
        Removes a window at center.

        Parameters
        ----------
        center : float
            The center of the window to remove.
            Raises ValueError if there is not such window.
        """
        if center not in self.window_seq.keys():
            c = find_nearest(
                np.fromiter(self.window_seq.keys(), dtype=float), center
            )
            raise ValueError(
                f"There is no window with center {center}. "
                f"Did you mean {c[0]}?"
            )
        self.window_seq.pop(center, None)
        return self

    @inplacify
    def remove_window_interval(self, start, stop):
        """
        Remove window interval inclusively.

        Parameters
        ----------
        start : float
            The start value of the interval.
        stop : float
            The stop value of the interval.
        """
        wins = np.fromiter(self.window_seq.keys(), dtype=float)
        mask = wins[(wins <= stop) & (wins >= start)]
        for center in mask:
            self.window_seq.pop(center, None)
        return self

    @inplacify
    def cover(self, N, **kwargs):
        """
        Cover the whole domain with `N` number of windows
        uniformly built with the given parameters.

        Parameters
        ----------
        N : float
            The number of Gaussian windows.
        kwargs : dict
            Keyword arguments to pass to the `window_class`.
        """

        self.add_window_linspace(np.min(self.x), np.max(self.x), N, **kwargs)

    def _calculate(
        self,
        reference_point,
        order,
        show_graph=False,
        silent=False,
        force_recalculate=False,
        fastmath=True,
        usenifft=False,
        parallel=False,
        ransac=False,
        errors="ignore",
        **kwds
    ):
        if len(self.window_seq) == 0:
            raise ValueError("Before calculating a window sequence must be set.")

        if self.cachedlen != len(self.window_seq) or fastmath != self.fastmath:
            force_recalculate = True
        self.fastmath = fastmath
        if force_recalculate:
            self.found_centers.clear()
            self.build_GD(
                silent=silent, fastmath=fastmath, usenifft=usenifft, parallel=parallel, errors=errors
            )
        if self.GD is None:
            self.build_GD(
                silent=silent, fastmath=fastmath, usenifft=usenifft, parallel=parallel, errors=errors
            )

        self.cachedlen = len(self.window_seq)

        if order == 1 or order > 6:
            raise ValueError("Order must be in [2, 6].")

        if ransac:
            print("Running RANSAC-filter..")
            self.GD.ransac_filter(order=order, plot=show_graph, **kwds)
            self.GD.apply_filter()

        d, ds, fr = self.GD._fit(
            reference_point=reference_point, order=order
        )
        if show_graph:
            self.GD.plot()
        return d, ds, fr

    def calculate(
            self,
            reference_point,
            order,
            show_graph=False,
            silent=False,
            force_recalculate=False,
            fastmath=True,
            usenifft=False,
            parallel=False,
            ransac=False,
            errors="ignore",
            **kwds
    ):
        """
        Calculates the dispersion.

        Parameters
        ----------
        reference_point : float
            The reference point.
        order : int
            The dispersion order to look for. Must be in [2, 6].
        show_graph : bool, optional
            Whether to show the GD graph on complete. Default is False.
        silent : bool, optional
            Whether to print progressbar. By default it will print.
        force_recalculate : bool, optional
            Force to recalculate the GD graph not only the curve fitting.
            Default is False.
        fastmath : bool, optional
            Whether to build additional arrays to display heatmap.
            Default is True.
        usenifft : bool, optional
            Whether to use Non-unifrom FFT when calculating GD.
            Default is False. **Not stable.**
        parallel : bool, optional
            Whether to use parallel computation. Only availabe if `Dask`
            is installed. The speedup is about 50-70%. Default is False.
        ransac : bool, optional
            Whether to use RANSAC filtering on the detected peaks. Default
            is False.
        errors : str, optional
            Whether to raise an error is the algorithm couldn't find the
            center of the peak. Default is "ignore".
        kwds : optional
            Other keyword arguments to pass to RANSAC filter.

        Raises
        ------
        ValueError, if no window sequence is added to the interferogram.
        ValueError, if order is 1.
        ModuleNotFoundError, if `Dask` is not available when using parallel=True.
        """
        return self._calculate(
            reference_point,
            order,
            show_graph,
            silent,
            force_recalculate,
            fastmath,
            usenifft,
            parallel,
            ransac,
            errors,
            **kwds
        )

    def build_GD(self, silent=False, fastmath=True, usenifft=False, parallel=False, errors="ignore"):
        """
        Build the GD.

        Parameters
        ----------
        silent : bool, optional
            Whether to print progressbar. By default it will print.
        fastmath : bool, optional
            Whether to build additional arrays to display heatmap.
            Default is True.
        usenifft : bool, optional
            Whether to use Non-unifrom FFT when calculating GD.
            Default is False. **Not stable.**
        parallel : bool, optional
            Whether to use parallel computation. Only availabe if `Dask`
            is installed. The speedup is about 50-70%. Default is False.
        errors : str, optional
            Whether to raise an error is the algorithm couldn't find the
            center of the peak.

        Returns
        -------
        GD : pysprint.core.phase.Phase
            The phase object with `GD_mode=True`. See its docstring for more info.
        """
        if parallel:

            if not CAN_PARALLELIZE:
                raise ModuleNotFoundError(
                    "Module `dask` not found. Please install it in order to use parallelism."
                )

            else:
                self.fastmath = fastmath
                self._apply_window_seq_parallel(fastmath=fastmath, usenifft=usenifft, errors=errors)

                if not silent:
                    with ProgressBar():
                        computed = compute(*self.found_centers.values())

                else:
                    computed = compute(*self.found_centers.values())

                cleaned_delays = [
                    k for i, k in enumerate(self.found_centers.keys()) if computed[i] is not None
                ]
                delay = np.fromiter(cleaned_delays, dtype=float)

                omega = np.fromiter([c for c in computed if c is not None], dtype=float)

                if not silent:
                    print(f"Skipped: {len(self.window_seq) - sum(1 for _ in filter(None.__ne__, computed))}")

        else:
            self.fastmath = fastmath
            self._apply_window_sequence(silent=silent, fastmath=fastmath, usenifft=usenifft)
            self._clean_centers(silent=silent)

            delay = np.fromiter(self.found_centers.keys(), dtype=float)
            omega = np.fromiter(self.found_centers.values(), dtype=float)

        self.GD = Phase(delay, omega, GD_mode=True)
        return self.GD

    def build_phase(self):
        raise NotImplementedError("Use `build_GD` instead.")

    def _predict_ideal_window_fwhm(self):
        pass

    def _apply_window_sequence(
            self, silent=False, fastmath=True, usenifft=False, errors="ignore"
    ):
        winlen = len(self.window_seq)
        self.errorcounter = 0
        if not fastmath:
            # here we setup the shape for the Z array because
            # it is much faster than using np.append in every iteration
            _x, _y, _, _ = self._safe_cast()
            _obj = FFTMethod(_x, _y)
            _obj.ifft(usenifft=usenifft)
            x, y = find_roi(_obj.x, _obj.y)
            self.Y_cont = np.array(x)
            yshape = y.size
            xshape = len(self.window_seq)
            self.Z_cont = np.empty(shape=(yshape, xshape))

        for idx, (_center, _window) in enumerate(self.window_seq.items()):
            _x, _y, _, _ = self._safe_cast()
            _obj = FFTMethod(_x, _y)
            _obj.y *= _window.y
            _obj.ifft(usenifft=usenifft)
            x, y = find_roi(_obj.x, _obj.y)
            if not fastmath:
                self.Z_cont[:, idx] = y
            try:
                centx, _ = find_center(x, y)
                self.found_centers[_center] = centx
            except ValueError as err:
                self.errorcounter += 1
                if errors == "ignore":
                    self.found_centers[_center] = None
                else:
                    raise err
            if not silent:  # This creates about 5-15% overhead.. maybe create a buffer
                sys.stdout.write('\r')
                j = (idx + 1) / winlen
                sys.stdout.write(
                    "Progress : [%-30s] %d%% (Skipped: %d)" % ('=' * int(30 * j), 100 * j, self.errorcounter)
                )
                sys.stdout.flush()

    def _apply_window_seq_parallel(
            self, fastmath=True, usenifft=False, errors="ignore"
    ):
        self.errorcounter = 0
        if not fastmath:
            # here we setup the shape for the Z array and allocate Y, because
            # it is much faster than using np.append in every iteration
            _x, _y, _, _ = self._safe_cast()
            _obj = FFTMethod(_x, _y)
            _obj.ifft(usenifft=usenifft)
            x, y = find_roi(_obj.x, _obj.y)
            yshape = y.size
            self.Y_cont = np.array(x)
            xshape = len(self.window_seq)
            self.Z_cont = np.empty(shape=(yshape, xshape))

        for idx, (_center, _window) in enumerate(self.window_seq.items()):
            element = self._prepare_element(idx, _window, fastmath, usenifft, errors)
            if element is None:
                self.errorcounter += 1  # This might be useless, since we lazy evaluate things..
            self.found_centers[_center] = element

    @delayed
    def _prepare_element(self, idx, window, fastmath=True, usenifft=False, errors="ignore"):
        _x, _y, _, _ = self._safe_cast()
        _obj = FFTMethod(_x, _y)
        _obj.y *= window.y
        _obj.ifft(usenifft=usenifft)
        x, y = find_roi(_obj.x, _obj.y)
        if not fastmath:
            self.Z_cont[:, idx] = y
        try:
            centx, _ = find_center(x, y)
            return centx
        except ValueError as err:
            if errors == "ignore":
                return None
            else:
                raise err

    def _clean_centers(self, silent=False):
        dct = {k: v for k, v in self.found_centers.items() if v is not None}
        self.found_centers = dct
        winlen = len(self.window_seq)
        usefullen = len(self.found_centers)

        if not silent:
            if winlen != usefullen:
                print(
                    f"\n{abs(winlen-usefullen)} points skipped "
                    f"due to ambiguous peak position."
                )

    def errorplot(self, *args, **kwargs):
        """
        Plot the errors of fitting.

        Parameters
        ----------
        ax : matplotlib.axes.Axes, optional
            An axis to draw the plot on. If not given, it will plot
            of the last used axis.
        percent : bool, optional
            Whether to plot percentage difference. Default is False.
        title : str, optional
            The title of the plot. Default is "Errors".
        kwargs : dict, optional
            Additional keyword arguments to pass to plot function.
        """
        try:
            getattr(self.GD, "errorplot", None)(*args, **kwargs)
        except TypeError:
            raise NotCalculatedException("Must calculate before plotting errors.")

    @property
    def get_GD(self):
        """
        Return the GD if it is already calculated.
        """
        if self.GD is not None:
            return self.GD
        raise NotCalculatedException("Must calculate GD first.")

    @property
    def errors(self):
        """
        Return the fitting errors as np.ndarray.
        """
        return getattr(self.GD, "errors", None)

    def _collect_failures(self):
        return [k for k in self.window_seq.keys() if k not in self.found_centers.keys()]

    def _construct_heatmap_data(self):
        self.X_cont = np.fromiter(self.window_seq.keys(), dtype=float)

    def heatmap(self, ax=None, levels=None, cmap="viridis", include_ridge=True):
        """
        Plot the heatmap.

        Parameters
        ----------
        ax : matplotlib.axes.Axes, optional
            An axis to draw the plot on. If not given, it will plot
            of the last used axis.
        levels : np.ndarray, optional
            The levels to use for plotting.
        cmap : str, optional
            The colormap to use.
        include_ridge : bool, optional
            Whether to mark the detected ridge of the plot.
            Default is True.
        """
        if self.GD is None:
            raise NotCalculatedException("Must calculate GD first.")

        if self.fastmath:
            raise ValueError(
                "You need to recalculate with `fastmath=False` to plot the heatmap."
            )
        # Only construct if we need to..
        if not (self.Y_cont.size, self.X_cont.size) == self.Z_cont.shape:
            self._construct_heatmap_data()
        if ax is None:
            plt.contourf(
                self.X_cont, self.Y_cont, self.Z_cont, levels=levels, cmap=cmap, extend="both"
            )
        else:
            ax.contourf(
                self.X_cont, self.Y_cont, self.Z_cont, levels=levels, cmap=cmap, extend="both"
            )
        if include_ridge:
            if ax is None:
                plt.plot(*self.GD.data, color='red', label='detected ridge')
            else:
                ax.plot(*self.GD.data, color='red', label='detected ridge')
            plt.legend()
        if ax is None:
            plt.xlabel('Window center [PHz]')
            plt.ylabel('Delay [fs]')
            try:
                upper_bound = min(1.5 * np.max(self.GD.data[1]), np.max(self.Y_cont))
                plt.ylim(None, upper_bound)
            except ValueError:
                pass
            # ValueError: zero-size array to reduction operation maximum which has no identity
            # This means that the array is empty, we should pass that case.
        else:
            ax.set_autoscalex_on(False)
            try:
                upper_bound = min(1.5 * np.max(self.GD.data[1]), np.max(self.Y_cont))
                ax.set(
                    xlabel="Window center [PHz]",
                    ylabel="Delay [fs]",
                    ylim=(None, upper_bound)
                )
            except ValueError:
                pass

    def get_heatmap_data(self):
        """
        Return the data which was used to create the heatmap.

        Returns
        -------
        X_cont : np.ndarray
            The window centers with shape (n,).
        Y_cont : np.ndarray
            The time axis calculated from the IFFT of the dataset with shape (m,).
        Z_cont : np.ndarray
            2D array with shape (m, n) containing the depth information.
        """
        if all([self.Y_cont.size != 0, self.Z_cont.size != 0]):
            self._construct_heatmap_data()
        else:
            raise ValueError("Must calculate with `fastmath=False` before trying to access the heatmap data.")

        return self.X_cont, self.Y_cont, self.Z_cont