JohannesHeidecke/irl-benchmark

View on GitHub
irl_benchmark/rl/algorithms/value_iteration.py

Summary

Maintainability
A
35 mins
Test Coverage
"""Module for value iteration RL algorithm."""
from typing import Union

import gym
import numpy as np

from irl_benchmark.envs.maze_world import MazeWorld
from irl_benchmark.config import RL_CONFIG_DOMAINS, RL_ALG_REQUIREMENTS
from irl_benchmark.rl.algorithms.base_algorithm import BaseRLAlgorithm
from irl_benchmark.rl.model.model_wrapper import BaseWorldModelWrapper
from irl_benchmark.utils.wrapper import is_unwrappable_to, unwrap_env


class ValueIteration(BaseRLAlgorithm):
    """Value iteration algorithm.

    Solves an MDP given exact knowledge of transition dynamics and rewards.
    Currently only implemented for DiscreteEnv environments.
    """

    def __init__(self, env: gym.Env, config: Union[None, dict] = None):
        """

        Parameters
        ----------
        env: gym.Env
            A DiscreteEnv environment
        config: dict
            Configuration of hyperparameters.
        """
        assert is_unwrappable_to(env, BaseWorldModelWrapper)
        super(ValueIteration, self).__init__(env, config)

        self.model_wrapper = unwrap_env(env, BaseWorldModelWrapper)

        # +1 for absorbing state
        self.no_states = self.model_wrapper.n_states() + 1
        self.no_actions = env.action_space.n
        self.transitions = self.model_wrapper.get_transition_array()
        # will be filled in beginning of training:
        self.rewards = None
        # will be filled during training:
        self.state_values = None
        self.q_values = None
        # whenever self._policy is None, it will be re-calculated
        # based on current self.q_values when calling policy().
        self._policy = None

    def train(self, no_episodes: int):
        """ Train the agent

        Parameters
        ----------
        no_episodes: int
            Not used in this algorithm (since it assumes known transition dynamics)
        """
        assert is_unwrappable_to(
            self.env,
            gym.envs.toy_text.discrete.DiscreteEnv) or is_unwrappable_to(
                self.env, MazeWorld)
        # extract reward function from env (using wrapped reward function if available):
        self.rewards = self.model_wrapper.get_reward_array()

        # initialize state values:
        state_values = np.zeros([self.no_states])

        while True:  # stops when state values converge
            # remember old values for error computation
            old_state_values = state_values.copy()
            # calculate Q-values:
            q_values = self.rewards + \
                       self.config['gamma'] * self.transitions.dot(state_values)
            # calculate state values either with maximum or mellow maximum:
            if self.config['temperature'] is None:
                # using default maximum operator:
                state_values = self._argmax_state_values(q_values)
            else:
                # using softmax:
                state_values = self._softmax_state_values(q_values)

            # stopping condition:
            # check if state values converged (almost no change since last iteration:
            if np.allclose(
                    state_values, old_state_values,
                    atol=self.config['epsilon']):
                break

        # persist learned state values and Q-values:
        self.state_values = state_values
        self.q_values = q_values
        # flag to tell other methods that policy needs to be updated based on new values:
        self._policy = None

    def pick_action(self, state: int) -> int:
        """ Pick an action given a state.

        The way of picking is either uniformly random between all best options (argmax)
        or according to mellowmax distribution, if 'temperature' is not None in config.
        See :meth:`.policy`.

        Parameters
        ----------
        state: int
            An integer corresponding to a state of a DiscreteEnv.

        Returns
        -------
        int
            An action for a DiscreteEnv.
        """
        return np.random.choice(self.no_actions, p=self.policy(state))

    def policy(self, state: int) -> np.ndarray:
        """ Return the probabilities of picking all possible actions given a state.

        The probabilities are either uniformly random between all best options (argmax)
        or according to mellowmax distribution, if 'temperature' is not None in config.

        Parameters
        ----------
        state: int
            An integer corresponding to a state of a DiscreteEnv.

        Returns
        -------
        np.ndarray
            Action probabilities given the state.

        """
        # convert state to integer using model wrapper:
        state = self.model_wrapper.state_to_index(state)
        assert self.q_values is not None, "Call train() before calling this method."
        assert np.isscalar(state)
        assert isinstance(state, (int, np.int64))

        return self.policy_array()[state, :]

    def policy_array(self):
        """Return action probabilities for all states as a numpy array.

        Returns
        -------
        np.ndarray
            Array containing probabilities for actions given a state.
            Shape: (n_states, n_actions)
        """
        self._update_policy_if_necessary()
        return self._policy

    def _update_policy_if_necessary(self):
        if self._policy is None:
            if self.config['temperature'] is None:
                self._policy = self._argmax_policy(self.q_values)
            else:
                self._policy = self._softmax_policy(self.q_values)

    def _argmax_policy(self, q_values: np.ndarray) -> np.ndarray:
        """ Calculate an argmax policy.

        Only picks actions with maximal Q-value given a state. If several actions
        are maxima, they are equally likely to be picked.

        Parameters
        ----------
        q_values:
            Q-values for all state-action pairs. Shape (n_states, n_actions)
        Returns
        -------
        np.ndarray
            Probabilities for actions given a state. Shape (n_states, n_actions)
        """
        # Find best actions:
        best_actions = np.isclose(q_values,
                                  np.max(q_values, axis=1).reshape((-1, 1)))
        # Initialize probabilities to be zero
        policy = np.zeros((self.no_states, self.no_actions))
        # Assign probability max to all best actions:
        policy[best_actions] = 1
        # Normalize values so their sum is 1. for each state:
        policy /= np.sum(policy, axis=1, keepdims=True)
        return policy

    def _softmax_policy(self, q_values):
        assert self.config['temperature'] is not None
        temperature = self.config['temperature']
        # for numerical stability (avoiding under- or overflow of exponent),
        # re-scale exponent without changing results of softmax,
        # using softmax(x) = softmax(x + c) for any constant c
        q_max = q_values.max(axis=1, keepdims=True)
        q_scaled = (q_values - q_max) / temperature
        # calculate softmax policy:
        policy = np.exp(q_scaled)
        # normalize values so their sum is 1. for each state:
        policy /= np.sum(policy, axis=1, keepdims=True)
        return policy

    def _mellowmax_policy(self, q_values):
        raise NotImplementedError()

    def _argmax_state_values(self, q_values):
        assert self.config['temperature'] is None
        return np.max(q_values, axis=1)

    def _softmax_state_values(self, q_values):
        assert self.config['temperature'] is not None
        # obtain probabilities of picking each (s, a):
        softmax_policy = self._softmax_policy(q_values)
        # multiply q_values by probability of picking them
        # then sum over actions to get state values:
        softmax_state_values = (softmax_policy * q_values).sum(axis=1)
        return softmax_state_values

    def _mellowmax_state_values(self, q_values):
        raise NotImplementedError()


RL_CONFIG_DOMAINS[ValueIteration] = {
    'gamma': {
        'type': float,
        'min': 0.0,
        'max': 1.0,
        'default': 0.9,
    },
    'epsilon': {
        'type': float,
        'min': 0.0,
        'max': float('inf'),
        'default': 1e-6,
    },
    'temperature': {
        'type': float,
        'optional': True,  # allows value to be None
        'min': 1e-10,
        'max': float('inf'),
        'default': None
    }
}

RL_ALG_REQUIREMENTS[ValueIteration] = {
    'requires_features': True,
    'requires_transitions': True,
}