takuseno/d3rlpy

d3rlpy/algos/transformer/inputs.py

import dataclasses
from typing import Optional

import numpy as np
import torch

from ...dataset import (
    batch_pad_array,
    batch_pad_observations,
    get_axis_size,
    slice_observations,
)
from ...preprocessing import ActionScaler, ObservationScaler, RewardScaler
from ...torch_utility import convert_to_torch, convert_to_torch_recursively
from ...types import (
    Float32NDArray,
    Int32NDArray,
    NDArray,
    ObservationSequence,
    TorchObservation,
)

__all__ = ["TransformerInput", "TorchTransformerInput"]


@dataclasses.dataclass(frozen=True)
class TransformerInput:
    """Input sequence in numpy format, where L is the sequence length."""

    observations: ObservationSequence  # (L, ...)
    actions: NDArray  # (L, ...)
    rewards: Float32NDArray  # (L, 1)
    returns_to_go: Float32NDArray  # (L, 1)
    timesteps: Int32NDArray  # (L,)

    def __post_init__(self) -> None:
        # check sequence size
        length = get_axis_size(self.observations, axis=0)
        assert get_axis_size(self.actions, axis=0) == length
        assert get_axis_size(self.rewards, axis=0) == length
        assert get_axis_size(self.returns_to_go, axis=0) == length
        assert get_axis_size(self.timesteps, axis=0) == length

    @property
    def length(self) -> int:
        return get_axis_size(self.actions, axis=0)


@dataclasses.dataclass(frozen=True)
class TorchTransformerInput:
    """Input sequence in torch format with a leading batch dimension of 1."""

    observations: TorchObservation  # (1, L, ...)
    actions: torch.Tensor  # (1, L, ...)
    rewards: torch.Tensor  # (1, L, 1)
    returns_to_go: torch.Tensor  # (1, L, 1)
    timesteps: torch.Tensor  # (1, L)
    masks: torch.Tensor  # (1, L)
    length: int

    @classmethod
    def from_numpy(
        cls,
        inpt: TransformerInput,
        context_size: int,
        device: str,
        observation_scaler: Optional[ObservationScaler] = None,
        action_scaler: Optional[ActionScaler] = None,
        reward_scaler: Optional[RewardScaler] = None,
    ) -> "TorchTransformerInput":
        masks: Float32NDArray
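        # slice the latest context_size steps or pad the history so that the
        # resulting sequence length always equals context_size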
        if context_size < inpt.length:
            observations = slice_observations(
                inpt.observations, inpt.length - context_size, inpt.length
            )
            actions = inpt.actions[-context_size:]
            rewards = inpt.rewards[-context_size:]
            returns_to_go = inpt.returns_to_go[-context_size:]
            timesteps = inpt.timesteps[-context_size:]
            masks = np.ones(context_size, dtype=np.float32)
        else:
            # the history is shorter than the context, so pad every array up
            # to context_size and keep ones only at the real steps in the mask
            pad_size = context_size - inpt.length
            observations = batch_pad_observations(inpt.observations, pad_size)
            actions = batch_pad_array(inpt.actions, pad_size)
            rewards = batch_pad_array(inpt.rewards, pad_size)
            returns_to_go = batch_pad_array(inpt.returns_to_go, pad_size)
            timesteps = batch_pad_array(inpt.timesteps, pad_size)
            masks = batch_pad_array(
                np.ones(inpt.length, dtype=np.float32), pad_size
            )

        # convert numpy array to torch tensor
        observations_pt = convert_to_torch_recursively(observations, device)
        actions_pt = convert_to_torch(actions, device)
        rewards_pt = convert_to_torch(rewards, device)
        returns_to_go_pt = convert_to_torch(returns_to_go, device)
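        # cast timesteps to int64 (e.g., for use as embedding indices)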
        timesteps_pt = convert_to_torch(timesteps, device).long()
        masks_pt = convert_to_torch(masks, device)

        # apply scaler
        if observation_scaler:
            observations_pt = observation_scaler.transform(observations_pt)
        if action_scaler:
            actions_pt = action_scaler.transform(actions_pt)
        if reward_scaler:
            rewards_pt = reward_scaler.transform(rewards_pt)
            returns_to_go_pt = reward_scaler.transform(returns_to_go_pt)

        # add a batch dimension of 1 so that shapes become (1, L, ...);
        # a sequence of observation tensors is unsqueezed element-wise
        if isinstance(observations_pt, torch.Tensor):
            unsqueezed_observation = observations_pt.unsqueeze(0)
        else:
            unsqueezed_observation = [o.unsqueeze(0) for o in observations_pt]

        return TorchTransformerInput(
            observations=unsqueezed_observation,
            actions=actions_pt.unsqueeze(0),
            rewards=rewards_pt.unsqueeze(0),
            returns_to_go=returns_to_go_pt.unsqueeze(0),
            timesteps=timesteps_pt.unsqueeze(0),
            masks=masks_pt.unsqueeze(0),
            length=context_size,
        )
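
A minimal usage sketch (not part of the module; the observation and action
shapes are made up for illustration): a 10-step history is wrapped in
TransformerInput and converted with TorchTransformerInput.from_numpy, which
pads it to the requested context size and masks out the padded steps.

import numpy as np

from d3rlpy.algos.transformer.inputs import (
    TorchTransformerInput,
    TransformerInput,
)

# hypothetical 10-step history with 4-dim observations and 2-dim actions
inpt = TransformerInput(
    observations=np.random.random((10, 4)).astype(np.float32),
    actions=np.random.random((10, 2)).astype(np.float32),
    rewards=np.random.random((10, 1)).astype(np.float32),
    returns_to_go=np.random.random((10, 1)).astype(np.float32),
    timesteps=np.arange(10, dtype=np.int32),
)
assert inpt.length == 10

# context_size=20 exceeds the history, so every array is zero-padded to 20
# steps and the mask keeps ones only at the 10 real steps
torch_inpt = TorchTransformerInput.from_numpy(
    inpt, context_size=20, device="cpu:0"
)
assert torch_inpt.actions.shape == (1, 20, 2)
assert torch_inpt.masks.shape == (1, 20)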