tensorflow/models

View on GitHub
research/pcl_rl/trust_region.py

Summary

Maintainability
D
2 days
Test Coverage
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Trust region optimization.

Much of this is adapted from others' code.
See Schulman's Modular RL, wojzaremba's TRPO, etc.

"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from six.moves import xrange
import tensorflow as tf
import numpy as np


def var_size(v):
  """Returns the total number of elements of `v` as a Python int.

  Args:
    v: Any object whose `.shape` iterates over dimensions convertible with
      `int()` (for example a variable with a fully defined shape).

  Returns:
    The product of all dimensions; 1 for a scalar (empty shape).
  """
  num_elements = 1
  for dim in v.shape:
    num_elements *= int(dim)
  return num_elements


def gradients(loss, var_list):
  """Returns d(loss)/d(var) for every variable, never None.

  `tf.gradients` yields None for variables that `loss` does not depend on.
  Those entries are replaced with zero tensors matching the variable's shape
  and (base) dtype. Callers can then reshape, multiply and concatenate the
  results without special cases or dtype mismatches.

  Args:
    loss: Tensor to differentiate.
    var_list: List of variables to differentiate with respect to.

  Returns:
    A list of gradient tensors, one per entry of `var_list`, in order.
  """
  grads = tf.gradients(loss, var_list)
  # base_dtype strips the `_ref` suffix that TF1 reference variables carry.
  return [g if g is not None else tf.zeros(v.shape, dtype=v.dtype.base_dtype)
          for g, v in zip(grads, var_list)]

def flatgrad(loss, var_list):
  """Returns the gradient of `loss` flattened into a single 1-D tensor.

  The per-variable gradients come from `gradients`, so there is exactly one
  entry per variable (zeros where `loss` does not depend on it). They are
  concatenated in `var_list` order, the same layout `get_flat` and
  `set_from_flat` use.

  Args:
    loss: Tensor to differentiate.
    var_list: List of variables to differentiate with respect to.

  Returns:
    A 1-D tensor of length sum(var_size(v) for v in var_list).
  """
  grads = gradients(loss, var_list)
  # No None filtering here: `gradients` already substitutes zeros, and
  # dropping entries would shift every later slice of the flat layout.
  return tf.concat([tf.reshape(grad, [-1]) for grad in grads], 0)


def get_flat(var_list):
  """Concatenates every variable in `var_list` into one 1-D tensor.

  Each variable is flattened, and the pieces are joined in list order.

  Args:
    var_list: List of variables (or tensors) to flatten.

  Returns:
    A 1-D tensor holding all values of `var_list` back to back.
  """
  flattened = []
  for var in var_list:
    flattened.append(tf.reshape(var, [-1]))
  return tf.concat(flattened, 0)


def set_from_flat(var_list, flat_theta):
  """Builds an op that writes a flat vector back into the variables.

  `flat_theta` is consumed in consecutive slices, one per variable in
  `var_list` order. Each slice is reshaped to that variable's shape and
  assigned to it. This mirrors the layout produced by `get_flat`.

  Args:
    var_list: List of variables to assign.
    flat_theta: 1-D tensor holding at least sum(var_size(v)) values.

  Returns:
    A grouped op that performs all the assignments.
  """
  assigns = []
  start = 0
  for v in var_list:
    size = var_size(v)
    assigns.append(v.assign(
        tf.reshape(flat_theta[start:start + size], v.shape)))
    start += size
  return tf.group(*assigns)


