takuseno/d3rlpy

View on GitHub
d3rlpy/dataset/writers.py

Summary

Maintainability
A
1 hr
Test Coverage
from typing import Any, Callable, Dict, Sequence, Union

import numpy as np
from typing_extensions import Protocol

from ..types import NDArray, Observation, ObservationSequence
from .buffers import BufferProtocol
from .components import Episode, EpisodeBase, Signature
from .utils import get_dtype_from_observation, get_shape_from_observation

# Public API of this module: the writer-preprocess protocol, its two stock
# implementations, and the writer that feeds collected steps into a buffer.
# The private ``_ActiveEpisode`` helper is intentionally not exported.
__all__ = [
    "WriterPreprocessProtocol",
    "BasicWriterPreprocess",
    "LastFrameWriterPreprocess",
    "ExperienceWriter",
]


class WriterPreprocessProtocol(Protocol):
    r"""Structural interface for preprocessing data before it is written.

    An implementation receives each observation, action and reward right
    before it is stored in an episode and returns the value to store.
    """

    def process_observation(self, observation: Observation) -> Observation:
        r"""Transforms an observation prior to storage.

        Args:
            observation: Raw observation.

        Returns:
            Observation to store.
        """
        raise NotImplementedError()

    def process_action(self, action: NDArray) -> NDArray:
        r"""Transforms an action prior to storage.

        Args:
            action: Raw action.

        Returns:
            Action to store.
        """
        raise NotImplementedError()

    def process_reward(self, reward: NDArray) -> NDArray:
        r"""Transforms a reward prior to storage.

        Args:
            reward: Raw reward.

        Returns:
            Reward to store.
        """
        raise NotImplementedError()


class BasicWriterPreprocess(WriterPreprocessProtocol):
    """Standard writer preprocess.

    Identity transformation: every ``process_*`` method hands back the very
    object it was given, without copying or converting it.
    """

    def process_observation(self, observation: Observation) -> Observation:
        # Identity: store the observation as-is.
        return observation

    def process_action(self, action: NDArray) -> NDArray:
        # Identity: store the action as-is.
        return action

    def process_reward(self, reward: NDArray) -> NDArray:
        # Identity: store the reward as-is.
        return reward


class LastFrameWriterPreprocess(BasicWriterPreprocess):
    """Writer preprocess that stores only the last frame of an observation.

    The last entry along the first axis is selected and a leading axis of
    length one is re-inserted, so an array of shape ``(N, ...)`` becomes
    ``(1, ...)``. List/tuple observations are processed component-wise and
    returned as a list. Actions and rewards pass through unchanged.

    This class is designed to be used with ``FrameStackTransitionPicker``.
    """

    def process_observation(self, observation: Observation) -> Observation:
        if not isinstance(observation, (list, tuple)):
            return np.expand_dims(observation[-1], axis=0)
        last_frames = []
        for component in observation:
            last_frames.append(np.expand_dims(component[-1], axis=0))
        return last_frames


