# d3rlpy/models/torch/q_functions/mean_q_function.py
from typing import Optional, Union
import torch
import torch.nn.functional as F
from torch import nn
from ....types import TorchObservation
from ..encoders import Encoder, EncoderWithAction
from .base import (
ContinuousQFunction,
ContinuousQFunctionForwarder,
DiscreteQFunction,
DiscreteQFunctionForwarder,
QFunctionOutput,
)
from .utility import compute_huber_loss, compute_reduce, pick_value_by_action
# Public names of this module: the two mean Q-function networks and their
# matching forwarder classes. This list controls ``from <module> import *``.
__all__ = [
"DiscreteMeanQFunction",
"ContinuousMeanQFunction",
"DiscreteMeanQFunctionForwarder",
"ContinuousMeanQFunctionForwarder",
]
class DiscreteMeanQFunction(DiscreteQFunction):
    """Q-function for discrete actions: an encoder followed by a linear head.

    The observation is passed through the encoder and the result is fed to a
    single ``nn.Linear`` layer producing ``action_size`` outputs, which are
    returned as the ``q_value`` field of a ``QFunctionOutput``.

    Args:
        encoder: Module applied to the observation.
        hidden_size: Input feature size of the linear head (presumably the
            encoder's output width -- confirm against the encoder in use).
        action_size: Number of outputs of the linear head.
    """

    _encoder: Encoder
    _fc: nn.Linear

    def __init__(self, encoder: Encoder, hidden_size: int, action_size: int):
        super().__init__()
        self._encoder = encoder
        self._fc = nn.Linear(
            in_features=hidden_size, out_features=action_size
        )

    def forward(self, x: TorchObservation) -> QFunctionOutput:
        """Encode ``x`` and project the features with the linear head.

        Args:
            x: Observation passed to the encoder.

        Returns:
            ``QFunctionOutput`` whose ``q_value`` is the linear head's output;
            ``quantiles`` and ``taus`` are always ``None`` for this model.
        """
        features = self._encoder(x)
        q_values = self._fc(features)
        return QFunctionOutput(q_value=q_values, quantiles=None, taus=None)

    @property
    def encoder(self) -> Encoder:
        """The encoder passed at construction time."""
        return self._encoder
