import typing
import random
import numpy as np
from loguru import logger

from stagesepx import toolbox
from stagesepx import constants
from stagesepx.hook import BaseHook
from stagesepx.video import VideoObject, VideoFrame

class VideoCutRange(object):
    """A range of frame ids ``[start, end]`` (both inclusive) inside one video.

    Besides the boundaries, a range carries the ``ssim`` / ``mse`` / ``psnr``
    value lists handed to it by the caller, plus the timestamps of its first
    and last frame.
    """

    def __init__(
        self,
        # TODO why can it be a dict?
        video: "typing.Union[VideoObject, typing.Dict]",
        start: int,
        end: int,
        # TODO need refactored ?
        ssim: typing.List[float],
        mse: typing.List[float],
        psnr: typing.List[float],
        start_time: float,
        end_time: float,
    ):
        """Build a range.

        Args:
            video: a ``VideoObject``, or a dict of keyword arguments that is
                expanded into ``VideoObject(**video)``.
            start: id of the first frame.
            end: id of the last frame (inclusive).
            ssim: ssim values attached to this range.
            mse: mse values attached to this range.
            psnr: psnr values attached to this range.
            start_time: timestamp belonging to ``start``.
            end_time: timestamp belonging to ``end``.
        """
        if isinstance(video, dict):
            self.video = VideoObject(**video)
        else:
            self.video = video

        self.start = start
        self.end = end
        self.ssim = ssim
        self.mse = mse
        self.psnr = psnr
        self.start_time = start_time
        self.end_time = end_time

        # if length is 1
        # https://github.com/williamfzc/stagesepx/issues/9
        # boundaries may arrive reversed; normalise so that start <= end
        if start > end:
            self.start, self.end = self.end, self.start
            self.start_time, self.end_time = self.end_time, self.start_time

        logger.debug(
            f"new a range: {self.start}({self.start_time}) - {self.end}({self.end_time})"
        )

    def can_merge(
        self, another: "VideoCutRange", offset: typing.Optional[int] = None, **_
    ):
        """Tell whether ``another`` can be appended to this range.

        Without ``offset`` (``None`` or ``0``) ``another`` has to start exactly
        on this range's last frame. With an ``offset``, a gap of up to
        ``offset`` frames is tolerated. Both ranges must also refer to the same
        video path.
        """
        if not offset:
            is_continuous = self.end == another.start
        else:
            is_continuous = self.end + offset >= another.start
        return is_continuous and self.video.path == another.video.path

    def merge(self, another: "VideoCutRange", **kwargs) -> "VideoCutRange":
        """Return a new range spanning from this range's start to ``another``'s end.

        The metric lists of both ranges are concatenated; neither operand is
        modified. ``kwargs`` are forwarded to :meth:`can_merge`.

        Raises:
            AssertionError: if :meth:`can_merge` rejects ``another``.
        """
        assert self.can_merge(another, **kwargs)
        return __class__(
            self.video,
            self.start,
            another.end,
            self.ssim + another.ssim,
            self.mse + another.mse,
            self.psnr + another.psnr,
            self.start_time,
            another.end_time,
        )

    def contain(self, frame_id: int) -> bool:
        """Return True if ``frame_id`` lies within ``[start, end]``."""
        # in python:
        # range(0, 10) => [0, 10)
        # range(0, 10 + 1) => [0, 10]
        return frame_id in range(self.start, self.end + 1)

    # alias
    contain_frame_id = contain

    def contain_image(
        self, image_path: str = None, image_object: np.ndarray = None, *args, **kwargs
    ) -> typing.Dict[str, typing.Any]:
        """Look for an image in one frame of this range.

        The frame is the first id returned by :meth:`pick` (called with
        ``*args`` / ``**kwargs``); the lookup itself is delegated to that
        frame's ``contain_image`` and its result is returned unchanged.
        """
        # todo pick only one picture?
        target_id = self.pick(*args, **kwargs)[0]
        operator = self.video.get_operator()
        frame = operator.get_frame_by_id(target_id)
        return frame.contain_image(
            image_path=image_path, image_object=image_object, **kwargs
        )

    def pick(
        self,
        frame_count: typing.Optional[int] = None,
        is_random: typing.Optional[bool] = None,
        *_,
        **__,
    ) -> typing.List[int]:
        """Choose ``frame_count`` frame ids from this range.

        Args:
            frame_count: how many ids to return; falls back to 3 when falsy.
            is_random: when truthy, sample without replacement from
                ``range(start, end)`` (``end`` itself is excluded). Otherwise
                the ids are spread evenly across the range.

        Returns:
            The chosen frame ids.

        Raises:
            ValueError: in random mode, if ``frame_count`` exceeds the number
                of ids available (raised by ``random.sample``).
        """
        if not frame_count:
            frame_count = 3
        logger.debug(
            f"pick {frame_count} frames "
            f"from {self.start}({self.start_time}) "
            f"to {self.end}({self.end_time}) "
            f"on video {self.video.path}"
        )

        if is_random:
            return random.sample(range(self.start, self.end), frame_count)
        length = self.get_length()

        # https://github.com/williamfzc/stagesepx/issues/37
        # split the range into frame_count + 1 segments and take the inner
        # boundaries, so every picked id stays within [start, end]
        frame_count += 1
        return [
            int(self.start + length / frame_count * step)
            for step in range(1, frame_count)
        ]

    def get_frames(
        self, frame_id_list: typing.List[int], *_, **__
    ) -> "typing.List[VideoFrame]":
        """return a list of VideoFrame, usually works with pick"""
        operator = self.video.get_operator()
        return [operator.get_frame_by_id(each_id) for each_id in frame_id_list]

    def pick_and_get(self, *args, **kwargs) -> "typing.List[VideoFrame]":
        """Run :meth:`pick`, then load the chosen ids with :meth:`get_frames`."""
        picked = self.pick(*args, **kwargs)
        return self.get_frames(picked, *args, **kwargs)

    def get_length(self):
        """Return the number of frame ids covered, both ends included."""
        return self.end - self.start + 1

    def is_stable(
        self, threshold: float = None, psnr_threshold: float = None, **_
    ) -> bool:
        """Decide whether this range counts as stable.

        The mean of ``ssim`` must exceed ``threshold`` (defaults to
        ``constants.DEFAULT_THRESHOLD`` when falsy). If that passes and
        ``psnr_threshold`` is given, the mean of ``psnr`` must exceed it too.
        """
        # IMPORTANT function!
        # it decided whether a range is stable => everything is based on it!
        if not threshold:
            threshold = constants.DEFAULT_THRESHOLD

        # ssim
        res = np.mean(self.ssim) > threshold
        # psnr (double check if stable)
        if res and psnr_threshold:
            res = np.mean(self.psnr) > psnr_threshold

        return res

    def is_loop(self, threshold: float = None, **_) -> bool:
        """Return True if the first and last frame compare above ``threshold``.

        The comparison is ``toolbox.compare_ssim`` on the two frames' data;
        ``threshold`` defaults to ``constants.DEFAULT_THRESHOLD`` when falsy.
        """
        if not threshold:
            threshold = constants.DEFAULT_THRESHOLD
        operator = self.video.get_operator()
        start_frame = operator.get_frame_by_id(self.start)
        end_frame = operator.get_frame_by_id(self.end)
        return toolbox.compare_ssim(start_frame.data, end_frame.data) > threshold

    def diff(
        self,
        another: "VideoCutRange",
        pre_hooks: "typing.List[BaseHook]",
        *args,
        **kwargs,
    ) -> typing.List[float]:
        """Compare frames picked from this range with frames picked from ``another``.

        ``*args`` / ``**kwargs`` are forwarded to :meth:`pick_and_get` on both
        ranges; the two frame lists and ``pre_hooks`` are then handed to
        ``toolbox.multi_compare_ssim``, whose result is returned.
        """
        self_picked = self.pick_and_get(*args, **kwargs)
        another_picked = another.pick_and_get(*args, **kwargs)
        return toolbox.multi_compare_ssim(self_picked, another_picked, pre_hooks)

    def __str__(self):
        return f"<VideoCutRange [{self.start}({self.start_time})-{self.end}({self.end_time})] ssim={self.ssim}>"

    __repr__ = __str__