tensorflow/models

View on GitHub
research/efficient-hrl/context/samplers.py

Summary

Maintainability
F
6 days
Test Coverage
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Samplers for Contexts.

  Each sampler class should define __call__(batch_size).
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numbers

import gin.tf
import numpy as np
import tensorflow as tf

slim = tf.contrib.slim


@gin.configurable
class BaseSampler(object):
  """Common base class for context samplers.

  Subclasses implement `__call__` to produce a batch of contexts.
  """

  def __init__(self, context_spec, context_range=None, k=2, scope='sampler'):
    """Store the sampler configuration.

    Args:
      context_spec: A context spec; its `shape` and `dtype` attributes are
        used to validate sampled contexts.
      context_range: A tuple of (minval, maxval), where minval, maxval are
        floats or Numpy arrays with the same shape as the context.
      k: Value stored as `self._k`; subclasses in this file use it as the
        number of leading dimensions sliced from states and contexts.
      scope: A string denoting scope.
    """
    self._scope = scope
    self._k = k
    self._context_range = context_range
    self._context_spec = context_spec

  def __call__(self, batch_size, **kwargs):
    """Sample a batch of contexts; must be overridden by subclasses."""
    raise NotImplementedError

  def set_replay(self, replay=None):
    """Accept a replay buffer; the base implementation does nothing."""
    pass

  def _validate_contexts(self, contexts):
    """Check a batch of contexts against the context spec.

    Args:
      contexts: A [batch_size, num_contexts_dim] tensor.
    Raises:
      ValueError: If the per-example shape or the dtype differs from the spec.
    """
    spec = self._context_spec
    # Compare the shape of a single example (batch dimension stripped).
    sample_shape = contexts[0].shape
    if sample_shape != spec.shape:
      raise ValueError('contexts has invalid shape %s wrt spec shape %s' %
                       (sample_shape, spec.shape))
    actual_dtype = contexts.dtype
    if actual_dtype != spec.dtype:
      raise ValueError('contexts has invalid dtype %s wrt spec dtype %s' %
                       (actual_dtype, spec.dtype))


@gin.configurable
class ZeroSampler(BaseSampler):
  """Sampler whose contexts are all zeros."""

  def __call__(self, batch_size, **kwargs):
    """Build an all-zero batch of contexts.

    Args:
      batch_size: Batch size.
      **kwargs: Unused.
    Returns:
      Two [batch_size, num_context_dims] tensors (the same zero tensor twice).
    """
    spec = self._context_spec
    batch_shape = [batch_size] + spec.shape.as_list()
    zero_contexts = tf.zeros(shape=batch_shape, dtype=spec.dtype)
    return zero_contexts, zero_contexts


@gin.configurable
class BinarySampler(BaseSampler):
  """Sampler that produces contexts by thresholding uniform noise."""

  def __init__(self, probs=0.5, *args, **kwargs):
    """Construct sampler.

    Args:
      probs: Threshold compared against uniform float32 draws.
      *args: arguments forwarded to BaseSampler.
      **kwargs: keyword arguments forwarded to BaseSampler.
    """
    super(BinarySampler, self).__init__(*args, **kwargs)
    self._probs = probs

  def __call__(self, batch_size, **kwargs):
    """Sample a batch of binary contexts.

    Args:
      batch_size: Batch size.
      **kwargs: Unused.
    Returns:
      Two [batch_size, num_context_dims] tensors (the same tensor twice). An
      entry is the cast of `uniform_draw > probs` to the spec dtype.
    """
    spec = self._context_spec
    batch_shape = [batch_size] + spec.shape.as_list()
    uniform = tf.random_uniform(shape=batch_shape, dtype=tf.float32)
    exceeds_threshold = tf.greater(uniform, self._probs)
    binary_contexts = tf.cast(exceeds_threshold, dtype=spec.dtype)
    return binary_contexts, binary_contexts