class _ActiveEpisode(EpisodeBase):
    """Episode that is still being written.

    Storage for ``cache_size`` steps is preallocated up front and filled one
    step at a time by :meth:`append`; the public accessors only expose the
    first ``size()`` steps. When the episode ends, :meth:`shrink` replaces the
    preallocated storage with trimmed copies and freezes the episode so that
    any further :meth:`append` is rejected.
    """

    _preprocessor: WriterPreprocessProtocol
    _cache_size: int
    _cursor: int  # number of steps written so far
    _observation_signature: Signature
    _action_signature: Signature
    _reward_signature: Signature
    _observations: Sequence[NDArray]  # one buffer per observation component
    _actions: NDArray
    _rewards: NDArray
    _terminated: bool
    _frozen: bool  # True once shrink() has trimmed the storage

    def __init__(
        self,
        preprocessor: WriterPreprocessProtocol,
        cache_size: int,
        observation_signature: Signature,
        action_signature: Signature,
        reward_signature: Signature,
    ) -> None:
        self._preprocessor = preprocessor
        self._cache_size = cache_size
        self._cursor = 0
        shapes = observation_signature.shape
        dtypes = observation_signature.dtype
        self._observations = [
            np.empty((cache_size, *shape), dtype=dtype)
            for shape, dtype in zip(shapes, dtypes)
        ]
        self._actions = np.empty(
            (cache_size, *action_signature.shape[0]),
            dtype=action_signature.dtype[0],
        )
        self._rewards = np.empty(
            (cache_size, *reward_signature.shape[0]),
            dtype=reward_signature.dtype[0],
        )
        self._terminated = False
        self._observation_signature = observation_signature
        self._action_signature = action_signature
        self._reward_signature = reward_signature
        # Writable until shrink() is called.
        self._frozen = False

    def append(
        self,
        observation: Observation,
        action: Union[int, NDArray],
        reward: Union[float, NDArray],
    ) -> None:
        """Preprocesses one step and stores it at the current cursor.

        Args:
            observation: Observation (a single array or a list/tuple of
                arrays, one per observation component).
            action: Action. Python scalars and 0-d arrays are wrapped into a
                1-element array with the action signature's dtype.
            reward: Reward. Wrapped the same way as ``action``.

        Raises:
            AssertionError: If the episode has already been shrunk, or if
                ``cache_size`` steps have already been written.
        """
        # After shrink() the storage holds exactly ``size()`` steps, so a
        # write at the cursor would index past the end of the arrays.
        assert not self._frozen, "This episode is already shrinked."
        assert (
            self._cursor < self._cache_size
        ), "episode length exceeds cache_size."

        if not isinstance(action, np.ndarray) or action.ndim == 0:
            action = np.array([action], dtype=self._action_signature.dtype[0])
        if not isinstance(reward, np.ndarray) or reward.ndim == 0:
            reward = np.array([reward], dtype=self._reward_signature.dtype[0])

        # preprocess
        observation = self._preprocessor.process_observation(observation)
        action = self._preprocessor.process_action(action)
        reward = self._preprocessor.process_reward(reward)

        if isinstance(observation, (list, tuple)):
            for i, obs in enumerate(observation):
                self._observations[i][self._cursor] = obs
        else:
            self._observations[0][self._cursor] = observation
        self._actions[self._cursor] = action
        self._rewards[self._cursor] = reward
        self._cursor += 1

    def to_episode(self, terminated: bool) -> Episode:
        """Returns an :class:`Episode` holding copies of the written steps.

        Args:
            terminated: Termination flag for the returned episode.

        Returns:
            Episode whose observations are a single array when there is one
            observation component, otherwise a list of arrays.
        """
        observations: ObservationSequence
        if len(self._observations) == 1:
            observations = self._observations[0][: self._cursor].copy()
        else:
            observations = [
                obs[: self._cursor].copy() for obs in self._observations
            ]
        return Episode(
            observations=observations,
            actions=self._actions[: self._cursor].copy(),
            rewards=self._rewards[: self._cursor].copy(),
            terminated=terminated,
        )

    def shrink(self, terminated: bool) -> None:
        """Trims storage to the written length and freezes the episode.

        Args:
            terminated: Flag to represent environment termination.
        """
        episode = self.to_episode(terminated)
        if isinstance(episode.observations, np.ndarray):
            self._observations = [episode.observations]
        else:
            self._observations = episode.observations
        self._actions = episode.actions
        self._rewards = episode.rewards
        self._terminated = terminated
        # No further appends: the storage is now exactly ``size()`` long.
        self._frozen = True

    def size(self) -> int:
        return self._cursor

    @property
    def observations(self) -> ObservationSequence:
        if len(self._observations) == 1:
            return self._observations[0][: self._cursor]
        else:
            return [obs[: self._cursor] for obs in self._observations]

    @property
    def actions(self) -> NDArray:
        return self._actions[: self._cursor]

    @property
    def rewards(self) -> NDArray:
        return self._rewards[: self._cursor]

    @property
    def terminated(self) -> bool:
        return self._terminated

    @property
    def observation_signature(self) -> Signature:
        return self._observation_signature

    @property
    def action_signature(self) -> Signature:
        return self._action_signature

    @property
    def reward_signature(self) -> Signature:
        return self._reward_signature

    def compute_return(self) -> float:
        """Returns the sum of the rewards written so far."""
        # ``self.rewards`` is already limited to the written steps.
        return float(np.sum(self.rewards))

    def serialize(self) -> Dict[str, Any]:
        return {
            "observations": self.observations,
            "actions": self.actions,
            "rewards": self.rewards,
            "terminated": self.terminated,
        }

    @classmethod
    def deserialize(cls, serializedData: Dict[str, Any]) -> "EpisodeBase":
        raise NotImplementedError("_ActiveEpisode cannot be deserialized.")

    def __len__(self) -> int:
        return self.size()

    @property
    def transition_count(self) -> int:
        """Number of transitions that can be indexed in this episode.

        All steps count for a terminated episode; otherwise every step but
        the last (presumably because its next observation is not known yet).
        Never negative, so an empty episode reports 0 rather than -1.
        """
        count = self.size() if self.terminated else self.size() - 1
        return max(count, 0)


