kengz/SLM-Lab

slm_lab/agent/algorithm/policy_util.py

Summary: Maintainability A (1 hr) · Test Coverage C (78%)
'''
Action policy methods for sampling actions
An algorithm provides `calc_pdparam`, which takes a state and does a forward pass through its net;
the resulting pdparam is used to construct an action probability distribution appropriate to the action type indicated by the body.
The probability distribution is then used to sample an action.

The default form looks like:
```
ActionPD, pdparam, body = init_action_pd(state, algorithm, body)
action, action_pd = sample_action_pd(ActionPD, pdparam, body)
```

We can also augment pdparam before sampling, as in Boltzmann sampling,
or use epsilon-greedy to choose between pdparam-based sampling and random sampling.
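For example, Boltzmann sampling divides pdparam by a temperature tau before sampling
(a sketch of the flow implemented by `boltzmann()` below):
```
ActionPD, pdparam, body = init_action_pd(state, algorithm, body)
pdparam /= tau  # tau = body.explore_var; higher tau means more exploration
action, action_pd = sample_action_pd(ActionPD, pdparam, body)
```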
'''
from slm_lab.env.wrapper import LazyFrames
from slm_lab.lib import logger, math_util, util
from torch import distributions
import numpy as np
import pydash as ps
import torch

logger = logger.get_logger(__name__)


# probability distribution classes for each action type; the first in each list is the default
ACTION_PDS = {
    'continuous': ['Normal', 'Beta', 'Gumbel', 'LogNormal'],
    'multi_continuous': ['MultivariateNormal'],
    'discrete': ['Categorical', 'Argmax'],
    'multi_discrete': ['MultiCategorical'],
    'multi_binary': ['Bernoulli'],
}


class Argmax(distributions.Categorical):
    '''
    Special distribution class for argmax sampling, where probability is always 1 for the argmax.
    NOTE although argmax is not a sampling distribution, this implementation is for API consistency.
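    e.g. Argmax(probs=torch.tensor([0.1, 0.7, 0.2])).sample() always returns tensor(1),
    and log_prob(tensor(1)) is 0, so greedy evaluation reuses the same distribution API.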
    '''

    def __init__(self, probs=None, logits=None, validate_args=None):
        if probs is not None:
            new_probs = torch.zeros_like(probs, dtype=torch.float)
            new_probs[torch.argmax(probs, dim=0)] = 1.0
            probs = new_probs
        elif logits is not None:
            new_logits = torch.full_like(logits, -1e8, dtype=torch.float)
            max_idx = torch.argmax(logits, dim=0)
            new_logits[max_idx] = logits[max_idx]
            logits = new_logits

        super(Argmax, self).__init__(probs=probs, logits=logits, validate_args=validate_args)


class MultiCategorical(distributions.Categorical):
    '''MultiCategorical as collection of Categoricals'''
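    # e.g. with logits for two sub-actions of sizes 3 and 2 (an illustrative sketch):
    #   pd = MultiCategorical(logits=[torch.zeros(3), torch.zeros(2)])
    #   pd.sample()               # stacked tensor of shape (2,), one index per sub-action
    #   pd.log_prob(pd.sample())  # stacked per-sub-action log-probs, shape (2,)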

    def __init__(self, probs=None, logits=None, validate_args=None):
        self.categoricals = []
        if probs is None:
            probs = [None] * len(logits)
        elif logits is None:
            logits = [None] * len(probs)
        else:
            raise ValueError('Either probs or logits must be None')

        for sub_probs, sub_logits in zip(probs, logits):
            categorical = distributions.Categorical(probs=sub_probs, logits=sub_logits, validate_args=validate_args)
            self.categoricals.append(categorical)

    @property
    def logits(self):
        return [cat.logits for cat in self.categoricals]

    @property
    def probs(self):
        return [cat.probs for cat in self.categoricals]

    @property
    def param_shape(self):
        return [cat.param_shape for cat in self.categoricals]

    @property
    def mean(self):
        return torch.stack([cat.mean for cat in self.categoricals])

    @property
    def variance(self):
        return torch.stack([cat.variance for cat in self.categoricals])

    def sample(self, sample_shape=torch.Size()):
        return torch.stack([cat.sample(sample_shape=sample_shape) for cat in self.categoricals])

    def log_prob(self, value):
        return torch.stack([cat.log_prob(value[idx]) for idx, cat in enumerate(self.categoricals)])

    def entropy(self):
        return torch.stack([cat.entropy() for cat in self.categoricals])

    def enumerate_support(self):
        return [cat.enumerate_support() for cat in self.categoricals]


setattr(distributions, 'Argmax', Argmax)
setattr(distributions, 'MultiCategorical', MultiCategorical)


# base methods

def try_preprocess(state, algorithm, body, append=True):
    '''Try calling preprocess as implemented in body's memory to use for net input'''
    if isinstance(state, LazyFrames):
        state = state.__array__()  # from global env preprocessor
    if hasattr(body.memory, 'preprocess_state'):
        state = body.memory.preprocess_state(state, append=append)
    # as float, and always as minibatch for net input
    state = torch.from_numpy(state).float().unsqueeze_(dim=0)
    return state


def cond_squeeze(out):
    '''Helper to squeeze output depending on whether it is a tensor (discrete pdparam) or a list of tensors (continuous pdparam of loc and scale)'''
    if isinstance(out, list):
        for out_t in out:
            out_t.squeeze_(dim=0)
    else:
        out.squeeze_(dim=0)
    return out


def init_action_pd(state, algorithm, body, append=True):
    '''
    Build the proper action probability distribution to use for action sampling.
    state is passed through the algorithm's net via calc_pdparam, which the algorithm must implement using its proper net.
    This returns ActionPD, pdparam and body to allow augmentation, e.g. applying temperature tau to pdparam for Boltzmann sampling.
    The output must then be passed to sample_action_pd(ActionPD, pdparam, body) to sample an action.
    @returns {cls, tensor, *} ActionPD, pdparam, body
    '''
    pdtypes = ACTION_PDS[body.action_type]
    assert body.action_pdtype in pdtypes, f'Pdtype {body.action_pdtype} is not compatible/supported with action_type {body.action_type}. Options are: {ACTION_PDS[body.action_type]}'
    ActionPD = getattr(distributions, body.action_pdtype)

    state = try_preprocess(state, algorithm, body, append=append)
    state = state.to(algorithm.net.device)
    pdparam = algorithm.calc_pdparam(state, evaluate=False)
    return ActionPD, pdparam, body


def sample_action_pd(ActionPD, pdparam, body):
    '''
    This uses the outputs from init_action_pd and an optionally augmented pdparam to construct an action_pd for sampling an action.
    @returns {tensor, distribution} action, action_pd The sampled action, and the prob. dist. used for sampling, to enable later calculations such as KL and entropy.
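    e.g. for a continuous body with pdtype 'Normal', pdparam is a list [loc, scale] and
    ActionPD(*pdparam) builds distributions.Normal(loc, scale) before sampling.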
    '''
    pdparam = cond_squeeze(pdparam)
    if body.is_discrete:
        action_pd = ActionPD(logits=pdparam)
    else:  # continuous outputs a list, loc and scale
        assert len(pdparam) == 2, pdparam
        # scale (stdev) must be > 0; apply softplus for stability (skipped for larger values, where exp would overflow and softplus ~= identity)
        if pdparam[1] < 5:
            pdparam[1] = torch.log(1 + torch.exp(pdparam[1])) + 1e-8
        action_pd = ActionPD(*pdparam)
    action = action_pd.sample()
    return action, action_pd


# interface action sampling methods


def default(state, algorithm, body):
    '''Plain policy by direct sampling using outputs of net as logits and constructing ActionPD as appropriate'''
    ActionPD, pdparam, body = init_action_pd(state, algorithm, body)
    action, action_pd = sample_action_pd(ActionPD, pdparam, body)
    return action, action_pd


def random(state, algorithm, body):
    '''Random action sampling that returns the same data format as default(), but without forward pass. Uses gym.space.sample()'''
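    # e.g. for a discrete body this returns (torch.tensor(k), Categorical(uniform logits)),
    # where k = body.action_space.sample(), so it can be used interchangeably with default()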
    state = try_preprocess(state, algorithm, body, append=True)  # for consistency with init_action_pd inner logic
    if body.action_type == 'discrete':
        action_pd = distributions.Categorical(logits=torch.ones(body.action_space.high, device=algorithm.net.device))
    elif body.action_type == 'continuous':
        # Possibly this should have a 'device' set
        action_pd = distributions.Uniform(
            low=torch.tensor(body.action_space.low).float(),
            high=torch.tensor(body.action_space.high).float())
    elif body.action_type == 'multi_discrete':
        action_pd = distributions.Categorical(
            logits=torch.ones(body.action_space.high.size, body.action_space.high[0], device=algorithm.net.device))
    elif body.action_type == 'multi_continuous':
        raise NotImplementedError
    elif body.action_type == 'multi_binary':
        raise NotImplementedError
    else:
        raise NotImplementedError
    sample = body.action_space.sample()
    action = torch.tensor(sample, device=algorithm.net.device)
    return action, action_pd


def epsilon_greedy(state, algorithm, body):
    '''Epsilon-greedy policy: with probability epsilon, do random action, otherwise do default sampling.'''
    epsilon = body.explore_var
    if epsilon > np.random.rand():
        return random(state, algorithm, body)
    else:
        return default(state, algorithm, body)


def boltzmann(state, algorithm, body):
    '''
    Boltzmann policy: adjust pdparam with temperature tau; the higher tau, the more randomness/noise in the action.
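    e.g. logits [2.0, 0.0] with tau=2 become [1.0, 0.0], moving the softmax probabilities
    from roughly [0.88, 0.12] toward [0.73, 0.27], i.e. closer to uniform.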
    '''
    tau = body.explore_var
    ActionPD, pdparam, body = init_action_pd(state, algorithm, body)
    pdparam /= tau
    action, action_pd = sample_action_pd(ActionPD, pdparam, body)
    return action, action_pd


# multi-body policy with a single forward pass to calc pdparam

def multi_default(states, algorithm, body_list, pdparam):
    '''
    Apply default policy body-wise
    Note, for efficiency, do a single forward pass to calculate pdparam, then call this policy like:
    @example

    pdparam = self.calc_pdparam(states, evaluate=False)
    action_a, action_pd_a = self.action_policy(states, self, body_list, pdparam)
    '''
    pdparam.squeeze_(dim=0)
    # assert pdparam has been chunked
    assert len(pdparam.shape) > 1 and len(pdparam) == len(body_list), f'pdparam shape: {pdparam.shape}, bodies: {len(body_list)}'
    action_list, action_pd_a = [], []
    for idx, sub_pdparam in enumerate(pdparam):
        body = body_list[idx]
        try_preprocess(states[idx], algorithm, body, append=True)  # for consistency with init_action_pd inner logic
        ActionPD = getattr(distributions, body.action_pdtype)
        action, action_pd = sample_action_pd(ActionPD, sub_pdparam, body)
        action_list.append(action)
        action_pd_a.append(action_pd)
    action_a = torch.tensor(action_list, device=algorithm.net.device).unsqueeze_(dim=1)
    return action_a, action_pd_a


def multi_random(states, algorithm, body_list, pdparam):
    '''Apply random policy body-wise.'''
    pdparam.squeeze_(dim=0)
    action_list, action_pd_a = [], []
    for idx, body in enumerate(body_list):
        action, action_pd = random(states[idx], algorithm, body)
        action_list.append(action)
        action_pd_a.append(action_pd)
    action_a = torch.tensor(action_list, device=algorithm.net.device).unsqueeze_(dim=1)
    return action_a, action_pd_a


def multi_epsilon_greedy(states, algorithm, body_list, pdparam):
    '''Apply epsilon-greedy policy body-wise'''
    assert len(pdparam) > 1 and len(pdparam) == len(body_list), f'pdparam shape: {pdparam.shape}, bodies: {len(body_list)}'
    action_list, action_pd_a = [], []
    for idx, sub_pdparam in enumerate(pdparam):
        body = body_list[idx]
        epsilon = body.explore_var
        if epsilon > np.random.rand():
            action, action_pd = random(states[idx], algorithm, body)
        else:
            try_preprocess(states[idx], algorithm, body, append=True)  # for consistency with init_action_pd inner logic
            ActionPD = getattr(distributions, body.action_pdtype)
            action, action_pd = sample_action_pd(ActionPD, sub_pdparam, body)
        action_list.append(action)
        action_pd_a.append(action_pd)
    action_a = torch.tensor(action_list, device=algorithm.net.device).unsqueeze_(dim=1)
    return action_a, action_pd_a


def multi_boltzmann(states, algorithm, body_list, pdparam):
    '''Apply Boltzmann policy body-wise'''
    # pdparam.squeeze_(dim=0)
    assert len(pdparam) > 1 and len(pdparam) == len(body_list), f'pdparam shape: {pdparam.shape}, bodies: {len(body_list)}'
    action_list, action_pd_a = [], []
    for idx, sub_pdparam in enumerate(pdparam):
        body = body_list[idx]
        try_preprocess(states[idx], algorithm, body, append=True)  # for consistency with init_action_pd inner logic
        tau = body.explore_var
        sub_pdparam /= tau
        ActionPD = getattr(distributions, body.action_pdtype)
        action, action_pd = sample_action_pd(ActionPD, sub_pdparam, body)
        action_list.append(action)
        action_pd_a.append(action_pd)
    action_a = torch.tensor(action_list, device=algorithm.net.device).unsqueeze_(dim=1)
    return action_a, action_pd_a


# action policy update methods

class VarScheduler:
    '''
    Variable scheduler for decaying variables such as explore_var (epsilon, tau) and entropy

    e.g. spec
    "explore_var_spec": {
        "name": "linear_decay",
        "start_val": 1.0,
        "end_val": 0.1,
        "start_step": 0,
        "end_step": 800,
    },
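    With the spec above and linear_decay, the value moves roughly linearly from start_val 1.0
    at step 0 to end_val 0.1 at step 800, then stays at end_val. A rough usage sketch
    (attribute names such as body.env.clock are assumptions):

    scheduler = VarScheduler(explore_var_spec)
    body.explore_var = scheduler.start_val
    # then, each training step:
    body.explore_var = scheduler.update(algorithm, body.env.clock)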
    '''

    def __init__(self, var_decay_spec=None):
        self._updater_name = 'no_decay' if var_decay_spec is None else var_decay_spec['name']
        self._updater = getattr(math_util, self._updater_name)
        util.set_attr(self, dict(
            start_val=np.nan,
        ))
        util.set_attr(self, var_decay_spec, [
            'start_val',
            'end_val',
            'start_step',
            'end_step',
        ])
        if not getattr(self, 'end_val', None):
            self.end_val = self.start_val

    def update(self, algorithm, clock):
        '''Get an updated value for var'''
        if (util.in_eval_lab_modes()) or self._updater_name == 'no_decay':
            return self.end_val
        step = clock.get(clock.max_tick_unit)
        val = self._updater(self.start_val, self.end_val, self.start_step, self.end_step, step)
        return val


# misc calc methods

def guard_multi_pdparams(pdparams, body):
    '''Guard pdparams for multi action'''
    action_dim = body.action_dim
    is_multi_action = ps.is_iterable(action_dim)
    if is_multi_action:
        assert ps.is_list(pdparams)
        pdparams = [t.clone() for t in pdparams]  # clone for grad safety
        assert len(pdparams) == len(action_dim), pdparams
        # transpose into (batch_size, [action_dims])
        pdparams = [list(torch.split(t, action_dim, dim=0)) for t in torch.cat(pdparams, dim=1)]
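        # e.g. action_dim [2, 3]: pdparams starts as [tensor(batch, 2), tensor(batch, 3)];
        # after cat to (batch, 5) and per-row split, each batch element becomes [tensor(2,), tensor(3,)]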
    return pdparams


def calc_log_probs(algorithm, net, body, batch):
    '''
    Method to calculate log_probs fresh from batch data
    Body already stores log_prob from self.net. This is used for PPO, where log_probs need to be recalculated.
    '''
    states, actions = batch['states'], batch['actions']
    action_dim = body.action_dim
    is_multi_action = ps.is_iterable(action_dim)
    # construct log_probs for each state-action
    pdparams = algorithm.calc_pdparam(states, net=net)
    pdparams = guard_multi_pdparams(pdparams, body)
    assert len(pdparams) == len(states), f'batch_size of pdparams: {len(pdparams)} vs states: {len(states)}'

    pdtypes = ACTION_PDS[body.action_type]
    ActionPD = getattr(distributions, body.action_pdtype)

    log_probs = []
    for idx, pdparam in enumerate(pdparams):
        if not is_multi_action:  # already cloned for multi_action above
            pdparam = pdparam.clone()  # clone for grad safety
        _action, action_pd = sample_action_pd(ActionPD, pdparam, body)
        log_probs.append(action_pd.log_prob(actions[idx].float()).sum(dim=0))
    log_probs = torch.stack(log_probs)
    assert not torch.isnan(log_probs).any(), f'log_probs: {log_probs}, \npdparams: {pdparams} \nactions: {actions}'
    logger.debug(f'log_probs: {log_probs}')
    return log_probs


def update_online_stats(body, state):
    '''
    Method to calculate the running mean and standard deviation of the state space.
    See https://www.johndcook.com/blog/standard_deviation/ for more details
    for n >= 1
        M_n = M_n-1 + (state - M_n-1) / n
        S_n = S_n-1 + (state - M_n-1) * (state - M_n)
        variance = S_n / (n - 1)
        std_dev = sqrt(variance)
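    e.g. for scalar states 1.0 then 3.0: M_1 = 1.0, M_2 = 1 + (3 - 1)/2 = 2.0,
    S_2 = (3 - 1) * (3 - 2) = 2, so variance = 2 and std_dev = sqrt(2) ~= 1.41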
    '''
    logger.debug(f'mean: {body.state_mean}, std: {body.state_std_dev}, num examples: {body.state_n}')
    # Assumes only one state is given
    if ("Atari" in body.memory.__class__.__name__):
        assert state.ndim == 3
    elif getattr(body.memory, 'raw_state_dim', False):
        assert state.size == body.memory.raw_state_dim
    else:
        assert state.size == body.state_dim or state.shape == body.state_dim
    mean = body.state_mean
    body.state_n += 1
    if np.isnan(mean).any():
        assert np.isnan(body.state_std_dev_int)
        assert np.isnan(body.state_std_dev)
        body.state_mean = state
        body.state_std_dev_int = 0
        body.state_std_dev = 0
    else:
        assert body.state_n > 1
        body.state_mean = mean + (state - mean) / body.state_n
        body.state_std_dev_int = body.state_std_dev_int + (state - mean) * (state - body.state_mean)
        body.state_std_dev = np.sqrt(body.state_std_dev_int / (body.state_n - 1))
        # Guard against very small std devs
        if (body.state_std_dev < 1e-8).any():
            body.state_std_dev[np.where(body.state_std_dev < 1e-8)] += 1e-8
    logger.debug(f'new mean: {body.state_mean}, new std: {body.state_std_dev}, num examples: {body.state_n}')


def normalize_state(body, state):
    '''
    Normalizes one or more states using a running mean and standard deviation
    Details of the normalization from Deep RL Bootcamp, L6
    https://www.youtube.com/watch?v=8EcdaCk9KaQ&feature=youtu.be
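    e.g. with state_mean = 0.0 and state_std_dev = 2.0, a state value of 30.0 normalizes to
    clip(30.0 / 2.0, -10, 10) = 10.0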
    '''
    same_shape = False if isinstance(state, list) else state.shape == body.state_mean.shape
    has_preprocess = getattr(body.memory, 'preprocess_state', False)
    if ('Atari' in body.memory.__class__.__name__):
        # never normalize atari, it has its own normalization step
        logger.debug('skipping normalizing for Atari, already handled by preprocess')
        return state
    elif ('Replay' in body.memory.__class__.__name__) and has_preprocess:
        # normalization handled by preprocess_state function in the memory
        logger.debug('skipping normalizing, already handled by preprocess')
        return state
    elif same_shape:
        # if not atari, always normalize the state the first time we see it during act
        # if the shape is not transformed in some way
        if np.sum(body.state_std_dev) == 0:
            return np.clip(state - body.state_mean, -10, 10)
        else:
            return np.clip((state - body.state_mean) / body.state_std_dev, -10, 10)
    else:
        # broadcastable sample from an un-normalized memory so we should normalize
        logger.debug('normalizing sample from memory')
        if np.sum(body.state_std_dev) == 0:
            return np.clip(state - body.state_mean, -10, 10)
        else:
            return np.clip((state - body.state_mean) / body.state_std_dev, -10, 10)


# TODO Not currently used, this will crash for more exotic memory structures
# def unnormalize_state(body, state):
#     '''
#     Un-normalizes one or more states using a running mean and new_std_dev
#     '''
#     return state * body.state_std_dev + body.state_mean


def update_online_stats_and_normalize_state(body, state):
    '''
    Convenience combination function for updating running state mean and std_dev and normalizing the state in one go.
    '''
    logger.debug(f'state: {state}')
    update_online_stats(body, state)
    state = normalize_state(body, state)
    logger.debug(f'normalized state: {state}')
    return state


def normalize_states_and_next_states(body, batch, episodic_flag=None):
    '''
    Convenience function for normalizing the states and next states in a batch of data
    '''
    logger.debug(f'states: {batch["states"]}')
    logger.debug(f'next states: {batch["next_states"]}')
    episodic = episodic_flag if episodic_flag is not None else body.memory.is_episodic
    logger.debug(f'Episodic: {episodic}, episodic_flag: {episodic_flag}, body.memory: {body.memory.is_episodic}')
    if episodic:
        normalized = []
        for epi in batch['states']:
            normalized.append(normalize_state(body, epi))
        batch['states'] = normalized
        normalized = []
        for epi in batch['next_states']:
            normalized.append(normalize_state(body, epi))
        batch['next_states'] = normalized
    else:
        batch['states'] = normalize_state(body, batch['states'])
        batch['next_states'] = normalize_state(body, batch['next_states'])
    logger.debug(f'normalized states: {batch["states"]}')
    logger.debug(f'normalized next states: {batch["next_states"]}')
    return batch