@gin.configurable
class RandomSampler(BaseSampler):
  """Sampler that draws contexts uniformly from `context_range`."""

  def __call__(self, batch_size, **kwargs):
    """Sample a batch of context.

    Args:
      batch_size: Batch size.
      **kwargs: Ignored; accepted so that all samplers can be called with the
        same keyword arguments (e.g. `state`, `next_state`).
    Returns:
      Two [batch_size, num_context_dims] tensors (the same tensor twice).
    Raises:
      NotImplementedError: If the lower bound of the context range is neither
        an int/float nor a list, tuple or Numpy array.
      ValueError: If the sampled contexts do not match the context spec.
    """
    spec = self._context_spec
    context_range = self._context_range
    spec_shape = spec.shape.as_list()
    if isinstance(context_range[0], (int, float)):
      # Scalar bounds: one uniform draw covering the whole context shape.
      contexts = tf.random_uniform(
          shape=[batch_size] + spec_shape,
          minval=context_range[0],
          maxval=context_range[1],
          dtype=spec.dtype)
    elif isinstance(context_range[0], (list, tuple, np.ndarray)):
      # Per-dimension bounds: draw each context dimension separately with its
      # own (minval, maxval) and concatenate the columns.
      assert len(spec_shape) == 1
      assert spec_shape[0] == len(context_range[0])
      assert spec_shape[0] == len(context_range[1])
      contexts = tf.concat(
          [
              tf.random_uniform(
                  shape=[batch_size, 1] + spec_shape[1:],
                  minval=context_range[0][i],
                  maxval=context_range[1][i],
                  dtype=spec.dtype) for i in range(spec_shape[0])
          ],
          axis=1)
    else:
      raise NotImplementedError(context_range)
    self._validate_contexts(contexts)
    return contexts, contexts


@gin.configurable
class ScheduledSampler(BaseSampler):
  """Sampler that steps through a fixed list of values on a schedule."""

  def __init__(self,
               scope='default',
               values=None,
               scheduler='cycle',
               scheduler_params=None,
               *args, **kwargs):
    """Construct sampler.

    Args:
      scope: Scope name, used as a prefix of the pointer variable's name.
      values: A non-empty list of numbers or [num_context_dim] Numpy arrays
        representing the values to cycle.
      scheduler: scheduler type; only 'cycle' is implemented.
      scheduler_params: scheduler parameters (a dict; 'inc' sets the step of
        the 'cycle' scheduler). Defaults to an empty dict.
      *args: arguments forwarded to BaseSampler.
      **kwargs: keyword arguments forwarded to BaseSampler.
    """
    super(ScheduledSampler, self).__init__(*args, **kwargs)
    self._scope = scope
    self._scheduler = scheduler
    self._scheduler_params = scheduler_params or {}
    assert values is not None and len(values), 'must provide non-empty values.'
    self._n = len(values)
    # TODO(shanegu): move variable creation outside. resolve tf.cond problem.
    self._count = 0
    # Integer pointer into the list of values, starting at zero.
    pointer_name = '%s-scheduled_sampler_%d' % (self._scope, self._count)
    self._i = tf.Variable(
        tf.zeros(shape=(), dtype=tf.int32), name=pointer_name)
    self._values = tf.constant(values, dtype=self._context_spec.dtype)

  def __call__(self, batch_size, **kwargs):
    """Sample a batch of context.

    Each call creates an op that advances the pointer; the value is read
    under a control dependency on that op.

    Args:
      batch_size: Batch size.
      **kwargs: Unused.
    Returns:
      Two [batch_size, num_context_dims] tensors (the same tensor twice).
    """
    spec = self._context_spec
    advance_op = self._next(self._i)
    with tf.control_dependencies([advance_op]):
      current = self._values[self._i]
      if current.get_shape().as_list():
        # Non-scalar value: replicate it along a new leading batch dimension.
        multiples = (batch_size,) + (1,) * spec.shape.ndims
        values = tf.tile(tf.expand_dims(current, 0), multiples)
      else:
        # Scalar value: broadcast it over a zero tensor of the batch shape.
        values = current + tf.zeros(
            shape=[batch_size] + spec.shape.as_list(), dtype=spec.dtype)
    self._validate_contexts(values)
    self._count += 1
    return values, values

  def _next(self, i):
    """Return op that increments pointer to next value.

    Args:
      i: A tensorflow integer variable.
    Returns:
      Op that assigns `(i + inc) mod n` to `i`, where `inc` is the 'inc'
      scheduler parameter (1 when absent or falsy) and n the number of values.
    Raises:
      NotImplementedError: If the scheduler type is not 'cycle'.
    """
    if self._scheduler != 'cycle':
      raise NotImplementedError(self._scheduler)
    step = self._scheduler_params.get('inc') or 1
    return tf.assign(i, tf.mod(i + step, self._n))


