# tensorflow/models
#
# View on GitHub
# research/adversarial_text/evaluate.py
#
# Summary
#
# Maintainability
# A
# 1 hr
# Test Coverage
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Evaluates text classification model."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import time

# Dependency imports

import tensorflow as tf

import graphs

# Command-line flags for this evaluation script; values are read via FLAGS.
# NOTE(review): run_eval also reads FLAGS.batch_size, which is not defined in
# this file -- presumably it is registered by an imported module such as
# `graphs`; confirm.
flags = tf.app.flags
FLAGS = flags.FLAGS

# Passed as `master` to the Supervisor's managed_session in run_eval.
flags.DEFINE_string('master', '',
                    'BNS name prefix of the Tensorflow eval master, '
                    'or "local".')
# Created in main; used as the FileWriter directory and the Supervisor logdir.
flags.DEFINE_string('eval_dir', '/tmp/text_eval',
                    'Directory where to write event logs.')
# Forwarded unchanged to the model's eval_graph() in main.
flags.DEFINE_string('eval_data', 'test', 'Specify which dataset is used. '
                    '("train", "valid", "test") ')

# Directory whose checkpoint state restore_from_checkpoint looks up.
flags.DEFINE_string('checkpoint_dir', '/tmp/text_train',
                    'Directory where to read model checkpoints.')
# Sleep time between consecutive evaluations in main's loop.
flags.DEFINE_integer('eval_interval_secs', 60, 'How often to run the eval.')
# run_eval runs ceil(num_examples / FLAGS.batch_size) batches.
flags.DEFINE_integer('num_examples', 32, 'Number of examples to run.')
# When true, main stops after the first call to run_eval.
flags.DEFINE_bool('run_once', False, 'Whether to run eval only once.')


def restore_from_checkpoint(sess, saver):
  """Restores the checkpoint recorded under FLAGS.checkpoint_dir into sess.

  Args:
    sess: Session to restore the variables into.
    saver: Saver that performs the restore.

  Returns:
    bool: True if a checkpoint path was found and restored, False otherwise
      (in which case the miss is logged and the session is left untouched).
  """
  checkpoint_state = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
  # Either the state itself or its recorded path may be missing/empty.
  checkpoint_path = (
      checkpoint_state.model_checkpoint_path if checkpoint_state else None)

  if checkpoint_path:
    saver.restore(sess, checkpoint_path)
    return True

  tf.logging.info('No checkpoint found at %s', FLAGS.checkpoint_dir)
  return False


def run_eval(eval_ops, summary_writer, saver):
  """Restores a checkpoint and evaluates the metrics over FLAGS.num_examples.

  Runs the metric update ops for ceil(FLAGS.num_examples / FLAGS.batch_size)
  batches. Progress is logged every 10 batches and intermediate metric values
  every 50 batches (before that batch's update runs). After the last batch the
  metric values are logged and written to summary_writer.

  Args:
    eval_ops: dict<metric name, tuple(value, update_op)>
    summary_writer: Summary writer that receives the final metric values.
    saver: Saver used to restore the checkpoint.

  Returns:
    None. When no checkpoint can be restored, returns before running any batch.
  """
  supervisor = tf.train.Supervisor(
      logdir=FLAGS.eval_dir,
      saver=None,
      summary_op=None,
      summary_writer=None)
  with supervisor.managed_session(
      master=FLAGS.master, start_standard_services=False) as sess:
    checkpoint_restored = restore_from_checkpoint(sess, saver)
    if not checkpoint_restored:
      return
    supervisor.start_queue_runners(sess)

    # Split {name: (value_op, update_op)} into parallel sequences.
    names, op_pairs = zip(*eval_ops.items())
    metric_value_ops, metric_update_ops = zip(*op_pairs)
    values_by_name = dict(zip(names, metric_value_ops))

    # Accumulate the metrics batch by batch.
    num_batches = int(math.ceil(FLAGS.num_examples / FLAGS.batch_size))
    tf.logging.info('Running %d batches for evaluation.', num_batches)
    for batch_num in range(1, num_batches + 1):
      if batch_num % 10 == 0:
        tf.logging.info('Running batch %d/%d...', batch_num, num_batches)
      if batch_num % 50 == 0:
        _log_values(sess, values_by_name)
      sess.run(metric_update_ops)

    _log_values(sess, values_by_name, summary_writer=summary_writer)


def _log_values(sess, value_ops, summary_writer=None):
  """Evaluates the eval metrics, logs them, and optionally writes a summary.

  Args:
    sess: Session used to evaluate the metric value ops.
    value_ops: dict<metric name, value op>.
    summary_writer: Optional summary writer. When given, the metric values are
      added to it under the current global step and flushed to disk.
  """
  # Unzip into parallel sequences without rebinding the `value_ops` argument.
  metric_names, metric_ops = zip(*value_ops.items())
  values = sess.run(metric_ops)

  tf.logging.info('Eval metric values:')
  summary = tf.summary.Summary()
  for name, val in zip(metric_names, values):
    summary.value.add(tag=name, simple_value=val)
    tf.logging.info('%s = %.3f', name, val)

  if summary_writer is not None:
    global_step_val = sess.run(tf.train.get_global_step())
    # Pass the step as a logging argument (as the other log calls here do)
    # instead of concatenating eagerly; the emitted message is unchanged.
    tf.logging.info('Finished eval for step %s', global_step_val)
    summary_writer.add_summary(summary, global_step_val)
    # Write the result out now rather than waiting for the writer's periodic
    # background flush, so it is not lost if the process exits right after.
    summary_writer.flush()


def main(_):
  """Builds the eval graph and evaluates checkpoints once or periodically.

  Creates FLAGS.eval_dir, builds the eval graph for FLAGS.eval_data, then calls
  run_eval repeatedly, sleeping FLAGS.eval_interval_secs between runs, until
  FLAGS.run_once is set. The summary writer is always closed on the way out so
  that pending summaries are flushed to disk.

  Args:
    _: Unused positional argument supplied by tf.app.run().
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  tf.gfile.MakeDirs(FLAGS.eval_dir)
  tf.logging.info('Building eval graph...')
  output = graphs.get_model().eval_graph(FLAGS.eval_data)
  eval_ops, moving_averaged_variables = output

  # The saver restores only the variables handed back by the eval graph.
  saver = tf.train.Saver(moving_averaged_variables)
  summary_writer = tf.summary.FileWriter(
      FLAGS.eval_dir, graph=tf.get_default_graph())

  try:
    while True:
      run_eval(eval_ops, summary_writer, saver)
      if FLAGS.run_once:
        break
      time.sleep(FLAGS.eval_interval_secs)
  finally:
    # Without this, a --run_once job (or an interrupted one) can exit before
    # the writer has flushed the last summary to the event file.
    summary_writer.close()


if __name__ == '__main__':
  tf.app.run()