takuseno/d3rlpy

d3rlpy/dataset/mini_batch.py

import dataclasses
from typing import Sequence, Union

import numpy as np

from ..types import Float32NDArray, Shape
from .components import PartialTrajectory, Transition
from .utils import (
    cast_recursively,
    check_dtype,
    check_non_1d_array,
    get_shape_from_observation_sequence,
    stack_observations,
)

__all__ = ["TransitionMiniBatch", "TrajectoryMiniBatch"]


@dataclasses.dataclass(frozen=True)
class TransitionMiniBatch:
    r"""Mini-batch of transitions.

    Args:
        observations: Batched observations.
        actions: Batched actions.
        rewards: Batched rewards.
        next_observations: Batched next observations.
        terminals: Batched environment terminal flags.
        intervals: Batched timesteps between observations and next
            observations.
        transitions: List of transitions.
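
    Example:

        A minimal sketch of the batched layout; ``batch`` is assumed to be
        an existing ``TransitionMiniBatch`` (e.g. built with
        ``from_transitions``)::

            B = len(batch)             # batch size
            batch.observations.shape   # (B, ...) when observations are a single array
            batch.rewards.shape        # (B, 1)
            batch.terminals.shape      # (B, 1)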
    """

    observations: Union[Float32NDArray, Sequence[Float32NDArray]]  # (B, ...)
    actions: Float32NDArray  # (B, ...)
    rewards: Float32NDArray  # (B, 1)
    next_observations: Union[
        Float32NDArray, Sequence[Float32NDArray]
    ]  # (B, ...)
    terminals: Float32NDArray  # (B, 1)
    intervals: Float32NDArray  # (B, 1)
    transitions: Sequence[Transition]

    def __post_init__(self) -> None:
        assert check_non_1d_array(self.observations)
        assert check_dtype(self.observations, np.float32)
        assert check_non_1d_array(self.actions)
        assert check_dtype(self.actions, np.float32)
        assert check_non_1d_array(self.rewards)
        assert check_dtype(self.rewards, np.float32)
        assert check_non_1d_array(self.next_observations)
        assert check_dtype(self.next_observations, np.float32)
        assert check_non_1d_array(self.terminals)
        assert check_dtype(self.terminals, np.float32)
        assert check_non_1d_array(self.intervals)
        assert check_dtype(self.intervals, np.float32)

    @classmethod
    def from_transitions(
        cls, transitions: Sequence[Transition]
    ) -> "TransitionMiniBatch":
        r"""Constructs mini-batch from list of transitions.

        Args:
            transitions: List of transitions.

        Returns:
            Mini-batch.
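
        Example:

            A minimal sketch; ``transitions`` is assumed to be an existing
            list of ``Transition`` objects::

                batch = TransitionMiniBatch.from_transitions(transitions)
                assert len(batch) == len(transitions)
                batch.observation_shape  # per-step observation shape
                batch.action_shape       # per-step action shape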
        """
        observations = stack_observations(
            [transition.observation for transition in transitions]
        )
        actions = np.stack(
            [transition.action for transition in transitions], axis=0
        )
        rewards = np.stack(
            [transition.reward for transition in transitions], axis=0
        )
        next_observations = stack_observations(
            [transition.next_observation for transition in transitions]
        )
        terminals = np.reshape(
            np.array([transition.terminal for transition in transitions]),
            [-1, 1],
        )
        intervals = np.reshape(
            np.array([transition.interval for transition in transitions]),
            [-1, 1],
        )
        return TransitionMiniBatch(
            observations=cast_recursively(observations, np.float32),
            actions=cast_recursively(actions, np.float32),
            rewards=cast_recursively(rewards, np.float32),
            next_observations=cast_recursively(next_observations, np.float32),
            terminals=cast_recursively(terminals, np.float32),
            intervals=cast_recursively(intervals, np.float32),
            transitions=transitions,
        )

    @property
    def observation_shape(self) -> Shape:
        r"""Returns observation shape.

        Returns:
            Observation shape.
        """
        return get_shape_from_observation_sequence(self.observations)

    @property
    def action_shape(self) -> Sequence[int]:
        r"""Returns action shape.

        Returns:
            Action shape.
        """
        return self.actions.shape[1:]

    @property
    def reward_shape(self) -> Sequence[int]:
        r"""Returns reward shape.

        Returns:
            Reward shape.
        """
        return self.rewards.shape[1:]

    def __len__(self) -> int:
        return int(self.actions.shape[0])


@dataclasses.dataclass(frozen=True)
class TrajectoryMiniBatch:
    r"""Mini-batch of trajectories.

    Args:
        observations: Batched sequence of observations.
        actions: Batched sequence of actions.
        rewards: Batched sequence of rewards.
        returns_to_go: Batched sequence of returns-to-go.
        terminals: Batched sequence of environment terminal flags.
        timesteps: Batched sequence of environment timesteps.
        masks: Batched masks that represent padding.
        length: Length of trajectories.
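
    Example:

        A minimal sketch of the sequence layout, where ``B`` is the batch
        size and ``L`` is the trajectory length; ``batch`` is assumed to be
        an existing ``TrajectoryMiniBatch``::

            batch.actions.shape  # (B, L, ...)
            batch.rewards.shape  # (B, L, 1)
            batch.masks.shape    # (B, L), padding masks
            batch.length         # L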
    """

    observations: Union[Float32NDArray, Sequence[Float32NDArray]]  # (B, L, ...)
    actions: Float32NDArray  # (B, L, ...)
    rewards: Float32NDArray  # (B, L, 1)
    returns_to_go: Float32NDArray  # (B, L, 1)
    terminals: Float32NDArray  # (B, L, 1)
    timesteps: Float32NDArray  # (B, L)
    masks: Float32NDArray  # (B, L)
    length: int

    def __post_init__(self) -> None:
        assert check_dtype(self.observations, np.float32)
        assert check_dtype(self.actions, np.float32)
        assert check_dtype(self.rewards, np.float32)
        assert check_dtype(self.returns_to_go, np.float32)
        assert check_dtype(self.terminals, np.float32)
        assert check_dtype(self.timesteps, np.float32)
        assert check_dtype(self.masks, np.float32)

    @classmethod
    def from_partial_trajectories(
        cls, trajectories: Sequence[PartialTrajectory]
    ) -> "TrajectoryMiniBatch":
        r"""Constructs mini-batch from list of trajectories.

        Args:
            trajectories: List of trajectories.

        Returns:
            Mini-batch of trajectories.
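
        Example:

            A minimal sketch; ``trajectories`` is assumed to be an existing
            list of ``PartialTrajectory`` objects of equal length, since the
            arrays are stacked along a new batch axis and ``length`` is
            taken from the first trajectory::

                batch = TrajectoryMiniBatch.from_partial_trajectories(
                    trajectories
                )
                assert len(batch) == len(trajectories)
                assert batch.length == trajectories[0].length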
        """
        observations = stack_observations(
            [traj.observations for traj in trajectories]
        )
        actions = np.stack([traj.actions for traj in trajectories], axis=0)
        rewards = np.stack([traj.rewards for traj in trajectories], axis=0)
        returns_to_go = np.stack(
            [traj.returns_to_go for traj in trajectories], axis=0
        )
        terminals = np.stack([traj.terminals for traj in trajectories], axis=0)
        timesteps = np.stack([traj.timesteps for traj in trajectories], axis=0)
        masks = np.stack([traj.masks for traj in trajectories], axis=0)
        return TrajectoryMiniBatch(
            observations=cast_recursively(observations, np.float32),
            actions=cast_recursively(actions, np.float32),
            rewards=cast_recursively(rewards, np.float32),
            returns_to_go=cast_recursively(returns_to_go, np.float32),
            terminals=cast_recursively(terminals, np.float32),
            timesteps=cast_recursively(timesteps, np.float32),
            masks=cast_recursively(masks, np.float32),
            length=trajectories[0].length,
        )

    @property
    def observation_shape(self) -> Shape:
        r"""Returns observation shape.

        Returns:
            Observation shape.
        """
        return get_shape_from_observation_sequence(self.observations)

    @property
    def action_shape(self) -> Sequence[int]:
        r"""Returns action shape.

        Returns:
            Action shape.
        """
        return self.actions.shape[1:]

    @property
    def reward_shape(self) -> Sequence[int]:
        r"""Returns reward shape.

        Returns:
            Reward shape.
        """
        return self.rewards.shape[1:]

    def __len__(self) -> int:
        return int(self.actions.shape[0])