@gin.configurable
class ReplaySampler(BaseSampler):
  """Sampler that uses next states drawn from a replay buffer as contexts."""

  def __init__(self,
               prefetch_queue_capacity=2,
               override_indices=None,
               state_indices=None,
               *args,
               **kwargs):
    """Construct sampler.

    Args:
      prefetch_queue_capacity: Capacity for prefetch queue; a value <= 0
        disables prefetching.
      override_indices: Column indices whose values are replaced by uniform
        samples from the (scalar) context range, or None.
      state_indices: Non-negative column indices to select from the state
        dimension, or None to keep all columns.
      *args: arguments forwarded to BaseSampler.
      **kwargs: keyword arguments forwarded to BaseSampler.
    """
    super(ReplaySampler, self).__init__(*args, **kwargs)
    self._prefetch_queue_capacity = prefetch_queue_capacity
    self._override_indices = override_indices
    self._state_indices = state_indices
    # Populated by set_replay(); must be set before the sampler is called.
    self._replay = None

  def set_replay(self, replay):
    """Set replay.

    Args:
      replay: A replay buffer providing `GetRandomBatch(batch_size)`.
    """
    self._replay = replay

  def __call__(self, batch_size, **kwargs):
    """Sample a batch of context.

    Args:
      batch_size: Batch size.
      **kwargs: Unused.
    Returns:
      Two [batch_size, num_context_dims] tensors (the same tensor twice).
    Raises:
      ValueError: If the resulting contexts do not match the context spec.
    """
    batch = self._replay.GetRandomBatch(batch_size)
    # Element 4 of the replay batch is used as the next states.
    next_states = batch[4]
    if self._prefetch_queue_capacity > 0:
      batch_queue = slim.prefetch_queue.prefetch_queue(
          [next_states],
          capacity=self._prefetch_queue_capacity,
          name='%s/batch_context_queue' % self._scope)
      next_states = batch_queue.dequeue()
    if self._override_indices is not None:
      # numbers.Real covers int/float on Python 3 and int/long/float on
      # Python 2; the bare name `long` does not exist on Python 3.
      assert self._context_range is not None and isinstance(
          self._context_range[0], numbers.Real)
      # Keep the first num_context_dims columns, replacing the overridden
      # ones with uniform noise drawn from the context range.
      next_states = tf.concat(
          [
              tf.random_uniform(
                  shape=next_states[:, :1].shape,
                  minval=self._context_range[0],
                  maxval=self._context_range[1],
                  dtype=next_states.dtype)
              if i in self._override_indices else next_states[:, i:i + 1]
              for i in range(self._context_spec.shape.as_list()[0])
          ],
          axis=1)
    if self._state_indices is not None:
      # Select exactly the requested columns, in the order given.
      next_states = tf.concat(
          [next_states[:, i:i + 1] for i in self._state_indices], axis=1)
    self._validate_contexts(next_states)
    return next_states, next_states