class ExperienceWriter:
    """Experience writer.

    Collects steps into an in-memory active episode and hands them to
    ``buffer`` as ``(episode, index)`` pairs. By default each transition is
    appended as soon as the following step is written; with
    ``write_at_termination=True`` all transitions are appended when the
    episode is clipped.

    Args:
        buffer: Buffer.
        preprocessor: Writer preprocess.
        observation_signature: Signature of unprocessed observation.
        action_signature: Signature of unprocessed action.
        reward_signature: Signature of unprocessed reward.
        cache_size: Size of data in active episode. Must be at least the
            maximum number of steps in an episode; writing beyond it fails
            an assertion.
        write_at_termination: Flag to write experiences to the buffer at the
            end of an episode all at once.
    """

    _preprocessor: WriterPreprocessProtocol
    _buffer: BufferProtocol
    _cache_size: int
    _write_at_termination: bool
    _observation_signature: Signature  # signature AFTER preprocessing
    _action_signature: Signature  # signature AFTER preprocessing
    _reward_signature: Signature  # signature AFTER preprocessing
    _active_episode: _ActiveEpisode

    def __init__(
        self,
        buffer: BufferProtocol,
        preprocessor: WriterPreprocessProtocol,
        observation_signature: Signature,
        action_signature: Signature,
        reward_signature: Signature,
        cache_size: int = 10000,
        write_at_termination: bool = False,
    ) -> None:
        self._buffer = buffer
        self._preprocessor = preprocessor
        self._cache_size = cache_size
        self._write_at_termination = write_at_termination

        # The active episode stores preprocessed data, so its storage must be
        # sized from the signatures of preprocessed samples.
        self._observation_signature = self._process_observation_signature(
            preprocessor, observation_signature
        )
        self._action_signature = self._process_flat_signature(
            preprocessor.process_action, action_signature
        )
        self._reward_signature = self._process_flat_signature(
            preprocessor.process_reward, reward_signature
        )
        self._active_episode = self._new_active_episode()

    @staticmethod
    def _process_observation_signature(
        preprocessor: WriterPreprocessProtocol, signature: Signature
    ) -> Signature:
        """Returns the signature of a preprocessed observation sample."""
        if len(signature.dtype) == 1:
            processed = preprocessor.process_observation(signature.sample()[0])
            assert isinstance(processed, np.ndarray)
            return Signature(shape=[processed.shape], dtype=[processed.dtype])

        processed_tuple = preprocessor.process_observation(signature.sample())
        observation_shape = get_shape_from_observation(processed_tuple)
        assert isinstance(observation_shape[0], (list, tuple))
        observation_dtype = get_dtype_from_observation(processed_tuple)
        assert isinstance(observation_dtype, (list, tuple))
        return Signature(
            shape=observation_shape,  # type: ignore
            dtype=observation_dtype,
        )

    @staticmethod
    def _process_flat_signature(
        process: Callable[[NDArray], NDArray], signature: Signature
    ) -> Signature:
        """Returns the single-array signature of a preprocessed sample.

        Used for actions and rewards; a scalar result is described as shape
        ``(1,)`` to match how ``_ActiveEpisode.append`` wraps scalars.
        """
        processed = process(signature.sample()[0])
        shape: Sequence[int]
        if not isinstance(processed, np.ndarray) or processed.ndim == 0:
            shape = (1,)
        else:
            shape = processed.shape
        return Signature(shape=[shape], dtype=[processed.dtype])

    def _new_active_episode(self) -> _ActiveEpisode:
        """Creates an empty active episode using this writer's settings."""
        return _ActiveEpisode(
            self._preprocessor,
            cache_size=self._cache_size,
            observation_signature=self._observation_signature,
            action_signature=self._action_signature,
            reward_signature=self._reward_signature,
        )

    def write(
        self,
        observation: Observation,
        action: Union[int, NDArray],
        reward: Union[float, NDArray],
    ) -> None:
        r"""Writes state tuple to buffer.

        Args:
            observation: Observation.
            action: Action.
            reward: Reward.
        """
        self._active_episode.append(observation, action, reward)
        # In streaming mode, the previous step becomes a complete transition
        # as soon as this step is written.
        if (
            not self._write_at_termination
            and self._active_episode.transition_count > 0
        ):
            self._buffer.append(
                episode=self._active_episode,
                index=self._active_episode.transition_count - 1,
            )

    def clip_episode(self, terminated: bool) -> None:
        r"""Clips the current episode.

        Flushes any pending transitions, shrinks the episode's storage and
        starts a new active episode. Does nothing when no step has been
        written since the last clip.

        Args:
            terminated: Flag to represent environment termination.
        """
        # Guard on the number of steps, not transitions: a one-step episode
        # has zero transitions but must still be reset (and, if terminated,
        # written), while an empty episode must not append index -1.
        if self._active_episode.size() == 0:
            return

        if self._write_at_termination:
            for i in range(self._active_episode.transition_count):
                self._buffer.append(episode=self._active_episode, index=i)

        # shrink heap memory
        self._active_episode.shrink(terminated)

        # Once terminated, the final step is a transition too.
        if terminated:
            self._buffer.append(
                episode=self._active_episode,
                index=self._active_episode.transition_count - 1,
            )

        # prepare next active episode
        self._active_episode = self._new_active_episode()