tensorflow/models

View on GitHub
research/efficient-hrl/train_utils.py

Summary

Maintainability
C
7 hrs
Test Coverage
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

r""""""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from collections import namedtuple
import os
import time

import tensorflow as tf

import gin.tf

flags = tf.app.flags

# Configuration files and parameter bindings.
flags.DEFINE_multi_string(
    'config_file', None, 'List of paths to the config files.')
flags.DEFINE_multi_string(
    'params', None, 'Newline separated list of Gin parameter bindings.')

# Training job setup and checkpoint / summary cadence.
flags.DEFINE_string(
    'train_dir', None,
    'Directory for writing logs/summaries during training.')
flags.DEFINE_string(
    'master', 'local', 'BNS name of the TensorFlow master to use.')
flags.DEFINE_integer('task', 0, 'task id')
flags.DEFINE_integer(
    'save_interval_secs', 300,
    'The frequency at which checkpoints are saved, in seconds.')
flags.DEFINE_integer(
    'save_summaries_secs', 30,
    'The frequency at which summaries are saved, in seconds.')
flags.DEFINE_boolean(
    'summarize_gradients', False, 'Whether to generate gradient summaries.')

FLAGS = flags.FLAGS

# Bundle of the ops consumed by TrainStep.train_step: the train op, the meta
# train op and the experience-collection op.
TrainOps = namedtuple(
    'TrainOps', 'train_op meta_train_op collect_experience_op')


class TrainStep(object):
  """Handles training step.

  Holds the configuration and bookkeeping for the per-iteration training
  callback `train_step`, which returns `(total_loss, should_stop)`.
  """

  def __init__(self,
               max_number_of_steps=0,
               num_updates_per_observation=1,
               num_collect_per_update=1,
               num_collect_per_meta_update=1,
               log_every_n_steps=1,
               policy_save_fn=None,
               save_policy_every_n_steps=0,
               should_stop_early=None):
    """Configures the function executed at each step of slim training.

    Args:
      max_number_of_steps: Optional maximum number of train steps to take;
        a falsy value (e.g. 0) means no limit.
      num_updates_per_observation: Number of train-op runs per call to
        `train_step`; the last one is fused with experience collection.
      num_collect_per_update: Number of experience-collection runs per call
        to `train_step`; the last one is fused with the train op.
      num_collect_per_meta_update: The meta train op runs whenever the global
        step crosses a multiple of this value. Must be non-zero.
      log_every_n_steps: The frequency, in terms of global steps, at which the
        loss and global step are logged; a falsy value disables logging.
      policy_save_fn: A callable invoked as `policy_save_fn(sess)` to save the
        policy.
      save_policy_every_n_steps: How frequently, in global steps, to save the
        policy; a falsy value disables saving.
      should_stop_early: Optional zero-argument callable reporting whether
        training should stop.

    Raises:
      ValueError: If policy_save_fn is not provided when
        save_policy_every_n_steps > 0.
    """
    if save_policy_every_n_steps and policy_save_fn is None:
      raise ValueError(
          'policy_save_fn is required when save_policy_every_n_steps > 0')
    self.max_number_of_steps = max_number_of_steps
    self.num_updates_per_observation = num_updates_per_observation
    self.num_collect_per_update = num_collect_per_update
    self.num_collect_per_meta_update = num_collect_per_meta_update
    self.log_every_n_steps = log_every_n_steps
    self.policy_save_fn = policy_save_fn
    self.save_policy_every_n_steps = save_policy_every_n_steps
    self.should_stop_early = should_stop_early
    self.last_global_step_val = 0
    # Session callables are built lazily on the first call to train_step,
    # because the session and ops are only available there.
    self.train_op_fn = None
    self.meta_train_op_fn = None
    self.collect_fn = None
    self.collect_and_train_fn = None
    self.collect_and_meta_train_fn = None
    tf.logging.info('Training for %d max_number_of_steps',
                    self.max_number_of_steps)

  @staticmethod
  def _crossed_multiple(new_step, old_step, interval):
    """Returns True if the two steps fall in different `interval` buckets.

    When the step increases, this means a multiple of `interval` lies in
    (old_step, new_step], even if the step jumped by more than one.
    """
    return new_step // interval != old_step // interval

  def _build_callables(self, sess, train_ops, global_step):
    """Creates the session callables used by `train_step`."""
    self.train_op_fn = sess.make_callable([train_ops.train_op, global_step])
    self.meta_train_op_fn = sess.make_callable(
        [train_ops.meta_train_op, global_step])
    self.collect_fn = sess.make_callable(
        [train_ops.collect_experience_op, global_step])
    self.collect_and_train_fn = sess.make_callable(
        [train_ops.train_op, global_step, train_ops.collect_experience_op])
    # NOTE(review): not used within this class; kept in case other code
    # relies on it.
    self.collect_and_meta_train_fn = sess.make_callable(
        [train_ops.meta_train_op, global_step,
         train_ops.collect_experience_op])

  def train_step(self, sess, train_ops, global_step, _):
    """This function will be called at each step of training.

    This represents one step of the DDPG algorithm. One call runs, in order:
    1. `num_collect_per_update - 1` experience-collection ops;
    2. `num_updates_per_observation - 1` train ops;
    3. one fused train + experience-collection op;
    4. the meta train op, if the global step crossed a multiple of
       `num_collect_per_meta_update`.
    It then saves the policy and logs progress whenever the global step
    crosses the respective intervals.

    Args:
      sess: A Tensorflow session.
      train_ops: A `TrainOps` tuple of train ops to run.
      global_step: The global step.
      _: Unused.

    Returns:
      A scalar total loss (the value fetched for `train_ops.train_op` in the
      fused call).
      A boolean should stop.
    """
    start_time = time.time()
    if self.train_op_fn is None:
      self._build_callables(sess, train_ops, global_step)
    for _ in range(self.num_collect_per_update - 1):
      self.collect_fn()
    for _ in range(self.num_updates_per_observation - 1):
      self.train_op_fn()

    # The final train update and experience collection share one session run.
    total_loss, global_step_val, _ = self.collect_and_train_fn()
    if self._crossed_multiple(global_step_val, self.last_global_step_val,
                              self.num_collect_per_meta_update):
      self.meta_train_op_fn()

    time_elapsed = time.time() - start_time
    should_stop = False
    if self.max_number_of_steps:
      should_stop = global_step_val >= self.max_number_of_steps
    if global_step_val != self.last_global_step_val:
      if (self.save_policy_every_n_steps and
          self._crossed_multiple(global_step_val, self.last_global_step_val,
                                 self.save_policy_every_n_steps)):
        self.policy_save_fn(sess)

      # Boundary crossing (not `% n == 0`) so a log point is not skipped when
      # the global step advances by more than one per call.
      if (self.log_every_n_steps and
          self._crossed_multiple(global_step_val, self.last_global_step_val,
                                 self.log_every_n_steps)):
        # Guard against a zero elapsed time from a coarse-resolution clock.
        steps_per_sec = (1.0 / time_elapsed if time_elapsed > 0
                         else float('inf'))
        tf.logging.info(
            'global step %d: loss = %.4f (%.3f sec/step) (%.2f steps/sec)',
            global_step_val, total_loss, time_elapsed, steps_per_sec)

    self.last_global_step_val = global_step_val
    stop_early = bool(self.should_stop_early and self.should_stop_early())
    return total_loss, should_stop or stop_early


