import math
import yaml
import os
import re
import pathlib
import datetime
import logging
import tempfile

import numpy as np

from typing import List, Tuple, Optional, NamedTuple

from bopt.basic_types import Hyperparameter, OptimizationFailed
from bopt.hyperparam_values import HyperparamValues
from bopt.gp_config import GPConfig
from bopt.models.model import Model
from bopt.sample import Sample, CollectFlag, SampleCollection
from bopt.models.parameters import ModelParameters
from bopt.models.random_search import RandomSearch
from bopt.runner.abstract import Runner
from bopt.runner.runner_loader import RunnerLoader
from bopt.models.gpy_model import GPyModel

# TODO: set this at a proper global place
# logging.getLogger().setLevel(logging.DEBUG)
# logging.getLogger("matplotlib").setLevel(logging.INFO)

class ExperimentStats(NamedTuple):
    """Summary statistics over the collected results of an experiment."""

    # Smallest and largest collected result.
    min: float
    max: float
    # Mean and standard deviation of the collected results.
    mean: float
    std: float
    # Median of the collected results.
    median: float

class NoAliasDumper(yaml.Dumper):
    """PyYAML dumper that never writes anchors/aliases (``&id001`` / ``*id001``).

    PyYAML's representer asks ``ignore_aliases(data)`` whether an object seen
    more than once may be emitted as an alias. Answering ``True`` for every
    object makes each occurrence be written out in full.
    """

    def ignore_aliases(self, data):
        # Unconditional: no object, whatever ``data`` is, gets an alias.
        return True

