JohannesHeidecke/irl-benchmark

irl_benchmark/irl/algorithms/base_algorithm.py

"""Module for the abstract base class of all IRL algorithms."""
from abc import ABC, abstractmethod
from typing import Callable, Dict, List, Tuple, Union

import gym
import numpy as np

from irl_benchmark.config import preprocess_config, IRL_CONFIG_DOMAINS, IRL_ALG_REQUIREMENTS
from irl_benchmark.irl.feature.feature_wrapper import FeatureWrapper
from irl_benchmark.irl.reward.reward_function import BaseRewardFunction
from irl_benchmark.irl.reward.reward_wrapper import RewardWrapper
from irl_benchmark.metrics.base_metric import BaseMetric
from irl_benchmark.rl.algorithms.base_algorithm import BaseRLAlgorithm
from irl_benchmark.rl.model.model_wrapper import BaseWorldModelWrapper
from irl_benchmark.utils.wrapper import is_unwrappable_to
import irl_benchmark.utils.irl as irl_utils


class BaseIRLAlgorithm(ABC):
    """The abstract base class for all IRL algorithms."""

    def __init__(self,
                 env: gym.Env,
                 expert_trajs: List[Dict[str, list]],
                 rl_alg_factory: Callable[[gym.Env], BaseRLAlgorithm],
                 metrics: Union[List[BaseMetric], None] = None,
                 config: Union[dict, None] = None):
        """

        Parameters
        ----------
        env: gym.Env
            The gym environment to be trained on.
            Needs to be wrapped in a RewardWrapper so that the true reward
            function is not leaked to the IRL algorithm. Depending on the
            algorithm's requirements, it may additionally need to be wrapped
            in a FeatureWrapper and/or a BaseWorldModelWrapper (see the
            assertions below).
        expert_trajs: List[Dict[str, list]]
            A list of trajectories.
            Each trajectory is a dictionary with keys
            ['states', 'actions', 'rewards', 'true_rewards', 'features'].
            The values of each dictionary are lists.
            See :func:`irl_benchmark.irl.collect.collect_trajs`.
        rl_alg_factory: Callable[[gym.Env], BaseRLAlgorithm]
            A function which returns a new RL algorithm when called.
        metrics: List[BaseMetric]
            A list of metrics to be evaluated at the end of each IRL iteration
            (see :meth:`evaluate_metrics`). Defaults to an empty list if None.
        config: Union[dict, None]
            A dictionary containing algorithm-specific parameters. It is
            preprocessed with :func:`irl_benchmark.config.preprocess_config`.
        """

        assert is_unwrappable_to(env, RewardWrapper)

        if IRL_ALG_REQUIREMENTS[type(self)]['requires_features']:
            assert is_unwrappable_to(env, FeatureWrapper)
        if IRL_ALG_REQUIREMENTS[type(self)]['requires_transitions']:
            assert is_unwrappable_to(env, BaseWorldModelWrapper)

        self.env = env
        self.expert_trajs = expert_trajs
        self.rl_alg_factory = rl_alg_factory
        if metrics is None:
            metrics = []
        self.metrics = metrics
        # One result list per metric. A list comprehension is used instead of
        # [[]] * len(metrics), which would alias a single shared list.
        self.metric_results = [[] for _ in metrics]
        self.config = preprocess_config(self, IRL_CONFIG_DOMAINS, config)
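
    # Construction sketch (illustrative only; `MyIRL` and `MyRLAlg` are
    # hypothetical stand-ins for a concrete BaseIRLAlgorithm subclass and a
    # BaseRLAlgorithm implementation). The environment is assumed to already be
    # wrapped as required above, with expert trajectories collected via
    # irl_benchmark.irl.collect.collect_trajs:
    #
    #     irl_alg = MyIRL(
    #         env=wrapped_env,
    #         expert_trajs=expert_trajs,
    #         rl_alg_factory=lambda env: MyRLAlg(env),
    #         metrics=None,
    #         config=None)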

    @abstractmethod
    def train(self, no_irl_iterations: int,
              no_rl_episodes_per_irl_iteration: int,
              no_irl_episodes_per_irl_iteration: int
              ) -> Tuple[BaseRewardFunction, BaseRLAlgorithm]:
        """Train the IRL algorithm.

        Parameters
        ----------
        no_irl_iterations: int
            The number of iterations the IRL algorithm should be run for.
        no_rl_episodes_per_irl_iteration: int
            The number of episodes the RL algorithm is allowed to run in
            each iteration of the IRL algorithm.
        no_irl_episodes_per_irl_iteration: int
            The number of episodes permitted to be run in each iteration
            to update the current reward estimate (e.g. to estimate state frequencies
            of the currently optimal policy).

        Returns
        -------
        Tuple[BaseRewardFunction, BaseRLAlgorithm]
            The estimated reward function and an RL agent trained for this estimate.
        """
        raise NotImplementedError()
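
    # Call sketch (illustrative; the iteration counts are arbitrary): a concrete
    # subclass returns the reward estimate together with an RL agent trained on
    # that estimate, as documented above.
    #
    #     reward_function, agent = irl_alg.train(
    #         no_irl_iterations=50,
    #         no_rl_episodes_per_irl_iteration=1000,
    #         no_irl_episodes_per_irl_iteration=1000)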

    def evaluate_metrics(self, evaluation_input: dict):
        """ Evaluate all metrics in self.metrics. To be called at the end
        of each IRL iteration.

        Parameters
        ----------
        evaluation_input: specific metric input.
        """
        for metric_index, metric in enumerate(self.metrics):
            result = metric.evaluate(evaluation_input)
            # Append the result to the list belonging to this metric:
            self.metric_results[metric_index].append(result)
            print(type(metric).__name__ + ': \t' + str(result))
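
    # Usage sketch (illustrative): a subclass would typically call this at the
    # end of each IRL iteration, passing whatever inputs its configured metrics
    # expect, e.g. with hypothetical keys:
    #
    #     self.evaluate_metrics({'irl_reward': reward_function,
    #                            'irl_agent': agent})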

    def feature_count(self, trajs: List[Dict[str, list]],
                      gamma: float) -> np.ndarray:
        """Return empirical discounted feature counts of input trajectories.

        Parameters
        ----------
        trajs: List[Dict[str, list]]
            A list of trajectories.
            Each trajectory is a dictionary with keys
            ['states', 'actions', 'rewards', 'true_rewards', 'features'].
            The values of each dictionary are lists.
            See :func:`irl_benchmark.irl.collect.collect_trajs`.
        gamma: float
            The discount factor. Must be in range [0., 1.].

        Returns
        -------
        np.ndarray
            A numpy array containing discounted feature counts. The shape
            is the same as the trajectories' feature shapes. One scalar
            feature count per feature.
        """
        # The actual computation is implemented in irl_benchmark.utils.irl:
        return irl_utils.feature_count(self.env, trajs, gamma)