"""
This file implements the Dataset class with all the functionality
that an interferogram should have in general.
"""
import base64
from collections.abc import Iterable
from contextlib import suppress, contextmanager
from io import BytesIO
import json
import logging
from math import factorial
import numbers
import re
from textwrap import dedent
import warnings

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d
from jinja2 import Template

from pysprint.config import _get_config_value
from pysprint.core.bases._dataset_base import _DatasetBase
from pysprint.core.bases._dataset_base import C_LIGHT
from pysprint.core.bases._apply import _DatasetApply
from pysprint.core._evaluate import is_inside
from pysprint.core._evaluate import ifft_method
from pysprint.core._fft_tools import find_center
from pysprint.core.io._parser import _parse_raw
from pysprint.mpl_tools.spp_editor import SPPEditor
from pysprint.mpl_tools.normalize import DraggableEnvelope
from pysprint.utils import MetaData, find_nearest
from pysprint.utils.decorators import inplacify
from pysprint.core._preprocess import (
    savgol,
    find_peak,
    convolution,
    cut_data,
    cwt,
)
from pysprint.utils.exceptions import (
    InterpolationWarning,
    DatasetError,
    PySprintWarning,
)

logger = logging.getLogger(__name__)
FORMAT = "[ %(filename)s:%(lineno)s - %(funcName)20s() ] %(message)s"
logging.basicConfig(format=FORMAT)

__all__ = ["Dataset"]


class Dataset(metaclass=_DatasetBase):
    """
    This class implements all the functionality a dataset
    should have in general.
    """

    meta = MetaData("""Additional info about the dataset""", copy=False)

    def __init__(
            self,
            x,
            y,
            ref=None,
            sam=None,
            meta=None,
            errors="raise",
            callback=None,
            parent=None,
            **kwargs
    ):
        """
        Base constructor for Dataset.

        Parameters
        ----------
        x : np.ndarray
            The x values.
        y : np.ndarray
            The y values.
        ref : np.ndarray, optional
            The reference arm's spectra.
        sam : np.ndarray, optional
            The sample arm's spectra.
        meta : dict-like
            The dictionary containing further information about the dataset.
            Can be extended, or set to be any valid ~collections.abc.Mapping.
        errors: str, optional
            Whether to raise on mismatched data sizes. Must be "raise" or
            "force". If "force" then truncate to the shortest size. Default is
            "raise".
        callback : callable, optional
            The function that notifies parent objects about SPP related
            changes. In most cases the user should leave this empty. The
            default callback is only initialized if this object is constructed
            by the `pysprint.SPPMethod` object.
        parent : any class, optional
            The object which handles the callback function. In most cases
            the user should leave this empty.
        kwargs : dict, optional
            The window class to use in WFTMethod. Has no effect while using other
            methods. Must be a subclass of pysprint.core.windows.WindowBase.

        Note
        ----
        To load data from files, see the alternative constructor `parse_raw`.
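
        Examples
        --------
        A short, synthetic example (the arrays below are purely
        illustrative):

        >>> import numpy as np
        >>> x = np.linspace(2, 3, 1000)
        >>> y = np.cos(2000 * x) ** 2
        >>> ifg = Dataset(x, y)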
        """
        super().__init__()

        if errors not in ("raise", "force"):
            raise ValueError("errors must be `raise` or `force`.")

        self.callback = callback or (lambda *args: args)
        self.parent = parent

        self.x = np.array(x, dtype=np.float64)
        self.y = np.array(y, dtype=np.float64)
        if ref is None:
            self.ref = []
        else:
            self.ref = ref
        if sam is None:
            self.sam = []
        else:
            self.sam = sam
        self._is_normalized = False
        if not isinstance(self.x, np.ndarray):
            try:
                self.x = np.array(self.x).astype(float)
            except ValueError:
                raise DatasetError("Invalid type of data")
        if not isinstance(self.y, np.ndarray):
            try:
                self.y = np.array(self.y).astype(float)
            except ValueError:
                raise DatasetError("Invalid type of data")
        if not isinstance(self.ref, np.ndarray):
            try:
                self.ref = np.array(self.ref).astype(float)
            except ValueError:
                pass  # just ignore invalid arms
        if not isinstance(self.sam, np.ndarray):
            try:
                self.sam = np.array(self.sam).astype(float)
            except ValueError:
                pass  # just ignore invalid arms

        if not len(self.x) == len(self.y):
            if errors == 'raise':
                raise ValueError(
                    f"Mismatching data shapes with {self.x.shape} and {self.y.shape}."
                )
            else:
                truncated_shape = min(len(self.x), len(self.y))
                # probably we should cut down the first half
                self.x, self.y = self.x[-truncated_shape:], self.y[-truncated_shape:]

        if len([x for x in (self.ref, self.sam) if len(x) != 0]) == 1:
            warnings.warn(
                "Reference and sample arm should be passed together or neither one.",
                PySprintWarning
            )

        if len(self.ref) == 0 or len(self.sam) == 0:
            self.y_norm = self.y
            self._is_normalized = self._ensure_norm()

        else:
            if not np.all([len(self.sam) == len(self.x), len(self.ref) == len(self.x)]):
                if errors == 'raise':
                    raise ValueError(
                        f"Mismatching data shapes with {self.x.shape}, "
                        f"{self.ref.shape} and {self.sam.shape}."
                    )
                else:
                    truncated_shape = min(len(self.x), len(self.ref), len(self.sam), len(self.y))
                    # same as above..
                    self.ref, self.sam = self.ref[-truncated_shape:], self.sam[-truncated_shape:]

            self.y_norm = (self.y - self.ref - self.sam) / (
                    2 * np.sqrt(self.sam * self.ref)
            )
            self._is_normalized = True

        self.plt = plt
        self.xmin = None
        self.xmax = None
        self.probably_wavelength = None
        self.unit = None
        self._check_domain()

        if meta is not None:
            self.meta = meta

        self._delay = None
        self._positions = None

        nanwarning = np.isnan(self.y_norm).sum()
        infwarning = np.isinf(self.y_norm).sum()
        if nanwarning > 0 or infwarning > 0:
            warnings.warn(
                ("Extreme values encountered during normalization.\n"
                f"Nan values: {nanwarning}\nInf values: {infwarning}"),
                PySprintWarning
            )

        self._dispersion_array = None

    @inplacify
    def chrange(self, current_unit, target_unit="phz"):
        """
        Change the domain range of the dataset.

        Supported units for frequency:
            * PHz
            * THz
            * GHz
        Supported units for wavelength:
            * um
            * nm
            * pm
            * fm

        Parameters
        ----------
        current_unit : str
            The current unit of the domain. Case insensitive.
        target_unit : str, optional
            The target unit. Must be compatible with the current unit.
            Case insensitive. Default is `phz`.
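
        Examples
        --------
        Illustrative usage (``ifg`` denotes a previously constructed
        ``Dataset`` whose x axis is currently in nm):

        >>> ifg.chrange("nm", "um")  # doctest: +SKIP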
        """
        current_unit, target_unit = current_unit.lower(), target_unit.lower()
        conversions = {
            "um": {"um": 1, "nm": 1000, "pm": 1E6, "fm": 1E9},
            "nm": {"um": 1 / 1000, "nm": 1, "pm": 1000, "fm": 1E6},
            "pm": {"um": 1 / 1E6, "nm": 1 / 1000, "pm": 1, "fm": 1000},
            "fm": {"um": 1 / 1E9, "nm": 1 / 1E6, "pm": 1 / 1000, "fm": 1},
            "phz": {"phz": 1, "thz": 1000, "ghz": 1E6},
            "thz": {"phz": 1 / 1000, "thz": 1, "ghz": 1000},
            "ghz": {"phz": 1 / 1E6, "thz": 1 / 1000, "ghz": 1}
        }
        try:
            ratio = float(conversions[current_unit][target_unit])
        except KeyError as error:
            raise ValueError("Units are not compatible") from error
        self.x = (self.x * ratio)
        self.unit = self._render_unit(target_unit)
        return self

    def __len__(self):
        return len(self.x)

    @staticmethod
    def _render_unit(unit, mpl=False):
        unit = unit.lower()
        charmap = {
            "um": (r"\mu m", "um"),
            "nm": ("nm", "nm"),
            "pm": ("pm", "pm"),
            "fm": ("fm", "fm"),
            "phz": ("PHz", "PHz"),
            "thz": ("THz", "THz"),
            "ghz": ("GHz", "GHz")
        }
        if mpl:
            return charmap[unit][0]
        return charmap[unit][1]

    @inplacify
    def transform(self, func, axis=None, args=None, kwargs=None):
        """
        Apply an arbitrary function to the dataset.

        Parameters
        ----------
        func : callable
            The function to apply on the dataset.
        axis : int or str, optional
            The axis the operation is performed on.
            Must be 'x', 'y', '0' or '1'.
        args : tuple, optional
            Additional arguments to pass to func.
        kwargs : dict, optional
            Additional keyword arguments to pass to func.
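
        Examples
        --------
        Illustrative usage (``ifg`` is a previously constructed ``Dataset``
        and ``np.sqrt`` is just an arbitrary example function):

        >>> ifg.transform(np.sqrt, axis="y")  # doctest: +SKIP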
        """
        operation = _DatasetApply(
            obj=self, func=func, axis=axis, args=args, kwargs=kwargs
        )
        operation.perform()
        return self

    #  TODO : Rewrite this
    def phase_plot(self, exclude_GD=False):
        """
        Plot the phase if the dispersion is already calculated.

        Parameters
        ----------
        exclude_GD : bool
            Whether to exclude the GD part of the polynomial.
            Default is `False`.
        """
        if not np.all(self._dispersion_array):
            raise ValueError("Dispersion must be calculated before plotting the phase.")

        coefs = np.array(
            [
                self._dispersion_array[i] / factorial(i + 1)
                for i in range(len(self._dispersion_array))
            ]
        )

        if exclude_GD:
            coefs[0] = 0

        phase_poly = np.poly1d(coefs[::-1], r=False)

        self.plt.plot(self.x, phase_poly(self.x))
        self.plt.grid()
        self.plt.ylabel(r"$\Phi\, [rad]$")
        self.plt.xlabel(r"$\omega \,[PHz]$")
        self.plt.show()

    @property
    def delay(self):
        """
        Return the delay value if set.
        """
        return self._delay

    @delay.setter
    def delay(self, value):
        self._delay = value
        try:
            self.callback(self, self.parent)
        except ValueError:
            pass  # delay or position is missing

    @property
    def positions(self):
        """
        Return the SPP position(s) if set.
        """
        return self._positions

    @positions.setter
    def positions(self, value):
        if isinstance(value, numbers.Number):
            if value < np.min(self.x) or value > np.max(self.x):
                raise ValueError(
                    f"Cannot set SPP position to {value} since it's not in the dataset's range."
                )
        # FIXME: maybe we don't need to distinguish between np.ndarray and Iterable
        elif isinstance(value, np.ndarray) or isinstance(value, Iterable):
            for val in value:
                if not isinstance(val, numbers.Number):
                    raise ValueError(
                        f"Expected numeric values, got {type(val)} instead."
                    )
                if val < np.min(self.x) or val > np.max(self.x):
                    raise ValueError(
                        f"Cannot set SPP position to {val} since it's not in the dataset's range."
                    )
        self._positions = value
        try:
            self.callback(self, self.parent)
        except ValueError:
            pass  # delay or position is missing

    def _ensure_norm(self):
        """
        Ensure that the interferogram is normalized, allowing only a small
        fraction of points to lie outside the [-1, 1] interval (due to noise).
        """
        try:
            idx = np.where((self.y_norm > 2))
            val = len(idx[0]) / len(self.y_norm)
        except TypeError as e:
            raise DatasetError("Non-numeric values found while reading dataset.") from e
        if val > 0.015:  # this is a custom threshold, which often works..
            return False
        return True

    def scale_up(self):
        """
        If the interferogram is normalized to [0, 1] interval, scale
        up to [-1, 1] with easy algebra.
        """
        self.y_norm = (self.y_norm - 0.5) * 2
        self.y = (self.y - 0.5) * 2

    def GD_lookup(self, reference_point=None, engine="cwt", silent=False, **kwargs):
        """
        Quick GD lookup: find extremal points near the `reference_point`
        and estimate the GD as the average of 2*pi divided by the distances
        between consecutive minimal and consecutive maximal values.
        Since it's relying on peak detection, the results may be irrelevant
        in some cases. If the parent class is `~pysprint.CosFitMethod`, then
        it will set the predicted value as initial parameter for fitting.

        Parameters
        ----------
        reference_point : float
            The reference point for the algorithm.
        engine : str, optional
            The backend to use. Must be "cwt", "normal" or "fft".
            "cwt" will use `scipy.signal.find_peaks_cwt` function to
            detect peaks, "normal" will use `scipy.signal.find_peaks`
            to detect peaks. The "fft" engine uses Fourier-transform and
            looks for the outer peak to guess delay value. It's not
            reliable when working with low delay values.
        silent : bool, optional
            Whether to print the results immediately. Default is `False`.
        kwargs : dict, optional
            Additional keyword arguments to pass for peak detection
            algorithms. These are:
            pmin, pmax, threshold, width, floor_thres, etc..
            Most of them are described in the `find_peaks` and
            `find_peaks_cwt` docs.
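
        Examples
        --------
        Illustrative usage (``ifg`` is a frequency-domain ``Dataset`` and
        the reference point below is arbitrary):

        >>> ifg.GD_lookup(reference_point=2.355, engine="cwt")  # doctest: +SKIP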
        """
        precision = _get_config_value("precision")

        if engine not in ("cwt", "normal", "fft"):
            raise ValueError("Engine must be `cwt`, `fft` or `normal`.")

        if reference_point is None and engine != "fft":
            warnings.warn(
                f"Engine `{engine}` isn't available without reference point, falling back to FFT based prediction.",
                PySprintWarning
            )
            engine = "fft"

        if engine == "fft":
            pred, _ = find_center(*ifft_method(self.x, self.y))
            if pred is None:
                if not silent:
                    print("Prediction failed, skipping.")
                return
            print(f"The predicted GD is ± {pred:.{precision}f} fs.")

            if hasattr(self, "params"):
                self.params[3] = pred
            return

        if engine == "cwt":
            widths = kwargs.pop("widths", np.arange(1, 20))
            floor_thres = kwargs.pop("floor_thres", 0.05)
            x_min, _, x_max, _ = self.detect_peak_cwt(
                widths=widths, floor_thres=floor_thres
            )

            # discard the other engine's defaults so the leftover-kwargs check below doesn't fire
            _ = kwargs.pop("pmin", 0.1)
            _ = kwargs.pop("pmax", 0.1)
            _ = kwargs.pop("threshold", 0.35)

        else:
            pmin = kwargs.pop("pmin", 0.1)
            pmax = kwargs.pop("pmax", 0.1)
            threshold = kwargs.pop("threshold", 0.35)
            x_min, _, x_max, _ = self.detect_peak(
                pmin=pmin, pmax=pmax, threshold=threshold
            )

            # discard the other engine's defaults so the leftover-kwargs check below doesn't fire
            _ = kwargs.pop("widths", np.arange(1, 10))
            _ = kwargs.pop("floor_thres", 0.05)

        if kwargs:
            raise TypeError(f"Invalid argument:{kwargs}")

        try:
            closest_val, idx1 = find_nearest(x_min, reference_point)
            m_closest_val, m_idx1 = find_nearest(x_max, reference_point)
        except (ValueError, IndexError):
            if not silent:
                print("Prediction failed, skipping.. ")
            return
        try:
            truncated = np.delete(x_min, idx1)
            second_closest_val, _ = find_nearest(truncated, reference_point)
        except (IndexError, ValueError):
            if not silent:
                print("Prediction failed, skipping.. ")
            return
        try:
            m_truncated = np.delete(x_max, m_idx1)
            m_second_closest_val, _ = find_nearest(m_truncated, reference_point)
        except (IndexError, ValueError):
            if not silent:
                print("Prediction failed, skipping.. ")
            return

        lowguess = 2 * np.pi / np.abs(closest_val - second_closest_val)
        highguess = 2 * np.pi / np.abs(m_closest_val - m_second_closest_val)

        #  estimate the GD with that
        if hasattr(self, "params"):
            self.params[3] = (lowguess + highguess) / 2

        if not silent:
            print(
                f"The predicted GD is ± {((lowguess + highguess) / 2):.{precision}f} fs"
                f" based on reference point of {reference_point:.{precision}f}."
            )

    def _safe_cast(self):
        """
        Return a copy of key attributes in order to prevent
        inplace modification.
        """
        x, y, ref, sam = (
            np.copy(self.x),
            np.copy(self.y),
            np.copy(self.ref),
            np.copy(self.sam),
        )
        return x, y, ref, sam

    @staticmethod
    def wave2freq(value):
        """Switches a single value between wavelength and angular frequency."""
        return (2 * np.pi * C_LIGHT) / value

    _dispatch = wave2freq.__func__

    @staticmethod
    def freq2wave(value):
        """Switches a single value between angular frequency and wavelength."""
        return Dataset._dispatch(value)

    def _check_domain(self):
        """
        Check the domain of the data by looking at the x axis' minimal value.
        Explicit units are not attached yet; we assume nm or PHz.
        """
        try:
            if min(self.x) > 50:
                self.probably_wavelength = True
                self.unit = "nm"
            else:
                self.probably_wavelength = False
                self.unit = "PHz"

        # This is the first function to fail if the user sets up
        # wrong values. Usually..
        except TypeError as error:
            msg = ValueError(
                "The file could not be parsed properly."
            )
            raise msg from error

    @classmethod
    def parse_raw(
        cls,
        filename,
        ref=None,
        sam=None,
        skiprows=0,
        decimal=".",
        sep=None,
        delimiter=None,
        comment=None,
        usecols=None,
        names=None,
        swapaxes=False,
        na_values=None,
        skip_blank_lines=True,
        keep_default_na=False,
        meta_len=1,
        errors="raise",
        callback=None,
        parent=None,
        **kwargs
    ):
        """
        Alternative constructor for a Dataset object.
        Loads data simply by passing the file names found in
        the target directory.

        Parameters
        ----------
        filename: `str`
            The base interferogram file generated by the spectrometer.
        ref: `str`, optional
            The reference arm's spectra file generated by the spectrometer.
        sam: `str`, optional
            The sample arm's spectra file generated by the spectrometer.
        skiprows: `int`, optional
            Skip rows at the top of the file. Default is `0`.
        decimal: `str`, optional
            Character recognized as decimal separator in the original dataset.
            Often `,` for European data.
            Default is `.`.
        sep: `str`, optional
            The delimiter in the original interferogram file.
            Default is `,`.
        delimiter: `str`, optional
            The delimiter in the original interferogram file.
            This is preferred over the `sep` argument if both given.
            Default is `,`.
        comment: `str`, optional
            Indicates remainder of line should not be parsed. If found at the beginning
            of a line, the line will be ignored altogether. This parameter must be a
            single character. Default is `'#'`.
        usecols: list-like or callable, optional
            If there are multiple columns in the file, use only a subset of columns.
            Default is [0, 1], which will use the first two columns.
        names: array-like, optional
            List of column names to use. Default is ['x', 'y']. Column marked
            with `x` (`y`) will be treated as the x (y) axis. Combined with the
            usecols argument it's possible to select data from a large number of
            columns.
        swapaxes: bool, optional
            Whether to swap x and y values in every parsed file. Default is False.
        na_values: scalar, str, list-like, or dict, optional
            Additional strings to recognize as NA/NaN. If dict passed, specific
            per-column NA values. By default the following values are interpreted as
            NaN: ‘’, ‘#N/A’, ‘#N/A N/A’, ‘#NA’, ‘-1.#IND’,
            ‘-1.#QNAN’, ‘-NaN’, ‘-nan’, ‘1.#IND’, ‘1.#QNAN’,
            ‘<NA>’, ‘N/A’, ‘NA’, ‘NULL’, ‘NaN’, ‘n/a’, ‘nan’, ‘null’.
        skip_blank_lines: bool
            If True, skip over blank lines rather than interpreting as NaN values.
            Default is True.
        keep_default_na: bool
            Whether or not to include the default NaN values when parsing the data.
            Depending on whether na_values is passed in, the behavior changes. Default
            is False. More information available at:
            https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_csv.html
        meta_len: `int`, optional
            The first `n` lines in the original file containing the meta
            information about the dataset. It is parsed to be dict-like.
            If the parsing fails, a new entry will be created in the
            dictionary with key `unparsed`.
            Default is `1`.
        errors: string, optional
            Determines how data columns with mismatched sizes are handled.
            The default is `raise`, and it will raise on any error.
            If set to `force`, it will truncate every array to have the
            same shape as the shortest column. It truncates from
            the top of the file.
        callback : callable, optional
            The function that notifies parent objects about SPP related
            changes. In most cases the user should leave this empty. The
            default callback is only initialized if this object is constructed
            by the `pysprint.SPPMethod` object.
        parent : any class, optional
            The object which handles the callback function. In most cases
            the user should leave this empty.
        kwargs : dict, optional
            The window class to use in WFTMethod. Has no effect while using other
            methods. Must be a subclass of pysprint.core.windows.WindowBase.
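
        Examples
        --------
        Illustrative usage (the file names and parsing options below are
        placeholders):

        >>> ifg = Dataset.parse_raw("ifg.txt", ref="ref.txt", sam="sam.txt", skiprows=8, decimal=",", meta_len=5)  # doctest: +SKIP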
        """

        parsed = _parse_raw(
            filename=filename,
            ref=ref,
            sam=sam,
            skiprows=skiprows,
            decimal=decimal,
            sep=sep,
            delimiter=delimiter,
            comment=comment,
            usecols=usecols,
            names=names,
            swapaxes=swapaxes,
            na_values=na_values,
            skip_blank_lines=skip_blank_lines,
            keep_default_na=keep_default_na,
            meta_len=meta_len
        )

        return cls(**parsed, errors=errors, callback=callback, parent=parent, **kwargs)

    def __str__(self):
        _unit = self._render_unit(self.unit)
        precision = _get_config_value("precision")
        string = dedent(
            f"""
        {type(self).__name__}
        ----------
        Parameters
        ----------
        Datapoints: {len(self.x)}
        Predicted domain: {'wavelength' if self.probably_wavelength else 'frequency'}
        Range: from {np.min(self.x):.{precision}f} to {np.max(self.x):.{precision}f} {_unit}
        Normalized: {self._is_normalized}
        Delay value: {str(self._format_delay()) + ' fs' if self._delay is not None else 'Not given'}
        SPP position(s): {str(self._format_positions()) + ' PHz' if np.all(self._positions) else 'Not given'}
        ----------------------------
        Metadata extracted from file
        ----------------------------
        """
        )
        string = re.sub(r'^\s+', '', string, flags=re.MULTILINE)
        string += json.dumps(self.meta, indent=4, sort_keys=True)
        return string

    def _repr_html_(self):  # TODO: move this to a separate template file
        _unit = self._render_unit(self.unit)
        precision = _get_config_value("precision")
        t = f"""
        <div id="header" class="row" style="height:10%;width:100%;">
        <div style='float:left' class="column">
        <table style="border:1px solid black;float:top;">
        <tbody>
        <tr>
        <td colspan=2 style="text-align:center">
        <font size="5">{type(self).__name__}</font>
        </td>
        </tr>
        <tr>
        <td colspan=2 style="text-align:center">
        <font size="3.5">Parameters</font>
        </td>
        </tr>
        <tr>
        <td style="text-align:center"><b>Datapoints<b></td>
            <td style="text-align:center"> {len(self.x)}</td>
        </tr>
        <tr>
            <td style="text-align:center"><b>Predicted domain<b> </td>
            <td style="text-align:center"> {'wavelength' if self.probably_wavelength else 'frequency'} </td>
        </tr>
        <tr>
        <td style="text-align:center"> <b>Range min</b> </td>
        <td style="text-align:center">{np.min(self.x):.{precision}f} {_unit}</td>
        </tr>
        <tr>
        <td style="text-align:center"> <b>Range max</b> </td>
        <td style="text-align:center">{np.max(self.x):.{precision}f} {_unit}</td>
        </tr>
        <tr>
        <td style="text-align:center"> <b>Normalized</b></td>
        <td style="text-align:center"> {self._is_normalized} </td>
        </tr>
        <tr>
        <td style="text-align:center"><b>Delay value</b></td>
        <td style="text-align:center">{str(self._format_delay()) + ' fs' if self._delay is not None else 'Not given'}</td>
        </tr>
        <tr>
        <td style="text-align:center"><b>SPP position(s)</b></td>
        <td style="text-align:center">{str(self._format_positions()) + ' PHz' if np.all(self._positions) else 'Not given'}</td>
        </tr>
        <tr>
        <td colspan=2 style="text-align:center">
        <font size="3.5">Metadata</font>
        </td>
        </tr>
        """
        jjstring = Template("""
        {% for key, value in meta.items() %}
           <tr>
        <th style="text-align:center"> <b>{{ key }} </b></th>
        <td style="text-align:center"> {{ value }} </td>
           </tr>
        {% endfor %}
            </tbody>
        </table>
        </div>
        <div style='float:left' class="column">""")
        rendered_fig = self._render_html_plot()
        return t + jjstring.render(meta=self.meta) + rendered_fig

    def _render_html_plot(self):
        fig, ax = plt.subplots(figsize=(7, 5))
        self.plot(ax=ax)
        plt.close()
        tmpfile = BytesIO()
        fig.savefig(tmpfile, format='png')
        encoded = base64.b64encode(tmpfile.getvalue()).decode('utf-8')
        html_fig = "<img src=\'data:image/png;base64,{}\'>".format(encoded)
        return html_fig

    @property
    def data(self):
        """
        Returns the *current* dataset as `pandas.DataFrame`.
        """
        if self._is_normalized:
            try:
                self._data = pd.DataFrame(
                    {
                        "x": self.x,
                        "y": self.y,
                        "sample": self.sam,
                        "reference": self.ref,
                        "y_normalized": self.y_norm,
                    }
                )
            except ValueError:
                self._data = pd.DataFrame({"x": self.x, "y": self.y})
        else:
            self._data = pd.DataFrame({"x": self.x, "y": self.y})
        return self._data

    # from : https://stackoverflow.com/a/15774013/11751294
    def __copy__(self):
        cls = self.__class__
        result = cls.__new__(cls)
        result.__dict__.update(self.__dict__)
        return result

    @property
    def is_normalized(self):
        """Retuns whether the dataset is normalized."""
        return self._is_normalized

    @inplacify
    def chdomain(self):
        """
        Change the domain from wavelength [nm] to angular frequency [PHz]
        and vice versa.
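
        Examples
        --------
        Illustrative usage (``ifg`` is a wavelength-domain ``Dataset``):

        >>> ifg.chdomain()  # doctest: +SKIP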
        """
        self.x = (2 * np.pi * C_LIGHT) / self.x
        self._check_domain()
        if hasattr(self, "original_x"):
            self.original_x = self.x
        return self

    def detect_peak_cwt(self, widths, floor_thres=0.05, side="both"):
        """
        Basic algorithm to find extremal points in data
        using ``scipy.signal.find_peaks_cwt``.

        Parameters
        ----------
        widths : np.ndarray
            The widths passed to `find_peaks_cwt`.
        floor_thres : float
            Will be removed.
        side : str
            The side to use. Must be "both", "max" or "min".
            Default is "both".

        Returns
        -------
        xmax : `array-like`
            x coordinates of the maximums
        ymax : `array-like`
            y coordinates of the maximums
        xmin : `array-like`
            x coordinates of the minimums
        ymin : `array-like`
            y coordinates of the minimums

        Note
        ----
        When using "min" or "max" as side, all the detected minimal and
        maximal values will be returned, but only the given side will be
        recorded for further calculation.
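
        Examples
        --------
        Illustrative usage (``ifg`` is a previously constructed ``Dataset``;
        the widths below are arbitrary):

        >>> xmax, ymax, xmin, ymin = ifg.detect_peak_cwt(widths=np.arange(1, 30))  # doctest: +SKIP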
        """
        if side not in ("both", "min", "max"):
            raise ValueError("Side must be 'both', 'min' or 'max'.")

        if hasattr(self, "_is_onesided"):
            self._is_onesided = side != "both"

        x, y, ref, sam = self._safe_cast()
        xmax, ymax, xmin, ymin = cwt(
            x, y, ref, sam, widths=widths, floor_thres=floor_thres
        )

        if side == "both":
            self.xmax = xmax
            self.xmin = xmin

        elif side == "min":
            self.xmin = xmin

        elif side == "max":
            self.xmax = xmax

        logger.info(f"{len(xmax)} max values and {len(xmin)} min values were found.")
        return xmax, ymax, xmin, ymin

    def savgol_fil(self, window=5, order=3):
        """
        Applies Savitzky-Golay filter on the dataset.

        Parameters
        ----------
        window : int
            Length of the convolutional window for the filter.
            Default is `5`.
        order : int
            Degree of polynomial to fit after the convolution.
            If not odd, it's incremented by 1. Must be lower than window.
            Usually it's a good idea to stay with a low degree, e.g 3 or 5.
            Default is 3.

        Note
        ----
        If arms were given, it will merge them into the `self.y` and
        `self.y_norm` variables. Also applies a linear interpolation
        on the dataset (and raises a warning).
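
        Examples
        --------
        Illustrative usage (``ifg`` is a previously constructed ``Dataset``;
        the window and order below are arbitrary):

        >>> ifg.savgol_fil(window=11, order=3)  # doctest: +SKIP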
        """
        self.x, self.y_norm = savgol(
            self.x, self.y, self.ref, self.sam, window=window, order=order
        )
        self.y = self.y_norm
        self.ref = []
        self.sam = []
        warnings.warn(
            "Linear interpolation have been applied to data.", InterpolationWarning,
        )

    @inplacify
    def slice(self, start=None, stop=None):
        """
        Cuts the dataset on x axis.

        Parameters
        ----------
        start : float, optional
            Start value of the cutting interval. If omitted (or given as
            `None`), the dataset's original minimum value is kept.
            Default is `None`.
        stop : float, optional
            Stop value of the cutting interval. If omitted (or given as
            `None`), the dataset's original maximum value is kept.
            Default is `None`.

        Note
        ----
        If arms were given, it will merge them into the `self.y` and
        `self.y_norm` variables. After this operation, the arms' spectra
        cannot be retrieved.
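
        Examples
        --------
        Illustrative usage (``ifg`` is a wavelength-domain ``Dataset``; the
        bounds below are arbitrary and given in the x axis' current unit):

        >>> ifg.slice(start=400, stop=800)  # doctest: +SKIP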
        """
        self.x, self.y_norm = cut_data(
            self.x, self.y, self.ref, self.sam, start=start, stop=stop
        )
        self.ref = []
        self.sam = []
        self.y = self.y_norm
        # Just to make sure it's correctly shaped. Later on we might
        # delete this.
        if hasattr(self, "original_x"):
            self.original_x = self.x
        self._is_normalized = self._ensure_norm()
        return self

    def convolution(self, window_length, std=20):
        """
        Convolve the dataset with a specified Gaussian window.

        Parameters
        ----------
        window_length : int
            Length of the gaussian window.
        std : float
            Standard deviation of the gaussian window.
            Default is `20`.

        Note
        ----
        If arms were given, it will merge them into the `self.y` and
        `self.y_norm` variables.
        Also applies a linear interpolation on dataset.
        """
        self.x, self.y_norm = convolution(
            self.x, self.y, self.ref, self.sam, window_length, standev=std
        )
        self.ref = []
        self.sam = []
        self.y = self.y_norm
        warnings.warn(
            "Linear interpolation have been applied to data.", InterpolationWarning,
        )

    @inplacify
    def resample(self, N, kind="linear", **kwds):
        """
        Resample the interferogram to have `N` datapoints.

        Parameters
        ----------
        N : int
            The number of datapoints required.
        kind : str, optional
            The type of interpolation to use. Default is `linear`.
        kwds : optional
            Additional keyword arguments to pass to `scipy.interpolate.interp1d`.

        Warns
        -----
        PySprintWarning, if trying to resample to fewer datapoints than
        the original; in that case the original shape is kept.
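
        Examples
        --------
        Illustrative usage (``ifg`` is a previously constructed ``Dataset``):

        >>> ifg.resample(10000, kind="cubic")  # doctest: +SKIP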
        """
        f = interp1d(self.x, self.y_norm, kind, **kwds)
        if N < len(self.x):
            N = len(self.x)
            warnings.warn(
                "Trying to resample to lower resolution, keeping shape..", PySprintWarning
            )
        xnew = np.linspace(np.min(self.x), np.max(self.x), N)
        ynew = f(xnew)
        setattr(self, "x", xnew)
        setattr(self, "y_norm", ynew)
        return self

    def detect_peak(
        self, pmax=0.1, pmin=0.1, threshold=0.1, except_around=None, side="both"
    ):
        """
        Basic algorithm to find extremal points in data
        using ``scipy.signal.find_peaks``.

        Parameters
        ----------
        pmax : float
            Prominence of maximum points.
            The lower it is, the more peaks will be found.
            Default is `0.1`.
        pmin : float
            Prominence of minimum points.
            The lower it is, the more peaks will be found.
            Default is `0.1`.
        threshold : float
            Sets the minimum distance (measured on y axis) required for a
            point to be accepted as extremal.
            Default is `0.1`.
        except_around : interval (array or tuple), optional
            Overwrites the threshold to be 0 at the given interval.
            Format is `(lower, higher)` or `[lower, higher]`.
            Default is None.
        side : str
            The side to use. Must be "both", "max" or "min".
            Default is "both".

        Returns
        -------
        xmax : `array-like`
            x coordinates of the maximums
        ymax : `array-like`
            y coordinates of the maximums
        xmin : `array-like`
            x coordinates of the minimums
        ymin : `array-like`
            y coordinates of the minimums

        Note
        ----
        When using "min" or "max" as side, all the detected minimal and
        maximal values will be returned, but only the given side will be
        recorded for further calculation.
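
        Examples
        --------
        Illustrative usage (``ifg`` is a previously constructed ``Dataset``;
        the prominences below are arbitrary):

        >>> xmax, ymax, xmin, ymin = ifg.detect_peak(pmax=0.2, pmin=0.2)  # doctest: +SKIP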
        """
        if side not in ("both", "min", "max"):
            raise ValueError("Side must be 'both', 'min' or 'max'.")

        if hasattr(self, "_is_onesided"):
            self._is_onesided = side != "both"

        x, y, ref, sam = self._safe_cast()
        xmax, ymax, xmin, ymin = find_peak(
            x,
            y,
            ref,
            sam,
            pro_max=pmax,
            pro_min=pmin,
            threshold=threshold,
            except_around=except_around,
        )
        if side == "both":
            self.xmax = xmax
            self.xmin = xmin

        elif side == "min":
            self.xmin = xmin

        elif side == "max":
            self.xmax = xmax

        logger.info(f"{len(xmax)} max values and {len(xmin)} min values were found.")
        return xmax, ymax, xmin, ymin

    def _plot_SPP_if_valid(self, ax=None, **kwargs):
        """
        Mark SPPs on the plot if they are valid.
        """
        if ax is None:
            ax = self.plt
        if isinstance(self.positions, numbers.Number):
            if is_inside(self.positions, self.x):
                x_closest, idx = find_nearest(self.x, self.positions)
                try:
                    ax.plot(x_closest, self.y_norm[idx], **kwargs)
                except (ValueError, TypeError):
                    ax.plot(x_closest, self.y[idx], **kwargs)

        if isinstance(self.positions, np.ndarray) or isinstance(
                self.positions, Iterable
        ):
            # a 0-d array cannot be iterated over: cast it with np.atleast_1d
            if np.array(self.positions).ndim == 0:
                self.positions = np.atleast_1d(self.positions)
            for i, val in enumerate(self.positions):
                if is_inside(self.positions[i], self.x):
                    x_closest, idx = find_nearest(self.x, self.positions[i])
                    try:
                        ax.plot(x_closest, self.y_norm[idx], **kwargs)
                    except (ValueError, TypeError):
                        ax.plot(x_closest, self.y[idx], **kwargs)

    def _format_delay(self):
        if self.delay is None:
            return ""
        if isinstance(self.delay, np.ndarray):
            if self.delay.size == 0:
                return 0
            delay = np.atleast_1d(self.delay).flatten()
            return delay[0]
        elif isinstance(self.delay, (list, tuple)):
            return self.delay[0]
        elif isinstance(self.delay, numbers.Number):
            return self.delay
        elif isinstance(self.delay, str):
            try:
                delay = float(self.delay)
            except ValueError as e:
                raise TypeError("Delay value not understood.") from e
            return delay
        else:
            raise TypeError("Delay value not understood.")

    def _format_positions(self):
        if self.positions is None:
            return "Not given"
        if isinstance(self.positions, np.ndarray):
            positions = np.atleast_1d(self.positions).flatten()
            return ", ".join(map(str, positions))
        elif isinstance(self.positions, (list, tuple)):
            return ", ".join(map(str, self.positions))
        elif isinstance(self.positions, numbers.Number):
            return self.positions
        elif isinstance(self.positions, str):
            split = self.positions.split(",")
            try:
                positions = [float(p) for p in split]
            except ValueError as e:
                raise TypeError("Delay value not understood.") from e
            return ", ".join(map(str, positions))
        else:
            raise TypeError("Delay value not understood.")

    def _prepare_SPP_data(self):
        pos_x, pos_y = [], []
        if self.positions is not None:
            position = np.array(self.positions, dtype=np.float64).flatten()
            for i, val in enumerate(position):
                if is_inside(position[i], self.x):
                    x_closest, idx = find_nearest(self.x, position[i])
                    try:
                        pos_x.append(x_closest)
                        pos_y.append(self.y_norm[idx])
                    except (ValueError, TypeError):
                        pos_x.append(x_closest)
                        pos_y.append(self.y[idx])
            pos_x = np.array(pos_x)
            pos_y = np.array(pos_y)
        return pos_x, pos_y

    # TODO: Remove the duplicated logic. This function is in pysprint's init.py
    # and we can't circular import it. It should be moved to a separate file.
    def plot_outside(self, *args, **kwargs):
        """
        Plot the current dataset outside of the notebook. For detailed
        parameters see the `Dataset.plot` function.
        """
        backend = kwargs.pop("backend", "Qt5Agg")
        original_backend = plt.get_backend()
        try:
            plt.switch_backend(backend)
            self.plot(*args, **kwargs)
            plt.show(block=True)
        except (AttributeError, ImportError, ModuleNotFoundError) as err:
            raise ValueError(
                f"Couldn't set backend {backend}, you should manually "
                "change to an appropriate GUI backend. (Matplotlib 3.3.1 "
                "is broken. In that case use backend='TkAgg')."
            ) from err
        finally:
            plt.switch_backend(original_backend)

    def plot(self, ax=None, title=None, xlim=None, ylim=None, **kwargs):
        """
        Plot the dataset.

        Parameters
        ----------
        ax : matplotlib.axes.Axes, optional
            An axis to draw the plot on. If not given, it will plot
            on the last used axis.
        title : str, optional
            The title of the plot.
        xlim : tuple, optional
            The limits of x axis.
        ylim : tuple, optional
            The limits of y axis.
        kwargs : dict, optional
            Additional keyword arguments to pass to plot function.

        Note
        ----
        If SPP positions are correctly set, it will mark them on plot.
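
        Examples
        --------
        Illustrative usage (``ifg`` is a previously constructed ``Dataset``;
        extra keyword arguments are forwarded to the underlying plot call):

        >>> ifg.plot(title="Example", xlim=(2, 3), color="blue")  # doctest: +SKIP
        >>> ifg.show()  # doctest: +SKIP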
        """
        datacolor = kwargs.pop("color", "red")
        nospp = kwargs.pop("nospp", False)
        _unit = self._render_unit(self.unit, mpl=True)
        xlabel = rf"$\lambda\,[{_unit}]$" if self.probably_wavelength else rf"$\omega\,[{_unit}]$"
        overwrite = kwargs.pop("overwrite", None)
        if overwrite is not None:
            xlabel = overwrite

        if ax is None:
            ax = self.plt
            self.plt.ylabel("I")
            self.plt.xlabel(xlabel)
            if xlim:
                self.plt.xlim(xlim)
            if ylim:
                self.plt.ylim(ylim)
            if title:
                self.plt.title(title)
        else:
            ax.set(ylabel="I")
            ax.set(xlabel=xlabel)
            if xlim:
                ax.set(xlim=xlim)
            if ylim:
                ax.set(ylim=ylim)
            if title:
                ax.set(title=title)

        if np.iscomplexobj(self.y):
            ax.plot(self.x, np.abs(self.y), color=datacolor, **kwargs)
        else:
            try:
                ax.plot(self.x, self.y_norm, color=datacolor, **kwargs)
            except (ValueError, TypeError):
                ax.plot(self.x, self.y, color=datacolor, **kwargs)
        if not nospp:
            self._plot_SPP_if_valid(ax=ax, color="black", marker="o", markersize=10, label="SPP")

    def show(self):
        """
        Equivalent with plt.show().
        """
        self.plt.show(block=True)

    @inplacify
    def normalize(self, filename=None, smoothing_level=0):
        """
        Normalize the interferogram by finding upper and lower envelope
        on an interactive matplotlib editor. Points can be deleted with
        key `d` and inserted with key `i`. Also points can be dragged
        using the mouse. On complete just close the window. Must be
        called with interactive backend. The best practice is to call
        this function inside `~pysprint.interactive` context manager.

        Parameters
        ----------
        filename : str, optional
            Save the normalized interferogram named by filename in the
            working directory. If not given it will not be saved.
            Default None.
        smoothing_level : int, optional
            The smoothing level used on the dataset before finding the
            envelopes. It applies Savitzky-Golay filter under the hood.
            Default is 0.
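
        Examples
        --------
        Illustrative usage (``ifg`` is a previously constructed ``Dataset``
        and ``ps`` is the imported ``pysprint`` package):

        >>> with ps.interactive():  # doctest: +SKIP
        ...     ifg.normalize(smoothing_level=9)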
        """
        x, y, _, _ = self._safe_cast()
        if smoothing_level != 0:
            x, y = savgol(x, y, [], [], window=smoothing_level)
        _l_env = DraggableEnvelope(x, y, "l")
        y_transform = _l_env.get_data()
        _u_env = DraggableEnvelope(x, y_transform, "u")
        y_final = _u_env.get_data()
        self.y = y_final
        self.y_norm = y_final
        self._is_normalized = True
        self.plt.title("Final")
        self.plot()
        self.show()
        if filename:
            if not filename.endswith(".txt"):
                filename += ".txt"
            np.savetxt(filename, np.column_stack((self.x, self.y)), delimiter=",")
            print(f"Successfully saved as {filename}.")
        return self

    def open_SPP_panel(self, header=None):
        """
        Opens the interactive matplotlib editor for SPP data.
        Use `i` button to add a new point, use `d` key to delete one.
        The delay field is parsed to only get the numeric values.
        Close the window on finish. Must be called with interactive
        backend. The best practice is to call this function inside
        `~pysprint.interactive` context manager.

        Parameters
        ----------
        header : str, optional
            An arbitrary string to include as header. This can be
            any attribute's name, or even metadata key.
        """
        if header is not None:
            if isinstance(header, str):
                head = getattr(self, header, None)
                metahead = self.meta.get(header, None)
                info = head or metahead or header
            else:
                info = None
        else:
            info = None
        spp_x, spp_y = self._prepare_SPP_data()
        _spp = SPPEditor(
            self.x, self.y_norm, info=info, x_pos=np.array(spp_x), y_pos=np.array(spp_y)
        )

        textbox = _spp._get_textbox()
        textbox.set_val(self._format_delay())

        _spp._show()

        # We need to split this into separate lines,
        # because empty results are broadcasted twice.
        delay, positions = _spp.get_data()
        self.delay, self.positions = delay, positions

    def emit(self):
        """
        Emit the current SPP data.

        Returns
        -------
        delay : np.ndarray
            The delay value for the current dataset, shaped exactly like
            positions.
        positions : np.ndarray
            The given SPP positions.
        """
        if self.positions is None:
            raise ValueError("SPP positions are missing.")
        if self.delay is None:
            raise ValueError("Delay value is missing.")
        # Important: Use underscored variables to avoid invoking the
        # setter again, which invokes the callback again, resulting in
        # a never-ending cycle.
        if not isinstance(self._positions, np.ndarray):
            self._positions = np.asarray(self.positions)
        if not isinstance(self.delay, np.ndarray):
            self._delay = np.ones_like(self._positions) * self._delay
        return np.atleast_1d(self.delay), np.atleast_1d(self.positions)

    def set_SPP_data(self, delay, positions, force=False):
        """
        Set the SPP data (delay and SPP positions) for the dataset.

        Parameters
        ----------
        delay : float
            The delay value that belongs to the current interferogram.
            Must be given in `fs` units.
        positions : float or iterable
            The SPP positions that belong to the current interferogram.
            Must be float or sequence of floats (tuple, list, np.ndarray, etc.)
        force : bool, optional
            Can be used to set specific SPP positions which are outside of
            the dataset's range. Note that in most cases you should avoid using
            this option. Default is `False`.

        Note
        ----
        Every position given must be in the current dataset's range, otherwise
        `ValueError` is raised. Be careful to change domain to frequency before
        feeding values into this function.
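
        Examples
        --------
        Illustrative usage (``ifg`` is a frequency-domain ``Dataset``; the
        delay and positions below are arbitrary):

        >>> ifg.set_SPP_data(delay=150, positions=(2.1, 2.2, 2.3))  # doctest: +SKIP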
        """
        if not isinstance(delay, float):
            delay = float(delay)
        delay = np.array(np.ones_like(positions) * delay)
        self.delay = delay
        if force:
            with suppress(ValueError):
                self._positions = positions
        else:
            self.positions = positions
        # trigger the callback here too
        try:
            self.callback(self, self.parent)
        except ValueError:
            pass


class MimickedDataset(Dataset):
    """
    Class that pretends to be a dataset, but its x-y values are missing.
    It allows setting the delay and SPP positions arbitrarily.
    """
    def __init__(self, delay, positions, *args, **kwargs):
        if delay is None or positions is None:
            raise ValueError("must specify SPP data.")

        x = np.empty(1)
        y = np.empty(1)

        super().__init__(x=x, y=y, *args, **kwargs)

        self.set_SPP_data(delay=delay, positions=positions, force=True)

    @contextmanager
    def _suppress_callbacks(self):
        try:
            self.restore_parent = self.parent
            self.restore_callback = self.callback
            self.parent = None
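            # A callback that always raises ValueError: calling the lambda
            # builds an empty generator and immediately throws into it.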
            self.callback = lambda x, y: (_ for _ in ()).throw(ValueError('mimicked'))
            yield
        finally:
            self.parent, self.callback = self.restore_parent, self.restore_callback

    def plot(self, *args, **kwargs):
        if self.x.size == 1:
            self.plt.text(0.31, 0.5, 'Dataset is missing.', size=15)

    # Redefine getter-setter without boundscheck
    @property
    def positions(self):
        return self._positions

    @positions.setter
    def positions(self, value):
        if isinstance(value, np.ndarray) or isinstance(value, Iterable):
            for val in value:
                if not isinstance(val, numbers.Number):
                    raise ValueError(
                        f"Expected numeric values, got {type(val)} instead."
                    )

        self._positions = value
        try:
            self.callback(self, self.parent)
        except ValueError:
            pass  # delay or position is missing

    def to_dataset(self, x, y=None, ref=None, sam=None, parse=True, **kwargs):
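        """
        Promote this `MimickedDataset` to a real `Dataset`, carrying over
        the already set SPP data (delay and SPP positions).

        If `parse` is True, `x` is interpreted as a file name (optionally
        together with `ref` and `sam` files) and is loaded with
        `Dataset.parse_raw`; otherwise `x` and `y` are used directly as
        data arrays. The resulting object replaces this one in the parent
        container.
        """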
        if parse:
            if y is not None:
                raise ValueError("cannot specify `y` explicitly if `parse=True`.")
            self.parent.ifg_names.append(x)
            if ref is not None and sam is not None:
                self.parent.sam_names.append(sam)
                self.parent.ref_names.append(ref)

            ds = Dataset.parse_raw(x, ref=ref, sam=sam, callback=self.callback, parent=self.parent, **kwargs)

        else:
            ds = Dataset(x=x, y=y, ref=ref, sam=sam, callback=self.callback, parent=self.parent, **kwargs)

        # replace the MimickedDataset with the real one
        # FIXME: need to invalidate 1 item in cache, not all

        idx = self.parent._mimicked_index(self)
        ds.parent._mimicked_set[idx] = ds
        ds.parent.__getitem__.cache_clear()

        # drop the reference
        self.parent._container.pop(self, None)
        self.parent = None
        self.callback = lambda x, y: (_ for _ in ()).throw(ValueError('mimicked'))

        # with self._suppress_callbacks():
        ds.set_SPP_data(delay=self.delay, positions=self.positions, force=True)
        return ds