MushroomRL/mushroom-rl
mushroom_rl/policy/torch_policy.py

import numpy as np

import torch
import torch.nn as nn

from mushroom_rl.policy import Policy
from mushroom_rl.approximators import Regressor
from mushroom_rl.approximators.parametric import TorchApproximator
from mushroom_rl.utils.torch import TorchUtils, CategoricalWrapper
from mushroom_rl.rl_utils.parameters import to_parameter

from itertools import chain


class TorchPolicy(Policy):
    """
    Interface for a generic PyTorch policy.
    A PyTorch policy is a policy implemented as a neural network using PyTorch.
    Methods whose names end with '_t' take tensors as input and, where
    applicable, return tensors as output.

    """

    # TODO: remove TorchUtils.to_float_tensor(array) and update the docstring to replace np.ndarray.

    def __init__(self, policy_state_shape=None):
        """
        Constructor.

        """
        super().__init__(policy_state_shape)

    def __call__(self, state, action, policy_state=None):
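        # The probability of the state-action pair is the exponential of the
        # log-probability computed by the concrete policy implementation.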
        s = TorchUtils.to_float_tensor(torch.atleast_2d(state))
        a = TorchUtils.to_float_tensor(torch.atleast_2d(action))

        return torch.exp(self.log_prob_t(s, a))

    def draw_action(self, state, policy_state=None):
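        # Sampling needs no gradient tracking; the leading batch dimension
        # added by atleast_2d is squeezed out before the action is returned.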
        with torch.no_grad():
            s = TorchUtils.to_float_tensor(torch.atleast_2d(state))
            a = self.draw_action_t(s)

        return torch.squeeze(a, dim=0).detach(), None

    def distribution(self, state):
        """
        Compute the policy distribution in the given states.

        Args:
            state (np.ndarray): the set of states where the distribution is
                computed.

        Returns:
            The torch distribution for the provided states.

        """
        s = TorchUtils.to_float_tensor(state)

        return self.distribution_t(s)

    def entropy(self, state=None):
        """
        Compute the entropy of the policy.

        Args:
            state (np.ndarray, None): the set of states to consider. If the
                entropy of the policy can be computed in closed form, then
                ``state`` can be None.

        Returns:
            The value of the entropy of the policy.

        """
        s = TorchUtils.to_float_tensor(state) if state is not None else None

        return self.entropy_t(s).detach()

    def draw_action_t(self, state):
        """
        Draw an action given a tensor.

        Args:
            state (torch.Tensor): set of states.

        Returns:
            The tensor of the actions to perform in each state.

        """
        raise NotImplementedError

    def log_prob_t(self, state, action):
        """
        Compute the logarithm of the probability of taking ``action`` in
        ``state``.

        Args:
            state (torch.Tensor): set of states.
            action (torch.Tensor): set of actions.

        Returns:
            The tensor of log-probability.

        """
        raise NotImplementedError

    def entropy_t(self, state):
        """
        Compute the entropy of the policy.

        Args:
            state (torch.Tensor): the set of states to consider. If the
                entropy of the policy can be computed in closed form, then
                ``state`` can be None.

        Returns:
            The tensor value of the entropy of the policy.

        """
        raise NotImplementedError

    def distribution_t(self, state):
        """
        Compute the policy distribution in the given states.

        Args:
            state (torch.Tensor): the set of states where the distribution is
                computed.

        Returns:
            The torch distribution for the provided states.

        """
        raise NotImplementedError

    def set_weights(self, weights):
        """
        Setter.

        Args:
            weights (np.ndarray): the vector of the new weights to be used by
                the policy.

        """
        raise NotImplementedError

    def get_weights(self):
        """
        Getter.

        Returns:
             The current policy weights.

        """
        raise NotImplementedError

    def parameters(self):
        """
        Returns the trainable policy parameters, as expected by torch
        optimizers.

        Returns:
            List of parameters to be optimized.

        """
        raise NotImplementedError


class GaussianTorchPolicy(TorchPolicy):
    """
    Torch policy implementing a Gaussian policy with trainable standard
    deviation. The standard deviation is not state-dependent.

    """
    def __init__(self, network, input_shape, output_shape, std_0=1., policy_state_shape=None, **params):
        """
        Constructor.

        Args:
            network (object): the network class used to implement the mean
                regressor;
            input_shape (tuple): the shape of the state space;
            output_shape (tuple): the shape of the action space;
            std_0 (float, 1.): initial standard deviation;
            policy_state_shape (tuple, None): the shape of the internal policy
                state, if any;
            **params: parameters used by the network constructor.

        """
        super().__init__(policy_state_shape)

        self._action_dim = output_shape[0]

        self._mu = Regressor(TorchApproximator, input_shape, output_shape, network=network, **params)
        self._predict_params = dict()

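        # The standard deviation is state-independent and stored as log(sigma),
        # so that sigma = exp(log_sigma) stays positive under unconstrained
        # gradient updates.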
        log_sigma_init = torch.ones(self._action_dim, device=TorchUtils.get_device()) * torch.log(TorchUtils.to_float_tensor(std_0))

        self._log_sigma = nn.Parameter(log_sigma_init)

        self._add_save_attr(
            _action_dim='primitive',
            _mu='mushroom',
            _predict_params='pickle',
            _log_sigma='torch'
        )

    def draw_action_t(self, state):
        return self.distribution_t(state).sample().detach()

    def log_prob_t(self, state, action):
        return self.distribution_t(state).log_prob(action)[:, None]

    def entropy_t(self, state=None):
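        # Closed-form differential entropy of a Gaussian with diagonal
        # covariance: H = k/2 * log(2*pi*e) + sum_i log(sigma_i); it does not
        # depend on the state.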
        return self._action_dim / 2 * torch.log(TorchUtils.to_float_tensor(2 * np.pi * np.e))\
               + torch.sum(self._log_sigma)

    def distribution_t(self, state):
        mu, chol_sigma = self.get_mean_and_chol(state)
        return torch.distributions.MultivariateNormal(loc=mu, scale_tril=chol_sigma, validate_args=False)

    def get_mean_and_chol(self, state):
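        # With a diagonal covariance, the Cholesky factor is simply
        # diag(exp(log_sigma)).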
        assert torch.all(torch.exp(self._log_sigma) > 0)
        return self._mu(state, **self._predict_params), torch.diag(torch.exp(self._log_sigma))

    def set_weights(self, weights):
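        # The flat weight vector is laid out as [mean network weights,
        # log_sigma]: the last action_dim entries are the log standard
        # deviations.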
        log_sigma_data = TorchUtils.to_float_tensor(weights[-self._action_dim:])
        self._log_sigma.data = log_sigma_data

        self._mu.set_weights(weights[:-self._action_dim])

    def get_weights(self):
        mu_weights = self._mu.get_weights()
        sigma_weights = self._log_sigma.data.detach()

        return torch.concatenate([mu_weights, sigma_weights])

    def parameters(self):
        return chain(self._mu.model.network.parameters(), [self._log_sigma])


class BoltzmannTorchPolicy(TorchPolicy):
    """
    Torch policy implementing a Boltzmann policy.

    """
    def __init__(self, network, input_shape, output_shape, beta, policy_state_shape=None, **params):
        """
        Constructor.

        Args:
            network (object): the network class used to implement the logits
                regressor;
            input_shape (tuple): the shape of the state space;
            output_shape (tuple): the shape of the action space;
            beta ([float, Parameter]): the inverse of the temperature of the
                Boltzmann distribution. As the temperature approaches infinity,
                the policy becomes more and more random; as the temperature
                approaches 0, the policy becomes more and more greedy;
            policy_state_shape (tuple, None): the shape of the internal policy
                state, if any;
            **params: parameters used by the network constructor.

        """
        super().__init__(policy_state_shape)

        self._action_dim = output_shape[0]
        self._predict_params = dict()

        self._logits = Regressor(TorchApproximator, input_shape, output_shape, network=network, **params)
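        # to_parameter wraps a plain float into a Parameter object, so beta can
        # also be passed as a (possibly scheduled) Parameter.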
        self._beta = to_parameter(beta)

        self._add_save_attr(
            _action_dim='primitive',
            _predict_params='pickle',
            _beta='mushroom',
            _logits='mushroom'
        )

    def draw_action_t(self, state):
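        # Sample from the categorical distribution and make sure the result
        # always carries a leading batch dimension, so that draw_action can
        # squeeze it out uniformly.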
        action = self.distribution_t(state).sample().detach()

        if len(action.shape) > 1:
            return action
        else:
            return action.unsqueeze(0)

    def log_prob_t(self, state, action):
        return self.distribution_t(state).log_prob(action)[:, None]

    def entropy_t(self, state):
        return torch.mean(self.distribution_t(state).entropy())

    def distribution_t(self, state):
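        # The network outputs are scaled by beta (the inverse temperature)
        # before being used as logits of a categorical distribution: a larger
        # beta makes the policy greedier, a smaller one more uniform.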
        logits = self._logits(state, **self._predict_params) * self._beta(state.numpy())
        return CategoricalWrapper(logits)

    def set_weights(self, weights):
        self._logits.set_weights(weights)

    def get_weights(self):
        return self._logits.get_weights()

    def parameters(self):
        return self._logits.model.network.parameters()

    def set_beta(self, beta):
        self._beta = to_parameter(beta)
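

# -----------------------------------------------------------------------------
# Usage sketch (illustration only, not part of the library): it shows how the
# two concrete policies above could be instantiated and queried. The _DemoMLP
# network below is hypothetical; it only assumes the usual mushroom-rl network
# convention (a constructor taking input_shape, output_shape and extra keyword
# parameters, and a forward taking the state tensor) and the default CPU
# device. The exact keyword arguments forwarded to TorchApproximator may
# differ between versions.
# -----------------------------------------------------------------------------
if __name__ == '__main__':
    import torch.nn.functional as F

    class _DemoMLP(nn.Module):
        # Hypothetical two-layer MLP used only for this sketch.
        def __init__(self, input_shape, output_shape, n_features=32, **kwargs):
            super().__init__()
            self._h1 = nn.Linear(input_shape[-1], n_features)
            self._h2 = nn.Linear(n_features, output_shape[0])

        def forward(self, state, **kwargs):
            return self._h2(F.relu(self._h1(state.float())))

    state_shape = (4,)
    state = torch.rand(state_shape)

    # Continuous actions: Gaussian policy with a learnable, state-independent
    # standard deviation.
    gaussian_pi = GaussianTorchPolicy(_DemoMLP, state_shape, (2,), std_0=1.,
                                      n_features=32)
    action, _ = gaussian_pi.draw_action(state)
    print('Gaussian action:', action, '- entropy:', gaussian_pi.entropy())

    # Discrete actions: Boltzmann policy sampling from a categorical
    # distribution over the logits scaled by the inverse temperature beta.
    boltzmann_pi = BoltzmannTorchPolicy(_DemoMLP, state_shape, (3,), beta=1.,
                                        n_features=32)
    action, _ = boltzmann_pi.draw_action(state)
    print('Boltzmann action:', action)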