# takuseno/d3rlpy
#
# View on GitHub
# d3rlpy/dataset/utils.py
#
# Summary
#
# Maintainability
# A
# 0 mins
# Test Coverage
from typing import Any, Sequence, Type, TypeVar, Union, overload

import numpy as np
import numpy.typing as npt
from gym.spaces import Box, Discrete
from gymnasium.spaces import Box as GymnasiumBox
from gymnasium.spaces import Discrete as GymnasiumDiscrete

from ..constants import ActionSpace
from ..types import (
    DType,
    GymEnv,
    NDArray,
    Observation,
    ObservationSequence,
    Shape,
)

# Public API of this module: the names exported by a star-import.
__all__ = [
    "retrieve_observation",
    "create_zero_observation",
    "slice_observations",
    "batch_pad_array",
    "batch_pad_observations",
    "stack_recent_observations",
    "stack_observations",
    "get_shape_from_observation",
    "get_shape_from_observation_sequence",
    "get_dtype_from_observation",
    "get_dtype_from_observation_sequence",
    "check_dtype",
    "check_non_1d_array",
    "cast_recursively",
    "detect_action_space",
    "detect_action_space_from_env",
    "detect_action_size_from_env",
    "is_tuple_shape",
    "cast_flat_shape",
    "cast_tuple_shape",
    "get_axis_size",
    "get_batch_dim",
]


@overload
def retrieve_observation(observations: NDArray, index: int) -> NDArray: ...


@overload
def retrieve_observation(
    observations: Sequence[NDArray], index: int
) -> Sequence[NDArray]: ...


def retrieve_observation(
    observations: ObservationSequence, index: int
) -> Observation:
    """Picks the observation stored at ``index``.

    Args:
        observations: A single array, or a list/tuple of arrays.
        index: Position along the first axis to read.

    Returns:
        ``observations[index]`` for an array input. For a list/tuple input,
        a list holding element ``index`` of every entry.

    Raises:
        ValueError: If ``observations`` is neither an array nor a
            list/tuple.
    """
    if isinstance(observations, (list, tuple)):
        return [component[index] for component in observations]
    if isinstance(observations, np.ndarray):
        return observations[index]  # type: ignore
    raise ValueError(f"invalid observations type: {type(observations)}")


@overload
def create_zero_observation(observation: NDArray) -> NDArray: ...


@overload
def create_zero_observation(
    observation: Sequence[NDArray],
) -> Sequence[NDArray]: ...


def create_zero_observation(observation: Observation) -> Observation:
    """Builds a zero-filled counterpart of ``observation``.

    Args:
        observation: A single array, or a list/tuple of arrays.

    Returns:
        ``np.zeros_like`` of the array input, or a list with
        ``np.zeros_like`` applied to every entry of a list/tuple input.

    Raises:
        ValueError: If ``observation`` is neither an array nor a
            list/tuple.
    """
    if isinstance(observation, (list, tuple)):
        return [np.zeros_like(component) for component in observation]
    if isinstance(observation, np.ndarray):
        return np.zeros_like(observation)
    raise ValueError(f"invalid observation type: {type(observation)}")


@overload
def slice_observations(
    observations: NDArray, start: int, end: int
) -> NDArray: ...


@overload
def slice_observations(
    observations: Sequence[NDArray], start: int, end: int
) -> Sequence[NDArray]: ...


def slice_observations(
    observations: ObservationSequence, start: int, end: int
) -> ObservationSequence:
    """Takes the ``[start:end]`` window along the first axis.

    Args:
        observations: A single array, or a list/tuple of arrays.
        start: First index of the window (inclusive).
        end: Last index of the window (exclusive).

    Returns:
        The sliced array, or a list with the same slice taken from every
        entry of a list/tuple input.

    Raises:
        ValueError: If ``observations`` is neither an array nor a
            list/tuple.
    """
    window = slice(start, end)
    if isinstance(observations, (list, tuple)):
        return [component[window] for component in observations]
    if isinstance(observations, np.ndarray):
        return observations[window]
    raise ValueError(f"invalid observation type: {type(observations)}")


_TDType = TypeVar("_TDType", bound=Any)


