d3rlpy/datasets.py

# pylint: disable=unused-import,too-many-return-statements

import enum
import os
import random
import re
from typing import Any, Dict, List, Optional, Tuple
from urllib import request

import gym
import gymnasium
import numpy as np
from gym.wrappers.time_limit import TimeLimit
from gymnasium.spaces import Box as GymnasiumBox
from gymnasium.spaces import Dict as GymnasiumDictSpace
from gymnasium.wrappers.time_limit import TimeLimit as GymnasiumTimeLimit

from .dataset import (
    BasicTrajectorySlicer,
    BasicTransitionPicker,
    Episode,
    EpisodeGenerator,
    FrameStackTrajectorySlicer,
    FrameStackTransitionPicker,
    InfiniteBuffer,
    MDPDataset,
    ReplayBuffer,
    TrajectorySlicerProtocol,
    TransitionPickerProtocol,
    create_infinite_replay_buffer,
    load_v1,
)
from .envs import ChannelFirst, FrameStack, GoalConcatWrapper
from .logging import LOG
from .types import NDArray, UInt8NDArray

__all__ = [
    "DATA_DIRECTORY",
    "DROPBOX_URL",
    "CARTPOLE_URL",
    "CARTPOLE_RANDOM_URL",
    "PENDULUM_URL",
    "PENDULUM_RANDOM_URL",
    "get_cartpole",
    "get_pendulum",
    "get_atari",
    "get_atari_transitions",
    "get_d4rl",
    "get_dataset",
]

DATA_DIRECTORY = "d3rlpy_data"
DROPBOX_URL = "https://www.dropbox.com/s"
CARTPOLE_URL = f"{DROPBOX_URL}/uep0lzlhxpi79pd/cartpole_v1.1.0.h5?dl=1"
CARTPOLE_RANDOM_URL = f"{DROPBOX_URL}/4lgai7tgj84cbov/cartpole_random_v1.1.0.h5?dl=1"  # pylint: disable=line-too-long
PENDULUM_URL = f"{DROPBOX_URL}/ukkucouzys0jkfs/pendulum_v1.1.0.h5?dl=1"
PENDULUM_RANDOM_URL = f"{DROPBOX_URL}/hhbq9i6ako24kzz/pendulum_random_v1.1.0.h5?dl=1"  # pylint: disable=line-too-long


def get_cartpole(
    dataset_type: str = "replay",
    transition_picker: Optional[TransitionPickerProtocol] = None,
    trajectory_slicer: Optional[TrajectorySlicerProtocol] = None,
    render_mode: Optional[str] = None,
) -> Tuple[ReplayBuffer, gym.Env[NDArray, int]]:
    """Returns cartpole dataset and environment.

    The dataset is automatically downloaded under the ``d3rlpy_data``
    directory if it does not exist.
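
    A minimal usage sketch:

    .. code-block:: python

        from d3rlpy.datasets import get_cartpole

        # "replay" is the default dataset_type; "random" is also available
        dataset, env = get_cartpole(dataset_type="replay")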

    Args:
        dataset_type: dataset type. Available options are
            ``['replay', 'random']``.
        transition_picker: TransitionPickerProtocol object.
        trajectory_slicer: TrajectorySlicerProtocol object.
        render_mode: Mode of rendering (``human``, ``rgb_array``).

    Returns:
        tuple of :class:`d3rlpy.dataset.ReplayBuffer` and gym environment.
    """
    if dataset_type == "replay":
        url = CARTPOLE_URL
        file_name = "cartpole_replay_v1.1.0.h5"
    elif dataset_type == "random":
        url = CARTPOLE_RANDOM_URL
        file_name = "cartpole_random_v1.1.0.h5"
    else:
        raise ValueError(f"Invalid dataset_type: {dataset_type}.")

    data_path = os.path.join(DATA_DIRECTORY, file_name)

    # download dataset
    if not os.path.exists(data_path):
        os.makedirs(DATA_DIRECTORY, exist_ok=True)
        print(f"Downloading cartpole.pkl into {data_path}...")
        request.urlretrieve(url, data_path)

    # load dataset
    with open(data_path, "rb") as f:
        episodes = load_v1(f)
    dataset = ReplayBuffer(
        InfiniteBuffer(),
        episodes=episodes,
        transition_picker=transition_picker,
        trajectory_slicer=trajectory_slicer,
    )

    # environment
    env = gym.make("CartPole-v1", render_mode=render_mode)

    return dataset, env


def get_pendulum(
    dataset_type: str = "replay",
    transition_picker: Optional[TransitionPickerProtocol] = None,
    trajectory_slicer: Optional[TrajectorySlicerProtocol] = None,
    render_mode: Optional[str] = None,
) -> Tuple[ReplayBuffer, gym.Env[NDArray, NDArray]]:
    """Returns pendulum dataset and environment.

    The dataset is automatically downloaded under the ``d3rlpy_data``
    directory if it does not exist.
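
    A minimal usage sketch:

    .. code-block:: python

        from d3rlpy.datasets import get_pendulum

        # load the "random" variant instead of the default "replay" one
        dataset, env = get_pendulum(dataset_type="random")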

    Args:
        dataset_type: dataset type. Available options are
            ``['replay', 'random']``.
        transition_picker: TransitionPickerProtocol object.
        trajectory_slicer: TrajectorySlicerProtocol object.
        render_mode: Mode of rendering (``human``, ``rgb_array``).

    Returns:
        tuple of :class:`d3rlpy.dataset.ReplayBuffer` and gym environment.
    """
    if dataset_type == "replay":
        url = PENDULUM_URL
        file_name = "pendulum_replay_v1.1.0.h5"
    elif dataset_type == "random":
        url = PENDULUM_RANDOM_URL
        file_name = "pendulum_random_v1.1.0.h5"
    else:
        raise ValueError(f"Invalid dataset_type: {dataset_type}.")

    data_path = os.path.join(DATA_DIRECTORY, file_name)

    if not os.path.exists(data_path):
        os.makedirs(DATA_DIRECTORY, exist_ok=True)
        print(f"Donwloading pendulum.pkl into {data_path}...")
        request.urlretrieve(url, data_path)

    # load dataset
    with open(data_path, "rb") as f:
        episodes = load_v1(f)
    dataset = ReplayBuffer(
        InfiniteBuffer(),
        episodes=episodes,
        transition_picker=transition_picker,
        trajectory_slicer=trajectory_slicer,
    )

    # environment
    env = gym.make("Pendulum-v1", render_mode=render_mode)

    return dataset, env


def _stack_frames(episode: Episode, num_stack: int) -> Episode:
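    """Pre-stacks frames of an Atari episode along a new channel axis.

    For each step ``t``, channel ``i`` holds the frame observed
    ``num_stack - 1 - i`` steps earlier (the last channel is the current
    frame); steps from before the episode start are left zero-padded.
    """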
    assert isinstance(episode.observations, np.ndarray)
    episode_length = episode.observations.shape[0]
    observations: UInt8NDArray = np.zeros(
        (episode_length, num_stack, 84, 84),
        dtype=np.uint8,
    )
    for i in range(num_stack):
        pad_size = num_stack - i - 1
        if pad_size > 0:
            observations[pad_size:, i] = np.reshape(
                episode.observations[:-pad_size], [-1, 84, 84]
            )
        else:
            observations[:, i] = np.reshape(episode.observations, [-1, 84, 84])
    return Episode(
        observations=observations,
        actions=episode.actions.copy(),
        rewards=episode.rewards.copy(),
        terminated=episode.terminated,
    )


def get_atari(
    env_name: str,
    num_stack: Optional[int] = None,
    sticky_action: bool = True,
    pre_stack: bool = False,
    render_mode: Optional[str] = None,
) -> Tuple[ReplayBuffer, gym.Env[NDArray, int]]:
    """Returns atari dataset and envrironment.

    The dataset is provided through d4rl-atari. See its GitHub page for more
    details, including the available datasets.

    .. code-block:: python

        from d3rlpy.datasets import get_atari

        dataset, env = get_atari('breakout-mixed-v0')
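
        # a sketch of sampling-time frame stacking via num_stack (see Args)
        dataset, env = get_atari('breakout-mixed-v0', num_stack=4)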

    References:
        * https://github.com/takuseno/d4rl-atari

    Args:
        env_name: environment id of d4rl-atari dataset.
        num_stack: Number of frames to stack. This affects both the wrapped
            environment and how observations are stacked in the dataset.
        sticky_action: Flag to enable sticky action.
        pre_stack: Flag to pre-stack observations. If this is ``False``,
            ``FrameStackTransitionPicker`` and ``FrameStackTrajectorySlicer``
            will be used to stack observations at sampling-time.
        render_mode: Mode of rendering (``human``, ``rgb_array``).

    Returns:
        tuple of :class:`d3rlpy.dataset.ReplayBuffer` and gym environment.
    """
    try:
        import d4rl_atari  # type: ignore

        env = gym.make(
            env_name,
            render_mode=render_mode,
            sticky_action=sticky_action,
        )
        raw_dataset = env.get_dataset()  # type: ignore
        episode_generator = EpisodeGenerator(**raw_dataset)
        episodes = episode_generator()

        if pre_stack:
            stacked_episodes = []
            for episode in episodes:
                assert num_stack is not None
                stacked_episode = _stack_frames(episode, num_stack)
                stacked_episodes.append(stacked_episode)
            episodes = stacked_episodes

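        # if frames are already stacked (or no stacking was requested), plain
        # sampling is enough; otherwise frames are stacked at sampling time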
        picker: TransitionPickerProtocol
        slicer: TrajectorySlicerProtocol
        if num_stack is None or pre_stack:
            picker = BasicTransitionPicker()
            slicer = BasicTrajectorySlicer()
        else:
            picker = FrameStackTransitionPicker(num_stack or 1)
            slicer = FrameStackTrajectorySlicer(num_stack or 1)

        dataset = create_infinite_replay_buffer(
            episodes=episodes,
            transition_picker=picker,
            trajectory_slicer=slicer,
        )
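        # wrap the env so that its observation layout matches the dataset:
        # stacked frames if num_stack is given, channel-first frames otherwise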
        if num_stack:
            env = FrameStack(env, num_stack=num_stack)
        else:
            env = ChannelFirst(env)
        return dataset, env
    except ImportError as e:
        raise ImportError(
            "d4rl-atari is not installed.\n" "$ d3rlpy install d4rl_atari"
        ) from e


def get_atari_transitions(
    game_name: str,
    fraction: float = 0.01,
    index: int = 0,
    num_stack: Optional[int] = None,
    sticky_action: bool = True,
    pre_stack: bool = False,
    render_mode: Optional[str] = None,
) -> Tuple[ReplayBuffer, gym.Env[NDArray, int]]:
    """Returns atari dataset as a list of Transition objects and envrironment.

    The dataset is provided through d4rl-atari.
    Unlike ``get_atari``, this function samples transitions across all 50
    epochs of training data, which is necessary for reproducing the Atari
    experiments.

    .. code-block:: python

        from d3rlpy.datasets import get_atari_transitions

        # get 1% of transitions from all epochs (1M x 50 epoch x 1% = 0.5M)
        dataset, env = get_atari_transitions('breakout', fraction=0.01)

    References:
        * https://github.com/takuseno/d4rl-atari

    Args:
        game_name: Atari 2600 game name in lowercase with hyphens
            (e.g. ``space-invaders``).
        fraction: fraction of sampled transitions.
        index: index to specify which trial to load.
        num_stack: Number of frames to stack. This affects both the wrapped
            environment and how observations are stacked in the dataset.
        sticky_action: Flag to enable sticky action.
        pre_stack: Flag to pre-stack observations. If this is ``False``,
            ``FrameStackTransitionPicker`` and ``FrameStackTrajectorySlicer``
            will be used to stack observations at sampling-time.
        render_mode: Mode of rendering (``human``, ``rgb_array``).

    Returns:
        tuple of :class:`d3rlpy.dataset.ReplayBuffer` and gym environment.
    """
    try:
        import d4rl_atari

        # each epoch consists of 1M steps
        num_transitions_per_epoch = int(1000000 * fraction)

        copied_episodes = []
        for i in range(50):
            env_name = f"{game_name}-epoch-{i + 1}-v{index}"
            LOG.info(f"Collecting {env_name}...")
            env = gym.make(
                env_name,
                sticky_action=sticky_action,
                render_mode=render_mode,
            )
            raw_dataset = env.get_dataset()  # type: ignore
            episode_generator = EpisodeGenerator(**raw_dataset)
            episodes = list(episode_generator())

            # shuffle so that the sampled fraction of episodes is unbiased,
            # then copy only the kept episodes to release memory of unused data
            random.shuffle(episodes)
            num_data = 0
            for episode in episodes:
                if num_data >= num_transitions_per_epoch:
                    break

                assert isinstance(episode.observations, np.ndarray)
                copied_episode = Episode(
                    observations=episode.observations.copy(),
                    actions=episode.actions.copy(),
                    rewards=episode.rewards.copy(),
                    terminated=episode.terminated,
                )
                if pre_stack:
                    assert num_stack is not None
                    copied_episode = _stack_frames(copied_episode, num_stack)

                # trim episode
                if num_data + copied_episode.size() > num_transitions_per_epoch:
                    end = num_transitions_per_epoch - num_data
                    copied_episode = Episode(
                        observations=copied_episode.observations[:end],
                        actions=copied_episode.actions[:end],
                        rewards=copied_episode.rewards[:end],
                        terminated=False,
                    )

                copied_episodes.append(copied_episode)
                num_data += copied_episode.size()

        picker: TransitionPickerProtocol
        slicer: TrajectorySlicerProtocol
        if num_stack is None or pre_stack:
            picker = BasicTransitionPicker()
            slicer = BasicTrajectorySlicer()
        else:
            picker = FrameStackTransitionPicker(num_stack or 1)
            slicer = FrameStackTrajectorySlicer(num_stack or 1)

        dataset = ReplayBuffer(
            InfiniteBuffer(),
            episodes=copied_episodes,
            transition_picker=picker,
            trajectory_slicer=slicer,
        )

        if num_stack:
            env = FrameStack(env, num_stack=num_stack)
        else:
            env = ChannelFirst(env)

        return dataset, env
    except ImportError as e:
        raise ImportError(
            "d4rl-atari is not installed.\n" "$ d3rlpy install d4rl_atari"
        ) from e


def get_d4rl(
    env_name: str,
    transition_picker: Optional[TransitionPickerProtocol] = None,
    trajectory_slicer: Optional[TrajectorySlicerProtocol] = None,
    render_mode: Optional[str] = None,
) -> Tuple[ReplayBuffer, gym.Env[NDArray, NDArray]]:
    """Returns d4rl dataset and envrironment.

    The dataset is provided through d4rl.

    .. code-block:: python

        from d3rlpy.datasets import get_d4rl

        dataset, env = get_d4rl('hopper-medium-v0')

    References:
        * `Fu et al., D4RL: Datasets for Deep Data-Driven Reinforcement
          Learning. <https://arxiv.org/abs/2004.07219>`_
        * https://github.com/rail-berkeley/d4rl

    Args:
        env_name: environment id of d4rl dataset.
        transition_picker: TransitionPickerProtocol object.
        trajectory_slicer: TrajectorySlicerProtocol object.
        render_mode: Mode of rendering (``human``, ``rgb_array``).

    Returns:
        tuple of :class:`d3rlpy.dataset.ReplayBuffer` and gym environment.
    """
    try:
        import d4rl  # type: ignore

        env = gym.make(env_name)
        raw_dataset: Dict[str, NDArray] = env.get_dataset()  # type: ignore

        observations = raw_dataset["observations"]
        actions = raw_dataset["actions"]
        rewards = raw_dataset["rewards"]
        terminals = raw_dataset["terminals"]
        timeouts = raw_dataset["timeouts"]

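        # terminals mark true episode termination; timeouts mark truncation by
        # the time limit (both delimit episodes in the dataset)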
        dataset = MDPDataset(
            observations=observations,
            actions=actions,
            rewards=rewards,
            terminals=terminals,
            timeouts=timeouts,
            transition_picker=transition_picker,
            trajectory_slicer=trajectory_slicer,
        )

        # unwrap NormalizedBoxEnv, which is incompatible with newer Gym
        unwrapped_env: gym.Env[Any, Any] = env.env.env.env.wrapped_env  # type: ignore
        unwrapped_env.render_mode = render_mode  # overwrite

        return dataset, TimeLimit(unwrapped_env, max_episode_steps=1000)
    except ImportError as e:
        raise ImportError(
            "d4rl is not installed.\n" "$ d3rlpy install d4rl"
        ) from e


class _MinariEnvType(enum.Enum):
    BOX = 0
    GOAL_CONDITIONED = 1


def get_minari(
    env_name: str,
    transition_picker: Optional[TransitionPickerProtocol] = None,
    trajectory_slicer: Optional[TrajectorySlicerProtocol] = None,
    render_mode: Optional[str] = None,
    tuple_observation: bool = False,
) -> Tuple[ReplayBuffer, gymnasium.Env[Any, Any]]:
    """Returns minari dataset and envrironment.

    The dataset is provided through minari.

    .. code-block:: python

        from d3rlpy.datasets import get_minari

        dataset, env = get_minari('door-cloned-v1')

    Args:
        env_name: environment id of minari dataset.
        transition_picker: TransitionPickerProtocol object.
        trajectory_slicer: TrajectorySlicerProtocol object.
        render_mode: Mode of rendering (``human``, ``rgb_array``).
        tuple_observation: Flag to return the observation and goal as a tuple
            instead of concatenating them.

    Returns:
        tuple of :class:`d3rlpy.dataset.ReplayBuffer` and gym environment.
    """
    try:
        import minari

        _dataset = minari.load_dataset(env_name, download=True)
        env = _dataset.recover_environment()
        unwrapped_env = env.unwrapped
        unwrapped_env.render_mode = render_mode

        if isinstance(env.observation_space, GymnasiumBox):
            env_type = _MinariEnvType.BOX
        elif (
            isinstance(env.observation_space, GymnasiumDictSpace)
            and "observation" in env.observation_space.spaces
            and "desired_goal" in env.observation_space.spaces
        ):
            env_type = _MinariEnvType.GOAL_CONDITIONED
            unwrapped_env = GoalConcatWrapper(
                unwrapped_env, tuple_observation=tuple_observation
            )
        else:
            raise ValueError(
                f"Unsupported observation space: {env.observation_space}"
            )

        observations = []
        actions = []
        rewards = []
        terminals = []
        timeouts = []

        for ep in _dataset:
            if env_type == _MinariEnvType.BOX:
                _observations = ep.observations
            elif env_type == _MinariEnvType.GOAL_CONDITIONED:
                assert isinstance(ep.observations, dict)
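                # desired_goal may itself be a Dict space; flatten its entries
                # in sorted key order for a deterministic layout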
                if isinstance(ep.observations["desired_goal"], dict):
                    sorted_keys = sorted(
                        list(ep.observations["desired_goal"].keys())
                    )
                    goal_obs = np.concatenate(
                        [
                            ep.observations["desired_goal"][key]
                            for key in sorted_keys
                        ],
                        axis=-1,
                    )
                else:
                    goal_obs = ep.observations["desired_goal"]
                if tuple_observation:
                    _observations = (ep.observations["observation"], goal_obs)
                else:
                    _observations = np.concatenate(
                        [
                            ep.observations["observation"],
                            goal_obs,
                        ],
                        axis=-1,
                    )
            else:
                raise ValueError("Unsupported observation format.")
            observations.append(_observations)
            actions.append(ep.actions)
            rewards.append(ep.rewards)
            terminals.append(ep.terminations)
            timeouts.append(ep.truncations)

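        # concatenate per-episode arrays into flat dataset arrays; tuple
        # observations are stacked element-wise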
        if tuple_observation:
            stacked_observations = tuple(
                np.concatenate([observation[i] for observation in observations])
                for i in range(2)
            )
        else:
            stacked_observations = np.concatenate(observations)

        dataset = MDPDataset(
            observations=stacked_observations,
            actions=np.concatenate(actions),
            rewards=np.concatenate(rewards),
            terminals=np.concatenate(terminals),
            timeouts=np.concatenate(timeouts),
            transition_picker=transition_picker,
            trajectory_slicer=trajectory_slicer,
        )

        return dataset, GymnasiumTimeLimit(
            unwrapped_env, max_episode_steps=env.spec.max_episode_steps
        )

    except ImportError as e:
        raise ImportError(
            "minari is not installed.\n" "$ d3rlpy install minari"
        ) from e


ATARI_GAMES = [
    "adventure",
    "air-raid",
    "alien",
    "amidar",
    "assault",
    "asterix",
    "asteroids",
    "atlantis",
    "bank-heist",
    "battle-zone",
    "beam-rider",
    "berzerk",
    "bowling",
    "boxing",
    "breakout",
    "carnival",
    "centipede",
    "chopper-command",
    "crazy-climber",
    "defender",
    "demon-attack",
    "double-dunk",
    "elevator-action",
    "enduro",
    "fishing-derby",
    "freeway",
    "frostbite",
    "gopher",
    "gravitar",
    "hero",
    "ice-hockey",
    "jamesbond",
    "journey-escape",
    "kangaroo",
    "krull",
    "kung-fu-master",
    "montezuma-revenge",
    "ms-pacman",
    "name-this-game",
    "phoenix",
    "pitfall",
    "pong",
    "pooyan",
    "private-eye",
    "qbert",
    "riverraid",
    "road-runner",
    "robotank",
    "seaquest",
    "skiing",
    "solaris",
    "space-invaders",
    "star-gunner",
    "tennis",
    "time-pilot",
    "tutankham",
    "up-n-down",
    "venture",
    "video-pinball",
    "wizard-of-wor",
    "yars-revenge",
    "zaxxon",
]


def get_dataset(
    env_name: str,
    transition_picker: Optional[TransitionPickerProtocol] = None,
    trajectory_slicer: Optional[TrajectorySlicerProtocol] = None,
    render_mode: Optional[str] = None,
) -> Tuple[ReplayBuffer, gym.Env[Any, Any]]:
    """Returns dataset and envrironment by guessing from name.

    This function returns a dataset by matching ``env_name`` with the
    following datasets.

    - cartpole-replay
    - cartpole-random
    - pendulum-replay
    - pendulum-random
    - d4rl-pybullet
    - d4rl-atari
    - d4rl

    .. code-block:: python

       import d3rlpy

       # cartpole dataset
       dataset, env = d3rlpy.datasets.get_dataset('cartpole')

       # pendulum dataset
       dataset, env = d3rlpy.datasets.get_dataset('pendulum')

       # d4rl-atari dataset
       dataset, env = d3rlpy.datasets.get_dataset('breakout-mixed-v0')

       # d4rl dataset
       dataset, env = d3rlpy.datasets.get_dataset('hopper-medium-v0')

    Args:
        env_name: environment id of the dataset.
        transition_picker: TransitionPickerProtocol object.
        trajectory_slicer: TrajectorySlicerProtocol object.
        render_mode: Mode of rendering (``human``, ``rgb_array``).

    Returns:
        tuple of :class:`d3rlpy.dataset.ReplayBuffer` and gym environment.
    """
    if env_name == "cartpole-replay":
        return get_cartpole(
            dataset_type="replay",
            transition_picker=transition_picker,
            trajectory_slicer=trajectory_slicer,
            render_mode=render_mode,
        )
    elif env_name == "cartpole-random":
        return get_cartpole(
            dataset_type="random",
            transition_picker=transition_picker,
            trajectory_slicer=trajectory_slicer,
            render_mode=render_mode,
        )
    elif env_name == "pendulum-replay":
        return get_pendulum(
            dataset_type="replay",
            transition_picker=transition_picker,
            trajectory_slicer=trajectory_slicer,
            render_mode=render_mode,
        )
    elif env_name == "pendulum-random":
        return get_pendulum(
            dataset_type="random",
            transition_picker=transition_picker,
            trajectory_slicer=trajectory_slicer,
            render_mode=render_mode,
        )
    elif re.match(r"^bullet-.+$", env_name):
        return get_d4rl(
            env_name,
            transition_picker=transition_picker,
            trajectory_slicer=trajectory_slicer,
            render_mode=render_mode,
        )
    elif re.match(r"hopper|halfcheetah|walker|ant", env_name):
        return get_d4rl(
            env_name,
            transition_picker=transition_picker,
            trajectory_slicer=trajectory_slicer,
            render_mode=render_mode,
        )
    raise ValueError(f"Unrecognized env_name: {env_name}.")