tensorflow/models

View on GitHub
research/pcl_rl/trainer.py

Summary

Maintainability
D
2 days
Test Coverage
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Trainer for coordinating single or multi-replica training.

Main point of entry for running models.  Specifies most of
the parameters used by different algorithms.
"""

import tensorflow as tf
import numpy as np
import random
import os
import pickle

from six.moves import xrange
import controller
import model
import policy
import baseline
import objective
import full_episode_objective
import trust_region
import optimizers
import replay_buffer
import expert_paths
import gym_wrapper
import env_spec

# Short aliases for the TensorFlow sub-modules used throughout this file.
app = tf.app
flags = tf.flags
logging = tf.logging
gfile = tf.gfile

FLAGS = flags.FLAGS

# Environment, batching and training-length flags.
flags.DEFINE_string('env', 'Copy-v0', 'environment name')
flags.DEFINE_integer('batch_size', 100, 'batch size')
flags.DEFINE_integer('replay_batch_size', None,
                     'replay batch size; defaults to batch_size')
flags.DEFINE_integer('num_samples', 1,
                     'number of samples from each random seed initialization')
flags.DEFINE_integer('max_step', 200, 'max number of steps to train on')
flags.DEFINE_integer('cutoff_agent', 0,
                     'number of steps at which to cut-off agent. '
                     'Defaults to always cutoff')
flags.DEFINE_integer('num_steps', 100000, 'number of training steps')
flags.DEFINE_integer('validation_frequency', 100,
                     'every so many steps, output some stats')

# Target-network flags.
flags.DEFINE_float('target_network_lag', 0.95,
                   'This exponential decay on online network yields target '
                   'network')
flags.DEFINE_string('sample_from', 'online',
                    'Sample actions from "online" network or "target" network')

# Objective and optimization-procedure flags.
flags.DEFINE_string('objective', 'pcl',
                    'pcl/upcl/a3c/trpo/reinforce/urex')
flags.DEFINE_bool('trust_region_p', False,
                  'use trust region for policy optimization')
flags.DEFINE_string('value_opt', None,
                    'leave as None to optimize it along with policy '
                    '(using critic_weight). Otherwise set to '
                    '"best_fit" (least squares regression), "lbfgs", or "grad"')
flags.DEFINE_float('max_divergence', 0.01,
                   'max divergence (i.e. KL) to allow during '
                   'trust region optimization')

# Loss hyperparameters and network-input options.
flags.DEFINE_float('learning_rate', 0.01, 'learning rate')
flags.DEFINE_float('clip_norm', 5.0, 'clip norm')
flags.DEFINE_float('clip_adv', 0.0, 'Clip advantages at this value.  '
                   'Leave as 0 to not clip at all.')
flags.DEFINE_float('critic_weight', 0.1, 'critic weight')
# The trailing space keeps the two concatenated sentences from running together.
flags.DEFINE_float('tau', 0.1, 'entropy regularizer. '
                   'If using decaying tau, this is the final value.')
flags.DEFINE_float('tau_decay', None,
                   'decay tau by this much every 100 steps')
flags.DEFINE_float('tau_start', 0.1,
                   'start tau at this value')
flags.DEFINE_float('eps_lambda', 0.0, 'relative entropy regularizer.')
flags.DEFINE_bool('update_eps_lambda', False,
                  'Update lambda automatically based on last 100 episodes.')
flags.DEFINE_float('gamma', 1.0, 'discount')
flags.DEFINE_integer('rollout', 10, 'rollout')
flags.DEFINE_bool('use_target_values', False,
                  'use target network for value estimates')
flags.DEFINE_bool('fixed_std', True,
                  'fix the std in Gaussian distributions')
flags.DEFINE_bool('input_prev_actions', True,
                  'input previous actions to policy network')
flags.DEFINE_bool('recurrent', True,
                  'use recurrent connections')
flags.DEFINE_bool('input_time_step', False,
                  'input time step into value calculations')

# Batch-construction and replay-buffer flags.
flags.DEFINE_bool('use_online_batch', True,
                  'train on batches as they are sampled')
flags.DEFINE_bool('batch_by_steps', False,
                  'ensure each training batch has batch_size * max_step steps')
flags.DEFINE_bool('unify_episodes', False,
                  'Make sure replay buffer holds entire episodes, '
                  'even across distinct sampling steps')
flags.DEFINE_integer('replay_buffer_size', 5000, 'replay buffer size')
flags.DEFINE_float('replay_buffer_alpha', 0.5, 'replay buffer alpha param')
flags.DEFINE_integer('replay_buffer_freq', 0,
                     'replay buffer frequency (only supports -1/0/1)')
flags.DEFINE_string('eviction', 'rand',
                    'how to evict from replay buffer: rand/rank/fifo')
flags.DEFINE_string('prioritize_by', 'rewards',
                    'Prioritize replay buffer by "rewards" or "step"')
flags.DEFINE_integer('num_expert_paths', 0,
                     'number of expert paths to seed replay buffer with')

# Network-size and seeding flags.
flags.DEFINE_integer('internal_dim', 256, 'RNN internal dim')
flags.DEFINE_integer('value_hidden_layers', 0,
                     'number of hidden layers in value estimate')
flags.DEFINE_integer('tf_seed', 42, 'random seed for tensorflow')

# Trajectory save/load flags.
flags.DEFINE_string('save_trajectories_dir', None,
                    'directory to save trajectories to, if desired')
flags.DEFINE_string('load_trajectories_file', None,
                    'file to load expert trajectories from')

# supervisor flags
flags.DEFINE_bool('supervisor', False, 'use supervisor training')
flags.DEFINE_integer('task_id', 0, 'task id')
flags.DEFINE_integer('ps_tasks', 0, 'number of ps tasks')
flags.DEFINE_integer('num_replicas', 1, 'number of replicas used')
flags.DEFINE_string('master', 'local', 'name of master')
flags.DEFINE_string('save_dir', '', 'directory to save model to')
flags.DEFINE_string('load_path', '',
                    'path of saved model to load (if none in save_dir)')


class Trainer(object):
  """Coordinates single or multi-replica training.

  The constructor copies the command-line flags onto the instance and
  validates flag combinations.  The `get_*` methods are factories that are
  handed (as callables) to the controller and model so they can build their
  own components.  `run()` builds the graph, opens a session and executes
  the training loop.
  """

  def __init__(self):
    """Reads FLAGS, builds the train/eval environments and validates settings.

    Raises:
      AssertionError: if an incompatible combination of flags is given.
    """
    self.batch_size = FLAGS.batch_size
    self.replay_batch_size = FLAGS.replay_batch_size
    if self.replay_batch_size is None:
      # Documented flag default: fall back to the training batch size.
      self.replay_batch_size = self.batch_size
    self.num_samples = FLAGS.num_samples

    self.env_str = FLAGS.env
    # Training and evaluation use separate, identically configured wrappers.
    self.env = gym_wrapper.GymWrapper(self.env_str,
                                      distinct=FLAGS.batch_size // self.num_samples,
                                      count=self.num_samples)
    self.eval_env = gym_wrapper.GymWrapper(
        self.env_str,
        distinct=FLAGS.batch_size // self.num_samples,
        count=self.num_samples)
    self.env_spec = env_spec.EnvSpec(self.env.get_one())

    self.max_step = FLAGS.max_step
    self.cutoff_agent = FLAGS.cutoff_agent
    self.num_steps = FLAGS.num_steps
    self.validation_frequency = FLAGS.validation_frequency

    self.target_network_lag = FLAGS.target_network_lag
    self.sample_from = FLAGS.sample_from
    assert self.sample_from in ['online', 'target'], (
        'sample_from must be "online" or "target"')

    self.critic_weight = FLAGS.critic_weight
    self.objective = FLAGS.objective
    self.trust_region_p = FLAGS.trust_region_p
    self.value_opt = FLAGS.value_opt
    assert not self.trust_region_p or self.objective in ['pcl', 'trpo'], (
        'trust_region_p is only supported with the pcl and trpo objectives')
    assert self.objective != 'trpo' or self.trust_region_p, (
        'the trpo objective requires trust_region_p')
    # The flag may arrive as the literal string 'None' from the command line.
    assert self.value_opt is None or self.value_opt == 'None' or \
        self.critic_weight == 0.0, (
            'critic_weight must be 0.0 when value_opt is set')
    self.max_divergence = FLAGS.max_divergence

    self.learning_rate = FLAGS.learning_rate
    self.clip_norm = FLAGS.clip_norm
    self.clip_adv = FLAGS.clip_adv
    self.tau = FLAGS.tau
    self.tau_decay = FLAGS.tau_decay
    self.tau_start = FLAGS.tau_start
    self.eps_lambda = FLAGS.eps_lambda
    self.update_eps_lambda = FLAGS.update_eps_lambda
    self.gamma = FLAGS.gamma
    self.rollout = FLAGS.rollout
    self.use_target_values = FLAGS.use_target_values
    self.fixed_std = FLAGS.fixed_std
    self.input_prev_actions = FLAGS.input_prev_actions
    self.recurrent = FLAGS.recurrent
    assert not self.trust_region_p or not self.recurrent, (
        'trust_region_p cannot be combined with a recurrent policy')
    self.input_time_step = FLAGS.input_time_step
    assert not self.input_time_step or (self.cutoff_agent <= self.max_step), (
        'input_time_step requires cutoff_agent <= max_step')

    self.use_online_batch = FLAGS.use_online_batch
    self.batch_by_steps = FLAGS.batch_by_steps
    self.unify_episodes = FLAGS.unify_episodes
    if self.unify_episodes:
      assert self.batch_size == 1, 'unify_episodes requires batch_size == 1'

    self.replay_buffer_size = FLAGS.replay_buffer_size
    self.replay_buffer_alpha = FLAGS.replay_buffer_alpha
    self.replay_buffer_freq = FLAGS.replay_buffer_freq
    assert self.replay_buffer_freq in [-1, 0, 1], (
        'replay_buffer_freq only supports -1/0/1')
    self.eviction = FLAGS.eviction
    self.prioritize_by = FLAGS.prioritize_by
    assert self.prioritize_by in ['rewards', 'step'], (
        'prioritize_by must be "rewards" or "step"')
    self.num_expert_paths = FLAGS.num_expert_paths

    self.internal_dim = FLAGS.internal_dim
    self.value_hidden_layers = FLAGS.value_hidden_layers
    self.tf_seed = FLAGS.tf_seed

    self.save_trajectories_dir = FLAGS.save_trajectories_dir
    # Trajectory file is named after the environment, with '-' mapped to '_'.
    self.save_trajectories_file = (
        os.path.join(
            self.save_trajectories_dir, self.env_str.replace('-', '_'))
        if self.save_trajectories_dir else None)
    self.load_trajectories_file = FLAGS.load_trajectories_file

    # Snapshot every non-dunder, non-callable attribute set so far; this is
    # what hparams_string() reports.  Must stay the last statement so that all
    # settings above are captured.
    self.hparams = {
        attr: getattr(self, attr)
        for attr in dir(self)
        if not attr.startswith('__') and not callable(getattr(self, attr))
    }

  def hparams_string(self):
    """Returns the hparams as sorted 'name: value' lines joined by newlines."""
    return '\n'.join('%s: %s' % item for item in sorted(self.hparams.items()))

  def get_objective(self):
    """Builds the training objective selected by `self.objective`.

    Returns:
      An instance of the objective class for 'pcl'/'upcl', 'trpo', 'a3c',
      'reinforce' or 'urex'.

    Raises:
      AssertionError: if the objective name is unknown, or if tau decay is
        requested with tau_start < tau.
    """
    tau = self.tau
    if self.tau_decay is not None:
      assert self.tau_start >= self.tau
      # Decay from tau_start every 100 global steps, floored at self.tau.
      tau = tf.maximum(
          tf.train.exponential_decay(
              self.tau_start, self.global_step, 100, self.tau_decay),
          self.tau)

    if self.objective in ['pcl', 'a3c', 'trpo', 'upcl']:
      if self.objective in ['pcl', 'upcl']:
        cls = objective.PCL
      elif self.objective == 'trpo':
        cls = objective.TRPO
      else:
        cls = objective.ActorCritic
      policy_weight = 1.0

      return cls(self.learning_rate,
                 clip_norm=self.clip_norm,
                 policy_weight=policy_weight,
                 critic_weight=self.critic_weight,
                 tau=tau, gamma=self.gamma, rollout=self.rollout,
                 eps_lambda=self.eps_lambda, clip_adv=self.clip_adv,
                 use_target_values=self.use_target_values)

    if self.objective in ['reinforce', 'urex']:
      cls = (full_episode_objective.Reinforce
             if self.objective == 'reinforce' else
             full_episode_objective.UREX)
      return cls(self.learning_rate,
                 clip_norm=self.clip_norm,
                 num_samples=self.num_samples,
                 tau=tau, bonus_weight=1.0)  # TODO: bonus weight?

    # Raised explicitly (rather than `assert False`) so the check survives
    # `python -O`; the exception type is unchanged.
    raise AssertionError('Unknown objective %s' % self.objective)

  def get_policy(self):
    """Builds the policy: `policy.Policy` if recurrent, else `MLPPolicy`."""
    cls = policy.Policy if self.recurrent else policy.MLPPolicy
    return cls(self.env_spec, self.internal_dim,
               fixed_std=self.fixed_std,
               recurrent=self.recurrent,
               input_prev_actions=self.input_prev_actions)

  def get_baseline(self):
    """Builds the value baseline; 'upcl' uses `UnifiedBaseline`."""
    cls = (baseline.UnifiedBaseline if self.objective == 'upcl' else
           baseline.Baseline)
    return cls(self.env_spec, self.internal_dim,
               input_prev_actions=self.input_prev_actions,
               input_time_step=self.input_time_step,
               input_policy_state=self.recurrent,  # may want to change this
               n_hidden_layers=self.value_hidden_layers,
               hidden_dim=self.internal_dim,
               tau=self.tau)

  def get_trust_region_p_opt(self):
    """Returns a trust-region optimizer, or None if trust_region_p is off."""
    if not self.trust_region_p:
      return None
    return trust_region.TrustRegionOptimization(
        max_divergence=self.max_divergence)

  def get_value_opt(self):
    """Returns the separate value optimizer named by `value_opt`, or None.

    Recognized names are 'grad', 'lbfgs' and 'best_fit'; any other value
    (including None and the string 'None') yields None.
    """
    if self.value_opt == 'grad':
      return optimizers.GradOptimization(
          learning_rate=self.learning_rate, max_iter=5, mix_frac=0.05)
    if self.value_opt == 'lbfgs':
      return optimizers.LbfgsOptimization(max_iter=25, mix_frac=0.1)
    if self.value_opt == 'best_fit':
      return optimizers.BestFitOptimization(mix_frac=1.0)
    return None

  def get_model(self):
    """Builds the model, passing this trainer's factories as callbacks.

    Reads `self.global_step`, which `run()` sets before any controller is
    constructed.
    """
    cls = model.Model
    return cls(self.env_spec, self.global_step,
               target_network_lag=self.target_network_lag,
               sample_from=self.sample_from,
               get_policy=self.get_policy,
               get_baseline=self.get_baseline,
               get_objective=self.get_objective,
               get_trust_region_p_opt=self.get_trust_region_p_opt,
               get_value_opt=self.get_value_opt)

  def get_replay_buffer(self):
    """Returns a prioritized replay buffer, or None if replay is disabled.

    Raises:
      AssertionError: if replay is enabled with an objective other than
        'pcl' or 'upcl'.
    """
    if self.replay_buffer_freq <= 0:
      return None
    assert self.objective in ['pcl', 'upcl'], 'Can\'t use replay buffer with %s' % (
        self.objective)
    cls = replay_buffer.PrioritizedReplayBuffer
    return cls(self.replay_buffer_size,
               alpha=self.replay_buffer_alpha,
               eviction_strategy=self.eviction)

  def get_buffer_seeds(self):
    """Samples expert paths via `expert_paths.sample_expert_paths`."""
    return expert_paths.sample_expert_paths(
        self.num_expert_paths, self.env_str, self.env_spec,
        load_trajectories_file=self.load_trajectories_file)

  def get_controller(self, env):
    """Builds a controller bound to `env`.

    Args:
      env: environment wrapper the controller should act in (the training or
        evaluation wrapper built in the constructor).

    Returns:
      A `controller.Controller` configured from this trainer's settings.
    """
    cls = controller.Controller
    return cls(env, self.env_spec, self.internal_dim,
               use_online_batch=self.use_online_batch,
               batch_by_steps=self.batch_by_steps,
               unify_episodes=self.unify_episodes,
               replay_batch_size=self.replay_batch_size,
               max_step=self.max_step,
               cutoff_agent=self.cutoff_agent,
               save_trajectories_file=self.save_trajectories_file,
               use_trust_region=self.trust_region_p,
               use_value_opt=self.value_opt not in [None, 'None'],
               update_eps_lambda=self.update_eps_lambda,
               prioritize_by=self.prioritize_by,
               get_model=self.get_model,
               get_replay_buffer=self.get_replay_buffer,
               get_buffer_seeds=self.get_buffer_seeds)

  def do_before_step(self, step):
    """Hook called at the start of every training iteration; no-op here."""
    pass

  def _build_controllers(self):
    """Creates and sets up the training and evaluation controllers."""
    self.controller = self.get_controller(self.env)
    self.model = self.controller.model
    self.controller.setup()
    # Built under reuse=True so the evaluation controller shares the
    # variables already created for the training controller.
    with tf.variable_scope(tf.get_variable_scope(), reuse=True):
      self.eval_controller = self.get_controller(self.eval_env)
      self.eval_controller.setup(train=False)

  def run(self):
    """Run training.

    Builds the graph, creates a session (through a Supervisor when
    `--supervisor` is set), restores a checkpoint if one is configured, then
    alternates training and greedy evaluation until the model's global step
    reaches `num_steps`, the loop is exhausted, or the supervisor asks to
    stop.  The chief logs running averages every `validation_frequency`
    iterations.
    """
    is_chief = FLAGS.task_id == 0 or not FLAGS.supervisor
    sv = None

    def init_fn(sess, saver):
      """Restores from save_dir (non-supervisor only) or from load_path."""
      ckpt = None
      # `sv` is looked up when init_fn runs, so under a supervisor the
      # save_dir lookup is skipped.
      # NOTE(review): presumably the supervisor restores from its logdir
      # itself -- confirm.
      if FLAGS.save_dir and sv is None:
        load_dir = FLAGS.save_dir
        ckpt = tf.train.get_checkpoint_state(load_dir)
      if ckpt and ckpt.model_checkpoint_path:
        logging.info('restoring from %s', ckpt.model_checkpoint_path)
        saver.restore(sess, ckpt.model_checkpoint_path)
      elif FLAGS.load_path:
        logging.info('restoring from %s', FLAGS.load_path)
        saver.restore(sess, FLAGS.load_path)

    if FLAGS.supervisor:
      # Public API names: tf.train.replica_device_setter / tf.train.Supervisor.
      device_fn = tf.train.replica_device_setter(
          ps_tasks=FLAGS.ps_tasks, merge_devices=True)
      with tf.device(device_fn):
        self.global_step = tf.contrib.framework.get_or_create_global_step()
        tf.set_random_seed(FLAGS.tf_seed)
        self._build_controllers()

        saver = tf.train.Saver(max_to_keep=10)
        step = self.model.global_step
        sv = tf.train.Supervisor(logdir=FLAGS.save_dir,
                                 is_chief=is_chief,
                                 saver=saver,
                                 save_model_secs=600,
                                 summary_op=None,  # we define it ourselves
                                 save_summaries_secs=60,
                                 global_step=step,
                                 init_fn=lambda sess: init_fn(sess, saver))
        sess = sv.prepare_or_wait_for_session(FLAGS.master)
    else:
      tf.set_random_seed(FLAGS.tf_seed)
      self.global_step = tf.contrib.framework.get_or_create_global_step()
      self._build_controllers()

      saver = tf.train.Saver(max_to_keep=10)
      sess = tf.Session()
      # Non-deprecated replacement for tf.initialize_all_variables().
      sess.run(tf.global_variables_initializer())
      init_fn(sess, saver)

    self.sv = sv
    self.sess = sess

    logging.info('hparams:\n%s', self.hparams_string())

    # A restored checkpoint may already be at or past the requested step.
    model_step = sess.run(self.model.global_step)
    if model_step >= self.num_steps:
      logging.info('training has reached final step')
      return

    # Accumulators, reset after every validation log line.
    losses = []
    rewards = []
    all_ep_rewards = []
    for step in xrange(1 + self.num_steps):

      if sv is not None and sv.should_stop():
        logging.info('stopping supervisor')
        break

      self.do_before_step(step)

      (loss, summary,
       total_rewards, episode_rewards) = self.controller.train(sess)
      _, greedy_episode_rewards = self.eval_controller.eval(sess)
      self.controller.greedy_episode_rewards = greedy_episode_rewards
      losses.append(loss)
      rewards.append(total_rewards)
      all_ep_rewards.extend(episode_rewards)

      # Write summaries on roughly 10% of iterations, and only from a chief
      # whose supervisor actually has a summary writer.
      if (random.random() < 0.1 and summary and episode_rewards and
          is_chief and sv and sv.summary_writer):
        sv.summary_computed(sess, summary)

      model_step = sess.run(self.model.global_step)
      if is_chief and step % self.validation_frequency == 0:
        logging.info('at training step %d, model step %d: '
                     'avg loss %f, avg reward %f, '
                     'episode rewards: %f, greedy rewards: %f',
                     step, model_step,
                     np.mean(losses), np.mean(rewards),
                     np.mean(all_ep_rewards),
                     np.mean(greedy_episode_rewards))

        losses = []
        rewards = []
        all_ep_rewards = []

      if model_step >= self.num_steps:
        logging.info('training has reached final step')
        break

    if is_chief and sv is not None:
      logging.info('saving final model to %s', sv.save_path)
      sv.saver.save(sess, sv.save_path, global_step=sv.global_step)


def main(unused_argv):
  """Sets log verbosity to INFO, then builds a Trainer and runs it.

  Args:
    unused_argv: positional command-line arguments; ignored.
  """
  logging.set_verbosity(logging.INFO)
  Trainer().run()


# Start the TensorFlow app runner only when executed as a script, not when
# this module is imported.
if __name__ == '__main__':
  app.run()