def batch_pad_array(
    array: npt.NDArray[_TDType], pad_size: int
) -> npt.NDArray[_TDType]:
    """Prepends ``pad_size`` zero-filled entries along the first axis.

    Args:
        array: Array whose first axis is extended.
        pad_size: Number of zero entries placed in front of ``array``.

    Returns:
        A new array of shape ``(pad_size + array.shape[0], *array.shape[1:])``
        with the dtype of ``array``. The leading ``pad_size`` entries are
        zeros and the remaining entries are a copy of ``array``.
    """
    batch_size = array.shape[0]
    shape = array.shape[1:]
    padded_array = np.zeros((pad_size + batch_size, *shape), dtype=array.dtype)
    # Address the data region by its start offset. A negative-length slice
    # (``[-batch_size:]``) degenerates to the whole buffer when the batch is
    # empty, which makes the assignment fail instead of being a no-op.
    padded_array[pad_size:] = array
    return padded_array


@overload
def batch_pad_observations(observations: NDArray, pad_size: int) -> NDArray: ...


@overload
def batch_pad_observations(
    observations: Sequence[NDArray], pad_size: int
) -> Sequence[NDArray]: ...


def batch_pad_observations(
    observations: ObservationSequence, pad_size: int
) -> ObservationSequence:
    """Applies ``batch_pad_array`` to an array or to every array in a list.

    Args:
        observations: A single array, or a list/tuple of arrays.
        pad_size: Padding size forwarded to ``batch_pad_array``.

    Returns:
        The padded array, or a list with each entry of a list/tuple input
        padded by a recursive call.

    Raises:
        ValueError: If ``observations`` (or one of its entries) is neither
            an array nor a list/tuple.
    """
    if isinstance(observations, (list, tuple)):
        # Each entry goes back through this function, so nested entries are
        # dispatched by the same rules.
        return [
            batch_pad_observations(component, pad_size)
            for component in observations
        ]
    if isinstance(observations, np.ndarray):
        return batch_pad_array(observations, pad_size)
    raise ValueError(f"invalid observations type: {type(observations)}")


@overload
def stack_recent_observations(
    observations: NDArray, index: int, n_frames: int
) -> NDArray: ...


@overload
def stack_recent_observations(
    observations: Sequence[NDArray], index: int, n_frames: int
) -> Sequence[NDArray]: ...


def stack_recent_observations(
    observations: ObservationSequence, index: int, n_frames: int
) -> Observation:
    """Stacks the ``n_frames`` observations that end at ``index``.

    The window ``[index - n_frames + 1, index]`` is clipped at 0. When the
    clipped window holds fewer than ``n_frames`` entries, the missing ones
    are filled in through ``batch_pad_observations``. The first two axes of
    the result are then merged into one.

    Args:
        observations: A single array, or a list/tuple of arrays.
        index: Index of the most recent observation in the window.
        n_frames: Number of observations to stack.

    Returns:
        An array whose first two axes have been merged, or a list of such
        arrays for a list/tuple input.

    Raises:
        ValueError: If the windowed observations are neither an array nor a
            list/tuple.
    """
    first = max(index - n_frames + 1, 0)
    last = index + 1
    n_missing = n_frames - (last - first)

    # (B, C, ...)
    frames = slice_observations(observations, first, last)
    if n_missing > 0:
        frames = batch_pad_observations(frames, n_missing)

    def merge_leading_axes(array: NDArray) -> NDArray:
        dims = array.shape
        merged_size = dims[0] * dims[1]
        return np.reshape(array, [merged_size, *dims[2:]])

    # (B, C, ...) -> (B * C, ...)
    if isinstance(frames, (list, tuple)):
        return [merge_leading_axes(component) for component in frames]
    if isinstance(frames, np.ndarray):
        return merge_leading_axes(frames)
    raise ValueError(f"invalid observation type: {type(frames)}")


@overload
def stack_observations(observations: Sequence[NDArray]) -> NDArray: ...


@overload
def stack_observations(
    observations: Sequence[Sequence[NDArray]],
) -> Sequence[NDArray]: ...