class DiscreteMeanQFunctionForwarder(DiscreteQFunctionForwarder):
    """Evaluates a ``DiscreteMeanQFunction`` and computes its training loss.

    Args:
        q_func: Q-function to evaluate.
        action_size: Number of classes used when one-hot encoding actions in
            ``compute_error``.
    """

    _q_func: DiscreteMeanQFunction
    _action_size: int

    def __init__(self, q_func: DiscreteMeanQFunction, action_size: int):
        self._q_func = q_func
        self._action_size = action_size

    def compute_expected_q(self, x: TorchObservation) -> torch.Tensor:
        """Return the ``q_value`` tensor the Q-function produces for ``x``."""
        output = self._q_func(x)
        return output.q_value

    def compute_error(
        self,
        observations: TorchObservation,
        actions: torch.Tensor,
        rewards: torch.Tensor,
        target: torch.Tensor,
        terminals: torch.Tensor,
        gamma: Union[float, torch.Tensor] = 0.99,
        reduction: str = "mean",
    ) -> torch.Tensor:
        """Compute the loss between the taken action's Q-value and a target.

        The regression target is
        ``rewards + gamma * target * (1 - terminals)``.

        Args:
            observations: Observations fed to the Q-function.
            actions: Action indices; flattened with ``view(-1)`` before being
                one-hot encoded.
            rewards: Reward term of the regression target.
            target: Next-step value estimate used in the regression target.
                It is not detached here (presumably the caller supplies a
                gradient-free tensor -- confirm).
            terminals: Flags that zero out the ``target`` term where equal
                to 1.
            gamma: Multiplier applied to ``target``.
            reduction: Forwarded to ``compute_reduce``.

        Returns:
            The result of ``compute_reduce`` applied to the element-wise
            ``compute_huber_loss`` output.
        """
        # Mask selecting, per row, the column of the action that was taken.
        action_mask = F.one_hot(
            actions.view(-1), num_classes=self._action_size
        ).float()
        all_q_values = self._q_func(observations).q_value
        masked_q_values = all_q_values * action_mask
        # keepdim=True keeps a trailing dimension of size 1 on the result.
        # NOTE(review): rewards/target/terminals presumably share that
        # shape -- confirm against callers.
        value = masked_q_values.sum(dim=1, keepdim=True)

        not_done = 1 - terminals
        y = rewards + gamma * target * not_done

        elementwise_loss = compute_huber_loss(value, y)
        return compute_reduce(elementwise_loss, reduction)

    def compute_target(
        self, x: TorchObservation, action: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Return target Q-values for ``x``.

        Args:
            x: Observation fed to the Q-function.
            action: Optional actions. When ``None`` the full ``q_value``
                tensor is returned; otherwise the values are selected with
                ``pick_value_by_action(..., keepdim=True)``.

        Returns:
            Either all Q-values or the ones picked for ``action``.
        """
        q_values = self._q_func(x).q_value
        if action is not None:
            return pick_value_by_action(q_values, action, keepdim=True)
        return q_values
class ContinuousMeanQFunction(ContinuousQFunction):
    """Q-function for continuous actions: an encoder plus a scalar head.

    The observation and action are passed together to the encoder, and the
    result is fed to an ``nn.Linear`` layer with a single output, returned as
    the ``q_value`` field of a ``QFunctionOutput``.

    Args:
        encoder: Module applied to the ``(observation, action)`` pair.
        hidden_size: Input feature size of the linear head (presumably the
            encoder's output width -- confirm against the encoder in use).
    """

    _encoder: EncoderWithAction
    _fc: nn.Linear

    def __init__(self, encoder: EncoderWithAction, hidden_size: int):
        super().__init__()
        self._encoder = encoder
        self._fc = nn.Linear(in_features=hidden_size, out_features=1)

    def forward(
        self, x: TorchObservation, action: torch.Tensor
    ) -> QFunctionOutput:
        """Encode ``(x, action)`` and project to a single value.

        Args:
            x: Observation passed to the encoder.
            action: Action passed to the encoder alongside ``x``.

        Returns:
            ``QFunctionOutput`` whose ``q_value`` is the linear head's output;
            ``quantiles`` and ``taus`` are always ``None`` for this model.
        """
        features = self._encoder(x, action)
        q_value = self._fc(features)
        return QFunctionOutput(q_value=q_value, quantiles=None, taus=None)

    @property
    def encoder(self) -> EncoderWithAction:
        """The encoder passed at construction time."""
        return self._encoder
class ContinuousMeanQFunctionForwarder(ContinuousQFunctionForwarder):
    """Evaluates a ``ContinuousMeanQFunction`` and computes its training loss.

    Args:
        q_func: Q-function to evaluate.
    """

    _q_func: ContinuousMeanQFunction

    def __init__(self, q_func: ContinuousMeanQFunction):
        self._q_func = q_func

    def compute_expected_q(
        self, x: TorchObservation, action: torch.Tensor
    ) -> torch.Tensor:
        """Return the ``q_value`` tensor for the pair ``(x, action)``."""
        output = self._q_func(x, action)
        return output.q_value

    def compute_error(
        self,
        observations: TorchObservation,
        actions: torch.Tensor,
        rewards: torch.Tensor,
        target: torch.Tensor,
        terminals: torch.Tensor,
        gamma: Union[float, torch.Tensor] = 0.99,
        reduction: str = "mean",
    ) -> torch.Tensor:
        """Compute the squared error between the Q-value and a target.

        The regression target is
        ``rewards + gamma * target * (1 - terminals)``.

        Args:
            observations: Observations fed to the Q-function.
            actions: Actions fed to the Q-function.
            rewards: Reward term of the regression target.
            target: Next-step value estimate used in the regression target.
                It is not detached here (presumably the caller supplies a
                gradient-free tensor -- confirm).
            terminals: Flags that zero out the ``target`` term where equal
                to 1.
            gamma: Multiplier applied to ``target``.
            reduction: Forwarded to ``compute_reduce``.

        Returns:
            The result of ``compute_reduce`` applied to the element-wise
            squared error.
        """
        value = self._q_func(observations, actions).q_value

        not_done = 1 - terminals
        y = rewards + gamma * target * not_done

        # Keep the loss element-wise here; the reduction is delegated to
        # compute_reduce so the ``reduction`` argument is handled in one place.
        elementwise_loss = F.mse_loss(value, y, reduction="none")
        return compute_reduce(elementwise_loss, reduction)

    def compute_target(
        self, x: TorchObservation, action: torch.Tensor
    ) -> torch.Tensor:
        """Return the ``q_value`` tensor for ``(x, action)`` as the target."""
        output = self._q_func(x, action)
        return output.q_value