from __future__ import absolute_import, division, print_function

import os
import warnings
import numpy as np
from functools import partial

from astropy import units as u
from astropy.nddata import StdDevUncertainty, InverseVariance, VarianceUncertainty
from astropy.utils.console import ProgressBar
from astropy.wcs import WCS

from .nikamap import NikaMap
from .utils import _shuffled_average, cpu_count

# Public API of this module: the three resampling/noise-estimation classes
__all__ = ["HalfDifference", "Jackknife", "Bootstrap"]

def compare_header(header_ref, header_target):
    """Crude comparison of two headers.

    Parameters
    ----------
    header_ref : astropy.io.fits.Header
        the reference header
    header_target : astropy.io.fits.Header
        the target header to check

    Raises
    ------
    AssertionError
        if the WCS built from the two headers differ, or if one of the
        ``UNIT``, ``NAXIS1``, ``NAXIS2`` keywords present in `header_ref`
        is missing from, or different in, `header_target`.

    Notes
    -----
    The checks use an explicit ``raise AssertionError`` rather than ``assert``
    so they are not stripped when Python runs with ``-O``; callers rely on
    catching ``AssertionError``.
    """
    wcs_ref = WCS(header_ref)
    wcs_target = WCS(header_target)

    if not wcs_ref.wcs == wcs_target.wcs:
        raise AssertionError("Different header found")

    for key in ["UNIT", "NAXIS1", "NAXIS2"]:
        # .get() so that a keyword missing from the target is reported as a
        # difference instead of escaping as a KeyError
        if key in header_ref and header_target.get(key) != header_ref[key]:
            raise AssertionError("Different key found")

def check_filenames(filenames):
    """Check filenames existence.

    Parameters
    ----------
    filenames : list of str
        filenames list to be checked

    Returns
    -------
    list of str
        curated list of files, in the original order, keeping only the
        entries that are existing regular files

    Warns
    -----
    UserWarning
        for each filename that does not exist (it is removed from the list)
    """
    _filenames = []
    for filename in filenames:
        if os.path.isfile(filename):
            _filenames.append(filename)
        else:
            warnings.warn("{} does not exist, removing from list".format(filename), UserWarning)
    return _filenames