@overload
def stack_observations(observations: Sequence[Observation]) -> Observation: ...


def stack_observations(observations: Sequence[Observation]) -> Observation:
    """Stacks a sequence of observations along a new first axis.

    The first element decides how the whole sequence is handled.

    Args:
        observations: Sequence of arrays, or sequence of lists/tuples of
            arrays.

    Returns:
        ``np.stack(observations, axis=0)`` when the elements are arrays.
        When the elements are lists/tuples, a list with one stacked array
        per position of the first element.

    Raises:
        ValueError: If the first element is neither an array nor a
            list/tuple.
    """
    head = observations[0]
    if isinstance(head, np.ndarray):
        return np.stack(observations, axis=0)
    if isinstance(head, (list, tuple)):
        stacked = []
        for position in range(len(head)):
            column = [obs[position] for obs in observations]
            stacked.append(np.stack(column, axis=0))
        return stacked
    raise ValueError(f"invalid observation type: {type(head)}")


@overload
def get_shape_from_observation(observation: NDArray) -> Sequence[int]: ...


@overload
def get_shape_from_observation(
    observation: Sequence[NDArray],
) -> Sequence[Sequence[int]]: ...


def get_shape_from_observation(observation: Observation) -> Shape:
    """Reports the shape of a single observation.

    Args:
        observation: A single array, or a list/tuple of arrays.

    Returns:
        ``observation.shape`` for an array, or a list with the shape of
        every entry of a list/tuple input.

    Raises:
        ValueError: If ``observation`` is neither an array nor a
            list/tuple.
    """
    if isinstance(observation, (list, tuple)):
        return [component.shape for component in observation]
    if isinstance(observation, np.ndarray):
        return observation.shape
    raise ValueError(f"invalid observation type: {type(observation)}")


@overload
def get_shape_from_observation_sequence(
    observations: NDArray,
) -> Sequence[int]: ...


@overload
def get_shape_from_observation_sequence(
    observations: Sequence[NDArray],
) -> Sequence[Sequence[int]]: ...


def get_shape_from_observation_sequence(
    observations: ObservationSequence,
) -> Shape:
    """Reports the per-observation shape, i.e. the shape minus its first axis.

    Args:
        observations: A single array, or a list/tuple of arrays.

    Returns:
        ``observations.shape[1:]`` for an array, or a list with
        ``shape[1:]`` of every entry of a list/tuple input.

    Raises:
        ValueError: If ``observations`` is neither an array nor a
            list/tuple.
    """
    if isinstance(observations, (list, tuple)):
        return [component.shape[1:] for component in observations]
    if isinstance(observations, np.ndarray):
        return observations.shape[1:]
    raise ValueError(f"invalid observation type: {type(observations)}")


@overload
def get_dtype_from_observation(observation: NDArray) -> DType: ...


@overload
def get_dtype_from_observation(
    observation: Sequence[NDArray],
) -> Sequence[DType]: ...


def get_dtype_from_observation(
    observation: Observation,
) -> Union[DType, Sequence[DType]]:
    """Reports the dtype of a single observation.

    Args:
        observation: A single array, or a list/tuple of arrays.

    Returns:
        ``observation.dtype`` for an array, or a list with the dtype of
        every entry of a list/tuple input.

    Raises:
        ValueError: If ``observation`` is neither an array nor a
            list/tuple.
    """
    if isinstance(observation, (list, tuple)):
        return [component.dtype for component in observation]
    if isinstance(observation, np.ndarray):
        return observation.dtype
    raise ValueError(f"invalid observation type: {type(observation)}")


@overload
def get_dtype_from_observation_sequence(
    observations: NDArray,
) -> DType: ...


@overload
def get_dtype_from_observation_sequence(
    observations: Sequence[NDArray],
) -> Sequence[DType]: ...


def get_dtype_from_observation_sequence(
    observations: ObservationSequence,
) -> Union[DType, Sequence[DType]]:
    """Reports the dtype of a sequence of observations.

    Args:
        observations: A single array, or a list/tuple of arrays.

    Returns:
        ``observations.dtype`` for an array, or a list with the dtype of
        every entry of a list/tuple input.

    Raises:
        ValueError: If ``observations`` is neither an array nor a
            list/tuple.
    """
    if isinstance(observations, (list, tuple)):
        return [component.dtype for component in observations]
    if isinstance(observations, np.ndarray):
        return observations.dtype
    raise ValueError(f"invalid observation type: {type(observations)}")


