d3rlpy/algos/qlearning/torch/cql_impl.py
import dataclasses
import math
from typing import Optional, Tuple

import torch
import torch.nn.functional as F
from torch.optim import Optimizer

from ....models.torch import (
ContinuousEnsembleQFunctionForwarder,
DiscreteEnsembleQFunctionForwarder,
Parameter,
get_parameter,
)
from ....torch_utility import (
TorchMiniBatch,
expand_and_repeat_recursively,
flatten_left_recursively,
)
from ....types import Shape, TorchObservation
from .ddpg_impl import DDPGBaseCriticLoss
from .dqn_impl import DoubleDQNImpl, DQNLoss, DQNModules
from .sac_impl import SACImpl, SACModules
from .utility import sample_q_values_with_policy

__all__ = ["CQLImpl", "DiscreteCQLImpl", "CQLModules", "DiscreteCQLLoss"]
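

# SAC modules extended with the Lagrange multiplier (log_alpha) that weights
# the conservative penalty; alpha_optim is None when alpha is kept fixed.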
@dataclasses.dataclass(frozen=True)
class CQLModules(SACModules):
log_alpha: Parameter
    alpha_optim: Optional[Optimizer]


@dataclasses.dataclass(frozen=True)
class CQLCriticLoss(DDPGBaseCriticLoss):
conservative_loss: torch.Tensor
alpha: torch.Tensor
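

# Conservative Q-Learning (Kumar et al., 2020) for continuous action spaces,
# implemented on top of SAC. The critic objective adds a penalty that pushes
# down Q-values of out-of-distribution actions while pushing up Q-values of
# dataset actions.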
class CQLImpl(SACImpl):
_modules: CQLModules
_alpha_threshold: float
_conservative_weight: float
_n_action_samples: int
_soft_q_backup: bool
    _max_q_backup: bool

    def __init__(
self,
observation_shape: Shape,
action_size: int,
modules: CQLModules,
q_func_forwarder: ContinuousEnsembleQFunctionForwarder,
targ_q_func_forwarder: ContinuousEnsembleQFunctionForwarder,
gamma: float,
tau: float,
alpha_threshold: float,
conservative_weight: float,
n_action_samples: int,
soft_q_backup: bool,
max_q_backup: bool,
device: str,
):
super().__init__(
observation_shape=observation_shape,
action_size=action_size,
modules=modules,
q_func_forwarder=q_func_forwarder,
targ_q_func_forwarder=targ_q_func_forwarder,
gamma=gamma,
tau=tau,
device=device,
)
self._alpha_threshold = alpha_threshold
self._conservative_weight = conservative_weight
self._n_action_samples = n_action_samples
self._soft_q_backup = soft_q_backup
self._max_q_backup = max_q_backup
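
    # critic loss = SAC TD loss + the alpha-weighted conservative penalty;
    # alpha itself is updated in the same pass when an optimizer is configured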
def compute_critic_loss(
self, batch: TorchMiniBatch, q_tpn: torch.Tensor
) -> CQLCriticLoss:
loss = super().compute_critic_loss(batch, q_tpn)
conservative_loss = self._compute_conservative_loss(
obs_t=batch.observations,
act_t=batch.actions,
obs_tp1=batch.next_observations,
returns_to_go=batch.returns_to_go,
)
if self._modules.alpha_optim:
self.update_alpha(conservative_loss)
return CQLCriticLoss(
critic_loss=loss.critic_loss + conservative_loss.sum(),
conservative_loss=conservative_loss.sum(),
alpha=get_parameter(self._modules.log_alpha).exp(),
        )

    def update_alpha(self, conservative_loss: torch.Tensor) -> None:
assert self._modules.alpha_optim
self._modules.alpha_optim.zero_grad()
        # dual gradient ascent on alpha: negate the conservative loss so that
        # a gradient-descent optimizer maximizes it w.r.t. alpha
        # (note: the original implementation scales the loss value)
loss = -conservative_loss.mean()
loss.backward(retain_graph=True)
self._modules.alpha_optim.step()
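
    # Q-values of actions sampled from the current policy together with their
    # log-probabilities, used below as importance-sampling corrections. The
    # policy output is detached so the penalty does not update the actor.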
def _compute_policy_is_values(
self,
policy_obs: TorchObservation,
value_obs: TorchObservation,
returns_to_go: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
return sample_q_values_with_policy(
policy=self._modules.policy,
q_func_forwarder=self._q_func_forwarder,
policy_observations=policy_obs,
value_observations=value_obs,
n_action_samples=self._n_action_samples,
detach_policy_output=True,
        )

    def _compute_random_is_values(
self, obs: TorchObservation
) -> Tuple[torch.Tensor, float]:
# (batch, observation) -> (batch, n, observation)
repeated_obs = expand_and_repeat_recursively(
obs, self._n_action_samples
)
# (batch, n, observation) -> (batch * n, observation)
flat_obs = flatten_left_recursively(repeated_obs, dim=1)
        # estimate action-values for actions drawn from a uniform
        # distribution over [-1.0, 1.0]
batch_size = (
obs.shape[0] if isinstance(obs, torch.Tensor) else obs[0].shape[0]
)
flat_shape = (batch_size * self._n_action_samples, self._action_size)
zero_tensor = torch.zeros(flat_shape, device=self._device)
random_actions = zero_tensor.uniform_(-1.0, 1.0)
random_values = self._q_func_forwarder.compute_expected_q(
flat_obs, random_actions, "none"
)
random_values = random_values.view(
-1, batch_size, self._n_action_samples
)
        # log-probability under the uniform proposal: the density over
        # [-1.0, 1.0]^d is (1/2)^d, so the log-density is
        # action_size * log(0.5). This is the importance-sampling
        # correction for the random actions.
        random_log_probs = math.log(0.5**self._action_size)
        return random_values, random_log_probs
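
    # CQL(H) conservative penalty (Kumar et al., 2020), scaled by alpha and
    # conservative_weight:
    #   E_{s~D}[log sum_a exp Q(s, a)] - E_{(s,a)~D}[Q(s, a)] - alpha_threshold
    # The intractable logsumexp over actions is approximated by importance
    # sampling, mixing policy actions at s_t, policy actions at s_{t+1} and
    # uniform random actions, each corrected by its log sampling probability.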
def _compute_conservative_loss(
self,
obs_t: TorchObservation,
act_t: torch.Tensor,
obs_tp1: TorchObservation,
returns_to_go: torch.Tensor,
) -> torch.Tensor:
policy_values_t, log_probs_t = self._compute_policy_is_values(
policy_obs=obs_t,
value_obs=obs_t,
returns_to_go=returns_to_go,
)
policy_values_tp1, log_probs_tp1 = self._compute_policy_is_values(
policy_obs=obs_tp1,
value_obs=obs_t,
returns_to_go=returns_to_go,
)
random_values, random_log_probs = self._compute_random_is_values(obs_t)
# compute logsumexp
# (n critics, batch, 3 * n samples) -> (n critics, batch, 1)
target_values = torch.cat(
[
policy_values_t - log_probs_t,
policy_values_tp1 - log_probs_tp1,
random_values - random_log_probs,
],
dim=2,
)
logsumexp = torch.logsumexp(target_values, dim=2, keepdim=True)
# estimate action-values for data actions
data_values = self._q_func_forwarder.compute_expected_q(
obs_t, act_t, "none"
)
loss = (logsumexp - data_values).mean(dim=[1, 2])
# clip for stability
log_alpha = get_parameter(self._modules.log_alpha)
clipped_alpha = log_alpha.exp().clamp(0, 1e6)[0][0]
return (
clipped_alpha
* self._conservative_weight
* (loss - self._alpha_threshold)
)
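
    # soft_q_backup uses SAC's entropy-regularized target; otherwise the
    # target is computed from deterministic policy actions, optionally with
    # the max-Q backup from the reference implementation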
def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor:
if self._soft_q_backup:
target_value = super().compute_target(batch)
else:
with torch.no_grad():
target_value = self._compute_deterministic_target(batch)
        return target_value

    def _compute_deterministic_target(
self, batch: TorchMiniBatch
) -> torch.Tensor:
if self._max_q_backup:
q_values, _ = sample_q_values_with_policy(
policy=self._modules.policy,
q_func_forwarder=self._targ_q_func_forwarder,
policy_observations=batch.next_observations,
value_observations=batch.next_observations,
n_action_samples=self._n_action_samples,
detach_policy_output=True,
)
            return q_values.min(dim=0).values.max(dim=1, keepdim=True).values
else:
action = self._modules.policy(batch.next_observations).squashed_mu
return self._targ_q_func_forwarder.compute_target(
batch.next_observations,
action,
reduction="min",
            )


@dataclasses.dataclass(frozen=True)
class DiscreteCQLLoss(DQNLoss):
td_loss: torch.Tensor
conservative_loss: torch.Tensor
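

# Discrete-action CQL: Double DQN plus the conservative penalty
#   alpha * (E_{s~D}[log sum_a exp Q(s, a)] - E_{(s,a)~D}[Q(s, a)]),
# where the logsumexp over actions is computed exactly by enumeration.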
class DiscreteCQLImpl(DoubleDQNImpl):
    _alpha: float

    def __init__(
self,
observation_shape: Shape,
action_size: int,
modules: DQNModules,
q_func_forwarder: DiscreteEnsembleQFunctionForwarder,
targ_q_func_forwarder: DiscreteEnsembleQFunctionForwarder,
target_update_interval: int,
gamma: float,
alpha: float,
device: str,
):
super().__init__(
observation_shape=observation_shape,
action_size=action_size,
modules=modules,
q_func_forwarder=q_func_forwarder,
targ_q_func_forwarder=targ_q_func_forwarder,
target_update_interval=target_update_interval,
gamma=gamma,
device=device,
)
        self._alpha = alpha

    def _compute_conservative_loss(
self, obs_t: TorchObservation, act_t: torch.Tensor
) -> torch.Tensor:
# compute logsumexp
values = self._q_func_forwarder.compute_expected_q(obs_t)
logsumexp = torch.logsumexp(values, dim=1, keepdim=True)
# estimate action-values under data distribution
one_hot = F.one_hot(act_t.view(-1), num_classes=self.action_size)
data_values = (values * one_hot).sum(dim=1, keepdim=True)
return (logsumexp - data_values).mean()
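
    # total loss = Double DQN TD loss + alpha * conservative penalty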
def compute_loss(
self,
batch: TorchMiniBatch,
q_tpn: torch.Tensor,
) -> DiscreteCQLLoss:
td_loss = super().compute_loss(batch, q_tpn).loss
conservative_loss = self._compute_conservative_loss(
batch.observations, batch.actions.long()
)
loss = td_loss + self._alpha * conservative_loss
return DiscreteCQLLoss(
loss=loss, td_loss=td_loss, conservative_loss=conservative_loss
)
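

# Usage sketch (not part of this module): these impl classes are normally
# constructed indirectly through the public config API. A minimal example,
# assuming d3rlpy v2-style configs:
#
#   import d3rlpy
#
#   dataset, env = d3rlpy.datasets.get_pendulum()
#   cql = d3rlpy.algos.CQLConfig().create(device="cpu")
#   cql.fit(dataset, n_steps=10000)
#
#   dataset, env = d3rlpy.datasets.get_cartpole()
#   discrete_cql = d3rlpy.algos.DiscreteCQLConfig(alpha=1.0).create(device="cpu")
#   discrete_cql.fit(dataset, n_steps=10000)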