def create_counter_summaries(counters):
  """Adds a scalar summary for each named counter.

  Args:
    counters: A list of (name, counter) tuples. Each pair becomes one scalar
      summary under the 'Counters/' name scope. An empty or None value adds
      nothing.
  """
  if not counters:
    return
  with tf.name_scope('Counters/'):
    for counter_name, counter_value in counters:
      tf.summary.scalar(counter_name, counter_value)


def gen_debug_batch_summaries(batch):
  """Generates histogram summaries for the sampled replay batch.

  Inside the 'batch' name scope, adds one histogram per state dimension, one
  per next-state dimension, one per action dimension and one for the rewards.

  Args:
    batch: A 5-tuple (states, actions, rewards, <unused>, next_states) of
      tensors. states, next_states and actions are indexed as `[:, i]`, so
      they are presumably shaped [batch_size, dim] -- TODO confirm.
  """
  states, actions, rewards, _, next_states = batch
  with tf.name_scope('batch'):
    for s in range(states.get_shape()[-1]):
      tf.summary.histogram('states_%d' % s, states[:, s])
    # Iterate over next_states' own width instead of assuming it equals
    # states'.
    for s in range(next_states.get_shape()[-1]):
      tf.summary.histogram('next_states_%d' % s, next_states[:, s])
    for a in range(actions.get_shape()[-1]):
      tf.summary.histogram('actions_%d' % a, actions[:, a])
    tf.summary.histogram('rewards', rewards)