MushroomRL/mushroom-rl

mushroom_rl/algorithms/actor_critic/deep_actor_critic/deep_actor_critic.py

from mushroom_rl.core import Agent
from mushroom_rl.utils.torch import TorchUtils


class OnPolicyDeepAC(Agent):
    """
    Base class for on-policy deep actor-critic algorithms.

    """

    def _preprocess_state(self, state, next_state, output_old=True):
        """
        Apply the agent preprocessor to state and next_state, after updating
        the preprocessor with the raw state. If output_old is True, the state
        preprocessed with the statistics in use before the update is returned
        as well.

        """
        state_old = None

        if output_old:
            state_old = self._agent_preprocess(state)

        self._update_agent_preprocessor(state)
        state = self._agent_preprocess(state)
        next_state = self._agent_preprocess(next_state)

        if output_old:
            return state, next_state, state_old
        else:
            return state, next_state


class DeepAC(Agent):
    """
    Base class for off policy deep actor-critic algorithms.
    These algorithms use the reparametrization trick, such as SAC, DDPG and TD3.

    """

    def __init__(self, mdp_info, policy, actor_optimizer, parameters, backend='torch'):
        """
        Constructor.

        Args:
            actor_optimizer (dict): parameters to specify the actor optimizer
                algorithm. It must contain the optimizer 'class' and its
                'params'; optionally, a 'clipping' entry specifying the
                clipping 'method' and its 'params' can be provided;
            parameters (list): policy parameters to be optimized.

        """
        if actor_optimizer is not None:
            if parameters is not None and not isinstance(parameters, list):
                parameters = list(parameters)
            self._parameters = parameters

            self._optimizer = actor_optimizer['class'](parameters, **actor_optimizer['params'])

            self._clipping = None

            if 'clipping' in actor_optimizer:
                self._clipping = actor_optimizer['clipping']['method']
                self._clipping_params = actor_optimizer['clipping']['params']

        super().__init__(mdp_info, policy, backend=backend)
        
        self._add_save_attr(
            _optimizer='torch',
            _clipping='torch',
            _clipping_params='pickle'
        )

    def fit(self, dataset):
        """
        Fit step.

        Args:
            dataset (list): the dataset.

        """
        raise NotImplementedError('DeepAC is an abstract class')

    def _optimize_actor_parameters(self, loss):
        """
        Method used to update the actor parameters by performing a gradient
        step that minimizes the given loss, with optional gradient clipping.

        Args:
            loss (torch.Tensor): the loss computed by the algorithm.

        """
        self._optimizer.zero_grad()
        loss.backward()
        self._clip_gradient()
        self._optimizer.step()

    def _clip_gradient(self):
        # Apply the configured gradient clipping method, if any was specified.
        if self._clipping:
            self._clipping(self._parameters, **self._clipping_params)

    @staticmethod
    def _init_target(online, target):
        # Initialize each target approximator with the weights of the
        # corresponding online one.
        for i in range(len(target)):
            target[i].set_weights(online[i].get_weights())

    def _update_target(self, online, target):
        # Soft update of the target approximators:
        # target <- tau * online + (1 - tau) * target
        for i in range(len(target)):
            weights = self._tau() * online[i].get_weights()
            weights += (1 - self._tau.get_value()) * target[i].get_weights()
            target[i].set_weights(weights)

    def _update_optimizer_parameters(self, parameters):
        # Re-bind the actor optimizer to a new set of parameters, e.g. after
        # the agent has been reloaded.
        self._parameters = list(parameters)
        if self._optimizer is not None:
            TorchUtils.update_optimizer_parameters(self._optimizer, self._parameters)

    def _post_load(self):
        raise NotImplementedError('DeepAC is an abstract class. Subclasses need to implement the `_post_load` method.')
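

# Illustrative sketch (not part of the library): how a concrete off-policy
# algorithm typically combines the DeepAC hooks. `MyDeepAC`, `_critic` and
# `_target_critic` are hypothetical placeholders.
#
#   class MyDeepAC(DeepAC):
#       def fit(self, dataset):
#           ...                                    # sample a batch, fit the critic(s)
#           loss = ...                             # actor loss on the sampled states
#           self._optimize_actor_parameters(loss)  # zero_grad, backward, clip, step
#           self._update_target(self._critic, self._target_critic)
#
#       def _post_load(self):
#           self._update_optimizer_parameters(...)  # re-bind to the reloaded actor parameters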