JohannesHeidecke/irl-benchmark

View on GitHub
irl_benchmark/irl/collect/__init__.py

Summary

Maintainability
A
3 hrs
Test Coverage
"""This module is used to collect trajectories for inverse reinforcement learning"""

import os
from typing import Dict, List, Union

import gym
import msgpack
import msgpack_numpy as m
from tqdm import tqdm

from irl_benchmark.rl.algorithms.base_algorithm import BaseRLAlgorithm

# set up MessagePack to be numpy compatible:
m.patch()


# pylint: disable=too-many-arguments
# pylint: disable=too-many-locals
def collect_trajs(env: gym.Env,
                  agent: BaseRLAlgorithm,
                  no_trajectories: int,
                  max_steps_per_episode: Union[int, None] = None,
                  store_to: Union[str, None] = None,
                  verbose: bool = False) -> List[Dict[str, list]]:
    """ Collect and return trajectories of an agent in an environment.

    Parameters
    ----------
    env: gym.Env
        A gym environment
    agent: BaseRLAlgorithm
        An RL algorithm / agent
    no_trajectories: int
        Number of trajectories to be collected
    max_steps_per_episode: int
        Maximum number of steps allowed to take in each episode. Optional.
        If not set, the environment's default is used.
    store_to: Union[str, None]
        If not None: a path of where trajectories should be persisted.
    verbose: bool
        Whether to use a tqdm progress bar while collecting.

    Returns
    -------
    List[Dict[str, list]]
        A list of trajectories. Each trajectory is a dictionary with keys
        ['states', 'actions', 'rewards', 'true_rewards', 'features']. The
        values of each dictionary are lists. The list of states contains
        both starting and final state and is one element longer than the
        list of actions. The lists of true_rewards and features can be
        empty if the environment was not wrapped in a RewardWrapper or
        FeatureWrapper respectively.
    """

    # list of trajectories to be returned:
    trajectories = []

    for _ in (tqdm(range(no_trajectories))
              if verbose else range(no_trajectories)):
        # start new episode by resetting environment:
        state = env.reset()
        done = False

        # trajectories contain starting state.
        # this makes len(states) to be one larger than len(actions)
        states = [state]
        # initialize other lists:
        actions = []
        rewards = []
        true_rewards = []
        features = []

        # if no max steps are specified,
        # check if there is a default provided by the environment:
        if env.spec is not None and env.spec.max_episode_steps is not None:
            assert isinstance(env.spec.max_episode_steps, int)
            if max_steps_per_episode is None:
                max_steps_per_episode = env.spec.max_episode_steps
            elif env.spec.max_episode_steps < max_steps_per_episode:
                print(
                    'WARNING: running episodes for longer than the default of '
                    + str(env.spec.max_episode_steps) +
                    ' for this environment.')

        step_counter = 0
        while not done \
            and (max_steps_per_episode is None or step_counter <
                 max_steps_per_episode):
            step_counter += 1
            # Agent picks an action:
            action = agent.pick_action(state)
            # Action is executed by environment:
            next_state, reward, done, info = env.step(action)
            # Append obtained information to respective lists:
            states.append(next_state)
            actions.append(action)
            rewards.append(reward)
            if 'true_reward' in info.keys():
                true_rewards.append(info['true_reward'])
            if 'features' in info.keys():
                features.append(info['features'])
            # Update which state we are in:
            state = next_state

        # Construct trajectory dictionary:
        trajectory = {
            'states': states,
            'actions': actions,
            'rewards': rewards,
            'true_rewards': true_rewards,
            'features': features
        }
        # Append trajectory to list of trajectories:
        trajectories.append(trajectory)

    # If requested, store trajetories to file:
    if store_to is not None:
        store_trajs(trajectories, store_to)

    return trajectories


def store_trajs(trajs, store_to):
    """Store trajectories to store_to/trajs.data."""
    if not os.path.exists(store_to):
        os.makedirs(store_to)
    file_path = os.path.join(store_to, 'trajs.data')
    with open(file_path, 'wb') as file:
        msgpack.pack(trajs, file)


def load_stored_trajs(load_from):
    """Return trajectories storead at load_from/trajs.data."""
    file_path = os.path.join(load_from, 'trajs.data')
    with open(file_path, 'rb') as file:
        trajs = msgpack.unpack(file)
    # convert byte keys back to string keys:
    trajs = [{
        key.decode('utf-8') if isinstance(key, bytes) else key: traj[key]
        for key in traj.keys()
    } for traj in trajs]
    return trajs