"""Module for feature wrappers providing features for different environments."""

from abc import abstractmethod
import functools
from typing import Union, Tuple

import gym
import numpy as np

from irl_benchmark.envs.maze_world import MazeWorld, RANDOM_QUIT_CHANCE, REWARD_SMALL, REWARD_MEDIUM, REWARD_LARGE
from irl_benchmark.utils.general import to_one_hot
from irl_benchmark.utils.wrapper import unwrap_env

class FeatureWrapper(gym.Wrapper):
    """Wrapper that adds features to the info dictionary in the step function.

    Generally, each environment needs its own feature wrapper. Subclasses
    override :meth:`features`, :meth:`feature_dimensionality`,
    :meth:`feature_range` and, where feasible, :meth:`feature_array`.
    """

    def __init__(self, env: gym.Env):
        """Wrap a gym environment.

        Parameters
        ----------
        env: gym.Env
            The gym environment to be wrapped.
        """
        super(FeatureWrapper, self).__init__(env)
        # State the wrapped environment was in before the next step is taken.
        # Set by reset() and updated at the end of every step().
        self.current_state = None

    def reset(self, **kwargs):  # pylint: disable=method-hidden, R0801
        """Reset environment and return initial state.

        Any keyword arguments are forwarded unchanged to the wrapped
        environment's reset function. The returned state is remembered as
        the current state so that the next call to :meth:`step` can compute
        features for the transition starting there.
        """
        self.current_state = self.env.reset(**kwargs)
        return self.current_state

    def step(self, action: Union[np.ndarray, int, float]
             ) -> Tuple[Union[np.ndarray, float, int], float, bool, dict]:
        """Run one step in the wrapped environment and add features to info.

        Parameters
        ----------
        action: Union[np.ndarray, int, float]
            The action passed on to the wrapped environment's step function.

        Returns
        -------
        Tuple[Union[np.ndarray, float, int], float, bool, dict]
            Tuple with values for (state, reward, done, info).
            Normal return values of any gym step function.
            A field 'features' is added to the returned info dict.
            The value of 'features' is a np.ndarray of shape (1, d)
            where d is the dimensionality of the feature space.
        """
        # pylint: disable=E0202

        # execute action:
        next_state, reward, terminated, info = self.env.step(action)

        # features are computed for the full transition
        # (state before the step, action, state after the step):
        info['features'] = self.features(self.current_state, action,
                                         next_state)

        # remember which state we are in:
        self.current_state = next_state

        return next_state, reward, terminated, info

    def features(self, current_state: Union[np.ndarray, int, float],
                 action: Union[np.ndarray, int, float],
                 next_state: Union[np.ndarray, int, float]) -> np.ndarray:
        """Return features for a single state or a state-action pair or a
        state-action-next_state triplet. To be saved in step method's info dictionary.

        Parameters
        ----------
        current_state: Union[np.ndarray, int, float]
            The current state. Can be None if used reward function has properties
            action_in_domain == False and next_state_in_domain == False and if
            next_state is not None. In that case the features are calculated for
            the next state and used for the reward function R(s) - the reward for
            reaching next_state.
        action: Union[np.ndarray, int, float]
            A single action. Has to be given if used reward function has property
            action_in_domain == True.
        next_state: Union[np.ndarray, int, float]
            The next state. Has to be given if used reward function has property
            next_state_in_domain == True.

        Returns
        -------
        np.ndarray
            The features in a numpy array of shape (1, d), where d is the
            dimensionality of the feature space (see :meth:`.feature_dimensionality`).

        Raises
        ------
        NotImplementedError
            Always, in this base class; subclasses provide the implementation.
        """
        raise NotImplementedError()

    def feature_dimensionality(self) -> tuple:
        """Get the dimensionality of the feature space."""
        raise NotImplementedError()

    def feature_range(self) -> np.ndarray:
        """Get minimum and maximum values of all d features, where d is the
        dimensionality of the feature space (see :meth:`.feature_dimensionality`)

        Returns
        -------
        np.ndarray
            The minimum and maximum values in an array of shape (2, d).
            First row corresponds to minimum values and second row to maximum values.
        """
        raise NotImplementedError()

    def feature_array(self) -> np.ndarray:
        """Get features for the entire domain as an array.

        Has to be overwritten in each feature wrapper.
        Wrappers for large environments will not implement this method.

        Returns
        -------
        np.ndarray
            The features for the entire domain as an array.
            Shape: (domain_size, d).
        """
        raise NotImplementedError()