@gin.configurable
class TimeSampler(BaseSampler):
  """Sampler for a one-dimensional integer time context."""

  def __init__(self, minval=0, maxval=1, timestep=-1, *args, **kwargs):
    """Construct sampler.

    Args:
      minval: Min value integer.
      maxval: Max value integer.
      timestep: Time step between states and next_states.
      *args: arguments forwarded to BaseSampler.
      **kwargs: keyword arguments forwarded to BaseSampler.
    """
    super(TimeSampler, self).__init__(*args, **kwargs)
    # The time context must be a single scalar per example.
    assert self._context_spec.shape.as_list() == [1]
    self._timestep = timestep
    self._maxval = maxval
    self._minval = minval

  def __call__(self, batch_size, **kwargs):
    """Sample a batch of context.

    Args:
      batch_size: Batch size.
      **kwargs: Unused.
    Returns:
      Two [batch_size, 1] tensors cast to the spec dtype: the sampled times
      and the times shifted by `timestep`, clipped below at zero.
    """
    batch_shape = [batch_size, 1]
    if self._maxval != self._minval:
      contexts = tf.random_uniform(
          shape=batch_shape,
          minval=self._minval,
          maxval=self._maxval,
          dtype=tf.int32)
    else:
      # Degenerate range: every context equals the single allowed value.
      contexts = tf.constant(self._maxval, shape=batch_shape, dtype=tf.int32)
    next_contexts = tf.maximum(contexts + self._timestep, 0)

    out_dtype = self._context_spec.dtype
    current_out = tf.cast(contexts, dtype=out_dtype)
    next_out = tf.cast(next_contexts, dtype=out_dtype)
    return current_out, next_out


@gin.configurable
class ConstantSampler(BaseSampler):
  """Sampler that always returns the same constant context."""

  def __init__(self, value=None, *args, **kwargs):
    """Construct sampler.

    Args:
      value: A list or Numpy array for values of the constant.
      *args: arguments forwarded to BaseSampler.
      **kwargs: keyword arguments forwarded to BaseSampler.
    """
    super(ConstantSampler, self).__init__(*args, **kwargs)
    self._value = value

  def __call__(self, batch_size, **kwargs):
    """Sample a batch of context.

    Args:
      batch_size: Batch size.
      **kwargs: Unused.
    Returns:
      Two [batch_size, num_context_dims] tensors (the same tensor twice).
    """
    spec = self._context_spec
    constant = tf.constant(self._value, shape=spec.shape, dtype=spec.dtype)
    # Replicate the constant along a new leading batch dimension.
    multiples = (batch_size,) + (1,) * spec.shape.ndims
    batched = tf.tile(tf.expand_dims(constant, 0), multiples)
    self._validate_contexts(batched)
    return batched, batched


@gin.configurable
class DirectionSampler(RandomSampler):
  """Direction sampler.

  Starts from the uniform contexts of RandomSampler and, when states are
  provided, replaces the first `k` context dimensions by state differences
  within the batch plus scaled Gaussian noise.
  """

  def __call__(self, batch_size, **kwargs):
    """Sample a batch of context.

    Args:
      batch_size: Batch size.
      **kwargs: Must contain 'state' and 'next_state' (either may be None).
        May contain 'sampler_fn', a zero-argument callable returning the
        contexts used for the dimensions from `k` onwards.
    Returns:
      Two [batch_size, num_context_dims] tensors, both the gradient-stopped
      sampled contexts.
    """
    # Uniform sampling and spec validation are shared with RandomSampler.
    contexts, _ = super(DirectionSampler, self).__call__(batch_size, **kwargs)
    if 'sampler_fn' in kwargs:
      other_contexts = kwargs['sampler_fn']()
    else:
      other_contexts = contexts
    state, next_state = kwargs['state'], kwargs['next_state']
    if state is not None and next_state is not None:
      k = self._k
      context_range = self._context_range
      # Half-width of the context range, broadcast to the context shape.
      half_range = (
          (np.array(context_range[1]) - np.array(context_range[0])) / 2 *
          np.ones(self._context_spec.shape.as_list()))
      state_k = state[:, :k]
      noise = 0.1 * half_range[:k] * tf.random_normal(
          tf.shape(state_k), dtype=state.dtype)
      # Offset from each state to a state shuffled in from the same batch.
      directions = noise + tf.random_shuffle(state_k) - state_k
      contexts = tf.concat([directions, other_contexts[:, k:]], 1)
    # The next contexts are the same as the current contexts in both cases.
    return tf.stop_gradient(contexts), tf.stop_gradient(contexts)