takuseno/d3rlpy

View on GitHub
d3rlpy/dataset/io.py

Summary

Maintainability
B
5 hrs
Test Coverage
from typing import BinaryIO, Sequence, Type, TypeVar, cast

import h5py
import numpy as np

from .components import Episode, EpisodeBase
from .episode_generator import EpisodeGenerator

# Names exported by a star-import of this module.
__all__ = ["dump", "load", "load_v1", "DATASET_VERSION"]


# Format version string that ``dump`` stores in the "version" dataset of
# every file it writes; ``load`` reads it back to reject the obsolete "2.0"
# format.
DATASET_VERSION = "2.1"


def dump(episodes: Sequence[EpisodeBase], f: BinaryIO) -> None:
    r"""Writes episode data to file-like object.

    The HDF5 file contains a ``columns`` dataset (the keys of the first
    episode's ``serialize()`` result), a ``num_episodes`` dataset, a
    ``version`` dataset holding ``DATASET_VERSION`` and, for the ``i``-th
    episode, one dataset per key named ``{key}_{i}``. When the serialized
    value is a ``list`` or ``tuple``, each element ``j`` is stored separately
    as ``{key}_{i}_{j}``.

    Args:
        episodes: Sequence of episodes. Must not be empty.
        f: Binary file-like object.

    Raises:
        IndexError: If ``episodes`` is empty.
    """
    # Validate and collect the column names before the file is opened for
    # writing, so that an invalid call fails without touching ``f``.
    # IndexError is what ``episodes[0]`` raises on an empty sequence.
    if len(episodes) == 0:
        raise IndexError("episodes must contain at least one episode.")
    keys = list(episodes[0].serialize().keys())

    with h5py.File(f, "w") as h5:
        h5.create_dataset("columns", data=keys)
        h5.create_dataset("num_episodes", data=len(episodes))
        for i, episode in enumerate(episodes):
            serialized = episode.serialize()
            for key in keys:
                value = serialized[key]
                if isinstance(value, (list, tuple)):
                    # Store each element as its own dataset.
                    for j, element in enumerate(value):
                        h5.create_dataset(f"{key}_{i}_{j}", data=element)
                else:
                    h5.create_dataset(f"{key}_{i}", data=value)
        h5.create_dataset("version", data=DATASET_VERSION)
        h5.flush()


# Type variable tying the ``episode_cls`` argument of ``load`` to the element
# type of the sequence it returns; restricted to ``EpisodeBase`` subclasses.
_TEpisode = TypeVar("_TEpisode", bound=EpisodeBase)


def _as_str(value: object) -> str:
    """Returns a string value read from HDF5 as ``str``.

    ``bytes`` values are decoded as UTF-8; anything else goes through
    ``str()``.
    """
    if isinstance(value, bytes):
        return value.decode("utf-8")
    return str(value)


def load(episode_cls: Type[_TEpisode], f: BinaryIO) -> Sequence[_TEpisode]:
    r"""Constructs episodes from file-like object.

    Args:
        episode_cls: Episode class.
        f: Binary file-like object.

    Returns:
        Sequence of episodes.

    Raises:
        ValueError: If the file declares the obsolete version ``2.0``.
    """
    episodes = []
    with h5py.File(f, "r") as h5:
        # String datasets can come back as ``bytes`` (the column names below
        # need decoding for the same reason). Comparing raw bytes with the
        # ``str`` "2.0" would never match and silently disable this check.
        version = _as_str(h5["version"][()])
        if version == "2.0":
            raise ValueError("version=2.0 has been obsolete.")
        keys = [_as_str(column) for column in h5["columns"][()]]
        num_episodes = int(h5["num_episodes"][()])
        for i in range(num_episodes):
            data = {}
            for key in keys:
                path = f"{key}_{i}"
                if path in h5:
                    data[key] = h5[path][()]
                    continue
                # No "<key>_<i>" dataset: the value was stored element-wise
                # as "<key>_<i>_<j>". Collect consecutive j until the first
                # missing index (an empty list if there is none).
                elements = []
                j = 0
                while f"{path}_{j}" in h5:
                    elements.append(h5[f"{path}_{j}"][()])
                    j += 1
                data[key] = elements
            episodes.append(episode_cls.deserialize(data))
    return episodes  # type: ignore


def load_v1(f: BinaryIO) -> Sequence[Episode]:
    r"""Reads a v1 dataset file and turns it into episodes.

    The ``observations``, ``actions``, ``rewards`` and ``terminals`` datasets
    are read from the file. If an ``episode_terminals`` dataset is present,
    timeouts are derived as the steps where ``episode_terminals`` is set
    while ``terminals`` is not; otherwise no timeouts are passed.

    Args:
        f: Binary file-like object.

    Returns:
        Sequence of episodes.
    """
    required = ("observations", "actions", "rewards", "terminals")
    with h5py.File(f, "r") as h5:
        # ``[()]`` reads each dataset into memory so the values can be used
        # after the file is closed.
        arrays = {name: h5[name][()] for name in required}
        episode_terminals = (
            h5["episode_terminals"][()] if "episode_terminals" in h5 else None
        )

    timeouts = None
    if episode_terminals is not None:
        not_terminal = np.logical_not(arrays["terminals"])
        timeouts = np.logical_and(not_terminal, episode_terminals)

    generate_episodes = EpisodeGenerator(**arrays, timeouts=timeouts)
    return generate_episodes()