tensorflow/models

View on GitHub
research/efficient-hrl/context/context.py

Summary

Maintainability
F
3 days
Test Coverage
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Context for Universal Value Function agents.

A context specifies a list of contextual variables, each with its
  own sampling and reward computation methods.

Examples of contextual variables include
  goal states, reward combination vectors, etc.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numbers
import numpy as np
import tensorflow as tf
from tf_agents import specs
import gin.tf
from utils import utils as uvf_utils


@gin.configurable
class Context(object):
  """Base context.

  Holds a tuple of contextual variables (one per entry of `context_shapes`)
  together with the sampler fns, reward fn, normalizers and transition fns that
  operate on them. String entries in `reward_fn` / `samplers` are looked up in
  `self._custom_reward_fns` / `self._custom_sampler_fns`; the base class never
  defines those dicts, so subclasses are expected to create them (presumably in
  `_add_custom_reward_fns` / `_add_custom_sampler_fns`).
  """
  VAR_NAME = 'action'

  def __init__(self,
               tf_env,
               context_ranges=None,
               context_shapes=None,
               state_indices=None,
               variable_indices=None,
               gamma_index=None,
               settable_context=False,
               timers=None,
               samplers=None,
               reward_weights=None,
               reward_fn=None,
               random_sampler_mode='random',
               normalizers=None,
               context_transition_fn=None,
               context_multi_transition_fn=None,
               meta_action_every_n=None):
    """Creates the context, its specs, samplers, reward fn and variables.

    Args:
      tf_env: Environment wrapper exposing `observation_spec()` (and
        `pyenv._gym_env`, used by the `gym_env` property).
      context_ranges: Optional list with one range per context; the first and
        last entries of each range become the minimum/maximum of
        `context_as_action_specs`.
      context_shapes: Iterable with one shape per context; a None entry means
        "same shape as the observation". Required despite the None default.
      state_indices: Optional list with one entry per context (stored only).
      variable_indices: Optional list with one entry per context (stored only).
      gamma_index: Stored as `self.gamma_index`.
      settable_context: If False, `set_env_context_op` does not touch the env.
      timers: Stored as `self.timers`.
      samplers: Optional dict mapping a mode string to a sampler class or
        custom sampler name, or to a list of them (one per context).
      reward_weights: Optional list of weights, one per reward fn. None or an
        empty sequence means every weight is 1.0.
      reward_fn: A reward fn, a custom reward fn name, or a list of them.
      random_sampler_mode: Mode used by `sample_random_contexts`.
      normalizers: Optional list with one normalizer factory (or None) per
        context; each factory is called with a zeros tensor of its spec.
      context_transition_fn: Fn used by `step` to compute the next contexts.
      context_multi_transition_fn: Fn used by `context_multi_transition_fn`.
      meta_action_every_n: With an agent, `step` replaces the low-level context
        with `action_fn`'s output whenever the timer is a multiple of this.
    Raises:
      ValueError: If `variable_indices` does not have one entry per context.
    """
    self._tf_env = tf_env
    self.variable_indices = variable_indices
    self.gamma_index = gamma_index
    self._settable_context = settable_context
    self.timers = timers
    self._context_transition_fn = context_transition_fn
    self._context_multi_transition_fn = context_multi_transition_fn
    self._random_sampler_mode = random_sampler_mode

    # assign specs
    self._obs_spec = self._tf_env.observation_spec()
    self._context_shapes = tuple([
        shape if shape is not None else self._obs_spec.shape
        for shape in context_shapes
    ])
    self.context_specs = tuple([
        specs.TensorSpec(dtype=self._obs_spec.dtype, shape=shape)
        for shape in self._context_shapes
    ])
    if context_ranges is not None:
      self.context_ranges = context_ranges
    else:
      self.context_ranges = [None] * len(self._context_shapes)

    # NOTE(review): a None range fails here on `context_range[0]`, so the
    # default `context_ranges=None` cannot currently be used — confirm intent.
    self.context_as_action_specs = tuple([
        specs.BoundedTensorSpec(
            shape=shape,
            dtype=(tf.float32 if self._obs_spec.dtype in
                   [tf.float32, tf.float64] else self._obs_spec.dtype),
            minimum=context_range[0],
            maximum=context_range[-1])
        for shape, context_range in zip(self._context_shapes, self.context_ranges)
    ])

    if state_indices is not None:
      self.state_indices = state_indices
    else:
      self.state_indices = [None] * len(self._context_shapes)
    if self.variable_indices is not None and self.n != len(
        self.variable_indices):
      raise ValueError(
          'variable_indices (%s) must have the same length as contexts (%s).' %
          (self.variable_indices, self.context_specs))
    assert self.n == len(self.context_ranges)
    assert self.n == len(self.state_indices)

    # assign reward/sampler fns
    self._sampler_fns = dict()
    self._samplers = dict()
    self._reward_fns = dict()

    # assign reward fns; an empty weight sequence is treated like None.
    self._add_custom_reward_fns()
    self._reward_fn = self._make_reward_fn(reward_fn, reward_weights or None)

    # assign samplers; `samplers` defaults to None, so guard the iteration.
    self._add_custom_sampler_fns()
    for mode, sampler_fns in (samplers or {}).items():
      self._make_sampler_fn(sampler_fns, mode)

    # create normalizers
    if normalizers is None:
      self._normalizers = [None] * len(self.context_specs)
    else:
      self._normalizers = [
          normalizer(tf.zeros(shape=spec.shape, dtype=spec.dtype))
          if normalizer is not None else None
          for normalizer, spec in zip(normalizers, self.context_specs)
      ]
    assert self.n == len(self._normalizers)

    self.meta_action_every_n = meta_action_every_n

    # create vars
    self.context_vars = {}
    self.timer_vars = {}
    self.create_vars(self.VAR_NAME)
    self.t = tf.Variable(
        tf.zeros(shape=(), dtype=tf.int32), name='num_timer_steps')

  def _add_custom_reward_fns(self):
    """Hook for subclasses to register named reward fns (no-op here)."""
    pass

  def _add_custom_sampler_fns(self):
    """Hook for subclasses to register named sampler fns (no-op here)."""
    pass

  def sample_random_contexts(self, batch_size):
    """Sample random batch contexts using `random_sampler_mode`.

    Args:
      batch_size: Batch size.
    Returns:
      A list of [batch_size, num_context_dims] contexts.
    """
    assert self._random_sampler_mode is not None
    return self.sample_contexts(self._random_sampler_mode, batch_size)[0]

  def sample_contexts(self, mode, batch_size, state=None, next_state=None,
                      **kwargs):
    """Sample a batch of contexts.

    Args:
      mode: A string representing the mode [`train`, `explore`, `eval`].
      batch_size: Batch size.
      state: Optional state tensor forwarded to the samplers.
      next_state: Optional next-state tensor forwarded to the samplers.
      **kwargs: Extra kwargs forwarded to the samplers.
    Returns:
      Two lists of [batch_size, num_context_dims] contexts.
    Raises:
      ValueError: If a sampled context does not match its spec.
    """
    contexts, next_contexts = self._sampler_fns[mode](
        batch_size, state=state, next_state=next_state,
        **kwargs)
    self._validate_contexts(contexts)
    self._validate_contexts(next_contexts)
    return contexts, next_contexts

  def compute_rewards(self, mode, states, actions, rewards, next_states,
                      contexts):
    """Compute context-based rewards.

    Args:
      mode: A string representing the mode ['uvf', 'task']. Currently unused;
        the single reward fn built in `__init__` is always applied.
      states: A [batch_size, num_state_dims] tensor.
      actions: A [batch_size, num_action_dims] tensor.
      rewards: A [batch_size] tensor representing unmodified rewards.
      next_states: A [batch_size, num_state_dims] tensor.
      contexts: A list of [batch_size, num_context_dims] tensors.
    Returns:
      The (rewards, discounts) pair produced by the combined reward fn.
    """
    return self._reward_fn(states, actions, rewards, next_states,
                           contexts)

  def _make_reward_fn(self, reward_fns_list, reward_weights):
    """Returns a fn that computes weighted rewards and combined discounts.

    Args:
      reward_fns_list: A reward fn, a custom reward fn name, or a list/tuple of
        them. Each fn must return a (rewards, discounts) pair.
      reward_weights: A list of reward weights (one per reward fn), or None for
        all-ones weights.
    Returns:
      A fn that forwards its args to every reward fn and returns
      (rewards, discounts): the weighted sum of the rewards and the product of
      the discounts.
    """
    if not isinstance(reward_fns_list, (list, tuple)):
      reward_fns_list = [reward_fns_list]
    if reward_weights is None:
      reward_weights = [1.0] * len(reward_fns_list)
    assert len(reward_fns_list) == len(reward_weights)

    reward_fns_list = [
        self._custom_reward_fns[fn] if isinstance(fn, (str,)) else fn
        for fn in reward_fns_list
    ]

    def reward_fn(*args, **kwargs):
      """Returns rewards, discounts."""
      reward_tuples = [fn(*args, **kwargs) for fn in reward_fns_list]
      rewards_list = [reward_tuple[0] for reward_tuple in reward_tuples]
      discounts_list = [reward_tuple[1] for reward_tuple in reward_tuples]
      ndims = max([r.shape.ndims for r in rewards_list])
      if ndims > 1:
        # Append trailing unit axes to lower-rank rewards/discounts so that,
        # e.g., a [batch] reward broadcasts against a [batch, k] reward.
        for i in range(len(rewards_list)):
          for _ in range(ndims - rewards_list[i].shape.ndims):
            rewards_list[i] = tf.expand_dims(rewards_list[i], axis=-1)
          for _ in range(ndims - discounts_list[i].shape.ndims):
            discounts_list[i] = tf.expand_dims(discounts_list[i], axis=-1)
      weighted_rewards = [
          r * tf.to_float(w) for r, w in zip(rewards_list, reward_weights)]
      # Accumulate with `+` instead of tf.add_n: add_n requires identical
      # input shapes and would reject the broadcastable shapes built above.
      rewards = weighted_rewards[0]
      for r in weighted_rewards[1:]:
        rewards += r
      discounts = discounts_list[0]
      for d in discounts_list[1:]:
        discounts *= d

      return rewards, discounts

    return reward_fn

  def _make_sampler_fn(self, sampler_cls_list, mode):
    """Builds the batch sampler fn for `mode` and stores it.

    The fn is stored in `self._sampler_fns[mode]`. Sampler instances created
    from classes are stored in `self._samplers[mode]` so `set_replay` can reach
    them. Custom (named) sampler fns are not stored there.

    Args:
      sampler_cls_list: A sampler class / custom sampler name, or a list of
        them with one entry per context.
      mode: A string representing the operating mode.
    """
    if not isinstance(sampler_cls_list, (list, tuple)):
      sampler_cls_list = [sampler_cls_list]

    self._samplers[mode] = []
    sampler_fns = []
    for spec, sampler in zip(self.context_specs, sampler_cls_list):
      if isinstance(sampler, (str,)):
        sampler_fn = self._custom_sampler_fns[sampler]
      else:
        sampler_fn = sampler(context_spec=spec)
        self._samplers[mode].append(sampler_fn)
      sampler_fns.append(sampler_fn)

    def batch_sampler_fn(batch_size, state=None, next_state=None, **kwargs):
      """Samples (contexts, next_contexts), applying normalizers if any."""
      contexts_tuples = [
          fn(batch_size, state=state, next_state=next_state, **kwargs)
          for fn in sampler_fns]
      contexts = [c[0] for c in contexts_tuples]
      next_contexts = [c[1] for c in contexts_tuples]
      # Only the current contexts update the normalizer statistics.
      contexts = [
          normalizer.update_apply(c) if normalizer is not None else c
          for normalizer, c in zip(self._normalizers, contexts)
      ]
      next_contexts = [
          normalizer.apply(c) if normalizer is not None else c
          for normalizer, c in zip(self._normalizers, next_contexts)
      ]
      return contexts, next_contexts

    self._sampler_fns[mode] = batch_sampler_fn

  def set_env_context_op(self, context, disable_unnormalizer=False):
    """Returns a TensorFlow op that sets the environment context.

    Args:
      context: A list of context Tensor variables.
      disable_unnormalizer: Disable unnormalization.
    Returns:
      A TensorFlow op that sets the environment context. If the context is not
      settable, a constant 1.0 identity op is returned instead.
    """
    ret_val = np.array(1.0, dtype=np.float32)
    if not self._settable_context:
      return tf.identity(ret_val)

    if not disable_unnormalizer:
      context = [
          normalizer.unapply(tf.expand_dims(c, 0))[0]
          if normalizer is not None else c
          for normalizer, c in zip(self._normalizers, context)
      ]

    def set_context_func(*env_context_values):
      tf.logging.info('[set_env_context_op] Setting gym environment context.')
      # pylint: disable=protected-access
      self.gym_env.set_context(*env_context_values)
      # pylint: enable=protected-access
      return ret_val

    with tf.name_scope('set_env_context'):
      set_op = tf.py_func(set_context_func, context, tf.float32,
                          name='set_env_context_py_func')
      set_op.set_shape([])
    return set_op

  def set_replay(self, replay):
    """Set replay buffer for samplers.

    Args:
      replay: A replay buffer.
    """
    for samplers in self._samplers.values():
      for sampler in samplers:
        sampler.set_replay(replay)

  def get_clip_fns(self):
    """Returns a list of clip fns for contexts.

    Returns:
      A list with one fn per context range. Each fn clips a tensor to
      [range[0], range[1]], or returns an identity copy if the range is None.
    """
    clip_fns = []
    for context_range in self.context_ranges:
      # Bind the range as a default arg so each closure keeps its own range.
      def clip_fn(var_, range_=context_range):
        """Clip a tensor."""
        if range_ is None:
          return tf.identity(var_)
        # numbers.Number covers int/float (and Python 2 `long`) plus numpy
        # scalars; the `long` builtin itself does not exist on Python 3.
        if isinstance(range_[0], (numbers.Number, list, np.ndarray)):
          return tf.clip_by_value(var_, range_[0], range_[1])
        raise NotImplementedError(range_)
      clip_fns.append(clip_fn)
    return clip_fns

  def _validate_contexts(self, contexts):
    """Validate if contexts have right specs.

    Args:
      contexts: A list of [batch_size, num_context_dim] tensors.
    Raises:
      ValueError: If shape or dtype mismatches that of spec.
    """
    for i, (context, spec) in enumerate(zip(contexts, self.context_specs)):
      # Compare the per-example shape (first batch element) with the spec.
      if context[0].shape != spec.shape:
        raise ValueError('contexts[%d] has invalid shape %s wrt spec shape %s' %
                         (i, context[0].shape, spec.shape))
      if context.dtype != spec.dtype:
        raise ValueError('contexts[%d] has invalid dtype %s wrt spec dtype %s' %
                         (i, context.dtype, spec.dtype))

  def context_multi_transition_fn(self, contexts, **kwargs):
    """Returns multiple future contexts starting from a batch."""
    assert self._context_multi_transition_fn
    return self._context_multi_transition_fn(contexts, None, None, **kwargs)

  def step(self, mode, agent=None, action_fn=None, **kwargs):
    """Returns ops that advance the context (and timer) by one step.

    Args:
      mode: a string representing the mode=[train, explore, eval].
      agent: Optional higher-level agent. When given, `agent.tf_context` is
        stepped first and this (low-level) context is recomputed afterwards.
      action_fn: Used only with `agent`; called as
        `action_fn(next_state, context=None)` whenever the timer is a multiple
        of `meta_action_every_n`.
      **kwargs: kwargs for context_transition_fn. With `agent` they must
        include 'next_state', 'state_repr' and 'next_state_repr'.
    Returns:
      Without `agent`: the context assign ops (if a transition fn is set)
      followed by the timer increment op.
      With `agent`: a one-element list holding the timer increment op, which
      has control dependencies on the context assignments.
    """
    if agent is None:
      ops = []
      if self._context_transition_fn is not None:
        def sampler_fn():
          samples = self.sample_contexts(mode, 1)[0]
          return [s[0] for s in samples]
        values = self._context_transition_fn(self.vars, self.t, sampler_fn, **kwargs)
        ops += [tf.assign(var, value) for var, value in zip(self.vars, values)]
      ops.append(tf.assign_add(self.t, 1))  # increment timer
      return ops

    ops = agent.tf_context.step(mode, **kwargs)
    next_state = kwargs['next_state']
    state_repr = kwargs['state_repr']
    next_state_repr = kwargs['next_state_repr']
    # Step high level context before computing low level one.
    with tf.control_dependencies(ops):
      # Get the context transition function output.
      values = self._context_transition_fn(self.vars, self.t, None,
                                           state=state_repr,
                                           next_state=next_state_repr)
      # Select a new goal every C steps, otherwise use context transition.
      low_level_context = [
          tf.cond(tf.equal(self.t % self.meta_action_every_n, 0),
                  lambda: tf.cast(action_fn(next_state, context=None), tf.float32),
                  lambda: values)]
      ops = [tf.assign(var, value)
             for var, value in zip(self.vars, low_level_context)]
      with tf.control_dependencies(ops):
        return [tf.assign_add(self.t, 1)]  # increment timer

  def reset(self, mode, agent=None, action_fn=None, state=None):
    """Returns ops that reset the context.

    Args:
      mode: a string representing the mode=[train, explore, eval].
      agent: Optional higher-level agent. When given, `agent.tf_context` is
        reset and this context's vars are zeroed instead of being sampled.
      action_fn: Unused.
      state: Unused.
    Returns:
      a list of ops that reset the context.
    """
    if agent is None:
      values = self.sample_contexts(mode=mode, batch_size=1)[0]
      if values is None:
        return []
      values = [value[0] for value in values]
      values[0] = uvf_utils.tf_print(
          values[0],
          values,
          message='context:reset, mode=%s' % mode,
          first_n=10,
          name='context:reset:%s' % mode)
      all_ops = []
      # Assign the sampled values to every registered group of context vars
      # (previously only the last group's ops were kept).
      for _, context_vars in sorted(self.context_vars.items()):
        all_ops += [
            tf.assign(var, value) for var, value in zip(context_vars, values)]
      all_ops.append(self.set_env_context_op(values))
      all_ops.append(tf.assign(self.t, 0))  # reset timer
      return all_ops

    ops = agent.tf_context.reset(mode)
    # NOTE: The code is currently written in such a way that the higher level
    # policy does not provide a low-level context until the second
    # observation.  Instead, we just zero-out low-level contexts.
    for key, context_vars in sorted(self.context_vars.items()):
      # zip with the agent's vars for `key` is kept: it truncates to the
      # shorter of the two and requires `key` to exist on the agent.
      ops += [tf.assign(var, tf.zeros_like(var)) for var, _meta_var in
              zip(context_vars, agent.tf_context.context_vars[key])]

    ops.append(tf.assign(self.t, 0))  # reset timer
    return ops

  def create_vars(self, name, agent=None):
    """Create tf variables for contexts.

    Args:
      name: Name of the variables; also the key in `self.context_vars`.
      agent: Optional agent whose `create_vars(name)` is called as well.
    Returns:
      A (context_vars, meta_vars) pair: a tuple with one zero-initialized
      variable per context spec, and the result of `agent.create_vars(name)`
      (an empty dict when `agent` is None).
    Raises:
      AssertionError: If variables under `name` were already created.
    """
    if agent is not None:
      meta_vars = agent.create_vars(name)
    else:
      meta_vars = {}
    assert name not in self.context_vars, ('Conflict! %s is already '
                                           'initialized.') % name
    self.context_vars[name] = tuple([
        tf.Variable(
            tf.zeros(shape=spec.shape, dtype=spec.dtype),
            name='%s_context_%d' % (name, i))
        for i, spec in enumerate(self.context_specs)
    ])
    return self.context_vars[name], meta_vars

  @property
  def n(self):
    """Number of contexts."""
    return len(self.context_specs)

  @property
  def vars(self):
    """Context variables registered under `VAR_NAME`."""
    return self.context_vars[self.VAR_NAME]

  # pylint: disable=protected-access
  @property
  def gym_env(self):
    """The underlying gym environment (private attribute of the py env)."""
    return self._tf_env.pyenv._gym_env

  @property
  def tf_env(self):
    """The wrapped TF environment."""
    return self._tf_env
  # pylint: enable=protected-access