takuseno/d3rlpy

d3rlpy/models/torch/q_functions/qr_q_function.py

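# Quantile Regression (QR-DQN) Q-function heads: each head predicts n_quantiles
# quantile values per action (discrete) or per state-action pair (continuous),
# and the expected Q-value is the mean over those quantiles.
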
from typing import Optional, Union

import torch
from torch import nn

from ....torch_utility import get_batch_size, get_device
from ....types import TorchObservation
from ..encoders import Encoder, EncoderWithAction
from .base import (
    ContinuousQFunction,
    ContinuousQFunctionForwarder,
    DiscreteQFunction,
    DiscreteQFunctionForwarder,
    QFunctionOutput,
)
from .utility import (
    compute_quantile_loss,
    compute_reduce,
    pick_quantile_value_by_action,
)

__all__ = [
    "DiscreteQRQFunction",
    "ContinuousQRQFunction",
    "ContinuousQRQFunctionForwarder",
    "DiscreteQRQFunctionForwarder",
]


def _make_taus(n_quantiles: int, device: torch.device) -> torch.Tensor:
    steps = torch.arange(n_quantiles, dtype=torch.float32, device=device)
    taus = ((steps + 1).float() / n_quantiles).view(1, -1)
    taus_dot = (steps.float() / n_quantiles).view(1, -1)
    return (taus + taus_dot) / 2.0
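# Illustration: _make_taus(4, torch.device("cpu")) returns a (1, 4) tensor,
# [[0.1250, 0.3750, 0.6250, 0.8750]], i.e. the quantile midpoints
# tau_hat_i = (i / N + (i + 1) / N) / 2 used as the quantile fractions in QR-DQN.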


class DiscreteQRQFunction(DiscreteQFunction):
    _action_size: int
    _encoder: Encoder
    _n_quantiles: int
    _fc: nn.Linear

    def __init__(
        self,
        encoder: Encoder,
        hidden_size: int,
        action_size: int,
        n_quantiles: int,
    ):
        super().__init__()
        self._encoder = encoder
        self._action_size = action_size
        self._n_quantiles = n_quantiles
        self._fc = nn.Linear(hidden_size, action_size * n_quantiles)

    def forward(self, x: TorchObservation) -> QFunctionOutput:
        quantiles = self._fc(self._encoder(x))
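        # reshape the flat (batch, action_size * n_quantiles) output into
        # per-action quantile estimates of shape (batch, action_size, n_quantiles)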
        quantiles = quantiles.view(-1, self._action_size, self._n_quantiles)
        return QFunctionOutput(
            q_value=quantiles.mean(dim=2),
            quantiles=quantiles,
            taus=_make_taus(self._n_quantiles, device=get_device(x)),
        )

    @property
    def encoder(self) -> Encoder:
        return self._encoder


class DiscreteQRQFunctionForwarder(DiscreteQFunctionForwarder):
    _q_func: DiscreteQRQFunction
    _n_quantiles: int

    def __init__(self, q_func: DiscreteQRQFunction, n_quantiles: int):
        self._q_func = q_func
        self._n_quantiles = n_quantiles

    def compute_expected_q(self, x: TorchObservation) -> torch.Tensor:
        return self._q_func(x).q_value

    def compute_error(
        self,
        observations: TorchObservation,
        actions: torch.Tensor,
        rewards: torch.Tensor,
        target: torch.Tensor,
        terminals: torch.Tensor,
        gamma: Union[float, torch.Tensor] = 0.99,
        reduction: str = "mean",
    ) -> torch.Tensor:
        batch_size = get_batch_size(observations)
        assert target.shape == (batch_size, self._n_quantiles)

        # extract quantiles corresponding to the taken actions
        output = self._q_func(observations)
        all_quantiles = output.quantiles
        taus = output.taus
        assert all_quantiles is not None and taus is not None
        quantiles = pick_quantile_value_by_action(all_quantiles, actions)

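        # quantile regression loss between the picked quantiles and the Bellman
        # target built from rewards, terminals, gamma and the target quantiles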
        loss = compute_quantile_loss(
            quantiles=quantiles,
            rewards=rewards,
            target=target,
            terminals=terminals,
            taus=taus,
            gamma=gamma,
        )

        return compute_reduce(loss, reduction)

    def compute_target(
        self, x: TorchObservation, action: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        quantiles = self._q_func(x).quantiles
        assert quantiles is not None
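        # with no action given, return quantiles for every action
        # (batch, action_size, n_quantiles); otherwise pick the taken action's
        # quantiles (batch, n_quantiles)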
        if action is None:
            return quantiles
        return pick_quantile_value_by_action(quantiles, action)


class ContinuousQRQFunction(ContinuousQFunction):
    _encoder: EncoderWithAction
    _fc: nn.Linear
    _n_quantiles: int

    def __init__(
        self,
        encoder: EncoderWithAction,
        hidden_size: int,
        n_quantiles: int,
    ):
        super().__init__()
        self._encoder = encoder
        self._fc = nn.Linear(hidden_size, n_quantiles)
        self._n_quantiles = n_quantiles

    def forward(
        self, x: TorchObservation, action: torch.Tensor
    ) -> QFunctionOutput:
        quantiles = self._fc(self._encoder(x, action))
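        # the continuous variant predicts one quantile set of shape
        # (batch, n_quantiles) for the given state-action pair; the expected
        # Q-value is its mean, kept as (batch, 1)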
        return QFunctionOutput(
            q_value=quantiles.mean(dim=1, keepdim=True),
            quantiles=quantiles,
            taus=_make_taus(self._n_quantiles, device=get_device(x)),
        )

    @property
    def encoder(self) -> EncoderWithAction:
        return self._encoder


class ContinuousQRQFunctionForwarder(ContinuousQFunctionForwarder):
    _q_func: ContinuousQRQFunction
    _n_quantiles: int

    def __init__(self, q_func: ContinuousQRQFunction, n_quantiles: int):
        self._q_func = q_func
        self._n_quantiles = n_quantiles

    def compute_expected_q(
        self, x: TorchObservation, action: torch.Tensor
    ) -> torch.Tensor:
        return self._q_func(x, action).q_value

    def compute_error(
        self,
        observations: TorchObservation,
        actions: torch.Tensor,
        rewards: torch.Tensor,
        target: torch.Tensor,
        terminals: torch.Tensor,
        gamma: Union[float, torch.Tensor] = 0.99,
        reduction: str = "mean",
    ) -> torch.Tensor:
        batch_size = get_batch_size(observations)
        assert target.shape == (batch_size, self._n_quantiles)

        output = self._q_func(observations, actions)
        quantiles = output.quantiles
        taus = output.taus
        assert quantiles is not None and taus is not None

        loss = compute_quantile_loss(
            quantiles=quantiles,
            rewards=rewards,
            target=target,
            terminals=terminals,
            taus=taus,
            gamma=gamma,
        )

        return compute_reduce(loss, reduction)

    def compute_target(
        self, x: TorchObservation, action: torch.Tensor
    ) -> torch.Tensor:
        quantiles = self._q_func(x, action).quantiles
        assert quantiles is not None
        return quantiles
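

# --- Usage sketch (illustrative, not part of the library) --------------------
# A minimal, hypothetical stand-in encoder is used here purely to show the
# constructor arguments and output shapes; real d3rlpy encoders are normally
# built through the encoder factories (e.g. d3rlpy.models.encoders).
if __name__ == "__main__":

    class _ToyEncoder(nn.Module):  # hypothetical stand-in for Encoder
        def __init__(self, observation_size: int, hidden_size: int):
            super().__init__()
            self._fc = nn.Linear(observation_size, hidden_size)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.relu(self._fc(x))

    q_func = DiscreteQRQFunction(
        encoder=_ToyEncoder(observation_size=8, hidden_size=64),
        hidden_size=64,
        action_size=4,
        n_quantiles=32,
    )
    output = q_func(torch.rand(2, 8))
    # expected shapes: q_value (2, 4), quantiles (2, 4, 32), taus (1, 32)
    print(output.q_value.shape, output.quantiles.shape, output.taus.shape)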