class TrustRegionOptimization(object):
  """Natural-gradient optimizer with a divergence (trust-region) constraint.

  Call `setup` once to build the graph ops. Then call `optimize` once per
  update with a session and a feed dict.
  """

  def __init__(self, max_divergence=0.1, cg_damping=0.1):
    """Stores the optimizer hyperparameters.

    Args:
      max_divergence: Divergence budget. It sets the size of the proposed
        full step (0.5 * s^T F s == max_divergence). Candidate parameters
        whose `divergence` is not below it are rejected.
      cg_damping: Coefficient c added as c * tangent to every Fisher-vector
        product, i.e. the linear system solved is (F + c * I) x = -g.
    """
    self.max_divergence = max_divergence
    self.cg_damping = cg_damping

  def setup_placeholders(self):
    """Creates the 1-D float32 placeholders used to feed flat vectors."""
    # Tangent vector for Fisher-vector products.
    self.flat_tangent = tf.placeholder(tf.float32, [None], 'flat_tangent')
    # Full flat parameter vector to load into the variables.
    self.flat_theta = tf.placeholder(tf.float32, [None], 'flat_theta')

  def setup(self, var_list, raw_loss, self_divergence,
            divergence=None):
    """Builds the graph ops used by `optimize`.

    Args:
      var_list: Variables to optimize.
      raw_loss: Loss tensor to minimize (presumably scalar).
      self_divergence: Tensor whose Hessian w.r.t. `var_list` serves as the
        curvature (Fisher) matrix. Products with it are formed by
        differentiating (gradient . tangent). Presumably a divergence between
        the current policy and a stop-gradient copy of itself; confirm
        against the caller.
      divergence: Optional tensor evaluated at candidate parameters. Updates
        with divergence >= max_divergence are rejected. If None, no such
        check is made.
    """
    self.setup_placeholders()

    self.raw_loss = raw_loss
    self.divergence = divergence
    self.loss_flat_gradient = flatgrad(raw_loss, var_list)
    self.divergence_gradient = gradients(self_divergence, var_list)

    # Split the flat tangent into per-variable pieces, using the same
    # variable order as the flat parameter vector.
    start = 0
    tangents = []
    for var in var_list:
      size = var_size(var)
      tangents.append(
          tf.reshape(self.flat_tangent[start:start + size], var.shape))
      start += size

    # Hessian-vector product: d/dtheta (grad(self_divergence) . tangent).
    self.grad_vector_product = sum(
        tf.reduce_sum(g * t) for (g, t) in zip(self.divergence_gradient, tangents))
    self.fisher_vector_product = flatgrad(self.grad_vector_product, var_list)

    self.flat_vars = get_flat(var_list)
    self.set_vars = set_from_flat(var_list, self.flat_theta)

  def optimize(self, sess, feed_dict):
    """Performs one trust-region update of the variables given to `setup`.

    On return the variables hold either the accepted new parameters or,
    if no acceptable step was found, their original values.

    Args:
      sess: Session used to evaluate the graph.
      feed_dict: Feed dict for the loss/divergence tensors. It is copied and
        never modified.
    """
    # Work on a copy so the tangent placeholder inserted below does not leak
    # into the caller's dictionary.
    feed_dict = dict(feed_dict)
    old_theta = sess.run(self.flat_vars)
    loss_flat_grad = sess.run(self.loss_flat_gradient,
                              feed_dict=feed_dict)

    def calc_fisher_vector_product(tangent):
      # Damped product (F + cg_damping * I) @ tangent.
      feed_dict[self.flat_tangent] = tangent
      fvp = sess.run(self.fisher_vector_product,
                     feed_dict=feed_dict)
      fvp += self.cg_damping * tangent
      return fvp

    # Search direction: approximate solution of (F + damping * I) x = -g.
    step_dir = conjugate_gradient(calc_fisher_vector_product, -loss_flat_grad)

    # Quadratic estimate of the divergence of the raw step; the step is
    # rescaled so that this estimate equals max_divergence.
    shs = 0.5 * step_dir.dot(calc_fisher_vector_product(step_dir))
    if not np.isfinite(shs) or shs <= 0:
      # Zero gradient or non-positive curvature: no meaningful step can be
      # scaled. The variables have not been modified yet, so stop here.
      return
    lm = np.sqrt(shs / self.max_divergence)
    fullstep = step_dir / lm
    neggdotstepdir = -loss_flat_grad.dot(step_dir)

    def calc_loss(theta):
      # Loads `theta` into the variables and returns (loss, constraint_ok).
      sess.run(self.set_vars, feed_dict={self.flat_theta: theta})
      if self.divergence is None:
        return sess.run(self.raw_loss, feed_dict=feed_dict), True
      raw_loss, divergence = sess.run(
          [self.raw_loss, self.divergence], feed_dict=feed_dict)
      return raw_loss, divergence < self.max_divergence

    # Find an acceptable point along fullstep.
    theta = linesearch(calc_loss, old_theta, fullstep, neggdotstepdir / lm)

    # linesearch leaves the variables at the last point it evaluated, which
    # need not be `theta`; load `theta` explicitly before checking it.
    sess.run(self.set_vars, feed_dict={self.flat_theta: theta})
    if self.divergence is not None:
      final_divergence = sess.run(self.divergence, feed_dict=feed_dict)
      # Written as `not <` so a NaN divergence also triggers the revert.
      if not final_divergence < self.max_divergence:
        sess.run(self.set_vars, feed_dict={self.flat_theta: old_theta})


