# d3rlpy/torch_utility.py
import collections
import dataclasses
from typing import (
Any,
BinaryIO,
Dict,
Optional,
Sequence,
TypeVar,
Union,
overload,
)
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
from typing_extensions import Self
from .dataclass_utils import asdict_without_copy
from .dataset import TrajectoryMiniBatch, TransitionMiniBatch
from .preprocessing import ActionScaler, ObservationScaler, RewardScaler
from .types import Float32NDArray, NDArray, TorchObservation
__all__ = [
"soft_sync",
"hard_sync",
"sync_optimizer_state",
"map_location",
"TorchMiniBatch",
"TorchTrajectoryMiniBatch",
"Checkpointer",
"Modules",
"convert_to_torch",
"convert_to_torch_recursively",
"convert_to_numpy_recursively",
"get_device",
"get_batch_size",
"expand_and_repeat_recursively",
"flatten_left_recursively",
"eval_api",
"train_api",
"View",
]
def soft_sync(targ_model: nn.Module, model: nn.Module, tau: float) -> None:
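    """Softly updates the target network parameters (Polyak averaging):
    ``targ_param <- (1 - tau) * targ_param + tau * param``.
    """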
with torch.no_grad():
params = model.parameters()
targ_params = targ_model.parameters()
for p, p_targ in zip(params, targ_params):
p_targ.data.mul_(1 - tau)
p_targ.data.add_(tau * p.data)
def hard_sync(targ_model: nn.Module, model: nn.Module) -> None:
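    """Copies the parameters of ``model`` into ``targ_model``."""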
with torch.no_grad():
params = model.parameters()
targ_params = targ_model.parameters()
for p, p_targ in zip(params, targ_params):
p_targ.data.copy_(p.data)
def sync_optimizer_state(targ_optim: Optimizer, optim: Optimizer) -> None:
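    """Copies the internal optimizer state (e.g. Adam moment estimates) from
    ``optim`` into ``targ_optim`` while keeping the target's own param_groups.
    """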
# source optimizer state
state = optim.state_dict()["state"]
# destination optimizer param_groups
param_groups = targ_optim.state_dict()["param_groups"]
# update only state
targ_optim.load_state_dict({"state": state, "param_groups": param_groups})
def map_location(device: str) -> Any:
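    """Returns a ``map_location`` argument for ``torch.load`` that places
    loaded storages on the given device.
    """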
if "cuda" in device:
return lambda storage, loc: storage.cuda(device)
if "cpu" in device:
return "cpu"
raise ValueError(f"invalid device={device}")
def convert_to_torch(array: NDArray, device: str) -> torch.Tensor:
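    """Converts a numpy array into a float32 tensor on ``device``.

    uint8 arrays (e.g. image observations) are copied as uint8 first and then
    cast to float32.
    """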
dtype = torch.uint8 if array.dtype == np.uint8 else torch.float32
tensor = torch.tensor(data=array, dtype=dtype, device=device)
return tensor.float()
@overload
def convert_to_torch_recursively(
array: NDArray, device: str
) -> torch.Tensor: ...
@overload
def convert_to_torch_recursively(
array: Sequence[NDArray], device: str
) -> Sequence[torch.Tensor]: ...
def convert_to_torch_recursively(
array: Union[NDArray, Sequence[NDArray]], device: str
) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
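    """Converts a numpy array, or a list/tuple of numpy arrays such as tuple
    observations, into torch tensors on ``device``.
    """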
if isinstance(array, (list, tuple)):
return [convert_to_torch(data, device) for data in array]
elif isinstance(array, np.ndarray):
return convert_to_torch(array, device)
else:
raise ValueError(f"invalid array type: {type(array)}")
def convert_to_numpy_recursively(
array: Union[torch.Tensor, Sequence[torch.Tensor]]
) -> Union[NDArray, Sequence[NDArray]]:
if isinstance(array, (list, tuple)):
return [data.numpy() for data in array]
elif isinstance(array, torch.Tensor):
return array.numpy() # type: ignore
else:
raise ValueError(f"invalid array type: {type(array)}")
def get_device(x: Union[torch.Tensor, Sequence[torch.Tensor]]) -> str:
if isinstance(x, torch.Tensor):
return str(x.device)
else:
return str(x[0].device)
def get_batch_size(x: Union[torch.Tensor, Sequence[torch.Tensor]]) -> int:
if isinstance(x, torch.Tensor):
return int(x.shape[0])
else:
return int(x[0].shape[0])
def flatten_left_recursively(
x: Union[torch.Tensor, Sequence[torch.Tensor]], dim: int
) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
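    """Flattens the leading dimensions up to and including ``dim`` into a
    single dimension, e.g. ``(B, L, C)`` with ``dim=1`` becomes ``(B * L, C)``.
    """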
if isinstance(x, torch.Tensor):
return x.reshape([-1, *x.shape[dim + 1 :]])
else:
return [flatten_left_recursively(_x, dim) for _x in x]
def expand_and_repeat_recursively(
x: Union[torch.Tensor, Sequence[torch.Tensor]], n: int
) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
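    """Inserts a new axis after the batch dimension and expands it to size
    ``n``: ``(batch_size, M) -> (batch_size, n, M)``.
    """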
if isinstance(x, torch.Tensor):
        # insert a new axis after the batch dimension
        # (batch_size, M) -> (batch_size, 1, M)
        reshaped_x = x.view(x.shape[0], 1, *x.shape[1:])
        # expand along the new axis without copying
        # (batch_size, 1, M) -> (batch_size, N, M)
        return reshaped_x.expand(x.shape[0], n, *x.shape[1:])
else:
return [expand_and_repeat_recursively(_x, n) for _x in x]
def _compute_return_to_go(
gamma: float,
rewards_to_go: Float32NDArray,
reward_scaler: Optional[RewardScaler],
) -> Float32NDArray:
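    """Computes the discounted return-to-go ``sum_k gamma^k * r_{t+k}`` over
    the remaining rewards, applying ``reward_scaler`` first if provided.
    """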
rewards = (
reward_scaler.transform_numpy(rewards_to_go)
if reward_scaler
else rewards_to_go
)
cum_gammas: Float32NDArray = np.array(
np.expand_dims(gamma ** np.arange(rewards.shape[0]), axis=1),
dtype=np.float32,
)
return np.sum(cum_gammas * rewards, axis=0) # type: ignore
@dataclasses.dataclass(frozen=True)
class TorchMiniBatch:
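    """Frozen dataclass holding a mini-batch of transitions as torch tensors
    placed on ``device``.
    """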
observations: TorchObservation
actions: torch.Tensor
rewards: torch.Tensor
next_observations: TorchObservation
returns_to_go: torch.Tensor
terminals: torch.Tensor
intervals: torch.Tensor
device: str
numpy_batch: Optional[TransitionMiniBatch] = None
@classmethod
def from_batch(
cls,
batch: TransitionMiniBatch,
gamma: float,
compute_returns_to_go: bool,
device: str,
observation_scaler: Optional[ObservationScaler] = None,
action_scaler: Optional[ActionScaler] = None,
reward_scaler: Optional[RewardScaler] = None,
) -> "TorchMiniBatch":
# convert numpy array to torch tensor
observations = convert_to_torch_recursively(batch.observations, device)
actions = convert_to_torch(batch.actions, device)
rewards = convert_to_torch(batch.rewards, device)
next_observations = convert_to_torch_recursively(
batch.next_observations, device
)
terminals = convert_to_torch(batch.terminals, device)
intervals = convert_to_torch(batch.intervals, device)
if compute_returns_to_go:
returns_to_go = convert_to_torch(
np.array(
[
_compute_return_to_go(
gamma=gamma,
rewards_to_go=transition.rewards_to_go,
reward_scaler=reward_scaler,
)
for transition in batch.transitions
]
),
device,
)
else:
returns_to_go = torch.zeros_like(rewards)
# apply scaler
if observation_scaler:
observations = observation_scaler.transform(observations)
next_observations = observation_scaler.transform(next_observations)
if action_scaler:
actions = action_scaler.transform(actions)
if reward_scaler:
rewards = reward_scaler.transform(rewards)
return TorchMiniBatch(
observations=observations,
actions=actions,
rewards=rewards,
next_observations=next_observations,
returns_to_go=returns_to_go,
terminals=terminals,
intervals=intervals,
device=device,
numpy_batch=batch,
)
@dataclasses.dataclass(frozen=True)
class TorchTrajectoryMiniBatch:
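    """Frozen dataclass holding a mini-batch of fixed-length trajectories as
    torch tensors. Shapes are noted per field (B: batch size, L: length).
    """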
observations: TorchObservation # (B, L, ...)
actions: torch.Tensor # (B, L, ...)
rewards: torch.Tensor # (B, L, 1)
returns_to_go: torch.Tensor # (B, L, 1)
terminals: torch.Tensor # (B, L, 1)
timesteps: torch.Tensor # (B, L, 1)
masks: torch.Tensor # (B, L)
device: str
numpy_batch: Optional[TrajectoryMiniBatch] = None
@classmethod
def from_batch(
cls,
batch: TrajectoryMiniBatch,
device: str,
observation_scaler: Optional[ObservationScaler] = None,
action_scaler: Optional[ActionScaler] = None,
reward_scaler: Optional[RewardScaler] = None,
) -> "TorchTrajectoryMiniBatch":
# convert numpy array to torch tensor
observations = convert_to_torch_recursively(batch.observations, device)
actions = convert_to_torch(batch.actions, device)
rewards = convert_to_torch(batch.rewards, device)
returns_to_go = convert_to_torch(batch.returns_to_go, device)
terminals = convert_to_torch(batch.terminals, device)
timesteps = convert_to_torch(batch.timesteps, device).long()
masks = convert_to_torch(batch.masks, device)
# apply scaler
if observation_scaler:
observations = observation_scaler.transform(observations)
if action_scaler:
actions = action_scaler.transform(actions)
if reward_scaler:
rewards = reward_scaler.transform(rewards)
            # NOTE: applying the reward scaler to returns-to-go might not be
            # strictly valid for every scaler
returns_to_go = reward_scaler.transform(returns_to_go)
return TorchTrajectoryMiniBatch(
observations=observations,
actions=actions,
rewards=rewards,
returns_to_go=returns_to_go,
terminals=terminals,
timesteps=timesteps,
masks=masks,
device=device,
numpy_batch=batch,
)
class Checkpointer:
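    """Saves and restores ``state_dict`` objects for a named set of modules
    and optimizers.
    """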
_modules: Dict[str, Union[nn.Module, Optimizer]]
_device: str
def __init__(
self, modules: Dict[str, Union[nn.Module, Optimizer]], device: str
):
self._modules = modules
self._device = device
def save(self, f: BinaryIO) -> None:
states = {k: v.state_dict() for k, v in self._modules.items()}
torch.save(states, f)
def load(self, f: BinaryIO) -> None:
chkpt = torch.load(f, map_location=map_location(self._device))
for k, v in self._modules.items():
v.load_state_dict(chkpt[k])
@property
def modules(self) -> Dict[str, Union[nn.Module, Optimizer]]:
return self._modules
@dataclasses.dataclass(frozen=True)
class Modules:
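    """Dataclass container for a set of ``nn.Module`` and ``Optimizer`` fields
    with utilities for checkpointing, freezing, train/eval switching,
    optimizer state resets and DDP (un)wrapping.
    """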
def create_checkpointer(self, device: str) -> Checkpointer:
modules = {
k: v
for k, v in asdict_without_copy(self).items()
if isinstance(v, (nn.Module, torch.optim.Optimizer))
}
return Checkpointer(modules=modules, device=device)
def freeze(self) -> None:
for v in asdict_without_copy(self).values():
if isinstance(v, nn.Module):
for p in v.parameters():
p.requires_grad = False
def unfreeze(self) -> None:
for v in asdict_without_copy(self).values():
if isinstance(v, nn.Module):
for p in v.parameters():
p.requires_grad = True
def set_eval(self) -> None:
for v in asdict_without_copy(self).values():
if isinstance(v, nn.Module):
v.eval()
def set_train(self) -> None:
for v in asdict_without_copy(self).values():
if isinstance(v, nn.Module):
v.train()
def reset_optimizer_states(self) -> None:
for v in asdict_without_copy(self).values():
if isinstance(v, torch.optim.Optimizer):
v.state = collections.defaultdict(dict)
def wrap_models_by_ddp(self) -> Self:
dict_values = asdict_without_copy(self)
for k, v in dict_values.items():
if isinstance(v, nn.Module):
device_id = next(v.parameters()).device.index
                dict_values[k] = DDP(
                    v, device_ids=[device_id] if device_id is not None else None
                )
return self.__class__(**dict_values)
def unwrap_models_by_ddp(self) -> Self:
dict_values = asdict_without_copy(self)
for k, v in dict_values.items():
if isinstance(v, DDP):
dict_values[k] = v.module
return self.__class__(**dict_values)
TCallable = TypeVar("TCallable")
def eval_api(f: TCallable) -> TCallable:
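    """Decorator that switches ``self.modules`` to evaluation mode before
    calling the wrapped method.
    """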
def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
assert hasattr(self, "modules")
assert isinstance(self.modules, Modules)
self.modules.set_eval()
return f(self, *args, **kwargs) # type: ignore
return wrapper # type: ignore
def train_api(f: TCallable) -> TCallable:
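    """Decorator that switches ``self.modules`` to training mode before
    calling the wrapped method.
    """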
def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
assert hasattr(self, "modules")
assert isinstance(self.modules, Modules)
self.modules.set_train()
return f(self, *args, **kwargs) # type: ignore
return wrapper # type: ignore
class View(nn.Module): # type: ignore
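    """``nn.Module`` that reshapes the input tensor to the configured shape."""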
_shape: Sequence[int]
def __init__(self, shape: Sequence[int]):
super().__init__()
self._shape = shape
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x.view(self._shape)
class Swish(nn.Module): # type: ignore
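    """Swish/SiLU activation: ``x * sigmoid(x)``."""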
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x * torch.sigmoid(x)
class GEGLU(nn.Module): # type: ignore
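    """GEGLU activation: splits the last dimension in half and returns
    ``a * GELU(b)``.
    """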
def forward(self, x: torch.Tensor) -> torch.Tensor:
assert x.shape[-1] % 2 == 0
a, b = x.chunk(2, dim=-1)
return a * F.gelu(b)