def check_dtype(array: Union[NDArray, Sequence[NDArray]], dtype: DType) -> bool:
    """Tests whether an array, or every array in a list/tuple, has ``dtype``.

    Args:
        array: A single array, or a list/tuple of arrays.
        dtype: The dtype to compare against.

    Returns:
        ``True`` when the dtype of the array (or of each entry) equals
        ``dtype``.

    Raises:
        ValueError: If ``array`` is neither an array nor a list/tuple.
    """
    if isinstance(array, np.ndarray):
        return array.dtype == dtype
    if isinstance(array, (list, tuple)):
        return all(element.dtype == dtype for element in array)
    raise ValueError(f"invalid array type: {type(array)}")


def check_non_1d_array(array: Union[NDArray, Sequence[NDArray]]) -> bool:
    """Tests whether an array, or every array in a list/tuple, has ndim > 1.

    Args:
        array: A single array, or a list/tuple of arrays.

    Returns:
        ``True`` when the array (or each entry) has more than one
        dimension.

    Raises:
        ValueError: If ``array`` is neither an array nor a list/tuple.
    """
    if isinstance(array, np.ndarray):
        return array.ndim > 1
    if isinstance(array, (list, tuple)):
        # "All have ndim > 1" expressed as "none has ndim <= 1".
        return not any(element.ndim <= 1 for element in array)
    raise ValueError(f"invalid array type: {type(array)}")


@overload
def cast_recursively(
    array: NDArray, dtype: Type[_TDType]
) -> npt.NDArray[_TDType]: ...


@overload
def cast_recursively(
    array: Sequence[NDArray], dtype: Type[_TDType]
) -> Sequence[npt.NDArray[_TDType]]: ...


def cast_recursively(
    array: Union[NDArray, Sequence[NDArray]], dtype: Type[_TDType]
) -> Union[npt.NDArray[_TDType], Sequence[npt.NDArray[_TDType]]]:
    """Casts an array, or every array in a list/tuple, to ``dtype``.

    Args:
        array: A single array, or a list/tuple of arrays.
        dtype: Target dtype handed to ``astype``.

    Returns:
        ``array.astype(dtype)`` for an array, or a list with ``astype``
        applied to every entry of a list/tuple input.

    Raises:
        ValueError: If ``array`` is neither an array nor a list/tuple.
    """
    if isinstance(array, np.ndarray):
        return array.astype(dtype)
    if isinstance(array, (list, tuple)):
        return [element.astype(dtype) for element in array]
    raise ValueError(f"invalid array type: {type(array)}")


def detect_action_space(actions: NDArray) -> ActionSpace:
    """Guesses the action space from the values in ``actions``.

    Args:
        actions: Array of actions.

    Returns:
        ``ActionSpace.DISCRETE`` when every value is unchanged by a cast to
        ``np.int32``, ``ActionSpace.CONTINUOUS`` otherwise.
    """
    as_int32 = np.array(actions, dtype=np.int32)
    survives_int_cast = np.all(as_int32 == actions)
    return (
        ActionSpace.DISCRETE if survives_int_cast else ActionSpace.CONTINUOUS
    )


def detect_action_space_from_env(env: GymEnv) -> ActionSpace:
    """Maps the type of ``env.action_space`` to an ``ActionSpace`` value.

    Args:
        env: Environment exposing an ``action_space`` attribute.

    Returns:
        ``ActionSpace.CONTINUOUS`` for a gym/gymnasium ``Box`` and
        ``ActionSpace.DISCRETE`` for a gym/gymnasium ``Discrete``.

    Raises:
        ValueError: If the action space is of any other type.
    """
    space = env.action_space
    if isinstance(space, (Discrete, GymnasiumDiscrete)):
        return ActionSpace.DISCRETE
    if isinstance(space, (Box, GymnasiumBox)):
        return ActionSpace.CONTINUOUS
    raise ValueError(f"Unsupported action_space: {type(space)}")


