MushroomRL/mushroom-rl

View on GitHub
mushroom_rl/algorithms/policy_search/black_box_optimization/eppo.py

Summary

Maintainability
A
1 hr
Test Coverage
F
20%
import torch

from mushroom_rl.algorithms.policy_search.black_box_optimization import BlackBoxOptimization
from mushroom_rl.utils.minibatches import minibatch_generator
from mushroom_rl.rl_utils.parameters import to_parameter


class ePPO(BlackBoxOptimization):
    """
    Episodic adaptation of the Proximal Policy Optimization algorithm.
    "Proximal Policy Optimization Algorithms".
    Schulman J. et al.. 2017.

    """
    def __init__(self, mdp_info, distribution, policy, optimizer, n_epochs_policy, batch_size, eps_ppo, ent_coeff=0.0,
                 context_builder=None):
        """
        Constructor.

        Args:
            mdp_info: information about the MDP, forwarded to ``BlackBoxOptimization``;
            distribution: the (torch) distribution over policy parameters that is optimized. It must
                expose ``parameters()``, whose result is handed to the optimizer;
            policy: the policy, forwarded to ``BlackBoxOptimization``;
            optimizer (dict): the gradient step optimizer, given as a dictionary with a ``'class'`` key
                (the optimizer class) and a ``'params'`` key (keyword arguments for its constructor);
            n_epochs_policy ([int, Parameter]): number of passes over the collected batch at every update;
            batch_size ([int, Parameter]): size of the minibatches used for each gradient step;
            eps_ppo ([float, Parameter]): clipping width; the probability ratio is clipped to
                ``[1 - eps_ppo, 1 + eps_ppo]``;
            ent_coeff ([float, Parameter], 0.0): coefficient of the entropy bonus subtracted from the loss;
            context_builder: optional context builder, forwarded to ``BlackBoxOptimization``.

        """
        assert hasattr(distribution, 'parameters'), 'ePPO requires a distribution exposing parameters()'

        self._optimizer = optimizer['class'](distribution.parameters(), **optimizer['params'])
        self._n_epochs_policy = to_parameter(n_epochs_policy)
        self._batch_size = to_parameter(batch_size)
        self._eps_ppo = to_parameter(eps_ppo)
        self._ent_coeff = to_parameter(ent_coeff)

        super().__init__(mdp_info, distribution, policy, context_builder=context_builder, backend='torch')

        self._add_save_attr(
            _optimizer='torch',
            _n_epochs_policy='mushroom',
            _batch_size='mushroom',
            _eps_ppo='mushroom',
            _ent_coeff='mushroom',
        )

    def _update(self, Jep, theta, context):
        """
        Update the distribution by maximizing the clipped PPO surrogate objective.

        Args:
            Jep: returns of the episodes, aligned with ``theta``;
            theta: parameter vectors sampled from the distribution;
            context: contexts aligned with ``theta``; only used when the distribution is contextual.

        """
        # as_tensor does not copy/warn when Jep is already a tensor (torch backend)
        Jep = torch.as_tensor(Jep)
        J_mean = torch.mean(Jep)
        # torch.std applies Bessel's correction and yields NaN for a single sample, which would
        # propagate into the loss and corrupt the distribution parameters. With one sample the
        # centered return is zero anyway, so a unit scale is a safe fallback.
        if Jep.numel() > 1:
            J_std = torch.std(Jep)
        else:
            J_std = torch.ones((), dtype=Jep.dtype)

        Jep = (Jep - J_mean) / (J_std + 1e-8)

        if self.distribution.is_contextual:
            # The old log-likelihood must be evaluated under the same conditional distribution
            # that is used inside the optimization loop, i.e. given the context.
            old_log_p = self.distribution.log_pdf(theta, context).detach()
            full_batch = (theta, Jep, old_log_p, context)
        else:
            old_log_p = self.distribution.log_pdf(theta).detach()
            full_batch = (theta, Jep, old_log_p)

        for _ in range(self._n_epochs_policy()):
            for minibatch in minibatch_generator(self._batch_size(), *full_batch):

                theta_i, context_i, Jep_i, old_log_p_i = self._unpack(minibatch)

                self._optimizer.zero_grad()
                prob_ratio = torch.exp(self.distribution.log_pdf(theta_i, context_i) - old_log_p_i)
                # Read the clipping parameter once and use the same value for both bounds.
                eps = self._eps_ppo()
                clipped_ratio = torch.clamp(prob_ratio, 1 - eps, 1 + eps)
                loss = -torch.mean(torch.min(prob_ratio * Jep_i, clipped_ratio * Jep_i))
                # A contextual distribution may return one entropy per context; averaging keeps the
                # loss a scalar (required by backward()) and is a no-op for a scalar entropy.
                entropy = self.distribution.entropy(context_i)
                loss = loss - self._ent_coeff() * torch.mean(entropy)
                loss.backward()
                self._optimizer.step()

    def _unpack(self, minibatch):
        """
        Split a minibatch produced from ``full_batch`` into its components.

        Returns:
            The tuple ``(theta_i, context_i, Jep_i, old_log_p_i)``, where ``context_i`` is None when
            the distribution is not contextual.

        """
        if self.distribution.is_contextual:
            theta_i, Jep_i, old_dist_i, context_i = minibatch
        else:
            theta_i, Jep_i, old_dist_i = minibatch
            context_i = None

        return theta_i, context_i, Jep_i, old_dist_i