def conjugate_gradient(f_Ax, b, cg_iters=10, residual_tol=1e-10):
  """Approximately solves A x = b with the conjugate gradient method.

  A is accessed only through `f_Ax`, so it never has to be materialized.
  The method assumes A is symmetric positive definite.

  Args:
    f_Ax: Callable returning the matrix-vector product A @ p for a vector p.
    b: Right-hand side (numpy array). It is not modified.
    cg_iters: Maximum number of iterations.
    residual_tol: Stop once the squared residual norm r.r falls below this.

  Returns:
    The approximate solution x, with the same shape and dtype as `b`.
    For b == 0 this is exactly zero.
  """
  p = b.copy()
  r = b.copy()
  x = np.zeros_like(b)
  rdotr = r.dot(r)
  for _ in range(cg_iters):
    # Check convergence before updating. For b == 0 this returns x = 0
    # instead of computing 0 / 0 and filling x with NaNs.
    if rdotr < residual_tol:
      break
    z = f_Ax(p)
    v = rdotr / p.dot(z)
    x += v * p
    r -= v * z
    newrdotr = r.dot(r)
    mu = newrdotr / rdotr
    p = r + mu * p
    rdotr = newrdotr
  return x


def linesearch(f, x, fullstep, expected_improve_rate,
               accept_ratio=0.1, max_backtracks=10):
  """Backtracking line search along `fullstep`, starting from `x`.

  Candidates are tried in order:

    x + fullstep, x + fullstep / 2, x + fullstep / 4, ...

  The first one accepted is returned. A candidate is accepted when all
  three hold: `f` marks it valid, `f` decreases, and the actual decrease
  divided by the linearly predicted decrease exceeds `accept_ratio`.

  Args:
    f: Callable mapping a point to `(value, valid)`; lower values are better.
    x: Starting point; must support `x + scalar * fullstep`.
    fullstep: Largest step to try.
    expected_improve_rate: Predicted decrease of `f` for the full step. The
      prediction for step fraction s is `expected_improve_rate * s`.
    accept_ratio: Minimum actual/predicted decrease ratio for acceptance.
    max_backtracks: Number of step fractions tried (1, 1/2, 1/4, ...).

  Returns:
    The accepted candidate, or `x` itself if none was accepted.
  """
  fval, _ = f(x)
  for stepfrac in .5 ** np.arange(max_backtracks):
    xnew = x + stepfrac * fullstep
    newfval, valid = f(xnew)
    if not valid:
      continue
    actual_improve = fval - newfval
    expected_improve = expected_improve_rate * stepfrac
    ratio = actual_improve / expected_improve
    if ratio > accept_ratio and actual_improve > 0:
      return xnew

  return x