takuseno/d3rlpy

View on GitHub
d3rlpy/dataset/trajectory_slicers.py

Summary

Maintainability
A
1 hr
Test Coverage
import numpy as np
from typing_extensions import Protocol

from ..types import Float32NDArray, Int32NDArray
from .components import EpisodeBase, PartialTrajectory
from .utils import batch_pad_array, batch_pad_observations, slice_observations

__all__ = [
    "TrajectorySlicerProtocol",
    "BasicTrajectorySlicer",
    "FrameStackTrajectorySlicer",
]


class TrajectorySlicerProtocol(Protocol):
    r"""Interface of TrajectorySlicer."""

    def __call__(
        self, episode: EpisodeBase, end_index: int, size: int
    ) -> PartialTrajectory:
        r"""Slice trajectory.

        This method returns a partial trajectory from ``t=end_index-size`` to
        ``t=end_index``. If ``end_index-size`` is smaller than 0, those parts
        will be padded by zeros.

        Args:
            episode: Episode.
            end_index: Index at the end of the sliced trajectory.
            size: Length of the sliced trajectory.

        Returns:
            Sliced trajectory.
        """
        raise NotImplementedError


class BasicTrajectorySlicer(TrajectorySlicerProtocol):
    r"""Standard trajectory slicer.

    This class implements a basic trajectory slicing.
    """

    def __call__(
        self, episode: EpisodeBase, end_index: int, size: int
    ) -> PartialTrajectory:
        end = end_index + 1
        start = max(end - size, 0)
        actual_size = end - start

        # prepare terminal flags
        terminals: Float32NDArray = np.zeros((actual_size, 1), dtype=np.float32)
        if episode.terminated and end_index == episode.size() - 1:
            terminals[-1][0] = 1.0

        # slice data
        observations = slice_observations(episode.observations, start, end)
        actions = episode.actions[start:end]
        rewards = episode.rewards[start:end]
        ret = np.sum(episode.rewards[start:])
        all_returns_to_go = ret - np.cumsum(episode.rewards[start:], axis=0)
        returns_to_go = all_returns_to_go[:actual_size].reshape((-1, 1))

        # prepare metadata
        timesteps: Int32NDArray = np.arange(start, end) + 1
        masks: Float32NDArray = np.ones(end - start, dtype=np.float32)

        # compute backward padding size
        pad_size = size - actual_size

        if pad_size == 0:
            return PartialTrajectory(
                observations=observations,
                actions=actions,
                rewards=rewards,
                returns_to_go=returns_to_go,
                terminals=terminals,
                timesteps=timesteps,
                masks=masks,
                length=size,
            )

        return PartialTrajectory(
            observations=batch_pad_observations(observations, pad_size),
            actions=batch_pad_array(actions, pad_size),
            rewards=batch_pad_array(rewards, pad_size),
            returns_to_go=batch_pad_array(returns_to_go, pad_size),
            terminals=batch_pad_array(terminals, pad_size),
            timesteps=batch_pad_array(timesteps, pad_size),
            masks=batch_pad_array(masks, pad_size),
            length=size,
        )


class FrameStackTrajectorySlicer(TrajectorySlicerProtocol):
    r"""Frame-stacking trajectory slicer.

    This class implements the frame-stacking logic. The observations are
    stacked with the last ``n_frames-1`` frames. When ``index`` specifies
    timestep below ``n_frames``, those frames are padded by zeros.

    .. code-block:: python

        episode = Episode(
            observations=np.random.random((100, 1, 84, 84)),
            actions=np.random.random((100, 2)),
            rewards=np.random.random((100, 1)),
            terminated=False,
        )

        frame_stacking_slicer = FrameStackTrajectorySlicer(n_frames=4)
        trajectory = frame_stacking_slicer(episode, 0, 10)

        trajectory.observations.shape == (10, 4, 84, 84)

    Args:
        n_frames: Number of frames to stack.
    """

    _n_frames: int

    def __init__(self, n_frames: int):
        assert n_frames > 0
        self._n_frames = n_frames

    def __call__(
        self, episode: EpisodeBase, end_index: int, size: int
    ) -> PartialTrajectory:
        end = end_index + 1
        start = max(end - size, 0)
        actual_size = end - start

        # prepare terminal flags
        terminals: Float32NDArray = np.zeros((actual_size, 1), dtype=np.float32)
        if episode.terminated and end_index == episode.size() - 1:
            terminals[-1][0] = 1.0

        # prepare stacked observation data with zero initialization
        assert (
            len(episode.observation_signature.shape) == 1
        ), "Tuple observations are not supported yet."
        stacked_shape = list(episode.observation_signature.shape[0])
        channel_size = stacked_shape[0]
        image_shape = stacked_shape[1:]
        stacked_observations = np.zeros(
            (self._n_frames, actual_size, channel_size, *image_shape)
        )

        # fill stacked observations
        for i in range(self._n_frames):
            offset = self._n_frames - i - 1
            frame_start = max(start - offset, 0)
            frame_end = max(end - offset, 0)
            pad_size = actual_size - (frame_end - frame_start)
            observations = slice_observations(
                episode.observations, frame_start, frame_end
            )
            stacked_observations[i, pad_size:] = observations

        # (N, T, C, W, H) -> (T, N, C, W, H)
        stacked_observations = np.swapaxes(stacked_observations, 0, 1)
        # (T, N, C, W, H) -> (T, N * C, W, H)
        stacked_observations = np.reshape(
            stacked_observations,
            [actual_size, channel_size * self._n_frames, *image_shape],
        )

        # slice data
        actions = episode.actions[start:end]
        rewards = episode.rewards[start:end]
        ret = np.sum(episode.rewards[start:])
        all_returns_to_go = ret - np.cumsum(episode.rewards[start:], axis=0)
        returns_to_go = all_returns_to_go[:actual_size].reshape((-1, 1))

        # prepare metadata
        timesteps: Int32NDArray = np.arange(start, end) + 1
        masks: Float32NDArray = np.ones(end - start, dtype=np.float32)

        # compute backward padding size
        pad_size = size - actual_size

        if pad_size == 0:
            return PartialTrajectory(
                observations=stacked_observations,
                actions=actions,
                rewards=rewards,
                returns_to_go=returns_to_go,
                terminals=terminals,
                timesteps=timesteps,
                masks=masks,
                length=size,
            )

        return PartialTrajectory(
            observations=batch_pad_observations(stacked_observations, pad_size),
            actions=batch_pad_array(actions, pad_size),
            rewards=batch_pad_array(rewards, pad_size),
            returns_to_go=batch_pad_array(returns_to_go, pad_size),
            terminals=batch_pad_array(terminals, pad_size),
            timesteps=batch_pad_array(timesteps, pad_size),
            masks=batch_pad_array(masks, pad_size),
            length=size,
        )