JohannesHeidecke/irl-benchmark

View on GitHub
irl_benchmark/utils/irl.py

Summary

Maintainability
A
0 mins
Test Coverage
"""Util functions related to IRL."""

from typing import Dict, List
import numpy as np

from irl_benchmark.irl.feature.feature_wrapper import FeatureWrapper
from irl_benchmark.utils.wrapper import unwrap_env, is_unwrappable_to


def feature_count(env, trajs: List[Dict[str, list]],
                  gamma: float) -> np.ndarray:
    """Return empirical discounted feature counts of input trajectories.

    Parameters
    ----------
    env: gym.Env
        A gym environment, wrapped in a feature wrapper
    trajs: List[Dict[str, list]]
         A list of trajectories.
        Each trajectory is a dictionary with keys
        ['states', 'actions', 'rewards', 'true_rewards', 'features'].
        The values of each dictionary are lists.
        See :func:`irl_benchmark.irl.collect.collect_trajs`.
    gamma: float
        The discount factor. Must be in range [0., 1.].

    Returns
    -------
    np.ndarray
        A numpy array containing discounted feature counts. The shape
        is the same as the trajectories' feature shapes. One scalar
        feature count per feature.

    Raises
    ------
    ValueError
        If `trajs` is empty (the mean feature count would be undefined).
    """
    assert is_unwrappable_to(env, FeatureWrapper)

    # Guard against an empty trajectory list, which would otherwise
    # surface as an opaque ZeroDivisionError in the final division:
    if not trajs:
        raise ValueError('trajs must contain at least one trajectory.')

    # Initialize feature count sum to zeros of correct shape:
    feature_dim = unwrap_env(env, FeatureWrapper).feature_dimensionality()
    # feature_dim is a 1-tuple,
    # extract the feature dimensionality as integer:
    assert len(feature_dim) == 1
    feature_dim = feature_dim[0]
    feature_count_sum = np.zeros(feature_dim)

    for traj in trajs:
        assert traj['features']  # empty lists are False in python

        # gammas is a vector [gamma^0, gamma^1, ..., gamma^(l-1)]
        # where l is the length of the trajectory:
        gammas = gamma ** np.arange(len(traj['features']))
        # Broadcast gammas over the feature dimension and sum over time
        # to get this trajectory's discounted feature count:
        features = np.asarray(traj['features']).reshape((-1, feature_dim))
        feature_count_sum += np.sum(gammas[:, np.newaxis] * features, axis=0)

    # Divide by the number of trajectories to get the empirical mean:
    return feature_count_sum / len(trajs)