# mushroom_rl/algorithms/actor_critic/deep_actor_critic/deep_actor_critic.py
from mushroom_rl.core import Agent
from mushroom_rl.utils.torch import TorchUtils
class OnPolicyDeepAC(Agent):
    """
    Base class for on-policy deep actor-critic agents, providing the state
    preprocessing helper they share.

    """
    def _preprocess_state(self, state, next_state, output_old=True):
        """
        Run the agent preprocessor on a batch of states and next states,
        updating the preprocessor with the current states in between.

        Args:
            state: the states to preprocess; also used to update the agent
                preprocessor;
            next_state: the next states to preprocess;
            output_old (bool, True): whether to additionally return the states
                as preprocessed *before* the preprocessor update.

        Returns:
            The tuple ``(state, next_state)`` computed after the preprocessor
            update; when ``output_old`` is True, the pre-update preprocessed
            states are appended as a third element.

        """
        # Snapshot computed with the preprocessor as it was before the update.
        pre_update_state = self._agent_preprocess(state) if output_old else None

        self._update_agent_preprocessor(state)
        processed_state = self._agent_preprocess(state)
        processed_next_state = self._agent_preprocess(next_state)

        if not output_old:
            return processed_state, processed_next_state

        return processed_state, processed_next_state, pre_update_state
class DeepAC(Agent):
    """
    Base class for off policy deep actor-critic algorithms.
    These algorithms use the reparametrization trick, such as SAC, DDPG and TD3.

    """
    def __init__(self, mdp_info, policy, actor_optimizer, parameters, backend='torch'):
        """
        Constructor.

        Args:
            mdp_info: forwarded unchanged to the ``Agent`` constructor;
            policy: forwarded unchanged to the ``Agent`` constructor;
            actor_optimizer (dict): parameters to specify the actor optimizer
                algorithm. It must contain the ``'class'`` and ``'params'``
                entries and may contain a ``'clipping'`` entry, itself a dict
                with ``'method'`` and ``'params'`` entries. If None, no
                optimizer is built;
            parameters (list): policy parameters to be optimized. Any other
                iterable is converted to a list;
            backend (str, 'torch'): forwarded unchanged to the ``Agent``
                constructor.

        """
        # Always define these attributes: `_update_optimizer_parameters` and
        # `_clip_gradient` read them unconditionally, so leaving them unset
        # when no optimizer (or no clipping) is configured would raise an
        # AttributeError instead of being treated as "not configured".
        self._parameters = None
        self._optimizer = None
        self._clipping = None
        self._clipping_params = None

        if actor_optimizer is not None:
            if parameters is not None and not isinstance(parameters, list):
                parameters = list(parameters)

            self._parameters = parameters

            self._optimizer = actor_optimizer['class'](parameters, **actor_optimizer['params'])

            if 'clipping' in actor_optimizer:
                self._clipping = actor_optimizer['clipping']['method']
                self._clipping_params = actor_optimizer['clipping']['params']

        super().__init__(mdp_info, policy, backend=backend)

        self._add_save_attr(
            _optimizer='torch',
            _clipping='torch',
            _clipping_params='pickle'
        )

    def fit(self, dataset):
        """
        Fit step.

        Args:
            dataset (list): the dataset.

        Raises:
            NotImplementedError: always, subclasses must override this method.

        """
        raise NotImplementedError('DeepAC is an abstract class')

    def _optimize_actor_parameters(self, loss):
        """
        Perform a single optimizer step on the actor parameters: reset the
        gradients, backpropagate the given loss, apply the configured gradient
        clipping (if any) and step the optimizer.

        Args:
            loss (torch.tensor): the loss computed by the algorithm.

        """
        self._optimizer.zero_grad()
        loss.backward()
        self._clip_gradient()
        self._optimizer.step()

    def _clip_gradient(self):
        """
        Apply the configured clipping method to the actor parameters. Does
        nothing when no clipping method has been configured.

        """
        if self._clipping:
            self._clipping(self._parameters, **self._clipping_params)

    @staticmethod
    def _init_target(online, target):
        """
        Copy the weights of every online approximator into the target
        approximator with the same index.

        Args:
            online: indexable collection of approximators exposing
                ``get_weights``;
            target: indexable, sized collection of approximators exposing
                ``set_weights``.

        """
        for i in range(len(target)):
            target[i].set_weights(online[i].get_weights())

    def _update_target(self, online, target):
        """
        Move every target approximator towards the online approximator with
        the same index, using the convex combination
        ``tau * online + (1 - tau) * target`` of their weights.

        Args:
            online: indexable collection of approximators exposing
                ``get_weights``;
            target: indexable, sized collection of approximators exposing
                ``get_weights`` and ``set_weights``.

        """
        # NOTE: `self._tau` is not set by this class; subclasses are expected
        # to define it before calling this method.
        # NOTE(review): the first term calls the parameter while the second
        # only reads it through `get_value()`; presumably calling it may also
        # advance an internal schedule -- confirm before unifying the two.
        for i in range(len(target)):
            weights = self._tau() * online[i].get_weights()
            weights += (1 - self._tau.get_value()) * target[i].get_weights()
            target[i].set_weights(weights)

    def _update_optimizer_parameters(self, parameters):
        """
        Replace the stored actor parameters and, when an optimizer exists,
        make the optimizer track the new parameters.

        Args:
            parameters (iterable): the new parameters to be optimized.

        """
        self._parameters = list(parameters)
        if self._optimizer is not None:
            TorchUtils.update_optimizer_parameters(self._optimizer, self._parameters)

    def _post_load(self):
        """
        Hook that subclasses must implement.

        Raises:
            NotImplementedError: always, subclasses must override this method.

        """
        raise NotImplementedError('DeepAC is an abstract class. Subclasses need to implement the `_post_load` method.')