from __future__ import absolute_import, division, print_function

import warnings
from copy import deepcopy
from functools import partial
from itertools import product
from pathlib import Path

import numpy as np
from astropy import units as u
from astropy.convolution import Box2DKernel, Gaussian2DKernel, Kernel2D
from astropy.convolution.kernels import _round_up_to_odd_integer
from astropy.coordinates import match_coordinates_sky
from astropy.io import fits, registry
from astropy.modeling import models
from astropy.modeling.fitting import LevMarLSQFitter
from astropy.modeling.utils import ellipse_extent
from astropy.nddata import Cutout2D, InverseVariance, NDDataArray, NDUncertainty, StdDevUncertainty, VarianceUncertainty
from astropy.nddata.ccddata import (
from astropy.stats import SigmaClip
from astropy.stats.funcs import gaussian_fwhm_to_sigma, gaussian_sigma_to_fwhm
from astropy.table import Column, MaskedColumn, Table
from astropy.utils.console import ProgressBar
from astropy.utils.exceptions import AstropyWarning
from astropy.wcs import WCS, InconsistentAxisTypesError
from astropy.wcs.utils import proj_plane_pixel_scales
from photutils.background import LocalBackground, MedianBackground
from photutils.centroids import centroid_2dg  # , centroid_sources
from photutils.datasets import make_gaussian_sources_image
from photutils.detection import find_peaks
from photutils.psf import PSFPhotometry, SourceGrouper
from powspec import power_spectral_density
from scipy import signal, stats
from scipy.optimize import curve_fit

from .utils import (

# Surface brightness unit (flux density per beam) used for continuum maps.
Jy_beam = u.Jy * u.beam**-1

# Public API exported by ``from ... import *``.
__all__ = ["ContMap"]

class ContBeam(Kernel2D):
    """ContBeam describe the beam of a ContMap.

    By default the beams are derived from :class:`astropy.convolution.Kernel2D` but follow the api of
    :class:`radio_beam.Beam` and implement a 2D gaussian function by default, but the class should be able
    to handle an arbitrary beam (given as an array).

    See also
    --------
    :class:`astropy.convolution.Kernel2D`, :class:`radio_beam.Beam`
    """

    _major = None
    _minor = None
    _pa = None
    _pixscale = None
    default_unit = None
    support_scaling = 8

    def __init__(
        self,
        major=None,
        minor=None,
        pa=None,
        area=None,
        default_unit=u.arcsec,
        meta=None,
        pixscale=None,
        support_scaling=8,
        array=None,
        **kwargs
    ):
        """Create a new Gaussian beam (or an array based beam).

        Parameters
        ----------
        major : :class:`~astropy.units.Quantity` with angular equivalency
            The FWHM major axis
        minor : :class:`~astropy.units.Quantity` with angular equivalency
            The FWHM minor axis, by default equal to ``major``
        pa : :class:`~astropy.units.Quantity` with angular equivalency
            The beam position angle, by default 0 deg
        area : :class:`~astropy.units.Quantity` with steradian equivalency
            The area of the beam.  This is an alternative to specifying the
            major/minor/PA, and will create those values assuming a circular
            Gaussian beam.
        default_unit : :class:`~astropy.units.Unit`
            The unit used to display major and minor in the representation.
            Bare numbers given for major, minor or pa are assumed to be in degrees (with a warning).
        meta : dict, optional
            Metadata attached to the beam
        pixscale : :class:`~astropy.units.Quantity` with angular equivalency
            the size of the pixel (mandatory)
        support_scaling : int, optional
            The discretized kernel size is ``support_scaling * 2`` times the largest gaussian extent
            (rounded up to an odd number of pixels), by default 8
        array : array_like, optional
            replace the gaussian beam by this array
        **kwargs
            Passed to :class:`astropy.convolution.Kernel2D`

        Raises
        ------
        ValueError
            If both ``area`` and ``major`` are given, if ``area`` is not a solid angle,
            if ``minor > major`` or if ``pixscale`` is not given.
        TypeError
            If ``meta`` is not a dict, or if neither ``major`` (or ``area``) nor ``array`` is given.
        """
        if area is not None:
            if major is not None:
                raise ValueError("Can only specify one of {major,minor,pa} " "and {area}")
            if not area.unit.is_equivalent(u.sr):
                raise ValueError("Area unit should be equivalent to steradian.")
            rad = np.sqrt(area / (2 * np.pi))
            major = rad * gaussian_sigma_to_fwhm
            minor = rad * gaussian_sigma_to_fwhm
            pa = 0.0 * u.deg

        # give specified values priority; bare numbers are interpreted as degrees
        major = self._as_angle(major, "major axis")
        minor = self._as_angle(minor, "minor axis")
        pa = self._as_angle(pa, "position angle")

        # some sensible defaults
        if pa is None:
            pa = 0.0 * u.deg
        if minor is None:
            minor = major

        if major is not None and minor > major:
            raise ValueError("Minor axis greater than major axis.")

        if meta is None:
            self.meta = {}
        elif isinstance(meta, dict):
            self.meta = meta
        else:
            raise TypeError("metadata must be a dictionary")

        self._major = major
        self._minor = minor
        self._pa = pa
        self.default_unit = default_unit
        self.support_scaling = support_scaling

        self._pixscale = pixscale

        if self._pixscale is None:
            raise ValueError("You must define pixscale.")

        if self._major is not None:
            stddev_maj = (self.stddev_maj / self.pixscale).decompose()
            stddev_min = (self.stddev_min / self.pixscale).decompose()
            # The position angle is counted from the y axis, hence the 90 deg offset to the model theta
            angle = (90 * u.deg + self.pa).to(u.radian).value

            # Amplitude chosen so that the gaussian integrates to 1 (in pixel units)
            self._model = models.Gaussian2D(
                1 / (2 * np.pi * stddev_maj * stddev_min), 0, 0, x_stddev=stddev_maj, y_stddev=stddev_min, theta=angle
            )

            max_extent = np.max(ellipse_extent(stddev_maj, stddev_min, angle))
            self._default_size = _round_up_to_odd_integer(self.support_scaling * 2 * max_extent)
            super(ContBeam, self).__init__(**kwargs)
        elif array is not None:
            super(ContBeam, self).__init__(array=array, **kwargs)
        else:
            raise TypeError("Must specify either major or array")

    @staticmethod
    def _as_angle(value, name):
        """Return ``value`` unchanged if angular (or None), else assume degrees and warn."""
        if value is None or u.deg.is_equivalent(value):
            return value
        warnings.warn("Assuming {} has been specified in degrees".format(name))
        return value * u.deg

    def __repr__(self):
        text = "ContBeam: "
        if self._major is not None:
            text += "BMAJ={0} BMIN={1} BPA={2} as ".format(
                self.major.to(self.default_unit), self.minor.to(self.default_unit), self.pa.to(u.deg)
            )
        text += "{} Kernel2D at pixscale {}".format(self._array.shape, self._pixscale)
        return text

    def to_header_keywords(self):  # pragma: no cover
        """Return the BMAJ, BMIN and BPA FITS keywords (all in degrees) as a dict."""
        return {
            "BMAJ": self.major.to(u.deg).value,
            "BMIN": self.minor.to(u.deg).value,
            "BPA": self.pa.to(u.deg).value,
        }

    def ellipse_to_plot(self, xcen, ycen, pixscale):  # pragma: no cover
        """Return a matplotlib ellipse for plotting.

        Parameters
        ----------
        xcen : int
            Center pixel in the x-direction.
        ycen : int
            Center pixel in the y-direction.
        pixscale : `~astropy.units.Quantity`
            Angular size of a pixel, used to convert the beam axes to pixels.

        Returns
        -------
        :class:`matplotlib.patches.Ellipse`
            Ellipse patch object centered on the given pixel coordinates.
        """
        from matplotlib.patches import Ellipse

        return Ellipse(
            (xcen, ycen),
            width=(self.major.to(u.deg) / pixscale).to(u.dimensionless_unscaled).value,
            height=(self.minor.to(u.deg) / pixscale).to(u.dimensionless_unscaled).value,
            # PA is 90 deg offset from x-y axes by convention
            # (it is angle from NCP)
            angle=(self.pa + 90 * u.deg).to(u.deg).value,
        )

    @property
    def major(self):
        """Beam FWHM Major Axis"""
        return self._major

    @property
    def stddev_maj(self):
        """Beam Stddev Major Axis"""
        return self._major * gaussian_fwhm_to_sigma

    @property
    def minor(self):
        """Beam FWHM Minor Axis"""
        return self._minor

    @property
    def stddev_min(self):
        """Beam Stddev Minor Axis"""
        return self._minor * gaussian_fwhm_to_sigma

    @property
    def pa(self):
        """Beam position angle"""
        return self._pa

    @property
    def pixscale(self):
        """Angular size of the pixel on which the beam is sampled"""
        return self._pixscale

    @property
    def sr(self):
        """Beam solid angle.

        Analytic for a gaussian beam, otherwise the kernel sum over its peak times the pixel area.
        """
        if self.major is not None:
            return (2 * np.pi * (self.major * self.minor) * gaussian_fwhm_to_sigma**2).to(u.sr)
        return (self._array.sum() / self._array.max() * (self.pixscale**2)).to(u.sr)

    def as_kernel(self, pixscale, **kwargs):
        """Return the beam as a kernel sampled at ``pixscale``.

        .. warning::
            This method is not aware of any misalignment between pixel
            and world coordinates.

        Parameters
        ----------
        pixscale : `~astropy.units.Quantity`
            Conversion from angular to pixel size.
        **kwargs
            Passed to the new :class:`ContBeam`

        Returns
        -------
        ContBeam
            ``self`` if the pixel scale is unchanged, otherwise a new gaussian beam at ``pixscale``.

        Raises
        ------
        ValueError
            For an array based beam with a different pixel scale.
        """
        if pixscale == self._pixscale:
            return self
        if self.major is not None:
            return ContBeam(major=self.major, minor=self.minor, pa=self.pa, pixscale=pixscale, **kwargs)
        raise ValueError("Do not rescale array kernel with different pixscale")

    def convolve(self, other):
        """Convolve one beam with another.

        Parameters
        ----------
        other : `ContBeam` or `Beam` or `Kernel2D`
            The beam to convolve with

        Returns
        -------
        new_beam : `ContBeam`
            The convolved Beam

        Raises
        ------
        ValueError
            If the pixel scales differ or ``other`` is not a supported type.
        """
        if self.major is not None and getattr(other, "major", None) is not None:
            # other could be a ContBeam or a radio_beam.Beam: analytic gaussian convolution
            new_major, new_minor, new_pa = beam_convolve(self, other)
            return ContBeam(major=new_major, minor=new_minor, pa=new_pa, pixscale=self.pixscale)

        if self.major is not None and isinstance(other, Gaussian2DKernel):
            warnings.warn("Assuming same pixelscale")
            # x_stddev/y_stddev are Parameters (with a .value); x_fwhm/y_fwhm are plain floats
            major = other.model.x_stddev.value * gaussian_sigma_to_fwhm * self.pixscale
            minor = other.model.y_stddev.value * gaussian_sigma_to_fwhm * self.pixscale
            pa = other.model.theta.value * u.radian - 90 * u.deg
            new_major, new_minor, new_pa = beam_convolve(
                self, ContBeam(major=major, minor=minor, pa=pa, pixscale=self.pixscale)
            )
            return ContBeam(major=new_major, minor=new_minor, pa=new_pa, pixscale=self.pixscale)

        if isinstance(other, Kernel2D):
            # Generic case: numerical convolution of the kernel arrays
            other_pixscale = getattr(other, "pixscale", None)
            if other_pixscale is None:
                warnings.warn("Assuming same pixelscale")
            elif self.pixscale != other_pixscale:
                raise ValueError("Do not know how to handle different pixscale")

            return ContBeam(array=signal.convolve(self.array, other.array), pixscale=self.pixscale)

        raise ValueError("Do not know how to handle {}".format(type(other)))

    @classmethod
    def from_fits_header(cls, hdr, unit=u.deg, pixscale=None):  # pragma: no cover
        """Instantiate the beam from a header.

        Attempts to extract the beam from the standard BMAJ/BMIN/BPA keywords. Failing that,
        it looks for a CASA or AIPS-style HISTORY entry.

        Parameters
        ----------
        hdr : :class:`astropy.io.fits.Header`, str or :class:`pathlib.Path`
            The header, or the name of a FITS file whose header is read with :func:`astropy.io.fits.getheader`.
        unit : :class:`~astropy.units.Unit`, optional
            Unit of the BMAJ, BMIN and BPA keywords, by default degree.
        pixscale : :class:`~astropy.units.Quantity`
            The size of the pixel, passed to the beam.

        Raises
        ------
        TypeError
            If ``hdr`` is neither a header nor a FITS filename.
        ValueError
            If no beam information can be found.
        """
        # ... given a file try to make a fits header
        # assume a string (or path) refers to a filename on disk
        if isinstance(hdr, Path):
            hdr = str(hdr)
        if not isinstance(hdr, fits.Header):
            if isinstance(hdr, str):
                # Compare against lower-case suffixes since the filename is lower-cased
                if hdr.lower().endswith((".fits", ".fits.gz", ".fit", ".fit.gz", ".fits.z", ".fit.z")):
                    hdr = fits.getheader(hdr)
                else:
                    raise TypeError("Unrecognized extension.")
            else:
                raise TypeError("Header is not a FITS header or a filename")

        # If we find a major axis keyword then we are in keyword
        # mode. Else look to see if there is an AIPS header.
        if "BMAJ" in hdr:
            major = hdr["BMAJ"] * unit
        else:
            hist_beam = cls.from_fits_history(hdr, pixscale=pixscale)
            if hist_beam is not None:
                return hist_beam
            raise ValueError("No BMAJ found and does not appear to be a CASA/AIPS header.")

        # Fill out the minor axis and position angle if they are
        # present. Else they will default.
        minor = hdr["BMIN"] * unit if "BMIN" in hdr else None
        pa = hdr["BPA"] * unit if "BPA" in hdr else None

        return cls(major=major, minor=minor, pa=pa, pixscale=pixscale)

    @classmethod
    def from_fits_history(cls, hdr, pixscale=None):  # pragma: no cover
        """Instantiate the beam from an AIPS or CASA header.

        AIPS and CASA hold the beam in history. This method of initializing uses the last such
        entry, giving precedence to the CASA style.

        Parameters
        ----------
        hdr : :class:`astropy.io.fits.Header`
            The header to search for HISTORY cards.
        pixscale : :class:`~astropy.units.Quantity`
            The size of the pixel, passed to the beam.

        Returns
        -------
        ContBeam or None
            The beam, or None when no beam information is found.
        """
        # a line looks like
        # HISTORY AIPS   CLEAN BMAJ=  1.7599E-03 BMIN=  1.5740E-03 BPA=   2.61
        if "HISTORY" not in hdr:
            return None

        aipsline = None
        for line in hdr["HISTORY"]:
            if "BMAJ" in line:
                aipsline = line

        # a line looks like
        # HISTORY Sat May 10 20:53:11 2014
        # HISTORY imager::clean() [] Fitted beam used in
        # HISTORY > restoration: 1.34841 by 0.830715 (arcsec)
        #        at pa 82.8827 (deg)

        casaline = None
        for line in hdr["HISTORY"]:
            if ("restoration" in line) and ("arcsec" in line):
                casaline = line
        # assert precedence for CASA style over AIPS
        #        this is a dubious choice

        # The caller's pixscale is forwarded: ContBeam refuses to be built without one
        if casaline is not None:
            bmaj = float(casaline.split()[2]) * u.arcsec
            bmin = float(casaline.split()[4]) * u.arcsec
            bpa = float(casaline.split()[8]) * u.deg
            return cls(major=bmaj, minor=bmin, pa=bpa, pixscale=pixscale)

        if aipsline is not None:
            bmaj = float(aipsline.split()[3]) * u.deg
            bmin = float(aipsline.split()[5]) * u.deg
            bpa = float(aipsline.split()[7]) * u.deg
            return cls(major=bmaj, minor=bmin, pa=bpa, pixscale=pixscale)

        return None

class ContMap(NDDataArray):
    """A ContMap object represent a continuum map with additionnal capabilities.

    It contains the metadata, wcs, and all attribute (data/stddev/time/unit/mask) as well as potential source list detected in these maps.

    data : :class:`~numpy.ndarray` or :class:`astropy.nddata.NDData`
        The actual data contained in this `NDData` object. Not that this
        will always be copies by *reference* , so you should make copy
        the ``data`` before passing it in if that's the  desired behavior.
    uncertainty : :class:`astropy.nddata.NDUncertainty`, optional
        Uncertainties on the data.
    mask : :class:`~numpy.ndarray`-like, optional
        Mask for the data, given as a boolean Numpy array or any object that
        can be converted to a boolean Numpy array with a shape
        matching that of the data. The values must be ``False`` where
        the data is *valid* and ``True`` when it is not (like Numpy
        masked arrays). If ``data`` is a numpy masked array, providing
        ``mask`` here will causes the mask from the masked array to be
    hits : :class:`~numpy.ndarray`-like, optional
        The hit per pixel on the map
    sampling_freq : float or :class:`~astropy.units.Quantity`
        the sampling frequency of the experiment, default 1 Hz
    wcs : undefined, optional
        WCS-object containing the world coordinate system for the data.
    meta : `dict`-like object, optional
        Metadata for this object.  "Metadata" here means all information that
        is included with this object but not part of any other attribute
        of this particular object.  e.g., creation date, unique identifier,
        simulation parameters, exposure time, telescope name, etc.
    unit : :class:`astropy.units.UnitBase` instance or str, optional
        The units of the data.
    beam : :class:`~nikamap.contmap.ContBeam`
        The beam corresponding to the data, by default a gaussian
        constructed from the header 'BMAJ' 'BMIN', 'PA' keyword.
    fake_source : :class:`astropy.table.Table`, optional
        The table of potential fake sources included in the data

        .. note::
            The table must contain at least 3 columns: ['ID', 'ra', 'dec']

    sources : :class`astropy.table.Table`, optional
        The table of detected sources in the data.


    # Residual image of the last PSF photometry (set by phot_sources, read by _to_ma("residual")); None until then.
    _residual = None

    def __init__(self, data, *args, **kwargs):
        """Build a ContMap from an array, an NDData or another ContMap.

        Parameters
        ----------
        data : :class:`~numpy.ndarray`, :class:`astropy.nddata.NDData` or ContMap
            The map data. When a ContMap is given, its ``hits``, ``beam`` and
            ``sampling_freq`` are reused unless given explicitly.
        *args, **kwargs
            Passed to :class:`astropy.nddata.NDDataArray`, except the ContMap specific keywords
            ``header`` (alias of ``meta``), ``fake_sources``, ``sources``, ``sampling_freq``,
            ``hits`` and ``beam`` which are consumed here.

        Raises
        ------
        ValueError
            If both ``header`` and ``meta`` are given.
        """
        if "meta" not in kwargs:
            kwargs["meta"] = kwargs.pop("header", None)
        if "header" in kwargs:
            raise ValueError("can't have both header and meta.")

        # Arbitrary unit by default
        if "unit" not in kwargs:
            kwargs["unit"] = getattr(data, "unit", "adu")

        # ContMap specific attributes, removed from kwargs so that they do not reach NDDataArray
        self.fake_sources = kwargs.pop("fake_sources", None)
        self.sources = kwargs.pop("sources", None)
        self.sampling_freq = kwargs.pop("sampling_freq", None)

        self.hits = kwargs.pop("hits", None)
        self.beam = kwargs.pop("beam", None)

        super().__init__(data, *args, **kwargs)

        # Inherit the ContMap specific attributes of a ContMap input when not given explicitly
        if isinstance(data, ContMap):
            if self.hits is None and data.hits is not None:
                self.hits = data.hits
            if self.beam is None and data.beam is not None:
                self.beam = data.beam
            if self.sampling_freq is None and getattr(data, "sampling_freq", None) is not None:
                self.sampling_freq = data.sampling_freq

        if isinstance(self.wcs, WCS):
            # Handles CDELT as well as CD/PC based WCS; the first axis is assumed to be in degrees
            pixscale = proj_plane_pixel_scales(self.wcs)[0] * u.deg
        else:
            pixscale = np.abs(self.meta.get("CDELT1", 1)) * u.deg

        self._pixel_scale = u.pixel_scale(pixscale / u.pixel)

        if self.beam is None:
            # Build the beam from the header keywords, with a default BMAJ of 1 (degree by default)
            header = meta_to_header(self.meta)
            if "BMAJ" not in header:
                header["BMAJ"] = 1
            self.beam = ContBeam.from_fits_header(header, pixscale=pixscale)

    @property
    def header(self):
        """The map metadata (alias of ``meta``)."""
        # ``self.meta`` goes through the NDData descriptor, which also works before any meta was set
        return self.meta

    @header.setter
    def header(self, value):
        self.meta = value

    @property
    def time(self):
        """Integration time per pixel, ``hits / sampling_freq`` in seconds, or None when unknown.

        A bare number ``sampling_freq`` is interpreted as Hz.
        """
        if self.hits is None or self.sampling_freq is None:
            return None
        sampling_freq = u.Quantity(self.sampling_freq, u.Hz)
        return (self.hits / sampling_freq).to(u.s)

    def compressed(self):
        """Return the unmasked data values as a flattened Quantity.

        All values are returned when the map has no mask.
        """
        # np.ma handles a None (or non boolean) mask, where ``~self.mask`` would fail
        return np.ma.array(self.data, mask=self.mask).compressed() * self.unit

    def uncertainty_compressed(self):
        """Return the unmasked uncertainty values as a flattened Quantity in the uncertainty unit.

        All values are returned when the map has no mask.
        """
        return np.ma.array(self.uncertainty.array, mask=self.mask).compressed() * self.uncertainty.unit

    def __array__(self, dtype=None, copy=None):
        """Return the data as a :class:`numpy.ma.MaskedArray`.

        This allows code that requests a Numpy array to use a ContMap object as a Numpy array.
        It overrides ``NDData.__array__`` to return a masked array built from ``data`` and ``mask``.

        Parameters
        ----------
        dtype : data-type, optional
            Requested dtype, as passed by numpy (e.g. ``np.asarray(cmap, dtype=float)``).
        copy : bool or None, optional
            Copy request passed by numpy >= 2; the data are copied only when ``copy`` is True.
        """
        masked = np.ma.array(self.data, mask=self.mask, copy=bool(copy))
        if dtype is not None:
            masked = masked.astype(dtype, copy=False)
        return masked

    def __u_array__(self):
        """Return the uncertainty values as a masked array sharing the map mask."""
        uncertainty_values = self.uncertainty.array
        return np.ma.array(uncertainty_values, mask=self.mask)

    def __t_array__(self):
        """Return the per pixel integration time (``time``) as a masked array with a fill value of 0."""
        integration_time = self.time
        return np.ma.array(integration_time, fill_value=0, mask=self.mask)

    def surface(self, box_size=None):
        """Retrieve the surface covered by unmasked pixels.

        Parameters
        ----------
        box_size : scalar or tuple, optional
            When given, the mask is first grown by a box kernel of this size,
            cropping the edge of the map. Default is None.

        Returns
        -------
        :class:`~astropy.units.Quantity`
            Surface covered by unmasked pixels, in arcsec**2

        Notes
        -----
        The default value for box_size in detect_sources is 5
        """
        if self.mask is None:
            # No mask: every pixel counts
            n_pix = np.prod(self.data.shape)
        else:
            valid_mask = self.mask
            if box_size is not None:
                valid_mask = shrink_mask(valid_mask, Box2DKernel(box_size))
            n_pix = np.sum(~valid_mask)

        # Side of one pixel in arcsec
        pix_side = u.pix.to(u.arcsec, equivalencies=self._pixel_scale)

        return n_pix * pix_side**2 * u.arcsec**2

    @property
    def uncertainty(self):
        """The uncertainty on the data, an :class:`astropy.nddata.NDUncertainty` or None."""
        return self._uncertainty

    @uncertainty.setter
    def uncertainty(self, value):
        """Set the uncertainty.

        An :class:`~astropy.nddata.NDUncertainty` is attached to this map (copied first if it already
        belongs to another NDData). A plain :class:`numpy.ndarray` of the data shape is wrapped in a
        :class:`~astropy.nddata.StdDevUncertainty`, with a warning.

        Raises
        ------
        ValueError
            If an array of a different shape than the data is given.
        TypeError
            For any other type.
        """
        if value is None:
            self._uncertainty = value
            return

        if isinstance(value, NDUncertainty):
            # Do not steal the uncertainty of another NDData: make a new (non-copying) instance
            if getattr(value, "_parent_nddata", None) is not None:
                value = value.__class__(value, copy=False)
            self._uncertainty = value
        elif isinstance(value, np.ndarray):
            if value.shape != self.shape:
                raise ValueError("uncertainty must have same shape as " "data.")
            self._uncertainty = StdDevUncertainty(value)
            warnings.warn("array provided for uncertainty; assuming it is a " "StdDevUncertainty.")
        else:
            raise TypeError("uncertainty must be an instance of a " "NDUncertainty object or a numpy array.")
        self._uncertainty.parent_nddata = self

    @property
    def weights(self):
        """Return the weights as inverse variance regardless of the uncertainty type.

        Raises
        ------
        ValueError
            If the uncertainty is not an InverseVariance, StdDevUncertainty or VarianceUncertainty.
        """
        uncertainty = self.uncertainty
        if isinstance(uncertainty, InverseVariance):
            return uncertainty.array
        if isinstance(uncertainty, StdDevUncertainty):
            return 1 / uncertainty.array**2
        if isinstance(uncertainty, VarianceUncertainty):
            return 1 / uncertainty.array
        raise ValueError("Unknown uncertainty type")

    @property
    def snr(self):
        """Signal to noise ratio map, ``data * sqrt(weights)``, as a masked array."""
        snr = self.data * np.sqrt(self.weights)

        return np.ma.array(snr, mask=self.mask)

    def _to_ma(self, item=None):
        """Get masked array quantities from object.

        Parameters
        ----------
        item : str, optional (None|signal|uncertainty|snr|residual)
            The quantity to retrieve, by default None, ie signal

        Returns
        -------
        data : ~np.ma.MaskedArray
            The corresponding item as masked quantity
        label : str
            The corresponding label

        Raises
        ------
        ValueError
            When item is not in list, or when the residual is requested before any PSF photometry.
        """
        if item == "snr":
            label = "SNR"
            data = self.snr
        elif item == "uncertainty":
            label = "Uncertainty"
            data = np.ma.array(self.uncertainty.array * self.unit, mask=self.mask)
        elif item in ["signal", None]:
            label = "Brightness"
            data = np.ma.array(self.data * self.unit, mask=self.mask)
        elif item == "residual":
            if self._residual is None:
                raise ValueError("No residual map available, run phot_sources(psf=True) first")
            label = "Residual"
            data = np.ma.array(self._residual * self.unit, mask=self.mask)
        else:
            raise ValueError("must be in (None|signal|uncertainty|snr|residual)")

        return data, label

    @property
    def beam(self):
        """The :class:`ContBeam` of the map, or None."""
        return self._beam

    @beam.setter
    def beam(self, value):
        """Set the beam; only None or a :class:`ContBeam` are accepted (ValueError otherwise)."""
        if value is None or isinstance(value, ContBeam):
            self._beam = value
        else:
            raise ValueError("Can not handle this beam type {}".format(type(value)))

    def _slice(self, item):
        """Return the keyword arguments used to build a sliced ContMap.

        Extends the NDData slicing so that the ContMap specific attributes follow:
        ``hits`` is sliced like the data, while ``beam``, ``sampling_freq``,
        ``fake_sources`` and ``sources`` are passed through unchanged.
        """
        # slice all the standard NDData attributes
        kwargs = super()._slice(item)
        # then add the ContMap specific keywords understood by __init__
        kwargs["hits"] = self.hits[item] if self.hits is not None else None
        kwargs["beam"] = self.beam
        # keep the sampling frequency so that ``time`` is still available on the slice
        kwargs["sampling_freq"] = self.sampling_freq

        kwargs["fake_sources"] = self.fake_sources
        kwargs["sources"] = self.sources

        return kwargs  # these must be returned

    def trim(self):
        """Remove masked region on the edges.

        Returns
        -------
        :class:`ContMap`
            A trimmed ContMap object, sliced to the bounding box of the rows and columns
            which are not fully masked. The whole map is returned when there is no mask.

        Raises
        ------
        ValueError
            If the map is fully masked.
        """
        mask = self.mask
        if mask is None:
            # Nothing is masked, keep the full extent
            return self[:, :]

        axis_slice = []
        # axis=1 averages along x, giving the good rows; axis=0 gives the good columns
        for axis in [1, 0]:
            good_pix = np.argwhere(np.mean(mask, axis=axis) != 1)
            if good_pix.size == 0:
                raise ValueError("Can not trim a fully masked map")
            axis_slice.append(slice(np.min(good_pix), np.max(good_pix) + 1))

        return self[axis_slice[0], axis_slice[1]]

    def add_gaussian_sources(self, within=(0, 1), cat_gen=pos_uniform, **kwargs):
        """Add gaussian sources into the map.

        Parameters
        ----------
        within : tuple of 2 numbers
            force the sources within this relative range in the map
        cat_gen : function (`pos_uniform`|`pos_gridded`|`pos_list`|...)
            the function used to generate the pixel positions and flux of the sources (see Notes below)
        **kwargs
            any keyword arguments to be passed to the `cat_gen` function

        Notes
        -----
        the `cat_gen` function is used to generate the list of x, y pixel positions and fluxes
        and must at least support the `shape=None, within=(0, 1), mask=None` arguments.

        The sources are added in place to the data, using the beam as gaussian shape, and the
        corresponding catalog (``fake_id``, ``amplitude``, ``ra``, ``dec``, ``_ra``, ``_dec``)
        is stored in ``fake_sources``.
        """
        map_shape = self.shape

        x_pix, y_pix, peak_flux = cat_gen(shape=map_shape, within=within, mask=self.mask, **kwargs)
        n_src = x_pix.shape[0]

        catalog = Table(masked=True)
        catalog["amplitude"] = peak_flux.to(self.unit * u.beam)
        catalog["x_mean"] = x_pix
        catalog["y_mean"] = y_pix
        # All sources share the beam widths, converted to pixels
        catalog["x_stddev"] = self.beam.stddev_maj.to(u.pix, self._pixel_scale).value * np.ones(n_src)
        catalog["y_stddev"] = self.beam.stddev_min.to(u.pix, self._pixel_scale).value * np.ones(n_src)
        catalog["theta"] = np.zeros(n_src)

        # Keep only sources falling on unmasked pixels (crude check)
        if self.mask is not None:
            rows = catalog["y_mean"].astype(int)
            cols = catalog["x_mean"].astype(int)
            catalog = catalog[~self.mask[rows, cols]]

        # Paint the gaussians into the data, in place
        self._data += make_gaussian_sources_image(map_shape, catalog)

        catalog.add_column(Column(np.arange(len(catalog)), name="fake_id"), 0)

        # Pixel -> world coordinates
        ra, dec = self.wcs.pixel_to_world_values(catalog["x_mean"], catalog["y_mean"])
        catalog.add_columns([Column(ra * u.deg, name="ra"), Column(dec * u.deg, name="dec")])

        catalog["_ra"] = catalog["ra"]
        catalog["_dec"] = catalog["dec"]

        # The pixel based model parameters are not needed anymore
        catalog.remove_columns(["x_mean", "y_mean", "x_stddev", "y_stddev", "theta"])

        self.fake_sources = catalog

    def detect_sources(self, threshold=3, box_size=5, detect_on="snr"):
        """Detect sources with find local peaks above a specified threshold value.

        The detection is made on the SNR map by default, and the resulting :class:`astropy.table.Table`
        with columns ``ID, ra, dec, SNR`` is stored in ``self.sources`` (None when nothing is found).
        If fake sources are present, a match is made with a distance threshold of ``beam_fwhm / 3``:
        a ``find_peak`` column is added to ``fake_sources`` and a ``fake_id`` column to the sources.

        Parameters
        ----------
        threshold : float or :class:`~astropy.units.Quantity`
            The data value or pixel-wise data values to be used for the
            detection threshold.
        box_size : scalar or tuple, optional
            The size of the local region to search for peaks at every point
            in ``data``.  If ``box_size`` is a scalar, then the region shape
            will be ``(box_size, box_size)``.
        detect_on : str (None|signal|uncertainty|snr|residual)
            do the detection on given array, default 'snr'

        Notes
        -----
        The edge of the map is cropped by the box_size in order to ensure proper subpixel fitting.
        """
        # Masked pixels are filled with 0 before searching for peaks
        detect_on = self._to_ma(item=detect_on)[0].filled(0)

        # A Quantity threshold is compared with the detection map in the same unit
        if isinstance(threshold, u.Quantity):
            detect_on = detect_on.to(threshold.unit).value
            threshold = threshold.value

        if self.mask is not None:
            # Make sure that there is no detection on the edge of the map
            box_kernel = Box2DKernel(box_size)
            detect_mask = shrink_mask(self.mask, box_kernel)
            detect_on[detect_mask] = np.nan

        # TODO: Have a look at
        # ~photutils.psf.IterativelySubtractedPSFPhotometry

        # NOTE(review): in this view the `try:` matching the `except` below, the arguments of
        # find_peaks(...) and its closing parenthesis are missing; restore them from the repository.
            # To avoid bad fit warnings...
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", AstropyWarning)
                sources = find_peaks(
        except InconsistentAxisTypesError:
            sources = []

        if sources is not None and len(sources) > 0:
            sources.rename_column("peak_value", "SNR")

            # Transform pixel coordinates column to world coordinates
            sources = xy_to_world(sources, self.wcs, "x_centroid", "y_centroid")

            # Transform to masked Table here to avoid future warnings
            sources = Table(sources, masked=True)
            sources.meta["method"] = "find_peak"
            sources.meta["threshold"] = threshold

            # Sort by decreasing SNR
            if "SNR" in sources.colnames:
                # NOTE(review): the body of this `if` (presumably the sort announced above) is missing in this view.

            # IDs follow the table order
            sources.add_column(Column(np.arange(len(sources)), name="ID"), 0)

        if self.fake_sources:
            # Match to the fake catalog
            fake_sources = self.fake_sources
            dist_threshold = self.beam.major / 3

            if sources is None or len(sources) == 0:
                # No detection: every fake source is flagged as not found (fully masked column)
                fake_sources["find_peak"] = MaskedColumn(np.ones(len(fake_sources), dtype=int), mask=True)
                # NOTE(review): an `else:` seems to be missing here; as shown, the cross-match below would
                # also run when there are no sources.
                fake_sc = cat_to_sc(fake_sources)
                sources_sc = cat_to_sc(sources)

                # For each fake source, the closest detection ID (masked beyond dist_threshold)
                idx, sep2d, _ = match_coordinates_sky(fake_sc, sources_sc)
                mask = sep2d > dist_threshold
                fake_sources["find_peak"] = MaskedColumn(sources[idx]["ID"], mask=mask)

                # For each detection, the closest fake_id (masked beyond dist_threshold)
                idx, sep2d, _ = match_coordinates_sky(sources_sc, fake_sc)
                mask = sep2d > dist_threshold
                sources["fake_id"] = MaskedColumn(fake_sources[idx]["fake_id"], mask=mask)

        if sources is not None and len(sources) > 0:
            self.sources = sources
            # NOTE(review): an `else:` seems to be missing before the next line; as shown the
            # detections would always be discarded.
            self.sources = None

    def match_sources(self, catalogs, dist_threshold=None):
        """Cross-match the detected sources with one or several reference catalogs.

        For each reference catalog, a column named after ``ref_cat.meta["name"]`` is added to
        ``self.sources``. It holds the index of the closest reference source, and is masked
        when that source is farther than ``dist_threshold``.

        Parameters
        ----------
        catalogs : :class:`astropy.table.Table` or list of Table
            Reference catalog(s), each with a ``name`` entry in its ``meta``.
        dist_threshold : :class:`~astropy.units.Quantity`, optional
            Maximum matching distance, by default a third of the beam major axis FWHM.
        """
        threshold = self.beam.major / 3 if dist_threshold is None else dist_threshold
        references = catalogs if isinstance(catalogs, list) else [catalogs]

        source_table = self.sources
        for reference in references:
            closest, separation, _ = match_coordinates_sky(cat_to_sc(source_table), cat_to_sc(reference))
            source_table[reference.meta["name"]] = MaskedColumn(closest, mask=separation > threshold)

    def phot_sources(
        # NOTE(review): the parameter list, the closing "):" and the docstring quotes (plus its
        # Parameters header) are missing in this view. Judging from the docstring text below, the
        # parameters are sources, peak, psf, fixed_positions, fixed_sigma, background,
        # local_background, background_clipping and grouping_threshold; restore them from the repository.
        # The photometry columns are added to the catalog, which is then stored in self.sources.

        sources : :class:`astropy.table.Table`, optional
            Catalog on which to do the photometry, by default None
        peak : bool, optional
            Do peak photometry, by default True
        psf : bool, optional
            Do psf photometry, by default True
        fixed_positions : bool, optional,
            Fix sources positions in the fit, default True
        fixed_sigma: bool, optional
            Fix the sigma of the psf in the fit, default True
        background : bool, optional
            Estimate and remove a global median background, by default True
        local_background : bool, optional
            Estimate and remove a local background, by default True
        background_clipping : int, optional
            Sigma clipping used for the background, by default 3
        grouping_threshold : int, optional
            Grouping distance for psf photometry in unit of psf fwhm, by default 3
        # Default to the detected sources
        if sources is None:
            sources = self.sources

        # Pixel coordinates of the catalog positions
        xx, yy = self.wcs.world_to_pixel_values(sources["ra"], sources["dec"])

        # Nearest pixel indexes (rounding half up)
        x_idx = np.floor(xx + 0.5).astype(int)
        y_idx = np.floor(yy + 0.5).astype(int)

        if peak:
            # Crude Peak Photometry
            # From pixel indexes to array indexing

            sources["flux_peak"] = Column(self.data[y_idx, x_idx], unit=self.unit * u.beam)
            sources["eflux_peak"] = Column(self.uncertainty.array[y_idx, x_idx], unit=self.unit * u.beam)

        if psf:
            data = self.data
            # BasicPSFPhotometry with fixed positions

            # Gaussian width of the PSF in pixels, from the beam major axis
            sigma_psf = self.beam.stddev_maj.to(u.pix, self._pixel_scale).value

            # Using an IntegratedGaussianPRF can cause biais in the photometry
            # TODO: Check the NIKA2 calibration scheme
            # from photutils.psf import IntegratedGaussianPRF
            # psf_model = IntegratedGaussianPRF(sigma=sigma_psf)
            psf_model = CircularGaussianPSF(sigma=sigma_psf)

            if fixed_positions:
                psf_model.x_0.fixed = True
                psf_model.y_0.fixed = True

            if fixed_sigma:
                psf_model.sigma.fixed = True
                # NOTE(review): an `else:` seems to be missing before the next line; as shown the
                # sigma is always left free.
                psf_model.sigma.fixed = False

            # Group sources closer than grouping_threshold beam FWHM so that they are fitted together
            source_grouper = None
            if len(xx) > 1:
                source_grouper = SourceGrouper(grouping_threshold * self.beam.major.to(u.pix, self._pixel_scale).value)

            # Sigma-clipped (mad_std) median background estimator, reused for the local background
            bkgstat = MedianBackground(sigma_clip=SigmaClip(sigma=background_clipping, stdfunc="mad_std"))

            if background:
                data = data - bkgstat(data)

            local_bkg = None
            if local_background:
                local_bkg = LocalBackground(5, 10, bkgstat)

            # NOTE(review): the arguments and closing parenthesis of PSFPhotometry(...) are missing in this view.
            photometry = PSFPhotometry(

            # Initial guesses: catalog positions and the pixel value at the nearest pixel
            # NOTE(review): the closing parenthesis of Table(...) is missing in this view.
            positions = Table(
                [Column(xx, name="x_0"), Column(yy, name="y_0"), Column(self.data[y_idx, x_idx], name="flux_init")]

            # NOTE(review): the other arguments and the closing parenthesis of this call are missing in this view.
            result_tab = photometry(
                error=1 / np.sqrt(self.weights),

            for _source, _tab in zip(["flux_psf", "eflux_psf"], ["flux_fit", "flux_err"]):
                # Sometimes the returning fluxes has no uncertainty....
                if _tab in result_tab.colnames:
                    # psf_model(0, 0) presumably evaluates the unit-flux model at its own center, turning the
                    # fitted flux into a peak brightness; relies on x_0/y_0 keeping their default 0 — confirm
                    sources[_source] = Column(result_tab[_tab] * psf_model(0, 0), unit=self.unit * u.beam)
            for key in ["local_bkg", "group_id", "qfit", "cfit"]:
                sources[key] = result_tab[key]

            if not fixed_positions:
                for key in ["x_fit", "x_err", "y_fit", "y_err"]:
                    sources[key] = result_tab[key]

                # Transform pixel coordinates column to world coordinates
                sources = xy_to_world(sources, self.wcs, "x_fit", "y_fit")

            if not fixed_sigma:
                # NOTE(review): assert is skipped under `python -O`; an explicit exception would be safer.
                assert np.abs(self.wcs.wcs.cdelt[0]) == np.abs(self.wcs.wcs.cdelt[1]), "Non square pixel not supported"
                # Angular size of one pixel, used to convert the fitted sigma (in pixels) to an angular FWHM
                unit = np.abs(self.wcs.wcs.cdelt[0]) * u.Unit(self.wcs.world_axis_units[0])
                for key in ["sigma_fit", "sigma_err"]:
                    sources[key.replace("sigma", "fwhm")] = gaussian_sigma_to_fwhm * result_tab[key] * unit

            # Kept for _to_ma("residual")
            self._residual = photometry.make_residual_image(data, (10, 10))

        self.sources = sources

    def match_filter(self, kernel):
        """Return a matched filtered version of the map.

        kernel : :class:`nikamap.contmap.ContBeam` or any :class:`astropy.convolution.kernel2D`
            the kernel used for filtering

            the resulting match filtered ContMap object

        This compute the match filtered :math:`MF` map as :

        .. math::

            MF = \\frac{B * (W M)}{B^2 * W}

        with :math:`B` the beam, :math:`W` the weights (inverse variance) and :math:`M` the signal map

        Peak photometry is conserved for data and e_data

        Resultings maps have a different mask

        >>> npix, std = 500, 4
        >>> kernel = Gaussian2DKernel(std)
        >>> mask = np.zeros((npix,npix))
        >>> data = np.random.normal(0, 1, size=mask.shape)
        >>> data[(npix-std*8)//2:(npix+std*8)//2+1,(npix-std*8)//2:(npix+std*8)//2+1] += kernel.array/kernel.array.max()
        >>> data = ContMap(data, uncertainty=StdDevUncertainty(np.ones_like(data)), time=np.ones_like(data)*u.s, mask=mask)
        >>> mf_data = data.match_filter(kernel)
        >>> import matplotlib.pypot as plt
        >>> plt.ion()
        >>> fig, axes = plt.subplots(ncols=2)
        >>> axes[0].imshow(data) ; axes[1].imshow(mf_data)

        mf_beam = self.beam.convolve(kernel)


        # Convolve the mask and retrieve the fully sampled region, this
        # will remove one kernel width on the edges
        # mf_mask = ~np.isclose(convolve(~self.mask, kernel, normalize_kernel=False), 1)
        if self.mask is not None:
            mf_mask = shrink_mask(self.mask, kernel)
            mf_mask = None

        # Convolve the time (integral for time)
        # with warnings.catch_warnings():
        #     warnings.simplefilter('ignore', AstropyWarning)
        #     mf_time = convolve(self.time, kernel, normalize_kernel=False)*self.time.unit
        if self.hits is not None:
            mf_hits = signal.fftconvolve(np.asarray(self.hits), kernel, mode="same")
            if mf_mask is not None:
                mf_hits[mf_mask] = 0
            mf_hits = None

        # Convolve the data (peak for unit conservation)
        kernel_sqr = kernel.array**2

        # ma.filled(0) required for the fft convolution
        weights = self.weights

        if self.mask is not None:
            weights[self.mask] = 0

        with np.errstate(invalid="ignore", divide="ignore"):
            mf_uncertainty = 1 / np.sqrt(signal.fftconvolve(weights, kernel_sqr, mode="same"))
        if mf_mask is not None:
            mf_uncertainty[mf_mask] = np.nan

        # Units are not propagated in masked arrays...
        mf_data = signal.fftconvolve(weights * self.__array__().filled(0), kernel, mode="same") * mf_uncertainty**2

        mf_data = ContMap(

        return mf_data

    def plot(self, to_plot=None, ax=None, cbar=False, cat=None, levels=None, beam=False, **kwargs):
        """Convenience routine to plot the dataset.

        Parameters
        ----------
        to_plot : str, optional (None|signal|uncertainty|snr|residual)
            Choose which quantity to plot, by default None (signal)
        ax : :class:`matplotlib.axes.Axes`, optional
            Axe to plot the map
        cbar: boolean
            Draw a colorbar (ax must be None), default=False
        cat : boolean or list of tuple [(cat, kwargs)], optional
            If True, overplot the current self.source catalog
            with '^' as marker.
            Otherwise overplot the given catalogs on the map, with kwargs.
            The given list is not modified.
        levels: array_like, optional
            Overplot levels contours, add negative contours as dashed line
        beam: boolean
            Draw a beam in the lower left corner, default False
        **kwargs :
            Arbitrary keyword arguments for :func:`matplotlib.pyplot.imshow`

        Returns
        -------
        image : `~matplotlib.image.AxesImage`

        Notes
        -----
        * if a fake_sources property is present, it will be overplotted with 'o' as marker
        * each catalog *must* have '_ra' & '_dec' column
        """
        try:
            data, cbar_label = self._to_ma(item=to_plot)
        except ValueError as e:
            raise ValueError("to_plot {}".format(e)) from e

        if isinstance(data.data, u.Quantity):
            # Remove unit to avoid matplotlib problems
            data = np.ma.array(data.data.to(self.unit).value, mask=data.mask)
            cbar_label = "{} [{}]".format(cbar_label, self.unit)

        ax = setup_ax(ax, self.wcs)

        iax = ax.imshow(data, origin="lower", interpolation="none", **kwargs)

        if levels is not None:
            # asarray so that plain lists/tuples support the negation below
            levels = np.asarray(levels)
            ax.contour(data, levels=levels, alpha=0.8, colors="w")
            ax.contour(data, levels=-levels[::-1], alpha=0.8, colors="w", linestyles="dashed")

        if cbar:
            fig = ax.get_figure()
            cbar = fig.colorbar(iax, ax=ax)
            cbar.set_label(cbar_label)

        if cat is True:
            cat = [(self.sources, {"marker": "^", "color": "red"})] if self.sources is not None else []
        elif cat is None or cat is False:
            cat = []
        else:
            # Copy so that appending the fake sources below never alters the caller's list
            cat = list(cat)

        # In case of fake sources, overplot them
        if self.fake_sources:
            cat.append((self.fake_sources, {"marker": "o", "c": "red", "alpha": 0.8}))

        for _cat, _kwargs in cat:
            label = _cat.meta.get("method") or _cat.meta.get("name") or _cat.meta.get("NAME") or "Unknown"
            cat_sc = cat_to_sc(_cat)
            x, y = self.wcs.world_to_pixel_values(cat_sc.ra, cat_sc.dec)
            if _kwargs is None:
                _kwargs = {"alpha": 0.8}
            # A user supplied "label" takes precedence over the catalog one (avoids a duplicate keyword)
            ax.scatter(x, y, **{"label": label, **_kwargs})

        if beam and hasattr(self, "beam"):
            ellipse_artist = self.beam.ellipse_to_plot(
                self.beam.support_scaling / 2, self.beam.support_scaling / 2, self.beam.pixscale
            )
            ax.add_artist(ellipse_artist)

        ax.set_xlim(0, self.shape[1] - 1)
        ax.set_ylim(0, self.shape[0] - 1)

        return iax

    def plot_SNR(self, vmin=-3, vmax=5, **kwargs):
        """Plot the signal-to-noise map of the dataset.

        Parameters
        ----------
        vmin, vmax : float
            Lower and upper limits of the color scale, by default -3 and 5.

        Notes
        -----
        Any additional keyword is forwarded to :func:`nikamap.ContMap.plot`.
        """
        plot_kwargs = dict(kwargs, vmin=vmin, vmax=vmax)
        return self.plot(to_plot="snr", **plot_kwargs)

    def check_SNR_pdf(self, ax=None, range=(-6, 3), return_mean=False):
        """Perform normality test on SNR map.

        This fits the cumulative distribution function of a normal distribution on the
        sorted snr pixels, clipped between -6 and 3 by default.

        Parameters
        ----------
        ax : :class:`~matplotlib.axes.Axes`, optional
            axe to plot the histogram and fits
        range: tuple of 2 floats or None
            perform the fit on the snr values strictly between min(range) and max(range),
            default (-6, 3). If None, all the unmasked pixels are used.
        return_mean: bool
            if True, return the mean of the fitted distribution, default False

        Returns
        -------
        std : float
            the standard deviation of the fitted normal distribution
        mean : float
            the mean of the fitted normal distribution, only if ``return_mean`` is True

        Notes
        -----
        To recover the normality you must multiply the uncertainty array by the returned stddev value,
        if uncertainty is StdDevUncertainty.

        The fit does not work with skewed distributions.

        >>> std = data.check_SNR_pdf()
        >>> data.uncertainty.array *= std
        """
        snr = self.snr.compressed()
        if range is None:
            # The whole sample spans the full empirical CDF
            p_min, p_max = 0.0, 1.0
        else:
            low, high = np.min(range), np.max(range)
            snr = snr[(snr > low) & (snr < high)]
            # The clipped sample only covers this part of the CDF of a standard normal distribution
            p_min, p_max = stats.norm.cdf(low), stats.norm.cdf(high)

        snr_sorted = np.sort(snr)
        p = np.linspace(p_min, p_max, len(snr))

        def normal_cdf(x, loc, scale):
            return stats.norm.cdf(x, loc=loc, scale=scale)

        popt, _ = curve_fit(normal_cdf, snr_sorted, p)
        mean, std = popt

        if ax is not None:
            _, bins_edges, _ = ax.hist(snr, bins="auto", histtype="stepfilled", alpha=0.2, density=True, range=range)
            ax.plot(bins_edges, stats.norm(mean, std).pdf(bins_edges))

        if return_mean:
            return std, mean
        return std

    def check_SNR(self, ax=None, bins="auto", range=(-6, 3), return_mean=False, **kwargs):
        """Perform normality test on SNR map.

        This perform a gaussian fit on snr pixels histogram clipped between -6 and 3

        Parameters
        ----------
        ax : :class:`~matplotlib.axes.Axes`, optional
            axe to plot the histogram and fits
        bins: int or 'auto'
            number of bins for the histogram. Default 'auto'.
        range: tuple of 2 floats
            perform the fit on the histogram between range[0] and range[1], default (-6, 3)
        return_mean: bool
            if True, return the mean of the histogram, default False
        **kwargs :
            ignored

        Returns
        -------
        std : float
            the (positive) standard deviation of the fitted gaussian
        mean : float
            the center of the fitted gaussian, only if ``return_mean`` is True

        Notes
        -----
        To recover the normality you must multiply the uncertainty array by the returned stddev value,
        if uncertainty is StdDevUncertainty.

        >>> std = data.check_SNR()
        >>> data.uncertainty.array *= std
        """
        snr = self.snr.compressed()
        hist, bin_edges = np.histogram(snr, bins=bins, density=True, range=range)

        # is biased if signal is present
        # is biased if trimmed
        # mu, std = norm.fit(SN)

        bin_center = (bin_edges[1:] + bin_edges[:-1]) / 2

        def gauss(x, a, c, s):
            return a * np.exp(-((x - c) ** 2) / (2 * s**2))

        # Start from a unit-width gaussian at the histogram peak, closer to the
        # expected solution than curve_fit's default (1, 1, 1)
        p0 = (hist.max(), bin_center[np.argmax(hist)], 1.0)
        popt, _ = curve_fit(gauss, bin_center.astype(np.float32), hist.astype(np.float32), p0=p0)
        # gauss only depends on s**2, so the fit may converge to a negative width
        popt[2] = np.abs(popt[2])
        mean, std = popt[1:]

        if ax is not None:
            ax.bar(bin_center, hist, width=np.median(np.diff(bin_center)), fill=True, alpha=0.2)
            ax.plot(bin_center, gauss(bin_center, *popt))

        if return_mean:
            return std, mean
        return std

    def check_SNR_simple(self, **kwargs):
        """Perform normality test on SNR maps.

        Returns the median absolute deviation (not rescaled to a standard deviation)
        of the unmasked snr pixels. Extra keywords are ignored.
        """
        values = self.snr.compressed()
        deviations = np.abs(values - np.median(values))
        return np.median(deviations)

    def normalize_uncertainty(self, factor=None, method="check_SNR", **kwargs):
        """Normalize the uncertainty so that the snr distribution has a unit standard deviation.

        Parameters
        ----------
        factor : float, optional
            the factor which normalize the snr distribution
        method : str, (check_SNR_simple|check_SNR|check_SNR_pdf),
            the method to compute this factor if not provided, by default `check_SNR`
        **kwargs :
            passed to the chosen ``method``

        Raises
        ------
        ValueError
            if neither ``factor`` nor ``method`` is provided, or if the uncertainty type is unknown

        Notes
        -----
        The applied factor is accumulated in ``self.meta["FACTOR"]``.
        """
        assert method in ("check_SNR_simple", "check_SNR", "check_SNR_pdf", None)

        if factor is None:
            if method is None:
                raise ValueError("You must provide either `factor` or `method`.")
            elif method == "check_SNR_simple":
                # check_SNR_simple returns the median absolute deviation; for a normal
                # distribution std = MAD / Phi^-1(3/4) (~1.4826 * MAD)
                factor = self.check_SNR_simple(**kwargs) / stats.norm.ppf(0.75)
            elif method == "check_SNR":
                factor = self.check_SNR(**kwargs)
            elif method == "check_SNR_pdf":
                factor = self.check_SNR_pdf(**kwargs)

        if isinstance(self.uncertainty, StdDevUncertainty):
            self.uncertainty.array *= factor
        elif isinstance(self.uncertainty, InverseVariance):
            self.uncertainty.array /= factor**2
        elif isinstance(self.uncertainty, VarianceUncertainty):
            self.uncertainty.array *= factor**2
        else:
            raise ValueError("Unknown uncertainty type")

        # Add the factor in the meta
        if "FACTOR" in self.meta:
            self.meta["FACTOR"] *= factor
        else:
            self.meta["FACTOR"] = factor

    def plot_PSD(self, to_plot=None, ax=None, bins=100, range=None, apod_size=None, **kwargs):
        """Plot the power spectrum of the map.

        Parameters
        ----------
        to_plot : str, optional (None|signal|uncertainty|snr|residual)
            Choose which quantity to plot, by default None (signal)
        ax : :class:`matplotlib.axes.Axes`, optional
            Axe to plot the power spectrum
        bins : int
            Number of bins for the histogram. Default 100.
        range : (float, float), optional
            The lower and upper range of the bins. (see `~numpy.histogram`)
        apod_size : optional
            forwarded as ``apod_size`` to ``power_spectral_density``
        **kwargs :
            Arbitrary keyword arguments for :func:`matplotlib.pyplot.loglog`

        Returns
        -------
        powspec_k : :class:`astropy.units.quantity.Quantity`
            The value of the power spectrum
        bin_edges : :class:`astropy.units.quantity.Quantity`
            Return the bin edges ``(length(hist)+1)``.
        """
        try:
            data, label = self._to_ma(item=to_plot)
        except ValueError as e:
            raise ValueError("to_plot {}".format(e)) from e

        res = (1 * u.pixel).to(u.arcsec, equivalencies=self._pixel_scale)
        powspec, bin_edges = power_spectral_density(data, res=res, bins=bins, range=range, apod_size=apod_size)

        if to_plot == "snr":
            powspec /= res**2
            label = "P(k) {}".format(label)
        else:
            # Convert from per-beam to per-steradian units
            pk_unit = u.Jy**2 / u.sr
            powspec /= (self.beam.sr / u.beam) ** 2
            powspec = powspec.to(pk_unit)
            label = "P(k) {} [{}]".format(label, pk_unit)

        if ax is not None:
            bin_center = (bin_edges[1:] + bin_edges[:-1]) / 2
            ax.loglog(bin_center, powspec, **kwargs)
            ax.set_xlabel(r"k [arcsec$^{-1}$]")
            ax.set_ylabel(label)

        return powspec, bin_edges

    def get_square_slice(self, start=None):
        """Retrieve the slice to get the maximum unmasked square.

        Parameters
        ----------
        start : (int, int)
            define the center (y, x) of the starting point (default: center of the image)

        Returns
        -------
        islice : slice
            to be applied on the object itself on both dimension data[islice, islice]

        Notes
        -----
        Simply grows a square from the starting point: first towards lower indices,
        then towards higher indices, as long as it contains no masked pixel and stays
        inside the image. A missing mask means every pixel is valid. If the starting
        pixel is masked, an empty slice is returned.
        """
        if start is None:
            start = np.asarray(self.shape) // 2
        else:
            assert isinstance(start, (list, tuple, np.ndarray)), "start should have a length of 2"
            assert len(start) == 2, "start should have a length of 2"

        # NOTE(review): the initial slice spans start[0] .. start[1] on both axes, which
        # is a single pixel only when start[0] == start[1] (kept from the original behavior)
        lower, upper = int(start[0]), int(start[1]) + 1
        limit = min(self.shape)
        mask = self.mask

        def is_unmasked(low, high):
            if mask is None:
                return True
            return bool(np.all(~mask[low:high, low:high]))

        if not is_unmasked(lower, upper):
            return slice(lower, lower)

        # Bounded growth: never step outside [0, limit), which previously looped forever
        while lower > 0 and is_unmasked(lower - 1, upper):
            lower -= 1
        while upper < limit and is_unmasked(lower, upper + 1):
            upper += 1

        return slice(lower, upper)

    def _arithmetic(self, operation, operand, *args, **kwargs):
        """Combine the ``hits`` of both operands, then delegate to the parent ``_arithmetic``.

        ``hits`` are combined with ``operation`` when both operands carry them; otherwise a
        copy of the available ``hits`` (or None) is attached to the result. Operands without
        a ``hits`` attribute (e.g. plain NDData) are treated as having no hits.
        """
        self_hits = self.hits
        operand_hits = getattr(operand, "hits", None)

        if self_hits is not None and operand_hits is not None:
            result_hits = operation(self_hits, operand_hits)
        elif self_hits is not None:
            result_hits = deepcopy(self_hits)
        else:
            result_hits = deepcopy(operand_hits)

        # Let the superclass do all the other attributes note that
        # this returns the result and a dictionary containing other attributes
        result, kwargs = super()._arithmetic(operation, operand, *args, **kwargs)
        # The arguments for creating a new instance are saved in kwargs,
        # add the processed hits so the new instance gets them
        kwargs["hits"] = result_hits
        return result, kwargs  # these must be returned

    def stack(self, coords, size, method="cutout2d", n_bootstrap=None, **kwargs):
        """Return a stacked map from a catalog of coordinates.

        coords : array of `~astropy.coordinates.SkyCoord`
            the position of the cutout arrays center
        size : `~astropy.units.Quantity`
            the size of the cutout array along each axis.  If ``size``
            is a scalar `~astropy.units.Quantity`, then a square cutout of
            ``size`` will be created. If ``size`` has two elements,
            they should be in ``(ny, nx)`` order.
        method : str ('cutout2d', 'reproject')
            the method to generate the cutouts
        n_bootstrap : int, optional
            use a bootstrap distribution of the signal instead of a weighted average

        Any additionnal keyword arguments are passed to the choosen method

        if method == "cutout2d":
            datas, weights, wcs = self._gen_cutout2d(coords, size, **kwargs)
        elif method == "reproject":
            datas, weights, wcs = self._gen_reproject(coords, size, **kwargs)
            raise ValueError("method should be cutout2d or reproject")

        header = self.header.copy()
        header["HISTORY"] = "Stacked on {} coordinates".format(len(coords))

        if n_bootstrap is None:
            # np.ma.average handle 0 weights in the final map
            data, weight = np.ma.average(datas, weights=weights, axis=0, returned=True)
            uncertainty = InverseVariance(weight)

            _ = partial(_shuffled_average, datas=datas, weights=weights)

            bs_array = ProgressBar.map(
                np.array_split(np.arange(n_bootstrap), cpu_count()),
            bs_array = np.concatenate(bs_array)

            data = np.mean(bs_array, axis=0)
            uncertainty = StdDevUncertainty(np.std(bs_array, axis=0, ddof=1))

        data = self.__class__(

        return data

    def _gen_cutout2d(self, coords, size, **kwargs):
        """Generate simple 2D cutouts from a catalog of coordinates.

        Parameters
        ----------
        coords : array of `~astropy.coordinates.SkyCoord`
            the position of the cutout arrays center
        size : `~astropy.units.Quantity`
            the size of the cutout array along each axis.  If ``size``
            is a scalar `~astropy.units.Quantity`, then a square cutout of
            ``size`` will be created. If ``size`` has two elements,
            they should be in ``(ny, nx)`` order.

        Returns
        -------
        data_cutouts, weights_cutouts : `~numpy.ndarray`
            stacked cutouts of shape ``(len(coords), ny, nx)``; the weights are 0 on
            masked pixels and where the data cutout is NaN (e.g. outside of the map)
        output_wcs : `~astropy.wcs.WCS`
            the WCS of the cutouts, with ``crval=(0, 0)`` at the central pixel

        Raises
        ------
        ValueError
            if ``size`` has more than two elements or is not an angular Quantity

        Notes
        -----
        The cutouts have odd number of pixels and are centered around the pixel containing the coordinates.
        """
        # Convert size into pixel (ny, nx) to insure an odd number of pixels
        size = np.atleast_1d(size)
        if len(size) == 1:
            size = np.repeat(size, 2)

        if len(size) > 2:
            raise ValueError("size must have at most two elements")

        # proj_plane_pixel_scales follows the WCS axis order (x, y): reverse it to match ``size`` (ny, nx)
        pixel_scales = u.Quantity(
            [scale * u.Unit(unit) for scale, unit in zip(proj_plane_pixel_scales(self.wcs), self.wcs.wcs.cunit)]
        )[::-1]

        shape = np.zeros(2).astype(int)

        # ``size`` can have a mixture of int and Quantity (and even units),
        # so evaluate each axis separately
        for axis, side in enumerate(size):
            if isinstance(side, u.Quantity) and side.unit.physical_type == "angle":
                shape[axis] = int(_round_up_to_odd_integer((side / pixel_scales[axis]).decompose()))
            else:
                raise ValueError("size must contains only Quantities with angular units")

        input_wcs = self.wcs
        input_array = self.__array__().filled(np.nan)
        data_cutouts = np.array(
            [
                Cutout2D(input_array, coord, shape, wcs=input_wcs, mode="partial", fill_value=np.nan).data
                for coord in coords
            ]
        )

        # Work on a copy so the map weights are left untouched; masked pixels do not contribute.
        # Guard on None: ``weights[None] = 0`` would zero every pixel.
        weights = self.weights.copy()
        if self.mask is not None:
            weights[self.mask] = 0
        weights_cutouts = np.array(
            [Cutout2D(weights, coord, shape, wcs=input_wcs, mode="partial", fill_value=np.nan).data for coord in coords]
        )
        weights_cutouts[np.isnan(data_cutouts)] = 0

        # Cutout2D requires the WCS to locate a SkyCoord position
        output_wcs = Cutout2D(input_array, coords[0], shape, wcs=input_wcs, mode="partial").wcs
        output_wcs.wcs.crval = (0, 0)
        # WCS crpix is in (x, y) order while shape is (ny, nx)
        output_wcs.wcs.crpix = (shape[::-1] - 1) / 2 + 1

        return data_cutouts, weights_cutouts, output_wcs

    def _gen_reproject(self, coords, size, type="interp", pixel_scales=None, **kwargs):
        """Generate reprojected 2D cutouts from a catalog of coordinates.

        Parameters
        ----------
        coords : array of `~astropy.coordinates.SkyCoord`
            the position of the cutout arrays center
        size : `~astropy.units.Quantity`
            the size of the cutout array along each axis.  If ``size``
            is a scalar `~astropy.units.Quantity`, then a square cutout of
            ``size`` will be created. If ``size`` has two elements,
            they should be in ``(ny, nx)`` order.
        type : str (``interp`` | ``adaptive`` | ``exact``)
            the type of reprojection used, default='interp'
        pixel_scales : `~astropy.units.Quantity`, optional
            the pixel scale of the output image, in ``(ny, nx)`` order when
            two values are given, default None (same as image)

        Returns
        -------
        data_cutouts, weights_cutouts : `~numpy.ndarray`
            stacked cutouts of shape ``(len(coords), ny, nx)``; the weights are 0
            on masked pixels and where the reprojected data is NaN
        output_wcs : `~astropy.wcs.WCS`
            the common WCS of the cutouts, with ``crval=(0, 0)``

        Raises
        ------
        ValueError
            for an unknown ``type``, or invalid ``size``/``pixel_scales``

        Notes
        -----
        The cutouts have odd number of pixels and are reprojected to be centered at the coordinates.
        """
        kind = type.lower()
        if kind == "interp":
            from reproject import reproject_interp as _reproject
        elif kind == "adaptive":
            from reproject import reproject_adaptive

            _reproject = partial(reproject_adaptive, kernel="gaussian", boundary_mode="strict", conserve_flux=True)
        elif kind == "exact":
            from reproject import reproject_exact as _reproject
        else:
            raise ValueError("Reprojection should be (``interp`` | ``adaptive`` | ``exact``)")

        # Convert size into pixel (ny, nx) to insure an odd number of pixels
        size = np.atleast_1d(size)
        if len(size) == 1:
            size = np.repeat(size, 2)

        if len(size) > 2:
            raise ValueError("size must have at most two elements")

        if pixel_scales is None:
            # proj_plane_pixel_scales follows the WCS axis order (x, y): reverse it to match ``size`` (ny, nx)
            pixel_scales = u.Quantity(
                [scale * u.Unit(unit) for scale, unit in zip(proj_plane_pixel_scales(self.wcs), self.wcs.wcs.cunit)]
            )[::-1]
        else:
            pixel_scales = np.atleast_1d(pixel_scales)
            if len(pixel_scales) == 1:
                pixel_scales = np.repeat(pixel_scales, 2)

            if len(pixel_scales) > 2:
                raise ValueError("pixel_scale must have at most two elements")

        shape = np.zeros(2).astype(int)
        cdelt = np.zeros(2)
        # Input cdelt in (ny, nx) order, to keep the orientation of each axis
        input_cdelt = np.asarray(self.wcs.wcs.cdelt)[::-1]

        # ``size`` can have a mixture of int and Quantity (and even units),
        # so evaluate each axis separately
        for axis, side in enumerate(size):
            if isinstance(side, u.Quantity) and side.unit.physical_type == "angle":
                cdelt[axis] = pixel_scales[axis].to(u.deg).value * np.sign(input_cdelt[axis])
                shape[axis] = int(_round_up_to_odd_integer((side / pixel_scales[axis]).decompose()))
            else:
                raise ValueError("size must contains only Quantities with angular units")

        output_wcs = WCS(naxis=2)
        output_wcs.wcs.ctype = self.wcs.wcs.ctype
        # WCS parameters are in (x, y) order while shape and cdelt are (ny, nx)
        output_wcs.wcs.crpix = (shape[::-1] - 1) / 2 + 1
        output_wcs.wcs.cdelt = cdelt[::-1]
        shape_out = tuple(int(n) for n in shape)

        input_array = self.__array__().filled(np.nan)
        # Work on a copy; masked pixels do not contribute (consistent with _gen_cutout2d)
        input_weights = self.weights.copy()
        if self.mask is not None:
            input_weights[self.mask] = 0
        input_wcs = self.wcs

        data_cutouts = []
        weights_cutouts = []
        for coord in coords:
            output_wcs.wcs.crval = (coord.ra.to("deg").value, coord.dec.to("deg").value)
            array_new, footprint = _reproject((input_array, input_wcs), output_wcs, shape_out)
            weight_new = _reproject((input_weights, input_wcs), output_wcs, shape_out, return_footprint=False)

            array_new[footprint == 0] = np.nan
            weight_new[np.isnan(array_new)] = 0

            data_cutouts.append(array_new)
            weights_cutouts.append(weight_new)

        output_wcs.wcs.crval = (0, 0)

        return np.array(data_cutouts), np.array(weights_cutouts), output_wcs

    def to_hdus(
        """Creates an HDUList object from a ContMap object.
        hdu_data, hdu_mask, hdu_uncertainty, hdu_hits : str or None, optional
            If it is a string append this attribute to the HDUList as
            `~astropy.io.fits.ImageHDU` with the string as extension name.
            Default is ``'DATA'`` for data, ``'MASK'`` for mask, ``'UNCERT'``
            for uncertainty and ``HITS`` for hits.
        wcs_relax : bool
            Value of the ``relax`` parameter to use in converting the WCS to a
            FITS header using `~astropy.wcs.WCS.to_header`. The common
            ``CTYPE`` ``RA---TAN-SIP`` and ``DEC--TAN-SIP`` requires
            ``relax=True`` for the ``-SIP`` part of the ``CTYPE`` to be
        key_uncertainty_type : str, optional
            The header key name for the class name of the uncertainty (if any)
            that is used to store the uncertainty type in the uncertainty hdu.
            Default is ``UTYPE``.
        fits_header_comment : dict, optional
            A dictionnary (key, comment) to update the fits header comments.

            - If ``self.mask`` is set but not a `numpy.ndarray`.
            - If ``self.uncertainty`` is set but not a astropy uncertainty type.
            - If ``self.uncertainty`` is set but has another unit then

        hdulist : `~astropy.io.fits.HDUList`
        if isinstance(self.header, fits.Header):
            # Copy here so that we can modify the HDU header by adding WCS
            # information without changing the header of the CCDData object.
            header = self.header.copy()
            # Because _insert_in_metadata_fits_safe is written as a method
            # we need to create a dummy CCDData instance to hold the FITS
            # header we are constructing. This probably indicates that
            # _insert_in_metadata_fits_safe should be rewritten in a more
            # sensible way...
            header = meta_to_header(self.header)

            if fits_header_comment is not None:
                for key, comment in fits_header_comment.items():
                    if key in header:
                        header.set(key, comment=comment)

        if self.unit is not u.dimensionless_unscaled:
            header["bunit"] = self.unit.to_string()

        if self.wcs:
            # Simply extending the FITS header with the WCS can lead to
            # duplicates of the WCS keywords; iterating over the WCS
            # header should be safer.
            # Turns out if I had read the io.fits.Header.extend docs more
            # carefully, I would have realized that the keywords exist to
            # avoid duplicates and preserve, as much as possible, the
            # structure of the commentary cards.
            # Note that until astropy/astropy#3967 is closed, the extend
            # will fail if there are comment cards in the WCS header but
            # not header.
            wcs_header = self.wcs.to_header(relax=wcs_relax)
            header.extend(wcs_header, useblanks=False, update=True)

        hdus = [fits.ImageHDU(self.data, header, name=hdu_data)]

        if hdu_mask and self.mask is not None:
            # Always assuming that the mask is a np.ndarray (check that it has
            # a 'shape').
            if not hasattr(self.mask, "shape"):
                raise ValueError("only a numpy.ndarray mask can be saved.")

            # Convert boolean mask to uint since io.fits cannot handle bool.
            hduMask = fits.ImageHDU(self.mask.astype(np.uint8), header, name=hdu_mask)

        if hdu_uncertainty and self.uncertainty is not None:
            # We need to save some kind of information which uncertainty was
            # used so that loading the HDUList can infer the uncertainty type.
            # No idea how this can be done so only allow StdDevUncertainty.
            uncertainty_cls = self.uncertainty.__class__
            if uncertainty_cls not in _known_uncertainties:
                raise ValueError("only uncertainties of type {} can be saved.".format(_known_uncertainties))
            uncertainty_name = _unc_cls_to_name[uncertainty_cls]

            hdr_uncertainty = fits.Header(header)
            hdr_uncertainty[key_uncertainty_type] = uncertainty_name

            # Assuming uncertainty is an StdDevUncertainty save just the array
            # this might be problematic if the Uncertainty has a unit differing
            # from the data so abort for different units. This is important for
            # astropy > 1.2
            if hasattr(self.uncertainty, "unit") and self.uncertainty.unit is not None:
                if not _uncertainty_unit_equivalent_to_parent(uncertainty_cls, self.uncertainty.unit, self.unit):
                    raise ValueError(
                        "saving uncertainties with a unit that is not "
                        "equivalent to the unit from the data unit is not "

            hduUncert = fits.ImageHDU(self.uncertainty.array, hdr_uncertainty, name=hdu_uncertainty)

        if hdu_hits and self.hits is not None:
            # Always assuming that the hits is a np.ndarray (check that it has
            # a 'shape').
            if not hasattr(self.hits, "shape"):
                raise ValueError("only a numpy.ndarray hits can be saved.")

            # Convert boolean mask to uint since io.fits cannot handle bool.
            hduHits = fits.ImageHDU(self.hits.astype(np.uint32), header, name=hdu_hits)

        hdulist = fits.HDUList(hdus)

        return hdulist

def fits_contmap_reader(
    with fits.open(filename, **kwd) as hdus:
        hdr = hdus[0].header

        sampling_freq = hdr.get("sampling_freq", None)
        if sampling_freq is not None:
            sampling_freq = sampling_freq * u.Hz

        if hdu_data is not None and hdu_data in hdus:
            data = hdus[hdu_data].data
            wcs = WCS(hdus[hdu_data].header)
            if unit is None:
                unit = hdus[hdu_data].header.get("BUNIT", None)
            data = None
            wcs = None
        if hdu_uncertainty is not None and hdu_uncertainty in hdus:
            unc_hdu = hdus[hdu_uncertainty]
            stored_unc_name = unc_hdu.header.get(key_uncertainty_type, "None")

            unc_type = _unc_name_to_cls.get(stored_unc_name, StdDevUncertainty)
            uncertainty = unc_type(unc_hdu.data)
            uncertainty = None
        if hdu_mask is not None and hdu_mask in hdus:
            # Mask is saved as uint but we want it to be boolean.
            mask = hdus[hdu_mask].data.astype(np.bool_)
            mask = None
        if hdu_hits is not None and hdu_hits in hdus:
            # hits is saved as uint but we want it to be boolean.
            hits = hdus[hdu_hits].data
            hits = None

    c_data = ContMap(
        data, wcs=wcs, uncertainty=uncertainty, mask=mask, hits=hits, meta=hdr, unit=unit, sampling_freq=sampling_freq

    return c_data

def fits_contmap_writer(
    c_data, filename, hdu_mask="MASK", hdu_uncertainty="UNCERT", hdu_hits="HITS", key_uncertainty_type="UTYPE", **kwd
    Write ContMap object to FITS file.
    filename : str
        Name of file.
    hdu_mask, hdu_uncertainty, hdu_hits : str or None, optional
        If it is a string append this attribute to the HDUList as
        `~astropy.io.fits.ImageHDU` with the string as extension name.
        Flags are not supported at this time. If ``None`` this attribute
        is not appended.
        Default is ``'MASK'`` for mask, ``'UNCERT'`` for uncertainty and
        ``HITS`` for flags.
    key_uncertainty_type : str, optional
        The header key name for the class name of the uncertainty (if any)
        that is used to store the uncertainty type in the uncertainty hdu.
        Default is ``UTYPE``.
        .. versionadded:: 3.1
    kwd :
        All additional keywords are passed to :py:mod:`astropy.io.fits`
        - If ``self.mask`` is set but not a `numpy.ndarray`.
        - If ``self.uncertainty`` is set but not a
        - If ``self.uncertainty`` is set but has another unit then
        Saving flags is not supported.
    # Build the primary header with history and comments

    header = meta_to_header(c_data.header)

    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=fits.verify.VerifyWarning)
        if c_data.sampling_freq is not None:
            header["sampling_freq"] = c_data.sampling_freq.to(u.Hz).value

    hdu = [fits.PrimaryHDU(None, header=header)]
    hdu += c_data.to_hdus(

    hdu = fits.HDUList(hdu)
    hdu.writeto(filename, **kwd)

# Register the FITS reader, writer and identifier for ContMap in astropy's I/O
# registry. ``delay_doc_updates`` defers the registry's docstring updates for
# ContMap until the three registrations below are done.
with registry.delay_doc_updates(ContMap):
    registry.register_reader("fits", ContMap, fits_contmap_reader)
    registry.register_writer("fits", ContMap, fits_contmap_writer)
    registry.register_identifier("fits", ContMap, fits.connect.is_fits)

def contmap_average(continuum_datas, normalize=False):
    """Return the weighted average of several ContMap, using inverse variance as the weights.

    Parameters
    ----------
    continuum_datas: list of class:`kidsdata.continuum_data.ContMap`
        the list of ContMap object, all sharing the same WCS
    normalize : bool
        normalize the uncertainty such that the snr std is 1, default False

    Returns
    -------
    data : class:`kidsdata.continuum_data.ContMap`
        the resulting combined ContMap object, with the unit of the first map;
        pixels with no valid contribution are masked

    Raises
    ------
    ValueError
        if ``continuum_datas`` is empty
    """
    continuum_datas = list(continuum_datas)
    if not continuum_datas:
        raise ValueError("continuum_datas must contain at least one ContMap")

    wcs = [item.wcs for item in continuum_datas]

    assert all([wcs[0].wcs == item.wcs for item in wcs[1:]]), "All wcs must be equal"

    datas = np.array([item.data for item in continuum_datas])
    weights = np.array([item.weights for item in continuum_datas])
    # A missing mask means every pixel of that map is valid
    masks = np.array(
        [np.zeros(np.shape(item.data), dtype=bool) if item.mask is None else item.mask for item in continuum_datas],
        dtype=bool,
    )

    datas[masks] = 0.0
    weights[masks] = 0.0

    with warnings.catch_warnings():
        # Pixels without any weight give 0/0 -> NaN, masked below
        warnings.simplefilter("ignore", RuntimeWarning)
        weight = np.sum(weights, axis=0)
        data = np.sum(datas * weights, axis=0) / weight

    mask = np.isnan(data)

    # Hits can only be summed if every map carries them
    hits_list = [item.hits for item in continuum_datas]
    hits = None if any(item_hits is None for item_hits in hits_list) else np.sum(np.array(hits_list), axis=0)

    # NOTE(review): assumes all maps share the same unit; the first one is propagated
    unit = continuum_datas[0].unit

    output = ContMap(data=data, uncertainty=InverseVariance(weight), wcs=wcs[0], hits=hits, mask=mask, unit=unit)

    if normalize:
        output.normalize_uncertainty()

    return output