tensorflow/models

View on GitHub
research/efficient-hrl/agents/circular_buffer.py

Summary

Maintainability
A
1 hr
Test Coverage
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""A circular buffer where each element is a list of tensors.

Each element of the buffer is a list of tensors. An example use case is a replay
buffer in reinforcement learning, where each element is a list of tensors
representing the state, action, reward etc.

New elements are added sequentially, and once the buffer is full, we
start overwriting them in a circular fashion. Reading does not remove any
elements, only adding new elements does.
"""

import collections
import numpy as np
import tensorflow as tf

import gin.tf


@gin.configurable
class CircularBuffer(object):
  """A circular buffer where each element is a list of tensors.

  Storage is one non-trainable variable per tensor slot, each with a leading
  dimension of `buffer_size`. The variables are created lazily by the first
  call to `add`/`maybe_add`; every later call is validated against them. A
  `num_adds` counter variable is incremented on every add and the write
  position is `(num_adds - 1) mod buffer_size`, so old elements are
  overwritten once the buffer is full.
  """

  def __init__(self, buffer_size=1000, scope='replay_buffer'):
    """Circular buffer of list of tensors.

    Args:
      buffer_size: (integer) maximum number of tensor lists the buffer can hold.
        Must be positive.
      scope: (string) variable scope for creating the variables.
    Raises:
      ValueError: If `buffer_size` is not positive.
    """
    self._buffer_size = np.int64(buffer_size)
    # The buffer size is used as the modulus for the write position and as the
    # leading dimension of the storage variables, so it has to be >= 1.
    if self._buffer_size <= 0:
      raise ValueError(
          'buffer_size must be positive. Received %d.' % self._buffer_size)
    self._scope = scope
    # Maps tensor name -> storage variable; filled in by _create_variables.
    self._tensors = collections.OrderedDict()
    with tf.variable_scope(self._scope):
      self._num_adds = tf.Variable(0, dtype=tf.int64, name='num_adds')
    # All reads and increments of num_adds go through this critical section.
    self._num_adds_cs = tf.CriticalSection(name='num_adds')

  @property
  def buffer_size(self):
    """Maximum number of elements the buffer can hold (np.int64)."""
    return self._buffer_size

  @property
  def scope(self):
    """Variable scope name under which the buffer variables are created."""
    return self._scope

  @property
  def num_adds(self):
    """The int64 variable counting how many elements have been added."""
    return self._num_adds

  def _create_variables(self, tensors):
    """Creates one storage variable per entry of `tensors`.

    Args:
      tensors: (dict of tensors) keyed by name. Each variable gets the shape
        `[buffer_size] + tensor.shape` and the dtype of the tensor.
    """
    with tf.variable_scope(self._scope):
      for name, tensor in tensors.items():
        self._tensors[name] = tf.get_variable(
            name='BufferVariable_' + name,
            shape=[self._buffer_size] + tensor.get_shape().as_list(),
            dtype=tensor.dtype,
            trainable=False)

  def _validate(self, tensors):
    """Validates `tensors` against the existing storage variables.

    Args:
      tensors: (dict of tensors) keyed by name.
    Raises:
      ValueError: If the number of tensors, their keys, their shapes or their
        data types differ from those of the storage variables.
    """
    if len(tensors) != len(self._tensors):
      raise ValueError('Expected tensors to have %d elements. Received %d '
                       'instead.' % (len(self._tensors), len(tensors)))
    if self._tensors.keys() != tensors.keys():
      raise ValueError('The keys of tensors should always be the same. '
                       'Received %s instead of %s.' %
                       (tensors.keys(), self._tensors.keys()))
    for name, tensor in tensors.items():
      variable = self._tensors[name]
      # Drop the leading buffer dimension before comparing element shapes.
      if tensor.get_shape().as_list() != variable.get_shape().as_list()[1:]:
        raise ValueError('Tensor %s has incorrect shape.' % name)
      if not tensor.dtype.is_compatible_with(variable.dtype):
        # base_dtype reports the value dtype without adding a read op to the
        # graph just to build the error message.
        raise ValueError(
            'Tensor %s has incorrect data type. Expected %s, received %s' %
            (name, variable.dtype.base_dtype, tensor.dtype))

  def add(self, tensors):
    """Adds an element (list/tuple/dict of tensors) to the buffer.

    Args:
      tensors: (list/tuple/dict of tensors) to be added to the buffer.
    Returns:
      An add operation that adds the input `tensors` to the buffer. Similar to
        an enqueue_op.
    Raises:
      ValueError: If the shapes and data types of input `tensors` are not the
        same across calls to the add function.
    """
    return self.maybe_add(tensors, True)

  def maybe_add(self, tensors, condition):
    """Adds an element (tensors) to the buffer based on the condition.

    Inputs that are not dicts are keyed by their position ('0', '1', ...).
    Dicts that are not OrderedDicts are sorted by key. The first call creates
    the storage variables; subsequent calls are validated against them.

    Args:
      tensors: (list/tuple/dict of tensors) to be added to the buffer.
      condition: A boolean Tensor controlling whether the tensors would be added
        to the buffer or not.
    Returns:
      An add operation that adds the input `tensors` to the buffer. Similar to
        a maybe_enqueue_op.
    Raises:
      ValueError: If the shapes and data types of input `tensors` are not the
        same across calls to the add function.
    """
    if not isinstance(tensors, dict):
      names = [str(i) for i in range(len(tensors))]
      tensors = collections.OrderedDict(zip(names, tensors))
    if not isinstance(tensors, collections.OrderedDict):
      tensors = collections.OrderedDict(
          sorted(tensors.items(), key=lambda t: t[0]))
    if not self._tensors:
      self._create_variables(tensors)
    else:
      self._validate(tensors)

    def _increment_num_adds():
      # Adding 0 to the num_adds variable is a trick to read the value of the
      # variable and return a read-only tensor. Doing this in a critical
      # section allows us to capture a snapshot of the variable that will
      # not be affected by other threads updating num_adds.
      return self._num_adds.assign_add(1) + 0

    def _add():
      num_adds_inc = self._num_adds_cs.execute(_increment_num_adds)
      # The counter was already incremented, so the slot to write is one
      # behind it, wrapped around the buffer size.
      current_pos = tf.mod(num_adds_inc - 1, self._buffer_size)
      update_ops = [
          tf.scatter_update(variable, current_pos, tensors[name])
          for name, variable in self._tensors.items()
      ]
      return tf.group(*update_ops)

    return tf.contrib.framework.smart_cond(condition, _add, tf.no_op)

  def get_random_batch(self, batch_size, keys=None, num_steps=1):
    """Samples a batch of tensors from the buffer with replacement.

    Start indices are drawn uniformly from
    `[0, min(buffer_size, num_adds - num_steps + 1))`.

    Args:
      batch_size: (integer) number of elements to sample.
      keys: List of keys of tensors to retrieve. If None retrieve all.
      num_steps: (integer) length of trajectories to return. If 1, the result
        of `gather` is returned, otherwise the result of `gather_nstep`.
    Returns:
      A list of tensors, where each element in the list is a batch sampled from
        one of the tensors in the buffer.
    Raises:
      ValueError: If get_random_batch is called before calling the add function.
      tf.errors.InvalidArgumentError: If this operation is executed before
        at least `num_steps` items are added to the buffer.
    """
    if not self._tensors:
      raise ValueError('The add function must be called before get_random_batch.')
    if keys is None:
      keys = self._tensors.keys()

    latest_start_index = self.get_num_adds() - num_steps + 1
    empty_buffer_assert = tf.Assert(
        tf.greater(latest_start_index, 0),
        ['Not enough elements have been added to the buffer.'])
    with tf.control_dependencies([empty_buffer_assert]):
      max_index = tf.minimum(self._buffer_size, latest_start_index)
      indices = tf.random_uniform(
          [batch_size],
          minval=0,
          maxval=max_index,
          dtype=tf.int64)
      if num_steps == 1:
        return self.gather(indices, keys)
      return self.gather_nstep(num_steps, indices, keys)

  def gather(self, indices, keys=None):
    """Returns elements at the specified indices from the buffer.

    Args:
      indices: (list of integers or rank 1 int Tensor) indices in the buffer to
        retrieve elements from.
      keys: List of keys of tensors to retrieve. If None retrieve all.
    Returns:
      A list of tensors, where each element in the list is obtained by indexing
        one of the tensors in the buffer.
    Raises:
      ValueError: If gather is called before calling the add function.
      tf.errors.InvalidArgumentError: If indices are bigger than the number of
        items in the buffer.
    """
    if not self._tensors:
      raise ValueError('The add function must be called before calling gather.')
    if keys is None:
      keys = self._tensors.keys()
    with tf.name_scope('Gather'):
      index_bound_assert = tf.Assert(
          tf.less(
              tf.cast(tf.reduce_max(indices), tf.int64),
              tf.minimum(self.get_num_adds(), self._buffer_size)),
          ['Index out of bounds.'])
      with tf.control_dependencies([index_bound_assert]):
        # Control dependencies only attach to ops created inside this context.
        # tf.identity always creates a new op, whereas tf.convert_to_tensor
        # returns an existing Tensor unchanged, which would leave the bounds
        # assert disconnected from the returned tensors.
        indices = tf.identity(indices)

      return [
          tf.gather(self._tensors[key], indices, name=key) for key in keys
      ]

  def gather_nstep(self, num_steps, indices, keys=None):
    """Returns trajectories starting at the specified indices from the buffer.

    Each start index `x` is expanded to the positions
    `range(x, x + num_steps) mod buffer_size`.

    Args:
      num_steps: (integer) length of trajectories to return.
      indices: (rank 1 int Tensor) start index in the buffer of each
        trajectory.
      keys: List of keys of tensors to retrieve. If None retrieve all.
    Returns:
      A list with one tensor per key. Each tensor stacks, over the start
        indices, the `num_steps` consecutive elements gathered from the
        corresponding tensor in the buffer.
    Raises:
      ValueError: If gather_nstep is called before calling the add function.
      tf.errors.InvalidArgumentError: If a trajectory extends past the number
        of items added to the buffer.
    """
    if not self._tensors:
      raise ValueError('The add function must be called before calling gather.')
    if keys is None:
      keys = self._tensors.keys()
    with tf.name_scope('Gather'):
      index_bound_assert = tf.Assert(
          tf.less_equal(
              tf.cast(tf.reduce_max(indices) + num_steps, tf.int64),
              self.get_num_adds()),
          ['Trajectory indices go out of bounds.'])
      with tf.control_dependencies([index_bound_assert]):
        # Expand every start index into the wrapped positions of its
        # trajectory.
        indices = tf.map_fn(
            lambda x: tf.mod(tf.range(x, x + num_steps), self._buffer_size),
            indices,
            dtype=tf.int64)

      batch = []
      for key in keys:

        # key and num_steps are bound as defaults so each closure keeps the
        # values of its own loop iteration.
        def _sample_trajectories(trajectory_indices, key=key,
                                 num_steps=num_steps):
          trajectory_indices.set_shape([num_steps])
          return tf.gather(self._tensors[key], trajectory_indices, name=key)

        batch.append(tf.map_fn(_sample_trajectories, indices,
                               dtype=self._tensors[key].dtype))
      return batch

  def get_position(self):
    """Returns the position at which the last element was added.

    Returns:
      An int tensor representing the index at which the last element was added
        to the buffer or -1 if no elements were added.
    """
    # Read the counter once so the predicate and the selected branch are
    # computed from the same snapshot of num_adds.
    num_adds = self.get_num_adds()
    return tf.cond(num_adds < 1,
                   lambda: num_adds - 1,
                   lambda: tf.mod(num_adds - 1, self._buffer_size))

  def get_num_adds(self):
    """Returns the number of additions to the buffer.

    The variable is read inside the same critical section that is used to
    increment it.

    Returns:
      An int tensor representing the number of elements that were added.
    """
    def num_adds():
      return self._num_adds.value()

    return self._num_adds_cs.execute(num_adds)

  def get_num_tensors(self):
    """Returns the number of tensors (slots) in the buffer."""
    return len(self._tensors)