MushroomRL/mushroom-rl
mushroom_rl/algorithms/actor_critic/deep_actor_critic/ppo_bptt.py

import torch

from mushroom_rl.algorithms.actor_critic.deep_actor_critic import OnPolicyDeepAC
from mushroom_rl.approximators import Regressor
from mushroom_rl.approximators.parametric import TorchApproximator
from mushroom_rl.utils.torch import TorchUtils
from mushroom_rl.utils.minibatches import minibatch_generator
from mushroom_rl.rl_utils.parameters import to_parameter


class PPO_BPTT(OnPolicyDeepAC):
    """
    Proximal Policy Optimization algorithm with truncated Backpropagation Through
    Time (BPTT), for recurrent policies that carry an internal policy state.
    "Proximal Policy Optimization Algorithms".
    Schulman J. et al., 2017.

    """
    def __init__(self, mdp_info, policy, actor_optimizer, critic_params,
                 n_epochs_policy, batch_size, eps_ppo, lam, dim_env_state, ent_coeff=0.0,
                 critic_fit_params=None, truncation_length=5):
        """
        Constructor.

        Args:
            policy (TorchPolicy): torch policy to be learned by the algorithm;
            actor_optimizer (dict): parameters to specify the actor optimizer
                algorithm;
            critic_params (dict): parameters of the critic approximator to
                build;
            n_epochs_policy ([int, Parameter]): number of policy updates for every dataset;
            batch_size ([int, Parameter]): size of minibatches for every optimization step;
            eps_ppo ([float, Parameter]): value for probability ratio clipping;
            lam ([float, Parameter]): lambda coefficient used by generalized
                advantage estimation;
            dim_env_state (int): dimension of the environment part of the state;
            ent_coeff ([float, Parameter], 0.): coefficient for the entropy regularization term;
            critic_fit_params (dict, None): parameters of the fitting algorithm
                of the critic approximator;
            truncation_length (int, 5): maximum length of the state sequences used for
                truncated backpropagation through time.

        """
        self._critic_fit_params = dict(n_epochs=10) if critic_fit_params is None else critic_fit_params

        self._n_epochs_policy = to_parameter(n_epochs_policy)
        self._batch_size = to_parameter(batch_size)
        self._eps_ppo = to_parameter(eps_ppo)

        self._optimizer = actor_optimizer['class'](policy.parameters(), **actor_optimizer['params'])

        self._lambda = to_parameter(lam)
        self._ent_coeff = to_parameter(ent_coeff)

        self._V = Regressor(TorchApproximator, **critic_params)

        self._truncation_length = truncation_length
        self._dim_env_state = dim_env_state

        self._iter = 1

        super().__init__(mdp_info, policy, backend='torch')

        self._add_save_attr(
            _critic_fit_params='pickle',
            _n_epochs_policy='mushroom',
            _batch_size='mushroom',
            _eps_ppo='mushroom',
            _ent_coeff='mushroom',
            _optimizer='torch',
            _lambda='mushroom',
            _V='mushroom',
            _iter='primitive',
            _dim_env_state='primitive'
        )

    def fit(self, dataset):
        state, action, reward, next_state, absorbing, last = dataset.parse(to='torch')
        state, next_state, state_old = self._preprocess_state(state, next_state)

        policy_state, policy_next_state = dataset.parse_policy_state(to='torch')
        state_old_seq, state_seq, policy_state_seq, act_seq, state_next_seq, policy_next_state_seq, lengths = \
            self._transform_to_sequences(state_old, state, policy_state, action, next_state, policy_next_state,
                                         last, absorbing)

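        # compute GAE advantages and value-function targets over the truncated
        # sequences, then standardize the advantages before the policy update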
        v_target, adv = self.compute_gae(self._V, state_seq, policy_state_seq, state_next_seq, policy_next_state_seq,
                                         lengths, reward, absorbing, last, self.mdp_info.gamma, self._lambda())
        adv = (adv - torch.mean(adv)) / (torch.std(adv) + 1e-8)

        old_pol_dist = self.policy.distribution_t(state_old_seq, policy_state_seq, lengths)
        old_log_p = old_pol_dist.log_prob(action)[:, None].detach()

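        # fit the critic on the GAE value targets, then run the clipped-surrogate policy updates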
        self._V.fit(state_seq, policy_state_seq, lengths, v_target, **self._critic_fit_params)

        self._update_policy(state_seq, policy_state_seq, action, lengths, adv, old_log_p)

        # Print fit information
        self._log_info(dataset, state_seq, policy_state_seq, lengths, v_target, old_pol_dist)
        self._iter += 1

    def _transform_to_sequences(self, states_old, states, policy_states, actions, next_states, policy_next_states,
                                last, absorbing):
        with torch.no_grad():
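            # For every step i, build the sequence of at most `truncation_length`
            # consecutive states ending at i, zero-padded on the right, together with
            # its true length. Example with truncation_length=3: if the previous
            # trajectory ended at step 3 (last[3] is True), the sequence stored for
            # i=5 is [s_4, s_5, zero-padding] and lengths[5] = 2.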
            s_old = torch.empty(len(states), self._truncation_length, states.shape[-1])
            s = torch.empty(len(states), self._truncation_length, states.shape[-1])
            ps = torch.empty(len(states), policy_states.shape[-1])
            a = torch.empty(len(actions), self._truncation_length, actions.shape[-1])
            ss = torch.empty(len(states), self._truncation_length, states.shape[-1])
            pss = torch.empty(len(states), policy_states.shape[-1])
            lengths = torch.empty(len(states), dtype=torch.long)

            for i in range(len(states)):
                # determine the beginning of the sequence
                begin_seq = max(i - self._truncation_length + 1, 0)
                end_seq = i + 1

                # the sequence may span more than one trajectory; cut it so that it
                # only contains the current one
                lasts_absorbing = last[begin_seq - 1: i].int() + absorbing[begin_seq - 1: i].int()
                begin_traj = torch.where(lasts_absorbing > 0)
                sequence_is_shorter_than_requested = len(begin_traj[0]) > 0
                if sequence_is_shorter_than_requested:
                    begin_seq = begin_seq + begin_traj[0][-1]

                # get the sequences
                states_old_seq = states_old[begin_seq:end_seq]
                states_seq = states[begin_seq:end_seq]
                actions_seq = actions[begin_seq:end_seq]
                next_states_seq = next_states[begin_seq:end_seq]

                # apply padding
                length_seq = len(states_seq)
                padded_states_old = torch.concatenate([states_old_seq,
                                                       torch.zeros((self._truncation_length - states_old_seq.shape[0],
                                                                    states_old_seq.shape[1]))])
                padded_states = torch.concatenate([states_seq,
                                                   torch.zeros((self._truncation_length - states_seq.shape[0],
                                                                states_seq.shape[1]))])
                padded_next_states = torch.concatenate([next_states_seq,
                                                        torch.zeros((self._truncation_length - next_states_seq.shape[0],
                                                                     next_states_seq.shape[1]))])
                padded_action_seq = torch.concatenate([actions_seq,
                                                       torch.zeros((self._truncation_length - actions_seq.shape[0],
                                                                    actions_seq.shape[1]))])

                s_old[i] = padded_states_old
                s[i] = padded_states
                ps[i] = policy_states[begin_seq]
                a[i] = padded_action_seq
                ss[i] = padded_next_states
                pss[i] = policy_next_states[begin_seq]

                lengths[i] = length_seq

            return s_old, s, ps, a, ss, pss, lengths

    def _update_policy(self, obs, pi_h, act, lengths, adv, old_log_p):
        for epoch in range(self._n_epochs_policy()):
            for obs_i, pi_h_i, act_i, length_i, adv_i, old_log_p_i in minibatch_generator(
                    self._batch_size(), obs, pi_h, act, lengths, adv, old_log_p):
                self._optimizer.zero_grad()
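                # clipped PPO surrogate: maximize min(ratio * A, clip(ratio, 1 - eps, 1 + eps) * A)
                # plus an entropy bonus weighted by ent_coeff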
                prob_ratio = torch.exp(
                    self.policy.log_prob_t(obs_i, act_i, pi_h_i, length_i) - old_log_p_i
                )
                clipped_ratio = torch.clamp(prob_ratio, 1 - self._eps_ppo(), 1 + self._eps_ppo())
                loss = -torch.mean(torch.min(prob_ratio * adv_i, clipped_ratio * adv_i))
                loss -= self._ent_coeff() * self.policy.entropy_t(obs_i)
                loss.backward()
                self._optimizer.step()

    def _log_info(self, dataset, x, pi_h, lengths, v_target, old_pol_dist):
        pass

    def _post_load(self):
        if self._optimizer is not None:
            TorchUtils.update_optimizer_parameters(self._optimizer, list(self.policy.parameters()))

    @staticmethod
    def compute_gae(V, s, pi_h, ss, pi_hn, lengths, r, absorbing, last, gamma, lam):
        """
        Function to compute Generalized Advantage Estimation (GAE)
        and new value function target over a dataset.

        "High-Dimensional Continuous Control Using Generalized
        Advantage Estimation".
        Schulman J. et al., 2016.

        Args:
            V (Regressor): the current value function regressor;
            s (torch.Tensor): the set of padded state sequences in which we want
                to evaluate the advantage;
            pi_h (torch.Tensor): the policy states associated with ``s``;
            ss (torch.Tensor): the set of padded next-state sequences in which we
                want to evaluate the advantage;
            pi_hn (torch.Tensor): the policy states associated with ``ss``;
            lengths (torch.Tensor): the true (unpadded) length of each sequence;
            r (torch.Tensor): the reward obtained in each transition
                from state s to state ss;
            absorbing (torch.Tensor): an array of boolean flags indicating
                if the reached state is absorbing;
            last (torch.Tensor): an array of boolean flags indicating
                if the reached state is the last of the trajectory;
            gamma (float): the discount factor of the considered problem;
            lam (float): the value of the lambda coefficient used by the GAE
                algorithm.

        Returns:
            The new value-function target for each state and the estimated
            generalized advantage.
        """
        with torch.no_grad():
            v = V(s, pi_h, lengths)
            v_next = V(ss, pi_hn, lengths)
            gen_adv = torch.empty_like(v)
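            # backward recursion: A_k = delta_k + gamma * lam * A_{k+1}, with
            # delta_k = r_k + gamma * V(s_{k+1}) - V(s_k); the recursion restarts at the
            # end of every trajectory, and the bootstrap term gamma * V(s_{k+1}) is
            # dropped when the trajectory ends in an absorbing state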
            for rev_k in range(len(v)):
                k = len(v) - rev_k - 1
                if last[k] or rev_k == 0:
                    gen_adv[k] = r[k] - v[k]
                    if not absorbing[k]:
                        gen_adv[k] += gamma * v_next[k]
                else:
                    gen_adv[k] = r[k] + gamma * v_next[k] - v[k] + gamma * lam * gen_adv[k + 1]

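            # gen_adv + v is the lambda-return target used to fit the critic in `fit`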
            return gen_adv + v, gen_adv