def detect_action_size_from_env(env: GymEnv) -> int:
    """Reads the action size from ``env.action_space``.

    Args:
        env: Environment exposing an ``action_space`` attribute.

    Returns:
        ``action_space.n`` for a gym/gymnasium ``Discrete`` and
        ``action_space.shape[0]`` for a gym/gymnasium ``Box``, converted to
        ``int``.

    Raises:
        ValueError: If the action space is of any other type.
    """
    space = env.action_space
    if isinstance(space, (Box, GymnasiumBox)):
        return int(space.shape[0])
    if isinstance(space, (Discrete, GymnasiumDiscrete)):
        return int(space.n)
    raise ValueError(f"Unsupported action_space: {type(space)}")


def is_tuple_shape(shape: Shape) -> bool:
    """Tells a sequence of shapes apart from a single flat shape.

    Args:
        shape: Either one shape (sequence of ints) or a sequence of shapes.

    Returns:
        ``True`` when the first element of ``shape`` is itself a list or
        tuple. ``False`` otherwise, including for an empty ``shape``.
    """
    # An empty shape (e.g. ``()`` from a 0-d array) has no first element to
    # inspect; report it as flat instead of failing on ``shape[0]``.
    return len(shape) > 0 and isinstance(shape[0], (list, tuple))


def cast_tuple_shape(shape: Shape) -> Sequence[Sequence[int]]:
    """Returns ``shape`` typed as a sequence of shapes.

    The value is passed through unchanged; only the static type is narrowed.

    Args:
        shape: Shape expected to satisfy ``is_tuple_shape``.

    Returns:
        The same ``shape`` object.

    Raises:
        AssertionError: If ``is_tuple_shape(shape)`` is false (only while
            assertions are enabled).
    """
    assert is_tuple_shape(shape)
    # The runtime check above is invisible to the type checker, hence the
    # ignore on the narrowed return.
    return shape  # type: ignore


def cast_flat_shape(shape: Shape) -> Sequence[int]:
    """Returns ``shape`` typed as a single flat shape.

    The value is passed through unchanged; only the static type is narrowed.

    Args:
        shape: Shape expected not to satisfy ``is_tuple_shape``.

    Returns:
        The same ``shape`` object.

    Raises:
        AssertionError: If ``is_tuple_shape(shape)`` is true (only while
            assertions are enabled).
    """
    assert not is_tuple_shape(shape)
    # The runtime check above is invisible to the type checker, hence the
    # ignore on the narrowed return.
    return shape  # type: ignore


def get_axis_size(array: Union[NDArray, Sequence[NDArray]], axis: int) -> int:
    """Returns the size of ``axis``, shared by all arrays of a list/tuple.

    Args:
        array: A single array, or a list/tuple of arrays.
        axis: Axis whose size is read.

    Returns:
        ``array.shape[axis]`` for an array. For a list/tuple, the size of
        ``axis`` that all entries have in common.

    Raises:
        AssertionError: If the entries of a list/tuple disagree on the size
            of ``axis`` (only while assertions are enabled).
        ValueError: If ``array`` is neither an array nor a list/tuple.
    """
    if isinstance(array, np.ndarray):
        return int(array.shape[axis])
    elif isinstance(array, (list, tuple)):
        sizes = [v.shape[axis] for v in array]
        # ``sizes`` is indexed by position in ``array``, not by axis number,
        # so the reference value is the first entry. Indexing it with
        # ``axis`` only happens to work for ``axis=0``.
        size = sizes[0]
        assert all(s == size for s in sizes)
        return int(size)
    else:
        raise ValueError(f"invalid array type: {type(array)}")


def get_batch_dim(array: Union[NDArray, Sequence[NDArray]]) -> int:
    """Returns the size of the first axis via ``get_axis_size``.

    Args:
        array: A single array, or a list/tuple of arrays.

    Returns:
        The result of ``get_axis_size`` for axis 0.
    """
    batch_axis = 0
    return get_axis_size(array, batch_axis)