takuseno/d3rlpy

View on GitHub
d3rlpy/algos/qlearning/torch/cal_ql_impl.py

Summary

Maintainability
A
0 mins
Test Coverage
from typing import Tuple

import torch

from ....types import TorchObservation
from .cql_impl import CQLImpl

# Public API of this module: only the Cal-QL implementation class is exported.
__all__ = ["CalQLImpl"]


class CalQLImpl(CQLImpl):
    """CQL implementation whose policy-sampled Q-values are floored by returns.

    Everything is inherited from :class:`CQLImpl` except
    ``_compute_policy_is_values``. The override takes the Q-values produced by
    the parent and replaces each one with the element-wise maximum of itself
    and ``returns_to_go``. The accompanying log-probabilities are passed
    through untouched. This lower bound is the calibration step the class name
    refers to (presumably following the Cal-QL method of Nakamoto et al.).
    """

    def _compute_policy_is_values(
        self,
        policy_obs: TorchObservation,
        value_obs: TorchObservation,
        returns_to_go: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the parent's policy Q-values, lower-bounded by returns.

        Args:
            policy_obs: Observations forwarded unchanged to
                ``CQLImpl._compute_policy_is_values``.
            value_obs: Observations forwarded unchanged to
                ``CQLImpl._compute_policy_is_values``.
            returns_to_go: Forwarded to the parent and then used as an
                element-wise floor for its Q-values via ``torch.maximum``.
                It must broadcast against those values. NOTE(review): this
                presumably has shape ``(batch, 1)``, so it broadcasts over any
                leading critic dimension and the trailing sample dimension.
                A 1-D ``(batch,)`` tensor would align with the wrong axis.
                Confirm against the caller.

        Returns:
            A ``(values, log_probs)`` pair. ``values`` is the element-wise
            maximum of the parent's values and ``returns_to_go`` (NaNs
            propagate, as with ``torch.maximum``). ``log_probs`` is exactly
            what the parent returned.
        """
        parent_result = super()._compute_policy_is_values(
            policy_obs=policy_obs,
            value_obs=value_obs,
            returns_to_go=returns_to_go,
        )
        q_values, sample_log_probs = parent_result

        # Only the value term is calibrated; the importance-sampling
        # log-probabilities are returned as-is.
        calibrated_q_values = torch.maximum(q_values, returns_to_go)
        return calibrated_q_values, sample_log_probs