d3rlpy/algos/transformer/base.py
import dataclasses
from abc import abstractmethod
from collections import defaultdict, deque
from typing import Callable, Deque, Dict, Generic, Optional, TypeVar, Union
import numpy as np
import torch
from tqdm.auto import tqdm
from typing_extensions import Self
from ...base import ImplBase, LearnableBase, LearnableConfig, save_config
from ...constants import IMPL_NOT_INITIALIZED_ERROR, ActionSpace
from ...dataset import ReplayBuffer, TrajectoryMiniBatch
from ...logging import (
LOG,
D3RLPyLogger,
FileAdapterFactory,
LoggerAdapterFactory,
)
from ...metrics import evaluate_transformer_with_environment
from ...torch_utility import TorchTrajectoryMiniBatch, eval_api, train_api
from ...types import GymEnv, NDArray, Observation
from ..utility import (
assert_action_space_with_dataset,
build_scalers_with_trajectory_slicer,
)
from .action_samplers import (
IdentityTransformerActionSampler,
SoftmaxTransformerActionSampler,
TransformerActionSampler,
)
from .inputs import TorchTransformerInput, TransformerInput
__all__ = [
"TransformerAlgoImplBase",
"StatefulTransformerWrapper",
"TransformerConfig",
"TransformerAlgoBase",
]
class TransformerAlgoImplBase(ImplBase):
@eval_api
def predict(self, inpt: TorchTransformerInput) -> torch.Tensor:
return self.inner_predict(inpt)
@abstractmethod
def inner_predict(self, inpt: TorchTransformerInput) -> torch.Tensor:
raise NotImplementedError
@train_api
def update(
self, batch: TorchTrajectoryMiniBatch, grad_step: int
) -> Dict[str, float]:
return self.inner_update(batch, grad_step)
@abstractmethod
def inner_update(
self, batch: TorchTrajectoryMiniBatch, grad_step: int
) -> Dict[str, float]:
raise NotImplementedError
@dataclasses.dataclass()
class TransformerConfig(LearnableConfig):
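    """Base configuration shared by Transformer-based algorithms.

    Args:
        context_size: Number of past steps fed to the model as context.
        max_timestep: Maximum environment timestep; timesteps are clipped to
            this value at inference time.
    """
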
context_size: int = 20
max_timestep: int = 1000
TTransformerImpl = TypeVar("TTransformerImpl", bound=TransformerAlgoImplBase)
TTransformerConfig = TypeVar("TTransformerConfig", bound=TransformerConfig)
class StatefulTransformerWrapper(Generic[TTransformerImpl, TTransformerConfig]):
r"""A stateful wrapper for inference of Transformer-based algorithms.
This wrapper class provides a similar interface of Q-learning-based
algoritms, which is especially useful when you evaluate Transformer-based
algorithms such as Decision Transformer.
.. code-block:: python
from d3rlpy.algos import DecisionTransformerConfig
from d3rlpy.algos import StatefulTransformerWrapper
dt = DecisionTransformerConfig().create()
dt.create_impl(<observation_shape>, <action_size>)
# initialize wrapper with a target return of 1000
wrapper = StatefulTransformerWrapper(dt, target_return=1000)
# shortcut is also available
wrapper = dt.as_stateful_wrapper(target_return=1000)
# predict next action to achieve the return of 1000 in the end
action = wrapper.predict(<observation>, <reward>)
# clear stateful information
wrapper.reset()
Args:
algo (TransformerAlgoBase): Transformer-based algorithm.
target_return (float): Target return.
action_sampler (d3rlpy.algos.TransformerActionSampler): Action sampler.
"""
_algo: "TransformerAlgoBase[TTransformerImpl, TTransformerConfig]"
_target_return: float
_action_sampler: TransformerActionSampler
_return_rest: float
_observations: Deque[Observation]
_actions: Deque[Union[NDArray, int]]
_rewards: Deque[float]
_returns_to_go: Deque[float]
_timesteps: Deque[int]
_timestep: int
def __init__(
self,
algo: "TransformerAlgoBase[TTransformerImpl, TTransformerConfig]",
target_return: float,
action_sampler: TransformerActionSampler,
):
assert algo.impl, "algo must be built."
self._algo = algo
self._target_return = target_return
self._action_sampler = action_sampler
self._return_rest = target_return
context_size = algo.config.context_size
self._observations = deque([], maxlen=context_size)
self._actions = deque([self._get_pad_action()], maxlen=context_size)
self._rewards = deque([], maxlen=context_size)
self._returns_to_go = deque([], maxlen=context_size)
self._timesteps = deque([], maxlen=context_size)
self._timestep = 1
def predict(self, x: Observation, reward: float) -> Union[NDArray, int]:
r"""Returns action.
Args:
x: Observation.
reward: Last reward.
Returns:
Action.
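
        A minimal rollout sketch (``wrapper`` refers to this object and
        ``env`` is an assumed Gymnasium-style environment, not part of this
        module):

        .. code-block:: python

            observation, _ = env.reset()
            reward = 0.0
            while True:
                action = wrapper.predict(observation, reward)
                observation, reward, done, truncated, _ = env.step(action)
                if done or truncated:
                    break
            wrapper.reset()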
"""
self._observations.append(x)
self._rewards.append(reward)
self._returns_to_go.append(self._return_rest - reward)
self._timesteps.append(self._timestep)
numpy_observations: Observation
if isinstance(x, np.ndarray):
numpy_observations = np.array(self._observations)
else:
numpy_observations = [
np.array([o[i] for o in self._observations])
for i in range(len(x))
]
inpt = TransformerInput(
observations=numpy_observations,
actions=np.array(self._actions),
rewards=np.array(self._rewards).reshape((-1, 1)),
returns_to_go=np.array(self._returns_to_go).reshape((-1, 1)),
timesteps=np.array(self._timesteps),
)
action = self._action_sampler(self._algo.predict(inpt))
self._actions[-1] = action
self._actions.append(self._get_pad_action())
self._timestep = min(self._timestep + 1, self._algo.config.max_timestep)
self._return_rest -= reward
return action
def reset(self) -> None:
"""Clears stateful information."""
self._observations.clear()
self._actions.clear()
self._rewards.clear()
self._returns_to_go.clear()
self._timesteps.clear()
self._actions.append(self._get_pad_action())
self._timestep = 1
self._return_rest = self._target_return
@property
def algo(
self,
) -> "TransformerAlgoBase[TTransformerImpl, TTransformerConfig]":
return self._algo
def _get_pad_action(self) -> Union[int, NDArray]:
assert self._algo.impl
pad_action: Union[int, NDArray]
if self._algo.get_action_type() == ActionSpace.CONTINUOUS:
pad_action = np.zeros(self._algo.impl.action_size, dtype=np.float32)
else:
pad_action = 0
return pad_action
class TransformerAlgoBase(
Generic[TTransformerImpl, TTransformerConfig],
LearnableBase[TTransformerImpl, TTransformerConfig],
):
def predict(self, inpt: TransformerInput) -> NDArray:
"""Returns action.
This is for internal use. For evaluation, use
``StatefulTransformerWrapper`` instead.
Args:
inpt: Sequence input.
Returns:
Action.
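
        A minimal input-construction sketch. The shapes below assume a
        10-step context with an 8-dimensional observation and 2-dimensional
        action and are purely illustrative; ``algo`` stands for a built
        Transformer-based algorithm and ``TransformerInput`` is defined in
        ``d3rlpy.algos.transformer.inputs``:

        .. code-block:: python

            import numpy as np

            inpt = TransformerInput(
                observations=np.zeros((10, 8), dtype=np.float32),
                actions=np.zeros((10, 2), dtype=np.float32),
                rewards=np.zeros((10, 1), dtype=np.float32),
                returns_to_go=np.ones((10, 1), dtype=np.float32),
                timesteps=np.arange(10),
            )
            action = algo.predict(inpt)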
"""
assert self._impl is not None, IMPL_NOT_INITIALIZED_ERROR
with torch.no_grad():
torch_inpt = TorchTransformerInput.from_numpy(
inpt=inpt,
context_size=self._config.context_size,
device=self._device,
observation_scaler=self._config.observation_scaler,
action_scaler=self._config.action_scaler,
reward_scaler=self._config.reward_scaler,
)
action = self._impl.predict(torch_inpt)
if self._config.action_scaler:
action = self._config.action_scaler.reverse_transform(action)
return action.cpu().detach().numpy() # type: ignore
def fit(
self,
dataset: ReplayBuffer,
n_steps: int,
n_steps_per_epoch: int = 10000,
experiment_name: Optional[str] = None,
with_timestamp: bool = True,
logger_adapter: LoggerAdapterFactory = FileAdapterFactory(),
show_progress: bool = True,
eval_env: Optional[GymEnv] = None,
eval_target_return: Optional[float] = None,
eval_action_sampler: Optional[TransformerActionSampler] = None,
save_interval: int = 1,
callback: Optional[Callable[[Self, int, int], None]] = None,
enable_ddp: bool = False,
) -> None:
"""Trains with given dataset.
Args:
dataset: Offline dataset to train.
n_steps: Number of steps to train.
n_steps_per_epoch: Number of steps per epoch. This value will
be ignored when ``n_steps`` is ``None``.
experiment_name: Experiment name for logging. If not passed,
the directory name will be `{class name}_{timestamp}`.
with_timestamp: Flag to add timestamp string to the last of
directory name.
logger_adapter: LoggerAdapterFactory object.
show_progress: Flag to show progress bar for iterations.
eval_env: Evaluation environment.
eval_target_return: Evaluation return target.
eval_action_sampler: Action sampler used in evaluation.
save_interval: Interval to save parameters.
callback: Callable function that takes ``(algo, epoch, total_step)``
, which is called every step.
enable_ddp: Flag to wrap models with DataDistributedParallel.
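
        A minimal training sketch (the step counts below are placeholders,
        not recommended values; ``dataset`` and ``env`` are assumed to be a
        d3rlpy ReplayBuffer and a Gymnasium-style environment):

        .. code-block:: python

            dt = DecisionTransformerConfig().create(device="cuda:0")
            dt.fit(
                dataset,
                n_steps=100000,
                n_steps_per_epoch=1000,
                eval_env=env,
                eval_target_return=1000,
            )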
"""
LOG.info("dataset info", dataset_info=dataset.dataset_info)
# check action space
assert_action_space_with_dataset(self, dataset.dataset_info)
# initialize scalers
build_scalers_with_trajectory_slicer(self, dataset)
# setup logger
if experiment_name is None:
experiment_name = self.__class__.__name__
logger = D3RLPyLogger(
adapter_factory=logger_adapter,
experiment_name=experiment_name,
with_timestamp=with_timestamp,
)
# instantiate implementation
if self._impl is None:
LOG.debug("Building models...")
action_size = dataset.dataset_info.action_size
observation_shape = (
dataset.sample_transition().observation_signature.shape
)
if len(observation_shape) == 1:
observation_shape = observation_shape[0] # type: ignore
self.create_impl(observation_shape, action_size)
LOG.debug("Models have been built.")
else:
LOG.warning("Skip building models since they're already built.")
        # wrap all PyTorch modules with DistributedDataParallel
if enable_ddp:
assert self._impl
self._impl.wrap_models_by_ddp()
# save hyperparameters
save_config(self, logger)
# training loop
n_epochs = n_steps // n_steps_per_epoch
total_step = 0
for epoch in range(1, n_epochs + 1):
# dict to add incremental mean losses to epoch
epoch_loss = defaultdict(list)
range_gen = tqdm(
range(n_steps_per_epoch),
disable=not show_progress,
desc=f"Epoch {int(epoch)}/{n_epochs}",
)
for itr in range_gen:
with logger.measure_time("step"):
# pick transitions
with logger.measure_time("sample_batch"):
batch = dataset.sample_trajectory_batch(
self._config.batch_size,
length=self._config.context_size,
)
# update parameters
with logger.measure_time("algorithm_update"):
loss = self.update(batch)
# record metrics
for name, val in loss.items():
logger.add_metric(name, val)
epoch_loss[name].append(val)
# update progress postfix with losses
if itr % 10 == 0:
mean_loss = {
k: np.mean(v) for k, v in epoch_loss.items()
}
range_gen.set_postfix(mean_loss)
total_step += 1
# call callback if given
if callback:
callback(self, epoch, total_step)
if eval_env:
assert eval_target_return is not None
eval_score = evaluate_transformer_with_environment(
algo=self.as_stateful_wrapper(
target_return=eval_target_return,
action_sampler=eval_action_sampler,
),
env=eval_env,
)
logger.add_metric("environment", eval_score)
# save metrics
logger.commit(epoch, total_step)
# save model parameters
if epoch % save_interval == 0:
logger.save_model(total_step, self)
logger.close()
def update(self, batch: TrajectoryMiniBatch) -> Dict[str, float]:
"""Update parameters with mini-batch of data.
Args:
batch: Mini-batch data.
Returns:
Dictionary of metrics.
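
        A minimal manual-loop sketch of what ``fit()`` does internally
        (``algo`` stands for a built Transformer-based algorithm and
        ``dataset`` for a d3rlpy ReplayBuffer; the step count is a
        placeholder):

        .. code-block:: python

            for _ in range(100000):
                batch = dataset.sample_trajectory_batch(
                    algo.config.batch_size, length=algo.config.context_size
                )
                metrics = algo.update(batch)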
"""
assert self._impl, IMPL_NOT_INITIALIZED_ERROR
torch_batch = TorchTrajectoryMiniBatch.from_batch(
batch=batch,
device=self._device,
observation_scaler=self._config.observation_scaler,
action_scaler=self._config.action_scaler,
reward_scaler=self._config.reward_scaler,
)
loss = self._impl.update(torch_batch, self._grad_step)
self._grad_step += 1
return loss
def as_stateful_wrapper(
self,
target_return: float,
action_sampler: Optional[TransformerActionSampler] = None,
) -> StatefulTransformerWrapper[TTransformerImpl, TTransformerConfig]:
"""Returns a wrapped Transformer algorithm for stateful decision making.
Args:
target_return: Target environment return.
action_sampler: Action sampler.
Returns:
StatefulTransformerWrapper object.
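
        A minimal evaluation sketch (``dt`` and ``env`` below are assumed to
        be a trained algorithm and a Gymnasium-style environment,
        respectively):

        .. code-block:: python

            wrapper = dt.as_stateful_wrapper(target_return=1000)
            score = evaluate_transformer_with_environment(algo=wrapper, env=env)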
"""
if action_sampler is None:
if self.get_action_type() == ActionSpace.CONTINUOUS:
action_sampler = IdentityTransformerActionSampler()
else:
action_sampler = SoftmaxTransformerActionSampler()
return StatefulTransformerWrapper(self, target_return, action_sampler)