takuseno/d3rlpy
d3rlpy/algos/utility.py
from typing import Any

from gym.spaces import Box, Discrete
from gymnasium.spaces import Box as GymnasiumBox
from gymnasium.spaces import Discrete as GymnasiumDiscrete

from ..base import LearnableBase
from ..constants import (
    CONTINUOUS_ACTION_SPACE_MISMATCH_ERROR,
    DISCRETE_ACTION_SPACE_MISMATCH_ERROR,
    ActionSpace,
)
from ..dataset import DatasetInfo, ReplayBufferBase
from ..logging import LOG
from ..types import GymEnv

__all__ = [
    "assert_action_space_with_dataset",
    "assert_action_space_with_env",
    "build_scalers_with_transition_picker",
    "build_scalers_with_trajectory_slicer",
    "build_scalers_with_env",
]


def assert_action_space_with_dataset(
    algo: LearnableBase[Any, Any], dataset_info: DatasetInfo
) -> None:
    """Checks that the algorithm's action type matches the dataset's action space."""
    if algo.get_action_type() == ActionSpace.BOTH:
        pass  # the algorithm supports both discrete and continuous actions
    elif dataset_info.action_space == ActionSpace.DISCRETE:
        assert (
            algo.get_action_type() == ActionSpace.DISCRETE
        ), DISCRETE_ACTION_SPACE_MISMATCH_ERROR
    else:
        assert (
            algo.get_action_type() == ActionSpace.CONTINUOUS
        ), CONTINUOUS_ACTION_SPACE_MISMATCH_ERROR


def assert_action_space_with_env(
    algo: LearnableBase[Any, Any], env: GymEnv
) -> None:
    """Checks that the algorithm's action type matches the environment's action space."""
    if isinstance(env.action_space, (Box, GymnasiumBox)):
        assert (
            algo.get_action_type() == ActionSpace.CONTINUOUS
        ), CONTINUOUS_ACTION_SPACE_MISMATCH_ERROR
    elif isinstance(env.action_space, (Discrete, GymnasiumDiscrete)):
        assert (
            algo.get_action_type() == ActionSpace.DISCRETE
        ), DISCRETE_ACTION_SPACE_MISMATCH_ERROR
    else:
        action_space = type(env.action_space)
        raise ValueError(f"The action-space is not supported: {action_space}")
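
# Usage sketch (illustrative only, not part of this module): `algo` below is
# a hypothetical stand-in showing how the validators above might be called
# before training starts.
#
#   import gymnasium as gym
#   env = gym.make("Pendulum-v1")  # Box action space -> continuous actions
#   assert_action_space_with_env(algo, env)  # passes for continuous algos,
#                                            # raises AssertionError otherwise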


def build_scalers_with_transition_picker(
    algo: LearnableBase[Any, Any], dataset: ReplayBufferBase
) -> None:
    """Fits any unbuilt observation/action/reward scalers from dataset transitions."""
    # initialize observation scaler
    if algo.observation_scaler and not algo.observation_scaler.built:
        LOG.debug(
            "Fitting observation scaler...",
            observation_scaler=algo.observation_scaler.get_type(),
        )
        algo.observation_scaler.fit_with_transition_picker(
            dataset.episodes, dataset.transition_picker
        )

    # initialize action scaler
    if algo.action_scaler and not algo.action_scaler.built:
        LOG.debug(
            "Fitting action scaler...",
            action_scaler=algo.action_scaler.get_type(),
        )
        algo.action_scaler.fit_with_transition_picker(
            dataset.episodes, dataset.transition_picker
        )

    # initialize reward scaler
    if algo.reward_scaler and not algo.reward_scaler.built:
        LOG.debug(
            "Fitting reward scaler...",
            reward_scaler=algo.reward_scaler.get_type(),
        )
        algo.reward_scaler.fit_with_transition_picker(
            dataset.episodes, dataset.transition_picker
        )


def build_scalers_with_trajectory_slicer(
    algo: LearnableBase[Any, Any], dataset: ReplayBufferBase
) -> None:
    """Fits any unbuilt observation/action/reward scalers from dataset trajectories."""
    # initialize observation scaler
    if algo.observation_scaler and not algo.observation_scaler.built:
        LOG.debug(
            "Fitting observation scaler...",
            observation_scaler=algo.observation_scaler.get_type(),
        )
        algo.observation_scaler.fit_with_trajectory_slicer(
            dataset.episodes, dataset.trajectory_slicer
        )

    # initialize action scaler
    if algo.action_scaler and not algo.action_scaler.built:
        LOG.debug(
            "Fitting action scaler...",
            action_scaler=algo.action_scaler.get_type(),
        )
        algo.action_scaler.fit_with_trajectory_slicer(
            dataset.episodes, dataset.trajectory_slicer
        )

    # initialize reward scaler
    if algo.reward_scaler and not algo.reward_scaler.built:
        LOG.debug(
            "Fitting reward scaler...",
            reward_scaler=algo.reward_scaler.get_type(),
        )
        algo.reward_scaler.fit_with_trajectory_slicer(
            dataset.episodes, dataset.trajectory_slicer
        )
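
# Note: because each branch above checks `scaler.built` before fitting, these
# build_scalers_* helpers are idempotent (assuming a fitted scaler marks
# itself as built). A minimal illustration, with `algo` and `dataset` as
# hypothetical stand-ins:
#
#   build_scalers_with_transition_picker(algo, dataset)
#   build_scalers_with_transition_picker(algo, dataset)  # no-op: already built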


def build_scalers_with_env(
    algo: LearnableBase[Any, Any],
    env: GymEnv,
) -> None:
    """Fits any unbuilt observation/action scalers from the environment."""
    # initialize observation scaler
    if algo.observation_scaler and not algo.observation_scaler.built:
        LOG.debug(
            "Fitting observation scaler...",
            observation_scaler=algo.observation_scaler.get_type(),
        )
        algo.observation_scaler.fit_with_env(env)

    # initialize action scaler
    if algo.action_scaler and not algo.action_scaler.built:
        LOG.debug(
            "Fitting action scaler...",
            action_scaler=algo.action_scaler.get_type(),
        )
        algo.action_scaler.fit_with_env(env)
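
# For reference, the scaler interface these helpers rely on can be sketched
# as a typing.Protocol. This is inferred from the call sites above, not a
# definition taken from d3rlpy itself; the library's real base classes may
# differ.
#
#   from typing import Any, Protocol, Sequence
#
#   class ScalerLike(Protocol):
#       built: bool
#
#       def get_type(self) -> str: ...
#
#       def fit_with_transition_picker(
#           self, episodes: Sequence[Any], transition_picker: Any
#       ) -> None: ...
#
#       def fit_with_trajectory_slicer(
#           self, episodes: Sequence[Any], trajectory_slicer: Any
#       ) -> None: ...
#
#       def fit_with_env(self, env: Any) -> None: ...
#
# Note that build_scalers_with_env stops after the action scaler: a reward
# scaler presumably cannot be fit from an environment alone, since rewards
# are only observed during interaction.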