class FrozenLakeFeatureWrapper(FeatureWrapper):
    """Feature wrapper that was ad hoc written for the FrozenLake env.

    Would also work to get one-hot features for any other discrete env
    such that feature-based IRL algorithms can be used in a tabular setting.
    """

    def features(self, current_state: None, action: None,
                 next_state: int) -> np.ndarray:
        """Return features to be saved in step method's info dictionary.
        One-hot encoding the next state.

        Parameters
        ----------
        current_state: None
            Not used by this wrapper.
        action: None
            Not used by this wrapper.
        next_state: int
            The next state. Must not be None.

        Returns
        -------
        np.ndarray
            The features in a numpy array, as produced by ``to_one_hot`` for
            ``next_state`` with the size of the observation space.

        Raises
        ------
        NotImplementedError
            If ``next_state`` is neither an integer (Python or NumPy) nor
            a numpy array.
        """
        assert next_state is not None
        # np.integer covers every NumPy integer width (int32, int64, ...),
        # so states coming from envs that do not use int64 are accepted too.
        if isinstance(next_state, (int, np.integer, np.ndarray)):
            return to_one_hot(next_state, self.env.observation_space.n)
        else:
            raise NotImplementedError()

    def feature_dimensionality(self) -> Tuple:
        """Return dimension of the one-hot vectors used as features."""
        return (self.env.observation_space.n, )

    def feature_range(self):
        """Get minimum and maximum values of all k features.

        Returns
        -------
        np.ndarray
            Array of shape (2, k) with the minimum values (all 0) in the
            first row and the maximum values (all 1) in the second row.
        """
        ranges = np.zeros((2, self.feature_dimensionality()[0]))
        ranges[1, :] = 1.0
        return ranges

    def feature_array(self):
        """Returns feature array for FrozenLake. Each state in the domain
        corresponds to a one_hot vector. Features of all states together
        are the identity matrix."""
        return np.eye(self.env.observation_space.n)

class MazeWorldFeatureWrapper(FeatureWrapper):
    """Feature wrapper for MazeWorld environments.

    Features are computed per state-action pair and have four components
    (see :meth:`features`).
    """

    def features(self, current_state: np.ndarray, action: int,
                 next_state: None) -> np.ndarray:
        """Return features to be saved in step method's info dictionary.

        There are four feature variables: expected walking distance,
        probability of reaching a small reward field, probability of reaching
        a medium reward field, probability of reaching a large reward field.
        Only one of the last three values will be non-zero.

        Parameters
        ----------
        current_state: np.ndarray
            A single state as a one-dimensional array. The first
            ``num_rewards`` entries are summed to detect the "not at any
            position" case, the remaining entries to detect the "all rewards
            collected" case.
        action: int
            A single action; also used as an index into ``current_state``
            when the agent walks to its current position.
        next_state: None
            Not used by this wrapper.

        Returns
        -------
        np.ndarray
            The four features in a one-dimensional numpy array.
        """

        maze_env = unwrap_env(self.env, MazeWorld)

        # can only calculate features for a single state-action pair.
        assert len(current_state.shape) == 1

        # special case: not at any position:
        if np.sum(current_state[:maze_env.num_rewards]) == 0:
            return np.array([1, 0, 0, 0])
        path_len = maze_env.get_path_len(current_state, action)

        # special case: all rewards collected:
        if np.sum(current_state[maze_env.num_rewards:]) == 0:
            return np.zeros(4)

        assert path_len > 0
        # special case: walking to current position
        if path_len == 1:
            # assert that agent is walking to its current position:
            assert current_state[action] == 1.0
            expected_walking_distance = 1.0
        else:
            # calculate expected walking distance feature:
            possible_distances = np.arange(1, path_len)
            prob_getting_to_distance = (
                1 - RANDOM_QUIT_CHANCE)**possible_distances
            # the walk stops with the random quit chance at every distance
            # except the last one, where it stops with certainty:
            prob_stopping_at_distance = np.ones_like(
                possible_distances, dtype=np.float32)
            prob_stopping_at_distance[:-1] = RANDOM_QUIT_CHANCE
            expected_walking_distance = np.sum(
                possible_distances * prob_getting_to_distance *
                prob_stopping_at_distance)

        # coin collection probabilities:
        ccps = np.zeros(3)
        rew_value = maze_env.get_rew_value(current_state, action)
        if rew_value != 0.:
            assert rew_value in [REWARD_SMALL, REWARD_MEDIUM, REWARD_LARGE]
            # position of the reward size among (small, medium, large):
            rew_value_index = [REWARD_SMALL, REWARD_MEDIUM,
                               REWARD_LARGE].index(rew_value)
            if path_len == 1:
                ccps[rew_value_index] = (1 - RANDOM_QUIT_CHANCE)
            else:
                ccps[rew_value_index] = (1 - RANDOM_QUIT_CHANCE)**(
                    path_len - 1)

        return np.concatenate((np.array([expected_walking_distance]), ccps))

    def feature_dimensionality(self):
        """Return the dimensionality of the feature space, which is (4, )."""
        return (4, )

    def feature_range(self):
        """Return minimum and maximum values of features.
        Max is set to an arbitrary high value."""
        return np.array([[0, 0, 0, 0], [1e3] * 4])

    def feature_array(self) -> np.ndarray:
        """Get features for the entire domain as an array.

        The domain consists of all state-action pairs: every state index is
        converted to a state with ``index_to_state`` and combined with every
        action.

        Returns
        -------
        np.ndarray
            The features for the entire domain as an array.
            Shape: (n_states, num_rewards, 4), where
            n_states = num_rewards * 2**num_rewards.
        """
        maze_world = unwrap_env(self.env, MazeWorld)
        num_rewards = maze_world.num_rewards
        n_states = num_rewards * 2**num_rewards
        feature_array = np.zeros((n_states, num_rewards, 4))
        for s in range(n_states):
            # the state only depends on s, so convert it once per state:
            state = maze_world.index_to_state(s)
            for a in range(num_rewards):
                feature_array[s, a, :] = self.features(state, a, None)
        return feature_array