class MultiScans(object):
    """A class to hold multi single scans from a list of fits files.

    This acts as a python lazy iterator and/or a callable

    Parameters
    ----------
    filenames : list or `~MultiScans` object
        the list of fits files to produce the Jackknifes or an already filled object
    n : int, optional
        the number of iteration for the iterator (`None` for no limit)
    ipython_widget : bool, optional
        If True, the progress bar will display as an IPython notebook widget.
    ignore_header : bool, optional
        if True, the check on header is ignored (a warning is issued instead
        of raising a ValueError)
    dataclass : class, optional
        class used to read the files (``dataclass.read``) and to build the
        output maps, by default :class:`nikamap.NikaMap`
    **kwd :
        additional keywords passed to ``dataclass.read``

    Notes
    -----
    A crude check is made on the wcs of each map when instanciated
    """

    dataclass = None
    filenames = None
    header = None
    unit = None
    shape = None
    datas = None
    weights = None
    hits = None
    mask = None
    extra_kwargs = dict()

    def __init__(self, filenames, n=None, ipython_widget=False, ignore_header=False, dataclass=NikaMap, **kwd):
        self.i = 0
        self.n = n
        self.dataclass = dataclass
        self.kwargs = kwd
        self.ipython_widget = ipython_widget
        # Per-instance dict: filling the class-level ``extra_kwargs`` would
        # share (and leak) these values between every instance
        self.extra_kwargs = dict()

        if isinstance(filenames, MultiScans):
            # Reuse the already loaded data of another MultiScans object
            data = filenames

            self.filenames = data.filenames
            self.header = data.header
            self.unit = data.unit
            self.shape = data.shape
            self.datas = data.datas
            self.weights = data.weights
            self.hits = data.hits
            self.mask = data.mask

            self.extra_kwargs.update(data.extra_kwargs)
            for key in ["sampling_freq", "primary_header"]:
                if hasattr(data, key):
                    self.extra_kwargs[key] = getattr(data, key)
        else:
            self.filenames = check_filenames(filenames)
            if not self.filenames:
                raise ValueError("No existing files in filenames")

            # The first file defines the reference header, unit and shape
            nm = self.dataclass.read(self.filenames[0], **kwd)

            self.header = nm.meta
            self.unit = nm.unit
            self.shape = nm.shape

            for key in ["sampling_freq", "primary_header"]:
                if hasattr(nm, key):
                    self.extra_kwargs[key] = getattr(nm, key)

            # This is a low_mem=False case ...
            # TODO: How to refactor that for low_mem=True ?
            datas = np.zeros((len(self.filenames),) + self.shape)
            weights = np.zeros((len(self.filenames),) + self.shape)
            hits = np.zeros(self.shape)

            for i, filename in enumerate(ProgressBar(self.filenames, ipython_widget=self.ipython_widget)):
                nm = self.dataclass.read(filename, **kwd)
                try:
                    compare_header(self.header, nm.meta)
                except AssertionError as e:
                    if ignore_header:
                        warnings.warn("{} for {}".format(e, filename), UserWarning)
                    else:
                        raise ValueError("{} for {}".format(e, filename))

                datas[i, :, :] = nm.data
                # inverse variance weights from the uncertainty array
                with np.errstate(invalid="ignore", divide="ignore"):
                    weights[i, :, :] = nm.uncertainty.array**-2
                hits += nm.hits

                # make sure that we do not have nans in the data
                unobserved = nm.hits == 0
                datas[i, unobserved] = 0
                weights[i, unobserved] = 0

            self.datas = datas
            self.weights = weights
            self.hits = hits
            self.mask = hits == 0

    def __len__(self):
        # to retrieve the length of the iterator, enable ProgressBar on it
        # NOTE(review): len() raises TypeError when n is None (infinite iterator)
        return self.n

    def __iter__(self):
        # Iterators are iterables too.
        # Adding this functions to make them so.
        return self

    def __call__(self):
        """The main method which should be overrided

        should return a  :class:`nikamap.NikaMap`
        """
        pass

    def __next__(self):
        """Iterator on the objects

        Produces ``self()`` until ``n`` items have been produced (forever when
        ``n`` is None), then raises StopIteration.
        """
        if self.n is None or self.i < self.n:
            # Produce data until last iter
            self.i += 1
            data = self.__call__()
        else:
            raise StopIteration()

        return data

class HalfDifference(MultiScans):
    """A class to create weighted half differences uncertainty maps from a list of scans.

    This acts as a python lazy iterator and/or a callable

    Parameters
    ----------
    filenames : list
        the list of fits files to produce the Jackknifes
    ipython_widget : bool, optional
        If True, the progress bar will display as an IPython notebook widget.
    n : int
        the number of Jackknifes maps to be produced in the iterator

            if set to `None`, produce only one weighted average of the maps

    parity_threshold : float
        mask threshold between 0 and 1 to keep partially jackknifed area
        * 1 pure jackknifed
        * 0 partially jackknifed, keep all

    Notes
    -----
    A crude check is made on the wcs of each map when instanciated
    """

    def __init__(self, filenames, parity_threshold=1, **kwd):
        super(HalfDifference, self).__init__(filenames, **kwd)
        self.parity_threshold = parity_threshold

        # Create weights for Half differences: +1 / -1 alternating, all +1
        # for a plain weighted average (n is None)
        jk_weights = np.ones(len(self.filenames))

        if self.n is not None:
            jk_weights[::2] *= -1

        if self.n is not None and len(self.filenames) % 2:
            # The weights are shuffled at each call, so the dropped scan is random
            warnings.warn("Odd number of files, dropping a random file", UserWarning)
            jk_weights[-1] = 0

        assert np.sum(jk_weights != 0), "Less than 2 existing files in filenames"

        self.jk_weights = jk_weights

    @property
    def parity_threshold(self):
        return self._parity

    @parity_threshold.setter
    def parity_threshold(self, value):
        if value is not None and isinstance(value, (int, float)) and 0 <= value <= 1:
            self._parity = value
        else:
            raise TypeError("parity must be between 0 and 1")

    def __call__(self):
        """Compute a Half Difference dataset

        Returns
        -------
        :class:`nikamap.NikaMap` (or the chosen ``dataclass``)
            a half difference data set
        """
        # Randomly choose which scans are added / subtracted (and which one is
        # dropped for an odd number of scans) at each call
        np.random.shuffle(self.jk_weights)
        jk_weights = self.jk_weights[:, np.newaxis, np.newaxis]
        # Dropped scans (jk weight 0) must not enter the normalization
        weights = self.weights * (jk_weights != 0)

        with np.errstate(invalid="ignore", divide="ignore"):
            e_data = 1 / np.sqrt(np.sum(weights, axis=0))
            data = np.sum(self.datas * weights * jk_weights, axis=0) * e_data**2
            parity = np.mean((self.weights != 0) * jk_weights, axis=0)
            # TBC: In principle we should use a weighted parity to avoid different scans/weights problems
            # weighted_parity = np.sum(self.weights * self.jk_weights[:, np.newaxis, np.newaxis], axis=0) * e_data ** 2

        if self.n is not None:
            # Half difference: keep pixels where +/- scans are balanced
            mask = (1 - np.abs(parity)) < self.parity_threshold
        else:
            # Plain average: keep pixels observed by enough scans
            mask = parity < self.parity_threshold

        mask = mask | self.mask

        data[mask] = np.nan
        e_data[mask] = np.nan

        # TBC: hits will have a different mask here....
        data = self.dataclass(
            data,
            uncertainty=StdDevUncertainty(e_data),
            mask=mask,
            hits=self.hits,
            wcs=WCS(self.header),
            meta=self.header,
            unit=self.unit,
            **self.extra_kwargs
        )

        return data  # , weighted_parity