class Experiment:
    kernel_names = ["rbf", "Mat32", "Mat52"]
    acquisition_fn_names = ["ei", "pi"]

    task_name: str
    batch_name: Optional[str]

    hyperparameters: List[Hyperparameter]
    runner: Runner
    samples: List[Sample]
    result_regex: str

    gp_config: GPConfig

    def __init__(self, task_name: str, batch_name: Optional[str],
                 hyperparameters: List[Hyperparameter],
                 runner: Runner, result_regex: str,
                 gp_config: GPConfig) -> None:
        """Create an experiment that has no samples yet.

        Args:
            task_name: name of the task being optimized.
            batch_name: optional name of the batch the experiment belongs to.
            hyperparameters: the hyperparameters being searched over.
            runner: used to start the jobs that evaluate samples.
            result_regex: regex whose first capture group is parsed (as a
                float) from a job's output to obtain its result.
            gp_config: configuration of the Gaussian-process model.
        """
        # Identification of the experiment.
        self.task_name, self.batch_name = task_name, batch_name

        # Search space and how a point in it gets evaluated.
        self.hyperparameters = hyperparameters
        self.runner = runner
        self.result_regex = result_regex
        self.gp_config = gp_config

        # Starts empty; filled in later (e.g. by ``from_dict``).
        self.samples = []

    def to_dict(self) -> dict:
        """Serialize the experiment into a plain dict, in the format ``from_dict`` reads.

        Hyperparameters are keyed by their name. Samples, the runner and the GP
        config are serialized through their own ``to_dict`` methods.
        """
        return {
            "task_name": self.task_name,
            "batch_name": self.batch_name,
            "hyperparameters": {h.name: h.to_dict() for h in self.hyperparameters},
            "samples": [s.to_dict() for s in self.samples],
            "runner": self.runner.to_dict(),
            "result_regex": self.result_regex,
            "gp_config": self.gp_config.to_dict(),
        }

    @staticmethod
    def from_dict(data: dict) -> "Experiment":
        """Rebuild an ``Experiment`` from the dict produced by ``to_dict``.

        Missing keys are tolerated as follows:
        - ``samples`` may be missing, ``None`` or empty; the experiment then
          starts with no samples.
        - ``task_name`` / ``batch_name`` default to the placeholder ``"XXX"``
          when missing.

        ``gp_config`` may be a ``GPConfig`` instance or its dict form.
        ``data`` itself is not modified.

        Args:
            data: the deserialized meta dict.

        Returns:
            The reconstructed experiment, with its samples attached.
        """
        hyperparameters = [Hyperparameter.from_dict(name, hyperparam_data)
                           for name, hyperparam_data in data["hyperparameters"].items()]

        # `else` branch was lost before: samples were always reset to [] (and
        # left unbound when the stored list was empty).
        samples = [Sample.from_dict(s, hyperparameters)
                   for s in (data.get("samples") or [])]

        runner = RunnerLoader.from_dict(data["runner"])

        # Presumably older meta files lack these names; use a placeholder
        # rather than failing, without mutating the caller's dict.
        task_name = data.get("task_name", "XXX")
        batch_name = data.get("batch_name", "XXX")

        if isinstance(data["gp_config"], GPConfig):
            gp_config = data["gp_config"]
        else:
            gp_config = GPConfig.from_dict(data["gp_config"])

        experiment = Experiment(task_name, batch_name,
                                hyperparameters, runner, data["result_regex"],
                                gp_config)

        experiment.samples = samples

        return experiment

    def best_result(self) -> float:
        """Return the highest collected result, i.e. the ``max`` field of ``stats()``."""
        summary = self.stats()
        return summary.max

    def stats(self) -> ExperimentStats:
        """Compute summary statistics over all samples that have a result.

        Returns:
            ``ExperimentStats`` with min, max, mean, standard deviation
            (``np.std``, i.e. ddof=0) and median of the results.

        Raises:
            ValueError: if no sample has a result yet.
        """
        results = self.sample_results()

        if not results:
            raise ValueError("Cannot compute stats: no sample has a result yet.")

        return ExperimentStats(min(results),
                               max(results),
                               float(np.mean(results)),
                               float(np.std(results)),
                               float(np.median(results)))

    def collect_results(self) -> None:
        """Update pending samples in place with any result that is now available.

        - ``WAITING_FOR_SIMILAR`` samples copy the result of a finished
          similar sample, once one exists.
        - ``WAITING_FOR_JOB`` samples whose job has finished get their output
          file ``output/job.o<job_id>`` parsed. The run time comes from bash
          ``time`` output and the result from ``self.result_regex``. A missing
          file or an unparsable result marks the sample ``COLLECT_FAILED``.

        Samples in any other state are left untouched. Samples are processed in
        order, so a sample may pick up a result collected earlier in the same
        call.
        """
        # TODO: collect run time + check collected_at

        for sample in self.samples:
            if sample.collect_flag == CollectFlag.WAITING_FOR_SIMILAR:
                self._collect_from_similar(sample)
            elif sample.collect_flag == CollectFlag.WAITING_FOR_JOB:
                self._collect_from_job_output(sample)

    def _collect_from_similar(self, sample: Sample) -> None:
        """Copy the result of a finished similar sample into ``sample``, if one exists."""
        assert sample.result is None

        finished_similar_samples = self.get_finished_similar_samples(sample.hyperparam_values)

        if not finished_similar_samples:
            return

        logging.info("Waiting for similar DONE, copying over results at %s",
                     sample.hyperparam_values)

        picked_similar = finished_similar_samples[0]

        sample.result = picked_similar.result
        sample.finished_at = datetime.datetime.now()
        sample.collected_at = sample.finished_at
        sample.collect_flag = CollectFlag.COLLECT_OK
        sample.run_time = (sample.finished_at - sample.created_at).total_seconds()

    def _collect_from_job_output(self, sample: Sample) -> None:
        """Parse run time and result from the output file of ``sample``'s finished job."""
        assert sample.job
        assert sample.result is None

        if not sample.job.is_finished():
            return

        # Since we're using `handle_cd` we always assume the working
        # directory is where meta.yml is.
        fname = os.path.join("output", f"job.o{sample.job.job_id}")

        if not os.path.exists(fname):
            # %s rather than %d: job_id is formatted with %s-safe semantics so a
            # non-int id does not break the log call.
            logging.error("Output file not found for job %s "
                          "even though it finished. It will be considered "
                          "as a failed job.", sample.job.job_id)

            sample.collect_flag = CollectFlag.COLLECT_FAILED
            return

        with open(fname, "r") as f:
            contents = f.read().rstrip("\n")

        # Line printed by bash's `time` builtin, e.g. "real\t1m2.345s".
        # The decimal point is escaped so only a literal "." matches.
        bash_time_regex = re.compile(r"real\t(\d+)m(\d+\.\d+)s")
        result_regex = re.compile(self.result_regex)
        found = False

        for line in contents.split("\n"):
            time_matches = bash_time_regex.match(line)

            if time_matches:
                minutes, seconds = time_matches.groups()
                sample.run_time = int(minutes) * 60 + float(seconds)
                sample.finished_at = sample.created_at + \
                    datetime.timedelta(seconds=sample.run_time)

                logging.info("Collect parsed runtime of %fs", sample.run_time)

            # Presumably a line like `RESULT=1,2,3,4`; only the first capture
            # group of `self.result_regex` is parsed, as a float.
            matches = result_regex.match(line)

            if matches:
                sample.result = float(matches.groups()[0])
                sample.collected_at = datetime.datetime.now()
                sample.collect_flag = CollectFlag.COLLECT_OK
                found = True

                if not sample.run_time:
                    logging.debug("No TIME parsed from the output, using `collected_at instead`.")
                    sample.run_time = (sample.collected_at - sample.created_at).total_seconds()

                logging.info("Collect got result %s", sample.result)

        if not found:
            logging.error("Job %s seems to have failed, "
                          "it finished running and its result cannot "
                          "be parsed.", sample.job.job_id)

            sample.collect_flag = CollectFlag.COLLECT_FAILED

    def samples_for_prediction(self) -> List[Sample]:
        """Return the samples usable by the model.

        A sample qualifies if it already has a result, or if its model
        parameters were not sampled from random search.
        """
        # `is not None` so a legitimate result of 0.0 is not treated as
        # "no result" (consistent with `sample_results`).
        return [s for s in self.samples
                if s.result is not None or not s.model.sampled_from_random_search()]

    def predictive_samples_before(self, sample: Sample) -> List[Sample]:
        """Return the ``samples_for_prediction()`` that ended before ``sample`` was created.

        A sample's end time is its ``finished_at``, falling back to
        ``collected_at``. Samples with neither timestamp are skipped.
        """
        result = []

        for other in self.samples_for_prediction():
            other_date = other.finished_at or other.collected_at
            if not other_date:
                continue

            if other_date < sample.created_at:
                result.append(other)

        return result

    def get_xy(self):
        """Return ``(X_sample, Y_sample)`` built from ``samples_for_prediction()``.

        The conversion to arrays is delegated to ``SampleCollection.to_xy``.
        """
        collection = SampleCollection(self.samples_for_prediction())
        x_values, y_values = collection.to_xy()
        return x_values, y_values

    def suggest(self) -> Tuple[HyperparamValues, Model]:
        """Propose the next hyperparameter values, together with the model that chose them.

        Random search is used when either holds:
        - fewer than two samples are usable for prediction;
        - ``gp_config.random_search_only`` is set.

        Otherwise a GP is fitted through ``GPyModel.predict_next``. If that
        raises ``OptimizationFailed``, the method falls back to random search.

        Returns:
            ``(job_params, fitted_model)``.
        """
        job_params: HyperparamValues
        fitted_model: Model

        # TODO: verify that this would also work for ok+running samples and mean_pred
        if (len(self.samples_for_prediction()) < 2) or self.gp_config.random_search_only:
            logging.info("Sampling with random search.")

            job_params = RandomSearch.predict_next(self.hyperparameters)
            fitted_model = RandomSearch()
        else:
            X_sample, Y_sample = self.get_xy()

            try:
                # NOTE(review): the argument list of this call was truncated in
                # the source and has been reconstructed; confirm it against
                # GPyModel.predict_next's signature.
                job_params, fitted_model = GPyModel.predict_next(self.gp_config,
                                                                 self.hyperparameters,
                                                                 X_sample, Y_sample)
            except OptimizationFailed as e:
                logging.error("Optimization failed, retrying with "
                              "RandomSearch: %s", e)

                job_params = RandomSearch.predict_next(self.hyperparameters)
                fitted_model = RandomSearch()

        return job_params, fitted_model

    def run_next(self, num_similar_retries: int = 5) -> Tuple[Model, Sample]:
        """Suggest new hyperparameter values and create a sample for them via ``manual_run``.

        When ``manual_run`` reports that a similar sample already existed,
        another suggestion is requested. At most ``num_similar_retries``
        suggestions are made in total, and at least one is always made, even
        when ``num_similar_retries`` is 0 or negative.

        Returns:
            ``(fitted_model, next_sample)`` of the last suggestion made.
        """
        attempts_left = max(num_similar_retries, 1)

        while True:
            attempts_left -= 1

            job_params, fitted_model = self.suggest()

            # NOTE(review): the ModelParameters argument was truncated in the
            # source; `to_model_params()` is assumed — confirm on Model.
            next_sample, found_similar = self.manual_run(job_params,
                                                         fitted_model.to_model_params())

            if not found_similar or attempts_left <= 0:
                return fitted_model, next_sample

    def get_similar_samples(self, hyperparam_values: HyperparamValues) -> List[Sample]:
        """Return the samples that have a job and whose values are similar to ``hyperparam_values``.

        Similarity is decided by ``HyperparamValues.similar_to``. Samples
        without a job are skipped before similarity is checked.
        """
        similar: List[Sample] = []

        for candidate in self.samples:
            if not candidate.job:
                continue
            if candidate.hyperparam_values.similar_to(hyperparam_values):
                similar.append(candidate)

        return similar

    def get_finished_similar_samples(self, hyperparam_values: HyperparamValues) -> List[Sample]:
        """Return the similar samples (see ``get_similar_samples``) whose status is ``COLLECT_OK``."""
        # Filters the already-filtered similar samples again; cheap, since an
        # experiment has only a few samples.
        finished: List[Sample] = []

        for candidate in self.get_similar_samples(hyperparam_values):
            if candidate.status() == CollectFlag.COLLECT_OK:
                finished.append(candidate)

        return finished

    def manual_run(self, hyperparam_values: HyperparamValues,
                   model_params: ModelParameters) -> Tuple[Sample, bool]:
        """Create a sample for ``hyperparam_values``, starting a job only when needed.

        - If a finished similar sample exists, a ``COLLECT_OK`` sample copying
          its result is created. No job is started.
        - If a similar sample exists but has not finished, a
          ``WAITING_FOR_SIMILAR`` sample is created; ``collect_results`` fills
          in its result later. No job is started.
        - Otherwise a job is started through ``self.runner`` and a
          ``WAITING_FOR_JOB`` sample is created. Its mu/sigma come from a GP
          fitted on the current data, or are ``None`` when there is no data
          yet.

        The new sample is appended to ``self.samples`` and the experiment is
        serialized.

        Returns:
            ``(next_sample, found_similar)``. ``found_similar`` tells whether
            any similar sample existed, i.e. no job was started.
        """
        assert isinstance(hyperparam_values, HyperparamValues)

        output_dir_path = pathlib.Path("output")
        output_dir_path.mkdir(parents=True, exist_ok=True)

        logging.debug("Output set to: %s\t%s", output_dir_path, output_dir_path.absolute())

        output_dir = str(output_dir_path)

        similar_samples = self.get_similar_samples(hyperparam_values)
        found_similar = len(similar_samples) > 0

        if found_similar:
            finished_similar_samples = self.get_finished_similar_samples(hyperparam_values)

            if finished_similar_samples:
                next_sample = self._copy_finished_similar_sample(
                    hyperparam_values, model_params, finished_similar_samples[0])
            else:
                next_sample = self._wait_for_similar_sample(
                    hyperparam_values, model_params, similar_samples[0])
        else:
            next_sample = self._start_job_sample(hyperparam_values, model_params, output_dir)

        # NOTE(review): these two lines were missing from the source and are
        # reconstructed from the "Serialization done" log below — confirm the
        # caller does not append the sample a second time.
        self.samples.append(next_sample)
        self.serialize()

        logging.debug("Serialization done")

        return next_sample, found_similar

    def _copy_finished_similar_sample(self, hyperparam_values: HyperparamValues,
                                      model_params: ModelParameters,
                                      similar_sample: Sample) -> Sample:
        """Build a ``COLLECT_OK`` sample that reuses the result of a finished similar sample."""
        warning_str = "Found finished similar sample, "
        warning_str += "creating MANUAL_SAMPLE with equal hyperparam values and result"
        warning_str += "... param values:\n%s\n%s"
        logging.warning(warning_str, hyperparam_values, similar_sample.hyperparam_values)

        assert similar_sample.result is not None

        created_at = datetime.datetime.now()

        next_sample = Sample(None, model_params, hyperparam_values,
                             similar_sample.mu_pred, similar_sample.sigma_pred,
                             CollectFlag.COLLECT_OK, created_at)

        next_sample.collected_at = created_at
        next_sample.run_time = 0.0
        next_sample.result = similar_sample.result
        # get_similar_samples only returns samples with a job, so job_id exists.
        next_sample.comment = "created as similar of {}".format(similar_sample.job.job_id)

        return next_sample

    def _wait_for_similar_sample(self, hyperparam_values: HyperparamValues,
                                 model_params: ModelParameters,
                                 similar_sample: Sample) -> Sample:
        """Build a ``WAITING_FOR_SIMILAR`` sample tied to a similar sample that is still running."""
        # TODO: fix:
        #   - the sample does not have to have a mu/sigma prediction
        #   - if it was already evaluated, should we skip launching the job
        #     and create a "ManualSample"?
        next_sample = Sample(None, model_params, hyperparam_values,
                             similar_sample.mu_pred, similar_sample.sigma_pred,
                             CollectFlag.WAITING_FOR_SIMILAR, datetime.datetime.now())

        next_sample.comment = "created as similar of {}".format(similar_sample.job.job_id)

        return next_sample

    def _start_job_sample(self, hyperparam_values: HyperparamValues,
                          model_params: ModelParameters,
                          output_dir: str) -> Sample:
        """Start a job for ``hyperparam_values`` and build its ``WAITING_FOR_JOB`` sample."""
        manual_file_args = self.runner.fetch_and_shift_manual_file_args()
        job = self.runner.start(output_dir, hyperparam_values, manual_file_args)

        mu, sigma = self._predict_mu_sigma(hyperparam_values, model_params)

        next_sample = Sample(job, model_params, hyperparam_values,
                             mu, sigma, CollectFlag.WAITING_FOR_JOB,
                             datetime.datetime.now())

        next_sample.comment = " ".join(manual_file_args)

        return next_sample

    def _predict_mu_sigma(self, hyperparam_values: HyperparamValues,
                          model_params: ModelParameters
                          ) -> Tuple[Optional[float], Optional[float]]:
        """Predict mean and std-dev at ``hyperparam_values``; ``(None, None)`` without data."""
        X_sample, Y_sample = self.get_xy()

        if len(X_sample) == 0:
            return None, None

        if model_params.can_predict_mean():
            # Use the fitted model to predict mu/sigma.
            # NOTE(review): the `model_params` argument is reconstructed (the
            # call looked truncated) — confirm GPyModel.from_model_params.
            gpy_model = GPyModel.from_model_params(self.gp_config, model_params,
                                                   X_sample, Y_sample)
            model = gpy_model.model
        else:
            # TODO: GPy is used in two places?
            model = GPyModel.gpy_regression(self.hyperparameters,
                                            self.gp_config, X_sample, Y_sample)

        X_next = np.array([hyperparam_values.x])

        mu, var = model.predict(X_next)

        # One query point -> single-element arrays. `.item()` extracts the
        # scalar without NumPy's deprecated float() conversion of ndim > 0 arrays.
        mu = float(np.asarray(mu).item())
        sigma = float(np.sqrt(np.asarray(var).item()))

        assert not math.isnan(mu)
        assert not math.isnan(sigma)

        return mu, sigma

    def sample_results(self) -> List[float]:
        """Return the results of all samples that have one, in sample order."""
        # TODO: finished samples only?
        collected: List[float] = []

        for sample in self.samples:
            if sample.result is None:
                continue
            collected.append(sample.result)

        return collected

    def bootstrapped_sample_results(self, num_bootstrap: int = 1000) -> List[float]:
        """Bootstrap the collected sample results.

        In the (hard-wired) ``MEAN_RESULTS`` mode, ``num_bootstrap`` resamples
        are drawn with replacement, each the size of the results. The maximum
        of each resample is returned. Otherwise 10000 results are resampled
        with replacement and returned as they are.

        Returns:
            The bootstrapped values, or an empty list when no sample has a
            result yet.
        """
        results = np.array(self.sample_results())

        if len(results) == 0:
            # np.random.choice cannot draw from an empty array.
            return []

        # NOTE: despite its name, this mode bootstraps the *maximum* result.
        MEAN_RESULTS = True

        if MEAN_RESULTS:
            # One row per bootstrap resample; drawn in a single call instead
            # of `num_bootstrap` separate ones.
            resamples = np.random.choice(results, size=(num_bootstrap, len(results)),
                                         replace=True)
            means = resamples.max(axis=1).tolist()
        else:
            means = np.random.choice(results, size=10000, replace=True).tolist()

        return means

    def sample_cumulative_results(self) -> List[float]:
        """Return the running maximum of ``sample_results()`` (best result so far at each step)."""
        running_best = np.maximum.accumulate(self.sample_results())
        return running_best.tolist()

    def serialize(self) -> None:
        """Write the experiment to ``meta.yml`` in the current directory.

        The YAML is written to a temporary file in the same directory first
        and then moved over ``meta.yml`` with ``os.replace``. The swap is
        therefore a single rename (atomic on POSIX within one filesystem), and
        it also succeeds on Windows when ``meta.yml`` already exists.

        Raises:
            OSError: if the file cannot be written or renamed. The temporary
                file is removed in that case.
        """
        dump = yaml.dump(self.to_dict(), default_flow_style=False, Dumper=NoAliasDumper)

        # mkstemp (unlike the deprecated mktemp) atomically creates the file,
        # so no other process can claim the name in between.
        # NOTE(review): mkstemp creates the file with mode 0600, which
        # meta.yml then inherits — confirm nothing relies on wider perms.
        fd, temp_meta_fname = tempfile.mkstemp(dir=".", prefix="meta.", suffix=".tmp")

        try:
            with os.fdopen(fd, "w") as f:
                f.write(dump)

            os.replace(temp_meta_fname, "meta.yml")
        except BaseException:
            # Don't leave a stray temp file behind if writing/renaming failed.
            try:
                os.remove(temp_meta_fname)
            except OSError:
                pass
            raise

    @staticmethod
    def deserialize() -> "Experiment":
        """Load the experiment from ``meta.json`` or, failing that, ``meta.yml``.

        Both files are looked up in the current directory, in that order. The
        first one that exists is parsed and passed to ``Experiment.from_dict``.

        Raises:
            RuntimeError: if neither file exists.
        """
        import json

        loaders = [
            ("meta.json", json.loads),
            # NOTE(review): yaml.Loader can construct arbitrary Python objects,
            # so only load trusted meta files. yaml.safe_load is not a drop-in
            # replacement if meta.yml contains python/* tags emitted by
            # yaml.Dumper.
            ("meta.yml", lambda contents: yaml.load(contents, Loader=yaml.Loader)),
        ]

        for fname, loader in loaders:
            if os.path.exists(fname):
                with open(fname, "r") as f:
                    obj = loader(f.read())

                return Experiment.from_dict(obj)

        tested_fnames = [fname for fname, _ in loaders]
        raise RuntimeError("No meta file found, tested {}".format(tested_fnames))

        # TODO: remove once the new implementation is tested
        # if os.path.exists(meta_json):
        #     with open(meta_json, "r") as f:
        #         obj = json.loads(f.read())
        # elif os.path.exists(meta_yaml):
        # # if os.path.exists(meta_yaml):
        #     with open(meta_yaml, "r") as f:
        #         contents = f.read()
        #         obj = yaml.load(contents, Loader=yaml.Loader)