# Registry mapping a gym environment id to a zero-argument factory that
# returns a feature wrapper around a newly created environment.
# It is filled by _register_feature_wrapper.
_FEATURE_WRAPPERS = {}


def feature_wrappable_envs() -> set:
    """Return the set of ids for all gym environments that can currently be
    wrapped with a feature wrapper.

    Returns
    -------
    set
        A new set containing the keys of the feature wrapper registry.
    """
    return set(_FEATURE_WRAPPERS)

def make(key: str) -> FeatureWrapper:
    """Return a feature wrapper around the gym environment specified with key.

    Parameters
    ----------
    key: str
        A gym environment's id, for example 'FrozenLake-v0'.

    Returns
    -------
    FeatureWrapper
        An environment created as :func:`irl_benchmark.envs.make_env`(key) wrapped in an
        adequate feature wrapper.

    Raises
    ------
    KeyError
        If no feature wrapper is registered under ``key``.
    """
    # the _FEATURE_WRAPPERS dict is filled below by registering environments with
    # @_register_feature_wrapper.
    return _FEATURE_WRAPPERS[key]()

def _register_feature_wrapper(key: str):
    """Create a decorator that registers a feature wrapper for ``key``.

    The decorated function must accept an environment and return the wrapped
    environment. It is replaced by a zero-argument factory which creates the
    environment for ``key`` and passes it to the decorated function. That
    factory is stored in ``_FEATURE_WRAPPERS`` under ``key``.
    """

    def decorator(decorated_function):
        def wrapper_factory():
            # Deferred import of the unified way of creating environments
            # (usually using gym.make, with some exceptions).
            from irl_benchmark.envs import make_env
            # build a new gym environment, then wrap it:
            fresh_env = make_env(key)
            return decorated_function(fresh_env)

        # add to the registry of feature-wrappable environments
        _FEATURE_WRAPPERS[key] = wrapper_factory
        # give the factory a docstring naming the environment it creates
        wrapper_factory.__doc__ = "Creates feature wrapper for {}".format(key)
        return wrapper_factory

    return decorator

def frozen_lake(env: gym.Env):
    """Wrap ``env`` in the feature wrapper intended for 'FrozenLake-v0'.

    NOTE(review): no @_register_feature_wrapper decorator is visible on this
    function, so nothing here adds it to the registry - confirm that the
    registration was not lost.
    """
    wrapped_env = FrozenLakeFeatureWrapper(env)
    return wrapped_env

def frozen_lake_8_8(env: gym.Env):
    """Wrap ``env`` in the feature wrapper intended for 'FrozenLake-v2'.

    NOTE(review): no @_register_feature_wrapper decorator is visible on this
    function, so nothing here adds it to the registry - confirm that the
    registration was not lost, and confirm the environment id.
    """
    # The wrapper written for 'FrozenLake-v0' is reused: it reads the size
    # of the state space from the environment's observation space.
    wrapped_env = FrozenLakeFeatureWrapper(env)
    return wrapped_env

def maze_world_0(env: gym.Env):
    """Wrap ``env`` in the feature wrapper intended for MazeWorld0 (10 rewards).

    NOTE(review): no @_register_feature_wrapper decorator is visible on this
    function, so nothing here adds it to the registry - confirm that the
    registration was not lost.
    """
    wrapped_env = MazeWorldFeatureWrapper(env)
    return wrapped_env

def maze_world_1(env: gym.Env):
    """Wrap ``env`` in the feature wrapper intended for MazeWorld1 (10 rewards).

    NOTE(review): no @_register_feature_wrapper decorator is visible on this
    function, so nothing here adds it to the registry - confirm that the
    registration was not lost.
    """
    wrapped_env = MazeWorldFeatureWrapper(env)
    return wrapped_env