takuseno/d3rlpy
d3rlpy/dataset/transition_pickers.py

import dataclasses

import numpy as np
from typing_extensions import Protocol

from ..types import Float32NDArray
from .components import EpisodeBase, Transition
from .utils import (
    create_zero_observation,
    retrieve_observation,
    stack_recent_observations,
)

__all__ = [
    "TransitionPickerProtocol",
    "BasicTransitionPicker",
    "SparseRewardTransitionPicker",
    "FrameStackTransitionPicker",
    "MultiStepTransitionPicker",
]


def _validate_index(episode: EpisodeBase, index: int) -> None:
    assert index < episode.transition_count


class TransitionPickerProtocol(Protocol):
    r"""Interface of TransitionPicker."""

    def __call__(self, episode: EpisodeBase, index: int) -> Transition:
        r"""Returns transition specified by ``index``.

        Args:
            episode: Episode.
            index: Index of the target transition.

        Returns:
            Transition.
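
        Any callable with this signature works as a transition picker. A
        minimal sketch of a custom picker, here simply delegating to
        ``BasicTransitionPicker`` (illustrative only):

        .. code-block:: python

            class DelegatingPicker(TransitionPickerProtocol):
                def __call__(self, episode: EpisodeBase, index: int) -> Transition:
                    return BasicTransitionPicker()(episode, index)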
        """
        raise NotImplementedError


class BasicTransitionPicker(TransitionPickerProtocol):
    r"""Standard transition picker.

    This class implements basic transition picking.
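
    A minimal usage sketch (the random ``Episode`` below is only
    illustrative):

    .. code-block:: python

        episode = Episode(
            observations=np.random.random((100, 4)),
            actions=np.random.random((100, 2)),
            rewards=np.random.random((100, 1)),
            terminated=False,
        )

        picker = BasicTransitionPicker()
        transition = picker(episode, 10)

        # observation at step 10 paired with the step-11 successor
        transition.observation.shape == (4,)
        transition.next_observation.shape == (4,)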
    """

    def __call__(self, episode: EpisodeBase, index: int) -> Transition:
        _validate_index(episode, index)

        observation = retrieve_observation(episode.observations, index)
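        # the final step of a terminated episode has no real successor;
        # in that case the next observation and action are zero-padded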
        is_terminal = episode.terminated and index == episode.size() - 1
        if is_terminal:
            next_observation = create_zero_observation(observation)
            next_action = np.zeros_like(episode.actions[index])
        else:
            next_observation = retrieve_observation(
                episode.observations, index + 1
            )
            next_action = episode.actions[index + 1]

        return Transition(
            observation=observation,
            action=episode.actions[index],
            reward=episode.rewards[index],
            next_observation=next_observation,
            next_action=next_action,
            terminal=float(is_terminal),
            interval=1,
            rewards_to_go=episode.rewards[index:],
        )


class SparseRewardTransitionPicker(TransitionPickerProtocol):
    r"""Sparse reward transition picker.

    This class extends BasicTransitionPicker to handle the special
    returns-to-go (``rewards_to_go``) calculation mainly used in AntMaze
    environments.

    For failure trajectories, this class sets a constant return value to
    avoid inconsistent horizons caused by timeouts.
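
    A minimal usage sketch, assuming ``episode`` is an ``Episode`` whose
    rewards are all equal to ``step_reward``, i.e. a failure trajectory
    (the ``-100.0`` value below is arbitrary):

    .. code-block:: python

        episode = Episode(
            observations=np.random.random((100, 4)),
            actions=np.random.random((100, 2)),
            rewards=np.zeros((100, 1)),
            terminated=False,
        )

        picker = SparseRewardTransitionPicker(failure_return=-100.0)
        transition = picker(episode, 10)

        # rewards_to_go is replaced by the constant failure return
        transition.rewards_to_go  # [[-100.]]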

    Args:
        failure_return (float): Return value for failure trajectories.
        step_reward (float): Immediate step reward value in sparse reward
            setting.
    """

    def __init__(self, failure_return: float, step_reward: float = 0.0):
        self._failure_return = failure_return
        self._step_reward = step_reward
        self._transition_picker = BasicTransitionPicker()

    def __call__(self, episode: EpisodeBase, index: int) -> Transition:
        transition = self._transition_picker(episode, index)
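        # a trajectory whose remaining rewards are all step_reward never
        # reached the goal, so treat it as a failure trajectory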
        if np.all(transition.rewards_to_go == self._step_reward):
            extended_rewards_to_go: Float32NDArray = np.array(
                [[self._failure_return]],
                dtype=np.float32,
            )
            transition = dataclasses.replace(
                transition,
                rewards_to_go=extended_rewards_to_go,
            )
        return transition


class FrameStackTransitionPicker(TransitionPickerProtocol):
    r"""Frame-stacking transition picker.

    This class implements frame stacking. Each observation is stacked with
    the previous ``n_frames - 1`` frames. When ``index`` specifies a
    timestep below ``n_frames``, the missing frames are padded with zeros.

    .. code-block:: python

        episode = Episode(
            observations=np.random.random((100, 1, 84, 84)),
            actions=np.random.random((100, 2)),
            rewards=np.random.random((100, 1)),
            terminated=False,
        )

        frame_stacking_picker = FrameStackTransitionPicker(n_frames=4)
        transition = frame_stacking_picker(episode, 10)

        transition.observation.shape == (4, 84, 84)

    Args:
        n_frames (int): Number of frames to stack.
    """

    _n_frames: int

    def __init__(self, n_frames: int):
        assert n_frames > 0
        self._n_frames = n_frames

    def __call__(self, episode: EpisodeBase, index: int) -> Transition:
        _validate_index(episode, index)

        observation = stack_recent_observations(
            episode.observations, index, self._n_frames
        )
        is_terminal = episode.terminated and index == episode.size() - 1
        if is_terminal:
            next_observation = create_zero_observation(observation)
            next_action = np.zeros_like(episode.actions[index])
        else:
            next_observation = stack_recent_observations(
                episode.observations, index + 1, self._n_frames
            )
            next_action = episode.actions[index + 1]

        return Transition(
            observation=observation,
            action=episode.actions[index],
            reward=episode.rewards[index],
            next_observation=next_observation,
            next_action=next_action,
            terminal=float(is_terminal),
            interval=1,
            rewards_to_go=episode.rewards[index:],
        )


class MultiStepTransitionPicker(TransitionPickerProtocol):
    r"""Multi-step transition picker.

    This class implements transition picking for the multi-step TD error.
    ``reward`` is computed as a multi-step discounted return.
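
    A minimal usage sketch, assuming ``episode`` is an ``Episode`` with at
    least 14 steps:

    .. code-block:: python

        picker = MultiStepTransitionPicker(n_steps=3, gamma=0.99)
        transition = picker(episode, 10)

        # reward is the discounted sum of rewards[10:13]
        transition.interval == 3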

    Args:
        n_steps: Delta timestep between ``observation`` and
            ``next_observation``.
        gamma: Discount factor to compute a multi-step return.
    """

    _n_steps: int
    _gamma: float

    def __init__(self, n_steps: int, gamma: float):
        self._n_steps = n_steps
        self._gamma = gamma

    def __call__(self, episode: EpisodeBase, index: int) -> Transition:
        _validate_index(episode, index)

        observation = retrieve_observation(episode.observations, index)

        # get observation N-step ahead
        if episode.terminated:
            next_index = min(index + self._n_steps, episode.size())
            is_terminal = next_index == episode.size()
            if is_terminal:
                next_observation = create_zero_observation(observation)
                next_action = np.zeros_like(episode.actions[index])
            else:
                next_observation = retrieve_observation(
                    episode.observations, next_index
                )
                next_action = episode.actions[next_index]
        else:
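            # the episode was truncated (e.g. timeout), so the picked
            # transition is never marked as terminal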
            is_terminal = False
            next_index = min(index + self._n_steps, episode.size() - 1)
            next_observation = retrieve_observation(
                episode.observations, next_index
            )
            next_action = episode.actions[next_index]

        # compute multi-step return
        interval = next_index - index
        cum_gammas = np.expand_dims(self._gamma ** np.arange(interval), axis=1)
        ret = np.sum(episode.rewards[index:next_index] * cum_gammas, axis=0)

        return Transition(
            observation=observation,
            action=episode.actions[index],
            reward=ret,
            next_observation=next_observation,
            next_action=next_action,
            terminal=float(is_terminal),
            interval=interval,
            rewards_to_go=episode.rewards[index:],
        )