class Jackknife(MultiScans):
    """A class to create weighted Jackknife maps from a list of scans.

    This acts as a python lazy iterator and/or a callable

    Parameters
    ----------
    filenames : list
        the list of fits files to produce the Jackknifes
    n_samples : int
        The number of (sub) samples to use (from 2 to len(filenames))
    parity_threshold : float
        mask threshold between 0 and 1 to keep partially jackknifed area
        * 1 pure jackknifed
        * 0 partially jackknifed, keep all
    ipython_widget : bool, optional
        If True, the progress bar will display as an IPython notebook widget.
    n : int
        the number of Jackknifes maps to be produced by the iterator

    Notes
    -----
    A crude check is made on the wcs of each map when instanciated
    """

    def __init__(self, filenames, n_samples=None, parity_threshold=1, **kwd):
        super(Jackknife, self).__init__(filenames, **kwd)

        assert len(self.filenames) > 1, "Less than 2 existing files in filenames"

        self.n_samples = n_samples  # Will create the indexes for the sub-samples
        self.parity_threshold = parity_threshold

    @property
    def parity_threshold(self):
        return self._parity

    @parity_threshold.setter
    def parity_threshold(self, value):
        if value is not None and isinstance(value, (int, float)) and 0 <= value <= 1:
            self._parity = value
        else:
            raise TypeError("parity must be between 0 and 1")

    @property
    def n_samples(self):
        return self._n_samples

    @n_samples.setter
    def n_samples(self, value):
        # None means one sub-sample per scan
        if value is None:
            value = len(self.filenames)

        assert (2 <= value) and (value <= len(self.filenames)), "n_samples must be between 2 and the number of scans"

        self._n_samples = value

        # Check compatibility between n_samples and filenames length
        n_filenames = len(self.filenames)
        remainder = n_filenames % value

        if remainder:
            warnings.warn(
                "Remainder in number of files for {} samples, dropping the last {}".format(value, remainder),
                UserWarning,
            )
            n_filenames -= remainder

        assert n_filenames, "Less than 2 existing files in filenames"

        # Create the indexes for the sub-samples; NaN entries match no
        # sub-sample index, so the corresponding scans are left out
        indexes = np.repeat(np.arange(value), n_filenames // value)

        if remainder:
            indexes = np.concatenate([indexes, np.full(remainder, np.nan)])

        self.indexes = indexes

    def __call__(self):
        """Compute a jackknifed dataset

        Returns
        -------
        :class:`nikamap.NikaMap` (or the chosen ``dataclass``)
            a jackknifed data set
        """
        # Randomly re-assign the scans to the sub-samples at each call (this
        # also randomizes which scans are left out when there is a remainder)
        np.random.shuffle(self.indexes)

        with np.errstate(invalid="ignore", divide="ignore"):
            # Weighted average of each sub-sample; np.ma masks the pixels where
            # the weights of a sub-sample sum to zero
            sub_datas = []
            sub_weights = []
            for idx in range(self.n_samples):
                in_sample = self.indexes == idx
                sub_data, sub_weight = np.ma.average(
                    self.datas[in_sample], weights=self.weights[in_sample], axis=0, returned=True
                )
                sub_datas.append(sub_data)
                sub_weights.append(sub_weight)

            sub_datas = np.ma.array(sub_datas)
            sub_weights = np.ma.array(sub_weights)

            data = np.ma.average(sub_datas, weights=sub_weights, axis=0)
            # unweighted sample variance
            V1 = self.n_samples
            e_data = np.sqrt(np.sum((sub_datas - data) ** 2, axis=0) / (V1 * (V1 - 1)))
            # TODO : weighted sample variance (NOT WORKING !!!)
            # V1 = np.sum(sub_weights, axis=0)
            # V2 = np.sum(sub_weights**2, axis=0)
            # e_data = np.sqrt(np.sum(sub_weights * (sub_datas - data)**2, axis=0)  / (V1 - V2 / V1) )

            # fraction of sub-samples observing each pixel
            parity = np.mean(sub_weights != 0, axis=0)

            # TBC: In principle we should use a weighted parity to avoid different scans/weights problems

            mask = parity < self.parity_threshold

        # Back to plain arrays: masked pixels become NaN instead of carrying
        # arbitrary values hidden under a numpy.ma mask
        data = np.ma.filled(data, np.nan)
        e_data = np.ma.filled(e_data, np.nan)
        mask = np.ma.filled(mask, True) | self.mask

        data[mask] = np.nan
        e_data[mask] = np.nan

        # TBC: hits will have a different mask here....
        data = self.dataclass(
            data,
            uncertainty=StdDevUncertainty(e_data),
            mask=mask,
            hits=self.hits,
            wcs=WCS(self.header),
            meta=self.header,
            unit=self.unit,
            **self.extra_kwargs
        )

        return data  # , weighted_parity

class Bootstrap(MultiScans):
    """A class to create bootstraped maps from a list of scans.

    This acts as a python lazy iterator and/or a callable

    Parameters
    ----------
    filenames : list
        the list of fits files to produce the Jackknifes
    n_bootstrap : int
        the number of realization to produce a bootsrapped map, by default 50 times the length of the input filename list
    ipython_widget : bool, optional
        If True, the progress bar will display as an IPython notebook widget.
    n : int
        the number of bootstrap maps to be produced by the iterator

    Notes
    -----
    A crude check is made on the wcs of each map when instanciated
    """

    def __init__(self, filenames, n_bootstrap=None, **kwd):
        super(Bootstrap, self).__init__(filenames, **kwd)

        if n_bootstrap is None:
            n_bootstrap = 50 * len(self.filenames)

        self.n_bootstrap = n_bootstrap

    def __call__(self):
        """Compute a bootstraped map

        Returns
        -------
        :class:`nikamap.NikaMap` (or the chosen ``dataclass``)
            a bootstraped data set; the value is the mean and the uncertainty
            the standard deviation of the bootstrap realizations
        """
        _ = partial(_shuffled_average, datas=self.datas, weights=self.weights)

        # Split the realization indexes in one chunk per cpu and process the
        # chunks in parallel; each chunk result is stacked along axis 0
        # NOTE(review): assumes _shuffled_average(chunk, datas=, weights=)
        # returns an array of shape (len(chunk),) + map shape
        bs_array = np.concatenate(
            ProgressBar.map(
                _,
                np.array_split(np.arange(self.n_bootstrap), cpu_count()),
                ipython_widget=self.ipython_widget,
                multiprocess=True,
            )
        )

        # Zero values are treated as missing realizations
        bs_array[bs_array == 0] = np.nan
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", RuntimeWarning)
            data = np.nanmean(bs_array, axis=0)
            e_data = np.nanstd(bs_array, axis=0)

        # Mask unobserved regions
        unobserved = self.hits == 0
        data[unobserved] = np.nan
        e_data[unobserved] = np.nan

        data = self.dataclass(
            data,
            uncertainty=StdDevUncertainty(e_data),
            mask=unobserved,
            hits=self.hits,
            wcs=WCS(self.header),
            meta=self.header,
            unit=self.unit,
            **self.extra_kwargs
        )

        return data