tensorflow/tensorflow

View on GitHub
tensorflow/python/training/monitored_session_test.py

Summary

Maintainability
F
1 mo
Test Coverage
# pylint: disable=g-bad-file-header
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for monitored_session."""

import collections
import glob
import os
import shutil
import sys
import threading
import time
import traceback

from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import debug_pb2
from tensorflow.python.checkpoint import checkpoint_management
from tensorflow.python.client import session as session_lib
from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import distribute_coordinator
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_assert
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_v1
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.saved_model import builder as saved_model_builder
from tensorflow.python.saved_model import load as saved_model_load
from tensorflow.python.summary import summary
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import coordinator
from tensorflow.python.training import monitored_session
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import summary_io
from tensorflow.python.training import training_util


def latest_summaries(base_dir):
  """Returns the summary events from the newest event file in `base_dir`.

  Event files are found with the `events.*` glob. The lexicographically
  greatest path is treated as the latest one.

  Args:
    base_dir: Directory to scan for event files.

  Returns:
    A list of the events that carry a `summary` field. The list is empty when
    `base_dir` holds no event files.
  """
  event_files = glob.glob(os.path.join(base_dir, 'events.*'))
  if not event_files:
    return []
  newest_file = max(event_files)
  return [
      event for event in summary_io.summary_iterator(newest_file)
      if event.HasField('summary')
  ]

class ScaffoldTest(test.TestCase):
  """Scaffold tests."""

  def test_nothing_created_before_finalize(self):
    with ops.Graph().as_default():
      scaffold = monitored_session.Scaffold()
      # Nothing is built until `finalize()` is called.
      self.assertIsNone(scaffold.init_op)
      self.assertIsNone(scaffold.init_feed_dict)
      self.assertIsNone(scaffold.init_fn)
      self.assertIsNone(scaffold.ready_op)
      self.assertIsNone(scaffold.ready_for_local_init_op)
      self.assertIsNone(scaffold.local_init_op)
      self.assertIsNone(scaffold.local_init_feed_dict)
      self.assertIsNone(scaffold.saver)

  def test_defaults_empty_graph(self):
    with ops.Graph().as_default():
      scaffold = monitored_session.Scaffold()
      variable_v1.VariableV1(1, name='my_var')
      variable_v1.VariableV1(
          2, name='my_local_var', collections=[ops.GraphKeys.LOCAL_VARIABLES])
      scaffold.finalize()
      self.assertIsInstance(scaffold.init_op, ops.Operation)
      self.assertIsNone(scaffold.init_feed_dict)
      self.assertIsNone(scaffold.init_fn)
      self.assertIsInstance(scaffold.ready_op, tensor.Tensor)
      self.assertIsInstance(scaffold.ready_for_local_init_op, tensor.Tensor)
      self.assertIsInstance(scaffold.local_init_op, ops.Operation)
      self.assertIsNone(scaffold.local_init_feed_dict)
      self.assertIsInstance(scaffold.saver, saver_lib.Saver)
      with self.cached_session() as sess:
        # Before any initialization both variables are reported as not ready.
        # Only the global one is reported by `ready_for_local_init_op`.
        self.assertItemsEqual([b'my_var', b'my_local_var'],
                              sess.run(scaffold.ready_op))
        self.assertItemsEqual([b'my_var'],
                              sess.run(scaffold.ready_for_local_init_op))
        sess.run(scaffold.init_op)
        self.assertEmpty(sess.run(scaffold.ready_for_local_init_op))
        sess.run(scaffold.local_init_op)
        self.assertEmpty(sess.run(scaffold.ready_op))

  def test_defaults_no_variables(self):
    with ops.Graph().as_default():
      scaffold = monitored_session.Scaffold()
      constant_op.constant(1, name='my_const')
      scaffold.finalize()
      self.assertIsInstance(scaffold.init_op, ops.Operation)
      self.assertIsNone(scaffold.init_feed_dict)
      self.assertIsNone(scaffold.init_fn)
      self.assertIsInstance(scaffold.ready_op, tensor.Tensor)
      self.assertIsInstance(scaffold.ready_for_local_init_op, tensor.Tensor)
      self.assertIsInstance(scaffold.local_init_op, ops.Operation)
      self.assertIsNone(scaffold.local_init_feed_dict)
      self.assertIsInstance(scaffold.saver, saver_lib.Saver)

  def test_caches_values(self):
    with ops.Graph().as_default():
      variable_v1.VariableV1([1])
      scaffold1 = monitored_session.Scaffold()
      scaffold1.finalize()
      scaffold2 = monitored_session.Scaffold()
      scaffold2.finalize()
      # The second scaffold must reuse the very same objects, not equal copies.
      self.assertIs(scaffold1.init_op, scaffold2.init_op)
      self.assertIs(scaffold1.ready_op, scaffold2.ready_op)
      self.assertIs(scaffold1.ready_for_local_init_op,
                    scaffold2.ready_for_local_init_op)
      self.assertIs(scaffold1.local_init_op, scaffold2.local_init_op)
      self.assertIs(scaffold1.saver, scaffold2.saver)

  def test_raise_error_if_more_than_one_cached_item(self):
    with ops.Graph().as_default():
      variable_v1.VariableV1([1])
      ops.add_to_collection(ops.GraphKeys.SAVERS, saver_lib.Saver())
      ops.add_to_collection(ops.GraphKeys.SAVERS, saver_lib.Saver())
      with self.assertRaisesRegex(RuntimeError, 'More than one item'):
        monitored_session.Scaffold().finalize()

  def test_uses_passed_values(self):
    with ops.Graph().as_default():
      variable_v1.VariableV1([1])
      saver = saver_lib.Saver()
      scaffold = monitored_session.Scaffold(
          init_op=2,
          init_feed_dict=3,
          init_fn=lambda scaffold, sess: 4,
          ready_op=5,
          ready_for_local_init_op=6,
          local_init_op=7,
          local_init_feed_dict=8,
          saver=saver)
      scaffold.finalize()
      self.assertEqual(2, scaffold.init_op)
      self.assertEqual(3, scaffold.init_feed_dict)
      self.assertTrue(callable(scaffold.init_fn))
      self.assertEqual(5, scaffold.ready_op)
      self.assertEqual(6, scaffold.ready_for_local_init_op)
      self.assertEqual(7, scaffold.local_init_op)
      self.assertEqual(8, scaffold.local_init_feed_dict)
      self.assertIs(saver, scaffold.saver)

  def test_graph_is_finalized(self):
    with ops.Graph().as_default():
      variable_v1.VariableV1([1])
      monitored_session.Scaffold().finalize()
      with self.assertRaisesRegex(RuntimeError,
                                  'Graph is finalized and cannot be modified'):
        constant_op.constant([0])

  def test_new_scaffold_from_default_scaffold(self):
    scaffold1 = monitored_session.Scaffold()
    with ops.Graph().as_default():
      variable_v1.VariableV1([1])
      saver = saver_lib.Saver()
      scaffold2 = monitored_session.Scaffold(
          init_op=2,
          init_feed_dict=3,
          init_fn=lambda scaffold, sess: 4,
          ready_op=5,
          ready_for_local_init_op=6,
          local_init_op=7,
          local_init_feed_dict=8,
          saver=saver,
          copy_from_scaffold=scaffold1)

      scaffold2.finalize()
      self.assertEqual(2, scaffold2.init_op)
      self.assertEqual(3, scaffold2.init_feed_dict)
      self.assertTrue(callable(scaffold2.init_fn))
      self.assertEqual(5, scaffold2.ready_op)
      self.assertEqual(6, scaffold2.ready_for_local_init_op)
      self.assertEqual(7, scaffold2.local_init_op)
      self.assertEqual(8, scaffold2.local_init_feed_dict)
      self.assertIs(saver, scaffold2.saver)

  def test_new_scaffold_from_existing_scaffold(self):
    with ops.Graph().as_default():
      variable_v1.VariableV1([1])
      saver = saver_lib.Saver()
      scaffold1 = monitored_session.Scaffold(
          init_op=2,
          init_feed_dict=3,
          init_fn=lambda scaffold, sess: 4,
          ready_op=5,
          ready_for_local_init_op=6,
          local_init_op=7,
          local_init_feed_dict=8,
          saver=saver)

      # Explicit arguments must win over the values copied from `scaffold1`.
      scaffold2 = monitored_session.Scaffold(
          init_op=4,
          init_feed_dict=6,
          init_fn=lambda scaffold, sess: 8,
          ready_op=10,
          ready_for_local_init_op=12,
          local_init_op=14,
          local_init_feed_dict=15,
          saver=saver,
          copy_from_scaffold=scaffold1)

      scaffold2.finalize()
      self.assertEqual(4, scaffold2.init_op)
      self.assertEqual(6, scaffold2.init_feed_dict)
      self.assertTrue(callable(scaffold2.init_fn))
      self.assertEqual(10, scaffold2.ready_op)
      self.assertEqual(12, scaffold2.ready_for_local_init_op)
      self.assertEqual(14, scaffold2.local_init_op)
      self.assertEqual(15, scaffold2.local_init_feed_dict)
      self.assertIs(saver, scaffold2.saver)

  def test_copy_from_scaffold_is_scaffold(self):
    with ops.Graph().as_default():
      with self.assertRaisesRegex(
          TypeError, 'copy_from_scaffold is not a Scaffold instance'):
        monitored_session.Scaffold(copy_from_scaffold=1)


def _test_dir(temp_dir, test_name):
  """Create an empty dir to use for tests.

  Args:
    temp_dir: Tmp directory path.
    test_name: Name of the test.

  Returns:
    Absolute path to the test directory.
  """
  test_dir = os.path.join(temp_dir, test_name)
  if os.path.isdir(test_dir):
    for f in glob.glob('%s/*' % test_dir):
      os.remove(f)
  else:
    os.makedirs(test_dir)
  return test_dir


class FakeHook(session_run_hook.SessionRunHook):
  """Hook double that counts its callbacks and remembers the last run args.

  `call_counter` maps each callback name ('begin', 'after_create_session',
  'before_run', 'after_run', 'end') to the number of times it was invoked.
  Set `request` to choose what `before_run` returns. Set `should_stop` to make
  `after_run` call `request_stop()` on the run context.
  """

  def __init__(self):
    self.should_stop = False
    self.request = None
    self.call_counter = collections.Counter()
    self.last_run_context = None
    self.last_run_values = None

  def _bump_call_count(self, callback_name):
    """Increments the invocation count recorded for `callback_name`."""
    self.call_counter[callback_name] += 1

  def begin(self):
    self._bump_call_count('begin')

  def after_create_session(self, session, coord):  # pylint: disable=unused-argument
    self._bump_call_count('after_create_session')

  def before_run(self, run_context):
    self._bump_call_count('before_run')
    self.last_run_context = run_context
    return self.request

  def after_run(self, run_context, run_values):
    self._bump_call_count('after_run')
    self.last_run_values = run_values
    if not self.should_stop:
      return
    run_context.request_stop()

  def end(self, session):
    self._bump_call_count('end')


class MonitoredTrainingSessionTest(test.TestCase):
  """Tests MonitoredTrainingSession."""

  def test_saving_restoring_checkpoint(self):
    logdir = _test_dir(self.get_temp_dir(), 'test_saving_restoring_checkpoint')
    with ops.Graph().as_default():
      gstep = training_util.get_or_create_global_step()
      do_step = state_ops.assign_add(gstep, 1)
      with monitored_session.MonitoredTrainingSession(
          is_chief=True, checkpoint_dir=logdir) as session:
        self.assertEqual(0, session.run(gstep))
        self.assertEqual(1, session.run(do_step))
        self.assertEqual(2, session.run(do_step))
      # A restart will find the checkpoint and recover automatically.
      with monitored_session.MonitoredTrainingSession(
          is_chief=True, checkpoint_dir=logdir) as session:
        self.assertEqual(2, session.run(gstep))

  def test_save_checkpoint_steps(self):
    logdir = _test_dir(self.get_temp_dir(), 'test_save_checkpoint_steps')
    with ops.Graph().as_default():
      gstep = training_util.get_or_create_global_step()
      new_gstep = state_ops.assign_add(gstep, 1)
      with monitored_session.MonitoredTrainingSession(
          is_chief=True,
          checkpoint_dir=logdir,
          save_checkpoint_steps=100,
          log_step_count_steps=10) as session:
        for _ in range(100):
          session.run(new_gstep)
      # A restart will find the checkpoint and recover automatically.
      with monitored_session.MonitoredTrainingSession(
          is_chief=True, checkpoint_dir=logdir) as session:
        self.assertEqual(100, session.run(gstep))

  def test_save_checkpoint_secs(self):
    logdir = _test_dir(self.get_temp_dir(), 'test_save_checkpoint_secs')
    with ops.Graph().as_default():
      gstep = training_util.get_or_create_global_step()
      new_gstep = state_ops.assign_add(gstep, 1)
      with monitored_session.MonitoredTrainingSession(
          is_chief=True,
          checkpoint_dir=logdir,
          save_checkpoint_secs=0.1,
          log_step_count_steps=10) as session:
        session.run(new_gstep)
        time.sleep(0.2)
        for _ in range(10):
          session.run(new_gstep)
      # A restart will find the checkpoint and recover automatically.
      with monitored_session.MonitoredTrainingSession(
          is_chief=True, checkpoint_dir=logdir) as session:
        self.assertEqual(11, session.run(gstep))

  def test_save_restore_checkpoint_v1_saved_model(self):

    def _write_v1_simple_saved_model(export_dir):
      # Create v1 Saved Model with single variable `w0` with value 5.0.
      builder = saved_model_builder.SavedModelBuilder(export_dir)
      with ops.Graph().as_default():
        _ = resource_variable_ops.ResourceVariable(5.0)
        with self.cached_session() as session:
          session.run(variables.global_variables_initializer())
          builder.add_meta_graph_and_variables(session, ['foo'])
      builder.save()

    test_dir = _test_dir(self.get_temp_dir(), 'saved_model')
    _write_v1_simple_saved_model(test_dir)

    with ops.Graph().as_default():
      # Load saved model with `load_v1_in_v2`.
      model = saved_model_load.load(test_dir)
      w0 = model.variables[0]
      # Define operation that increments `w0`.
      w_add = w0.assign_add(1.)
      gstep = training_util.get_or_create_global_step()
      new_gstep = state_ops.assign_add(gstep, 1)

      with monitored_session.MonitoredTrainingSession(
          checkpoint_dir=test_dir) as session:
        w1 = session.run(w_add)
        self.assertEqual(6., w1)
        session.run(new_gstep)
        w2 = session.run(w_add)
        self.assertEqual(7., w2)

      # Stop and resume training.
      with monitored_session.MonitoredTrainingSession(
          checkpoint_dir=test_dir) as session:
        # The checkpoint restores `w0` to 7, so the next increment yields 8.
        w3 = session.run(w_add)
        self.assertEqual(8., w3)

  def test_summaries_steps(self):
    logdir = _test_dir(self.get_temp_dir(), 'test_summaries_steps')
    with ops.Graph().as_default():
      gstep = training_util.get_or_create_global_step()
      new_gstep = state_ops.assign_add(gstep, 1)
      summary.scalar('my_summary_tag', new_gstep * 2)
      with monitored_session.MonitoredTrainingSession(
          is_chief=True,
          checkpoint_dir=logdir,
          save_summaries_steps=100,
          log_step_count_steps=10) as session:
        for _ in range(101):
          session.run(new_gstep)
    summaries = latest_summaries(logdir)
    tags = [s.summary.value[0].tag for s in summaries]
    self.assertIn('my_summary_tag', tags)
    self.assertIn('global_step/sec', tags)

  def test_summaries_secs(self):
    logdir = _test_dir(self.get_temp_dir(), 'test_summaries_secs')
    with ops.Graph().as_default():
      gstep = training_util.get_or_create_global_step()
      new_gstep = state_ops.assign_add(gstep, 1)
      summary.scalar('my_summary_tag', new_gstep * 2)
      with monitored_session.MonitoredTrainingSession(
          is_chief=True,
          checkpoint_dir=logdir,
          save_summaries_steps=None,
          save_summaries_secs=0.1,
          log_step_count_steps=10) as session:
        session.run(new_gstep)
        time.sleep(0.2)
        for _ in range(101):
          session.run(new_gstep)
    summaries = latest_summaries(logdir)
    tags = [s.summary.value[0].tag for s in summaries]
    self.assertIn('my_summary_tag', tags)
    self.assertIn('global_step/sec', tags)

  def test_custom_saving(self):
    # Each test gets its own logdir name (this one was previously a
    # copy-paste of the `test_saving_restoring_checkpoint` name).
    logdir = _test_dir(self.get_temp_dir(), 'test_custom_saving')
    fake_hook = FakeHook()
    with ops.Graph().as_default():
      gstep = training_util.get_or_create_global_step()
      do_step = state_ops.assign_add(gstep, 1)
      with monitored_session.MonitoredTrainingSession(
          is_chief=True,
          checkpoint_dir=logdir,
          chief_only_hooks=[fake_hook],
          save_checkpoint_secs=0) as session:
        self.assertEqual(0, session.run(gstep))
        self.assertEqual(1, session.run(do_step))
        self.assertEqual(2, session.run(do_step))

      # Check whether custom hook called or not
      self.assertEqual(1, fake_hook.call_counter['begin'])
      # A restart will not find the checkpoint, since we didn't save.
      with monitored_session.MonitoredTrainingSession(
          is_chief=True, checkpoint_dir=logdir) as session:
        self.assertEqual(0, session.run(gstep))

  def test_save_graph_def(self):
    logdir = _test_dir(self.get_temp_dir(), 'test_save_graph_def')
    with ops.Graph().as_default():
      gstep = training_util.get_or_create_global_step()
      new_gstep = state_ops.assign_add(gstep, 1)
      with monitored_session.MonitoredTrainingSession(
          is_chief=True,
          checkpoint_dir=logdir,
          save_checkpoint_steps=1,
          save_graph_def=True) as session:
        self.assertIn('graph.pbtxt', os.listdir(logdir))
        self.assertLen(glob.glob(os.path.join(logdir, '*.meta')), 1)
        session.run(new_gstep)
        self.assertLen(glob.glob(os.path.join(logdir, '*.meta')), 2)

  def test_save_graph_def_false(self):
    logdir = _test_dir(self.get_temp_dir(), 'test_save_graph_def_false')
    with ops.Graph().as_default():
      gstep = training_util.get_or_create_global_step()
      new_gstep = state_ops.assign_add(gstep, 1)
      with monitored_session.MonitoredTrainingSession(
          is_chief=True,
          checkpoint_dir=logdir,
          save_checkpoint_steps=1,
          save_graph_def=False) as session:
        self.assertNotIn('graph.pbtxt', os.listdir(logdir))
        self.assertEmpty(glob.glob(os.path.join(logdir, '*.meta')))
        session.run(new_gstep)
        self.assertEmpty(glob.glob(os.path.join(logdir, '*.meta')))


class MockExtended:
  """Stand-in for a distribution strategy's `extended` object.

  It stores the four constructor flags as plain attributes. The first two are
  exposed under `experimental_`-prefixed names.
  """

  def __init__(self, between_graph, should_init, should_checkpoint,
               should_save_summary):
    self.should_save_summary = should_save_summary
    self.should_checkpoint = should_checkpoint
    self.experimental_should_init = should_init
    self.experimental_between_graph = between_graph


class MockStrategy:
  """Minimal distribution-strategy double whose `extended` is a MockExtended.

  The keyword arguments are forwarded unchanged to `MockExtended`.
  """

  def __init__(self,
               between_graph=False,
               should_init=True,
               should_checkpoint=None,
               should_save_summary=None):
    self.extended = MockExtended(
        between_graph=between_graph,
        should_init=should_init,
        should_checkpoint=should_checkpoint,
        should_save_summary=should_save_summary)


class MonitoredTrainingSessionWithDistributeCoordinatorTest(test.TestCase):
  """Test distribute coordinator controls summary saving and checkpointing."""

  def test_summary_hook_enabled(self):
    context = distribute_coordinator._WorkerContext(
        MockStrategy(should_save_summary=True), None, None, None)

    logdir = _test_dir(self.get_temp_dir(), 'test_summaries_enabled')
    with ops.Graph().as_default():
      gstep = training_util.get_or_create_global_step()
      new_gstep = state_ops.assign_add(gstep, 1)
      summary.scalar('my_summary_tag', new_gstep * 2)
      with context, monitored_session.MonitoredTrainingSession(
          checkpoint_dir=logdir,
          save_summaries_steps=100,
          log_step_count_steps=10) as session:
        for _ in range(101):
          session.run(new_gstep)

    summaries = latest_summaries(logdir)
    tags = [s.summary.value[0].tag for s in summaries]
    self.assertIn('my_summary_tag', tags)
    self.assertIn('global_step/sec', tags)

  def test_summary_hook_disabled(self):
    context = distribute_coordinator._WorkerContext(
        MockStrategy(should_save_summary=False), None, None, None)

    logdir = _test_dir(self.get_temp_dir(), 'test_summaries_disabled')
    with ops.Graph().as_default():
      gstep = training_util.get_or_create_global_step()
      new_gstep = state_ops.assign_add(gstep, 1)
      summary.scalar('my_summary_tag', new_gstep * 2)
      with context, monitored_session.MonitoredTrainingSession(
          checkpoint_dir=logdir,
          save_summaries_steps=100,
          log_step_count_steps=10) as session:
        for _ in range(101):
          session.run(new_gstep)

    # No summary is saved.
    self.assertEmpty(latest_summaries(logdir))

  def test_checkpoint_hook_enabled(self):
    context = distribute_coordinator._WorkerContext(
        MockStrategy(should_checkpoint=True), None, None, None)

    logdir = _test_dir(self.get_temp_dir(), 'test_save_checkpoint_enabled')
    with ops.Graph().as_default():
      gstep = training_util.get_or_create_global_step()
      new_gstep = state_ops.assign_add(gstep, 1)
      with context, monitored_session.MonitoredTrainingSession(
          checkpoint_dir=logdir,
          save_checkpoint_steps=100,
          log_step_count_steps=10) as session:
        for _ in range(100):
          session.run(new_gstep)

      # A restart will find the checkpoint and recover automatically.
      with monitored_session.MonitoredTrainingSession(
          is_chief=True, checkpoint_dir=logdir) as session:
        self.assertEqual(100, session.run(gstep))

  def test_checkpoint_hook_disabled(self):
    context = distribute_coordinator._WorkerContext(
        MockStrategy(should_checkpoint=False), None, None, None)

    logdir = _test_dir(self.get_temp_dir(), 'test_save_checkpoint_disabled')
    with ops.Graph().as_default():
      gstep = training_util.get_or_create_global_step()
      new_gstep = state_ops.assign_add(gstep, 1)
      with context, monitored_session.MonitoredTrainingSession(
          checkpoint_dir=logdir,
          save_checkpoint_steps=100,
          log_step_count_steps=10) as session:
        for _ in range(100):
          session.run(new_gstep)

    # No checkpoint is saved.
    checkpoint = checkpoint_management.latest_checkpoint(logdir)
    self.assertIsNone(checkpoint)

  def test_checkpoint_hook_enable_on_non_chief_with_collective_ops(self):
    strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy()
    strategy.extended._is_chief = False

    context = distribute_coordinator._WorkerContext(strategy, None, 'worker', 1)

    # Use a logdir name of its own rather than the `..._disabled` one above.
    logdir = _test_dir(self.get_temp_dir(), 'test_save_checkpoint_non_chief')
    with ops.Graph().as_default():
      gstep = training_util.get_or_create_global_step()
      new_gstep = state_ops.assign_add(gstep, 1)
      with context, monitored_session.MonitoredTrainingSession(
          checkpoint_dir=logdir,
          save_checkpoint_steps=100,
          log_step_count_steps=10) as session:
        for _ in range(100):
          session.run(new_gstep)

    # No checkpoint is saved directly under `logdir` by this non-chief worker.
    checkpoint = checkpoint_management.latest_checkpoint(logdir)
    self.assertIsNone(checkpoint)

    # But saved to a temporary directory.
    checkpoint = checkpoint_management.latest_checkpoint(
        os.path.join(logdir, 'tmp_worker_1'))
    self.assertIsNotNone(checkpoint)


class StopAtNSession(monitored_session._WrappedSession):
  """A `_WrappedSession` whose `_check_stop` first returns True on call n+1.

  The first `n` calls to `_check_stop` return False. Every later call returns
  True.
  """

  def __init__(self, sess, n):
    super().__init__(sess)
    self._count = n

  def _check_stop(self):
    stop_now = self._count == 0
    if not stop_now:
      # Consume one of the remaining non-stopping checks.
      self._count -= 1
    return stop_now


class WrappedSessionTest(test.TestCase):
  """Tests for `monitored_session._WrappedSession`."""

  @test_util.run_deprecated_v1
  def test_properties(self):
    with self.cached_session() as sess:
      constant_op.constant(0.0)
      wrapped = monitored_session._WrappedSession(sess)
      # The wrapper exposes the underlying session's graph and target string.
      self.assertEqual(sess.graph, wrapped.graph)
      self.assertEqual(sess.sess_str, wrapped.sess_str)

  @test_util.run_deprecated_v1
  def test_should_stop_on_close(self):
    with self.cached_session() as sess:
      wrapped = monitored_session._WrappedSession(sess)
      self.assertFalse(wrapped.should_stop())
      wrapped.close()
      self.assertTrue(wrapped.should_stop())

  @test_util.run_deprecated_v1
  def test_should_stop_uses_check_stop(self):
    with self.cached_session() as sess:
      wrapped = StopAtNSession(sess, 3)
      for _ in range(3):
        self.assertFalse(wrapped.should_stop())
      self.assertTrue(wrapped.should_stop())

  @test_util.run_deprecated_v1
  def test_should_stop_delegates_to_wrapped_session(self):
    with self.cached_session() as sess:
      inner = StopAtNSession(sess, 4)
      outer = monitored_session._WrappedSession(inner)
      for _ in range(4):
        self.assertFalse(outer.should_stop())
      self.assertTrue(outer.should_stop())

  @test_util.run_deprecated_v1
  def test_close_twice(self):
    with self.cached_session() as sess:
      wrapped = monitored_session._WrappedSession(sess)
      # Closing a second time must be harmless and keep reporting a stop.
      for _ in range(2):
        wrapped.close()
        self.assertTrue(wrapped.should_stop())

  @test_util.run_deprecated_v1
  def test_run(self):
    with self.cached_session() as sess:
      c = constant_op.constant(0)
      v = array_ops.identity(c)
      self.assertEqual(42, sess.run(v, feed_dict={c: 42}))
      wrapped = monitored_session._WrappedSession(sess)
      self.assertEqual(51, wrapped.run(v, feed_dict={c: 51}))


def busy_wait_for_coord_stop(coord):
  """Polls `coord.should_stop()` until it returns True.

  The function sleeps 1 ms between polls.
  """
  poll_interval_secs = 0.001
  while True:
    if coord.should_stop():
      return
    time.sleep(poll_interval_secs)


class CoordinatedSessionTest(test.TestCase):
  """_CoordinatedSession tests."""

  @test_util.run_deprecated_v1
  def test_properties(self):
    with self.cached_session() as sess:
      constant_op.constant(0.0)
      coord = coordinator.Coordinator()
      coord_sess = monitored_session._CoordinatedSession(sess, coord)
      self.assertEqual(sess.graph, coord_sess.graph)
      self.assertEqual(sess.sess_str, coord_sess.sess_str)

  @test_util.run_deprecated_v1
  def test_run(self):
    with self.cached_session() as sess:
      c = constant_op.constant(0)
      v = array_ops.identity(c)
      coord = coordinator.Coordinator()
      coord_sess = monitored_session._CoordinatedSession(sess, coord)
      self.assertEqual(42, coord_sess.run(v, feed_dict={c: 42}))

  @test_util.run_deprecated_v1
  def test_should_stop_on_close(self):
    with self.cached_session() as sess:
      coord = coordinator.Coordinator()
      coord_sess = monitored_session._CoordinatedSession(sess, coord)
      self.assertFalse(coord_sess.should_stop())
      coord_sess.close()
      self.assertTrue(coord_sess.should_stop())

  @test_util.run_deprecated_v1
  def test_should_stop_on_coord_stop(self):
    with self.cached_session() as sess:
      coord = coordinator.Coordinator()
      coord_sess = monitored_session._CoordinatedSession(sess, coord)
      self.assertFalse(coord_sess.should_stop())
      coord.request_stop()
      self.assertTrue(coord_sess.should_stop())

  @test_util.run_deprecated_v1
  def test_dont_request_stop_on_exception_in_main_thread(self):
    with self.cached_session() as sess:
      c = constant_op.constant(0)
      v = array_ops.identity(c)
      coord = coordinator.Coordinator()
      coord_sess = monitored_session._CoordinatedSession(sess, coord)
      self.assertFalse(coord_sess.should_stop())
      self.assertEqual(0, coord_sess.run(c))
      self.assertEqual(1, coord_sess.run(v, feed_dict={c: 1}))
      with self.assertRaisesRegex(TypeError, 'None has invalid type'):
        coord_sess.run([None], feed_dict={c: 2})
      self.assertFalse(coord.should_stop())
      self.assertFalse(coord_sess.should_stop())

  @test_util.run_deprecated_v1
  def test_stop_threads_on_close_after_exception(self):
    with self.cached_session() as sess:
      c = constant_op.constant(0)
      v = array_ops.identity(c)
      coord = coordinator.Coordinator()
      # Daemon threads: if an assertion below fails before `close()` stops the
      # coordinator, the spinning threads must not block interpreter exit.
      threads = [
          threading.Thread(
              target=busy_wait_for_coord_stop, args=(coord,), daemon=True)
          for _ in range(3)
      ]
      for t in threads:
        coord.register_thread(t)
        t.start()
      coord_sess = monitored_session._CoordinatedSession(sess, coord)
      self.assertFalse(coord_sess.should_stop())
      for t in threads:
        self.assertTrue(t.is_alive())
      self.assertEqual(0, coord_sess.run(c))
      for t in threads:
        self.assertTrue(t.is_alive())
      self.assertEqual(1, coord_sess.run(v, feed_dict={c: 1}))
      for t in threads:
        self.assertTrue(t.is_alive())
      with self.assertRaisesRegex(TypeError, 'None has invalid type'):
        coord_sess.run([None], feed_dict={c: 2})
      coord_sess.close()
      for t in threads:
        self.assertFalse(t.is_alive())
      self.assertTrue(coord.should_stop())
      self.assertTrue(coord_sess.should_stop())

  def test_stop_threads_on_close(self):
    with self.cached_session() as sess:
      coord = coordinator.Coordinator()
      # Daemon threads so a failure before `close()` cannot hang shutdown.
      threads = [
          threading.Thread(
              target=busy_wait_for_coord_stop, args=(coord,), daemon=True)
          for _ in range(3)
      ]
      for t in threads:
        coord.register_thread(t)
        t.start()
      coord_sess = monitored_session._CoordinatedSession(sess, coord)
      coord_sess.close()
      for t in threads:
        self.assertFalse(t.is_alive())
      self.assertTrue(coord.should_stop())
      self.assertTrue(coord_sess.should_stop())

  @test_util.run_deprecated_v1
  def test_propagates_exception_trace(self):
    assertion = control_flow_assert.Assert(False, ['This should fail.'])
    with self.cached_session() as sess:
      coord = coordinator.Coordinator(clean_stop_exception_types=())
      coord_sess = monitored_session._CoordinatedSession(sess, coord)
      try:
        coord_sess.run([assertion])
      except errors_impl.InvalidArgumentError:
        # Extract the name of the file where the exception was first raised.
        _, _, exc_traceback = sys.exc_info()
        tb = traceback.extract_tb(exc_traceback)
        exc_source_basename = os.path.basename(tb[-1].filename)
        # If it's monitored_session.py then the original stack trace was not
        # correctly propagated.
        self.assertIn(
            exc_source_basename, ['session.py', 'monitored_session.py'],
            'The exception was raised from an unrecognized file. This unit '
            'test probably needs to be updated. Traceback:\n%s\n' % tb)
        self.assertEqual(
            exc_source_basename, 'session.py',
            'Original stack trace was not propagated by MonitoredSession. '
            'Traceback:\n%s' % tb)
      else:
        self.fail('No exception was raised by assertion.')


class AbortAtNSession:
  """A mock session that aborts after N successful run calls.

  The first `n` calls to `run` are forwarded to the wrapped session. Every
  later call raises `errors_impl.AbortedError`. For example, with `n=1` the
  first `run` succeeds and the second one aborts.

  Args:
    sess: Session whose `run` method calls are forwarded to.
    n: Number of `run` calls forwarded before aborting.
  """

  def __init__(self, sess, n):
    self._sess = sess
    self._count = n  # Remaining run() calls before aborting.

  def close(self):
    """No-op; the wrapped session is not closed."""
    pass

  def run(self, *args, **kwargs):
    """Forwards to the wrapped session until the call budget is spent.

    Raises:
      errors_impl.AbortedError: once `n` calls have already been forwarded.
    """
    if self._count == 0:
      # AbortedError takes (node_def, op, message); the text is the message.
      raise errors_impl.AbortedError(None, None, 'Aborted at N')
    self._count -= 1
    return self._sess.run(*args, **kwargs)


class StopCoordinatorWithException(session_run_hook.SessionRunHook):
  """With this hook Coordinator throws an exception after N-runs.

  `after_create_session` starts a side thread, at most once per hook instance,
  and registers it with the coordinator. That thread sleeps until `after_run`
  has been called `calls_before_stopping` times. It then calls
  `coord.request_stop(exception_to_raise)`. The `after_run` call that brings
  the counter to zero blocks until that stop request has been made.

  Args:
    calls_before_stopping: Number of `after_run` calls before the coordinator
      is asked to stop.
    exception_to_raise: Exception handed to `coord.request_stop`. Defaults to
      an `errors_impl.AbortedError`.
  """

  def __init__(self, calls_before_stopping, exception_to_raise=None):
    self._started_the_side_thread_already = False
    self._lock = threading.Lock()
    # Condition on `_lock`, notified by `after_run` when the countdown hits
    # zero, so the side thread blocks instead of busy-polling the counter.
    self._calls_changed = threading.Condition(self._lock)
    self._stored_exception_event = threading.Event()
    self._calls_before_stopping = calls_before_stopping
    self._exception_to_raise = (exception_to_raise or errors_impl.AbortedError(
        None, None, 'Aborted at N'))

  def _maybe_stop_with_exception(self, coord):
    """Side-thread body: requests a stop once the countdown reaches zero."""
    with self._calls_changed:
      self._calls_changed.wait_for(lambda: self._calls_before_stopping == 0)
      try:
        # Raise and catch so that `request_stop` is called from inside an
        # `except` block, as it would be for a genuine failure.
        raise self._exception_to_raise
      except Exception as e:  # pylint: disable=broad-except
        coord.request_stop(e)
        self._stored_exception_event.set()

  def after_create_session(self, session, coord):
    if self._started_the_side_thread_already:
      return

    separate_thread = threading.Thread(
        target=self._maybe_stop_with_exception, args=(coord,))

    coord.register_thread(separate_thread)
    separate_thread.start()
    self._started_the_side_thread_already = True
    # Coordinator will take care of joining `separate_thread`.

  def after_run(self, run_context, run_values):
    with self._calls_changed:
      self._calls_before_stopping -= 1
      stopping_now = self._calls_before_stopping == 0
      if stopping_now:
        # Wake the side thread so it can request the stop.
        self._calls_changed.notify_all()

    if stopping_now:
      # Wait outside the lock (the side thread needs it) until the stop has
      # been requested.
      self._stored_exception_event.wait()


class FailTrainingAfterCoordinatorStopped(StopCoordinatorWithException):
  """With this hook training encounters an exception after N-runs.

  This hook behaves like `StopCoordinatorWithException`. In addition, after
  every run it asks the coordinator to re-raise any stored exception. If that
  exception is an `AbortedError`, the hook raises `CancelledError` instead.
  This simulates the main thread failing because the coordinator stopped.
  """

  def __init__(self, calls_before_stopping):
    super().__init__(calls_before_stopping)
    self._coord = None  # Set by `after_create_session`.

  def after_create_session(self, session, coord):
    self._coord = coord
    return super().after_create_session(session, coord)

  def after_run(self, run_context, run_values):
    super().after_run(run_context, run_values)
    try:
      # Re-raises whatever the coordinator has stored, if anything.
      self._coord.raise_requested_exception()
    except errors_impl.AbortedError:
      # A real main thread may never learn why the coordinator stopped. It can
      # get stuck (e.g. a blocking `FIFOQueue.dequeue` whose producer was a
      # coordinator thread). In that case the session is eventually
      # garbage-collected, which surfaces as:
      raise errors_impl.CancelledError(None, None,
                                       'Session got garbage-collected.')


class CountingSessionCreator:
  """Session creator that hands back one shared session and counts requests.

  Args:
    session: The single session returned by every `create_session` call. Its
      `close` attribute is replaced with a no-op.
  """

  def __init__(self, session):
    self._create_session_calls = 0
    self._initial_session = session
    # There is exactly one real session per test case and it cannot be
    # re-created, so closing it is turned into a no-op.
    session.close = lambda *unused_args: None

  @property
  def number_of_sessions_created(self):
    """How many times `create_session` has been called."""
    return self._create_session_calls

  def create_session(self):
    """Returns the shared session and records the request."""
    self._create_session_calls = self._create_session_calls + 1
    return self._initial_session


class RecoverableSessionTest(test.TestCase):
  """_RecoverableSession tests.

  Also covers how `MonitoredSession`, which wraps a `_RecoverableSession`,
  reacts when a coordinator thread requests a stop with an exception.
  """

  class _SessionReturner:
    """Session creator whose `create_session` always returns `sess`."""

    def __init__(self, sess):
      self._sess = sess

    def create_session(self):
      return self._sess

  @test_util.run_deprecated_v1
  def test_properties(self):
    """`graph` and `sess_str` match those of the created session."""
    with self.cached_session() as sess:
      constant_op.constant(0.0)
      recoverable_sess = monitored_session._RecoverableSession(
          self._SessionReturner(sess))
      self.assertEqual(sess.graph, recoverable_sess.graph)
      self.assertEqual(sess.sess_str, recoverable_sess.sess_str)

  @test_util.run_deprecated_v1
  def test_run(self):
    """`run` forwards fetches and feeds to the created session."""
    with self.cached_session() as sess:
      c = constant_op.constant(0)
      v = array_ops.identity(c)
      recoverable_sess = monitored_session._RecoverableSession(
          self._SessionReturner(sess))
      self.assertEqual(51, recoverable_sess.run(v, feed_dict={c: 51}))

  @test_util.run_deprecated_v1
  def test_recovery(self):
    """Each AbortedError makes the session switch to the creator's next one."""
    with self.cached_session() as sess:

      class StackSessionCreator:
        """Hands out AbortAtNSession(sess, 1), (sess, 2), (sess, 3) in order."""

        def __init__(self, sess):
          self.sessions_to_use = [
              AbortAtNSession(sess, x + 1) for x in range(3)
          ]

        def create_session(self):
          return self.sessions_to_use.pop(0)

      c = constant_op.constant(0)
      v = array_ops.identity(c)
      session_creator = StackSessionCreator(sess)
      # List of 3 sessions to use for recovery.  The first one aborts
      # after 1 run() call, the second after 2 run calls, the third
      # after 3 run calls.
      self.assertEqual(3, len(session_creator.sessions_to_use))
      # Make the recoverable session use these 3 sessions in sequence by
      # passing a factory that pops from the sessions_to_use list.
      recoverable_sess = monitored_session._RecoverableSession(session_creator)
      self.assertEqual(
          2, len(session_creator.sessions_to_use))  # One session popped.
      # Using first session.
      self.assertEqual(51, recoverable_sess.run(v, feed_dict={c: 51}))
      self.assertEqual(
          2, len(session_creator.sessions_to_use))  # Still 2 sessions available
      # This will fail and recover by picking up the second session.
      self.assertEqual(42, recoverable_sess.run(v, feed_dict={c: 42}))
      self.assertEqual(
          1, len(session_creator.sessions_to_use))  # Still 1 session available
      self.assertEqual(33, recoverable_sess.run(v, feed_dict={c: 33}))
      self.assertEqual(
          1, len(session_creator.sessions_to_use))  # Still 1 session available
      # This will fail and recover by picking up the last session.
      self.assertEqual(24, recoverable_sess.run(v, feed_dict={c: 24}))
      self.assertEqual(
          0, len(session_creator.sessions_to_use))  # All sessions used.
      self.assertEqual(11, recoverable_sess.run(v, feed_dict={c: 11}))
      self.assertEqual(0, recoverable_sess.run(v, feed_dict={c: 0}))
      # This aborts again, and recovery now fails with a real error because
      # pop() finds the list empty.
      with self.assertRaisesRegex(IndexError, 'pop from empty list'):
        recoverable_sess.run(v, feed_dict={c: -12})

  @test_util.run_deprecated_v1
  def test_recovery_from_coordinator_exception(self):
    """An AbortedError stop request from a coordinator thread is recovered.

    The session is recreated and training continues instead of stopping.
    """
    with self.cached_session() as test_session:
      session_creator = CountingSessionCreator(test_session)
      session = monitored_session.MonitoredSession(
          session_creator,
          [StopCoordinatorWithException(calls_before_stopping=2)])

      self.assertEqual(1, session_creator.number_of_sessions_created)
      self.assertFalse(session.should_stop())

      c = constant_op.constant(0)
      v = array_ops.identity(c)

      # The coordinator will not abort during this call, since it's the call
      # number 0.
      self.assertEqual(51, session.run(v, feed_dict={c: 51}))
      self.assertFalse(session.should_stop())
      # The coordinator will abort during the next call, since it's the call
      # number 1.
      self.assertEqual(42, session.run(v, feed_dict={c: 42}))
      # Even though the coordinator was asked to stop, the underlying session is
      # recreated and is to be continued.
      self.assertFalse(session.should_stop())
      self.assertEqual(2, session_creator.number_of_sessions_created)

  @test_util.run_deprecated_v1
  def test_recovery_from_non_preemption_in_coordinator(self):
    """A non-preemption error in a coordinator thread stops training.

    No new session is created, and `close()` re-raises the `UnknownError`.
    """
    with self.cached_session() as test_session:
      session_creator = CountingSessionCreator(test_session)
      hook = StopCoordinatorWithException(
          calls_before_stopping=2,
          exception_to_raise=errors_impl.UnknownError(
              None, None, 'Some fatal exception inside the coordinator.'))
      session = monitored_session.MonitoredSession(session_creator, [hook])

      self.assertEqual(1, session_creator.number_of_sessions_created)
      self.assertFalse(session.should_stop())

      c = constant_op.constant(0)
      v = array_ops.identity(c)

      # The coordinator will not abort during this call, since it's the call
      # number 0.
      self.assertEqual(51, session.run(v, feed_dict={c: 51}))
      self.assertFalse(session.should_stop())
      # The coordinator will abort during the next call, since it's the call
      # number 1.
      self.assertEqual(42, session.run(v, feed_dict={c: 42}))
      # The coordinator was asked to stop due to non-redeemable error. Training
      # should stop and the session should not be recreated.
      self.assertTrue(session.should_stop())
      self.assertEqual(1, session_creator.number_of_sessions_created)
      with self.assertRaises(errors_impl.UnknownError):
        session.close()

  @test_util.run_deprecated_v1
  def test_recovery_from_session_getting_stuck(self):
    """Training failing after the coordinator stopped is recovered.

    The hook turns the coordinator's AbortedError into a CancelledError raised
    from `after_run`; the session is recreated and training continues.
    """
    with self.cached_session() as test_session:
      session_creator = CountingSessionCreator(test_session)
      session = monitored_session.MonitoredSession(
          session_creator,
          [FailTrainingAfterCoordinatorStopped(calls_before_stopping=2)])

      self.assertEqual(1, session_creator.number_of_sessions_created)
      self.assertFalse(session.should_stop())

      c = constant_op.constant(0)
      v = array_ops.identity(c)

      # Training will not fail, since it's the call number 0.
      self.assertEqual(51, session.run(v, feed_dict={c: 51}))
      self.assertFalse(session.should_stop())
      # Training will fail during the next call, since it's the call
      # number 1.
      self.assertEqual(42, session.run(v, feed_dict={c: 42}))
      # Even though the coordinator stopped and training failed, the
      # underlying session is recreated and training is to be continued.
      self.assertFalse(session.should_stop())
      self.assertEqual(2, session_creator.number_of_sessions_created)

  @test_util.run_deprecated_v1
  def test_step_fn_recovery_from_coordinator_exception_when_run_hooks(self):
    """Same as test_recovery_from_coordinator_exception, via run_step_fn.

    The step function uses `step_context.run_with_hooks`.
    """
    with self.cached_session() as test_session:
      session_creator = CountingSessionCreator(test_session)
      session = monitored_session.MonitoredSession(
          session_creator,
          [StopCoordinatorWithException(calls_before_stopping=2)])

      self.assertEqual(1, session_creator.number_of_sessions_created)
      self.assertFalse(session.should_stop())

      c = constant_op.constant(0)
      v = array_ops.identity(c)

      # Returns a step_fn that fetches `v` with `value` fed into `c`.
      def feed_step_fn(value):
        def step_fn(step_context):
          return step_context.run_with_hooks(fetches=v, feed_dict={c: value})
        return step_fn

      # The coordinator will not abort during this call, since it's the call
      # number 0.
      self.assertEqual(51, session.run_step_fn(feed_step_fn(51)))
      self.assertFalse(session.should_stop())
      # The coordinator will abort during the next call, since it's the call
      # number 1.
      self.assertEqual(42, session.run_step_fn(feed_step_fn(42)))
      # Even though the coordinator was asked to stop, the underlying session is
      # recreated and is to be continued.
      self.assertFalse(session.should_stop())
      self.assertEqual(2, session_creator.number_of_sessions_created)

  @test_util.run_deprecated_v1
  def test_recovery_from_non_preemption_in_coordinator_when_run_hooks(self):
    """Same as test_recovery_from_non_preemption_in_coordinator, via run_step_fn.

    The step function uses `step_context.run_with_hooks`.
    """
    with self.cached_session() as test_session:
      session_creator = CountingSessionCreator(test_session)
      hook = StopCoordinatorWithException(
          calls_before_stopping=2,
          exception_to_raise=errors_impl.UnknownError(
              None, None, 'Some fatal exception inside the coordinator.'))
      session = monitored_session.MonitoredSession(session_creator, [hook])

      self.assertEqual(1, session_creator.number_of_sessions_created)
      self.assertFalse(session.should_stop())

      c = constant_op.constant(0)
      v = array_ops.identity(c)

      # Returns a step_fn that fetches `v` with `value` fed into `c`.
      def feed_step_fn(value):
        def step_fn(step_context):
          return step_context.run_with_hooks(fetches=v, feed_dict={c: value})
        return step_fn

      # The coordinator will not abort during this call, since it's the call
      # number 0.
      self.assertEqual(51, session.run_step_fn(feed_step_fn(51)))
      self.assertFalse(session.should_stop())
      # The coordinator will abort during the next call, since it's the call
      # number 1.
      self.assertEqual(42, session.run_step_fn(feed_step_fn(42)))
      # The coordinator was asked to stop due to non-redeemable error. Training
      # should stop and the session should not be recreated.
      self.assertTrue(session.should_stop())
      self.assertEqual(1, session_creator.number_of_sessions_created)
      with self.assertRaises(errors_impl.UnknownError):
        session.close()

  @test_util.run_deprecated_v1
  def test_recovery_from_session_getting_stuck_when_run_hooks(self):
    """Same as test_recovery_from_session_getting_stuck, via run_step_fn.

    The step function uses `step_context.run_with_hooks`.
    """
    with self.cached_session() as test_session:
      session_creator = CountingSessionCreator(test_session)
      session = monitored_session.MonitoredSession(
          session_creator,
          [FailTrainingAfterCoordinatorStopped(calls_before_stopping=2)])

      self.assertEqual(1, session_creator.number_of_sessions_created)
      self.assertFalse(session.should_stop())

      c = constant_op.constant(0)
      v = array_ops.identity(c)

      # Returns a step_fn that fetches `v` with `value` fed into `c`.
      def feed_step_fn(value):
        def step_fn(step_context):
          return step_context.run_with_hooks(fetches=v, feed_dict={c: value})
        return step_fn

      # Training will not fail, since it's the call number 0.
      self.assertEqual(51, session.run_step_fn(feed_step_fn(51)))
      self.assertFalse(session.should_stop())
      # Training will fail during the next call, since it's the call
      # number 1.
      self.assertEqual(42, session.run_step_fn(feed_step_fn(42)))
      # Even though the coordinator stopped and training failed, the
      # underlying session is recreated and training is to be continued.
      self.assertFalse(session.should_stop())
      self.assertEqual(2, session_creator.number_of_sessions_created)

  def create_raw_session_with_failing_coordinator(self, session_creator, hook):
    """Return MonitoredSession that triggers coordinator failures.

    Args:
      session_creator: Session creator passed to `MonitoredSession`.
      hook: The single hook installed on the `MonitoredSession`.

    Returns:
      A `MonitoredSession` whose `_tf_sess` is patched as described below.
    """
    session = monitored_session.MonitoredSession(session_creator, [hook])
    # We would like to test a situation where during fetches through the
    # raw session, the coordinator fails with an exception.  To do that, we
    # are going to use (raw_session + StopCoordinatorWithException) hook
    # combination that is stored in
    # `MonitoredSession._RecoverableSession._CoordinatedSession._sess`
    # at this point:
    session._tf_sess = lambda: session._sess._sess._sess
    # `run()` on such a session is equivalent to `run()` on the raw session
    # with separate coordinator threads independently stopping with an
    # exception.
    return session

  @test_util.run_deprecated_v1
  def test_step_fn_recovery_from_coordinator_exception_with_raw_session(self):
    """Coordinator AbortedError is recovered when step_fn uses the raw session.

    The step function calls `step_context.session.run`; `_tf_sess` is patched
    by `create_raw_session_with_failing_coordinator`.
    """
    with self.cached_session() as test_session:
      session_creator = CountingSessionCreator(test_session)
      session = self.create_raw_session_with_failing_coordinator(
          session_creator,
          StopCoordinatorWithException(calls_before_stopping=2))

      self.assertEqual(1, session_creator.number_of_sessions_created)
      self.assertFalse(session.should_stop())

      c = constant_op.constant(0)
      v = array_ops.identity(c)

      # Returns a step_fn that fetches `v` with `value` fed into `c`.
      def feed_step_fn(value):

        def step_fn(step_context):
          return step_context.session.run(fetches=v, feed_dict={c: value})

        return step_fn

      # The coordinator will not abort during this call, since it's the call
      # number 0.
      self.assertEqual(51, session.run_step_fn(feed_step_fn(51)))
      self.assertFalse(session.should_stop())
      # The coordinator will abort during the next call, since it's the call
      # number 1.
      self.assertEqual(42, session.run_step_fn(feed_step_fn(42)))
      # Even though the coordinator was asked to stop, the underlying session is
      # recreated and is to be continued.
      self.assertFalse(session.should_stop())
      self.assertEqual(2, session_creator.number_of_sessions_created)

  @test_util.run_deprecated_v1
  def test_recovery_from_non_preemption_in_coordinator_with_raw_session(self):
    """A non-preemption coordinator error stops training (patched _tf_sess).

    Uses `create_raw_session_with_failing_coordinator`; the step function calls
    `step_context.run_with_hooks`.
    """
    with self.cached_session() as test_session:
      session_creator = CountingSessionCreator(test_session)
      session = self.create_raw_session_with_failing_coordinator(
          session_creator,
          StopCoordinatorWithException(
              calls_before_stopping=2,
              exception_to_raise=errors_impl.UnknownError(
                  None, None, 'Some fatal exception inside the coordinator.')))

      self.assertEqual(1, session_creator.number_of_sessions_created)
      self.assertFalse(session.should_stop())

      c = constant_op.constant(0)
      v = array_ops.identity(c)

      # Returns a step_fn that fetches `v` with `value` fed into `c`.
      def feed_step_fn(value):

        def step_fn(step_context):
          return step_context.run_with_hooks(fetches=v, feed_dict={c: value})

        return step_fn

      # The coordinator will not abort during this call, since it's the call
      # number 0.
      self.assertEqual(51, session.run_step_fn(feed_step_fn(51)))
      self.assertFalse(session.should_stop())
      # The coordinator will abort during the next call, since it's the call
      # number 1.
      self.assertEqual(42, session.run_step_fn(feed_step_fn(42)))
      # The coordinator was asked to stop due to non-redeemable error. Training
      # should stop and the session should not be recreated.
      self.assertTrue(session.should_stop())
      self.assertEqual(1, session_creator.number_of_sessions_created)
      with self.assertRaises(errors_impl.UnknownError):
        session.close()

  @test_util.run_deprecated_v1
  def test_recovery_from_session_getting_stuck_with_raw_session(self):
    """Training failing after coordinator stop is recovered (patched _tf_sess).

    Uses `create_raw_session_with_failing_coordinator`; the step function calls
    `step_context.run_with_hooks`.
    """
    with self.cached_session() as test_session:
      session_creator = CountingSessionCreator(test_session)
      session = self.create_raw_session_with_failing_coordinator(
          session_creator,
          FailTrainingAfterCoordinatorStopped(calls_before_stopping=2))

      self.assertEqual(1, session_creator.number_of_sessions_created)
      self.assertFalse(session.should_stop())

      c = constant_op.constant(0)
      v = array_ops.identity(c)

      # Returns a step_fn that fetches `v` with `value` fed into `c`.
      def feed_step_fn(value):

        def step_fn(step_context):
          return step_context.run_with_hooks(fetches=v, feed_dict={c: value})

        return step_fn

      # Training will not fail, since it's the call number 0.
      self.assertEqual(51, session.run_step_fn(feed_step_fn(51)))
      self.assertFalse(session.should_stop())
      # Training will fail during the next call, since it's the call
      # number 1.
      self.assertEqual(42, session.run_step_fn(feed_step_fn(42)))
      # Even though the coordinator stopped and training failed, the
      # underlying session is recreated and training is to be continued.
      self.assertFalse(session.should_stop())
      self.assertEqual(2, session_creator.number_of_sessions_created)


class FakeSession(monitored_session._WrappedSession):
  """A `_WrappedSession` that records the keyword arguments passed to `run`.

  Only `fetches` reaches the wrapped session. Every other keyword argument is
  captured in `args_called` (replaced on each call) for inspection by tests.
  """

  def __init__(self, sess):
    super().__init__(sess)
    self.args_called = {}

  def run(self, fetches, **kwargs):
    self.args_called = {**kwargs}
    # The other arguments are only recorded; forward the fetches alone.
    return super().run(fetches)


class HookedSessionTest(test.TestCase):
  """Tests of _HookedSession.

  Most cases wrap a real session with two `FakeHook`s. They then check how the
  hooks' fetch/feed requests are merged with the caller's own arguments.
  """

  def _hooked_session_with_two_fake_hooks(self, sess):
    """Wraps `sess` in a `_HookedSession` driven by two new `FakeHook`s.

    Args:
      sess: The session to wrap.

    Returns:
      A `(hooked_session, first_hook, second_hook)` tuple.
    """
    first_hook = FakeHook()
    second_hook = FakeHook()
    hooked_sess = monitored_session._HookedSession(
        sess=sess, hooks=[first_hook, second_hook])
    return hooked_sess, first_hook, second_hook

  def testRunPassesAllArguments(self):
    with ops.Graph().as_default(), session_lib.Session() as sess:
      recording_sess = FakeSession(sess)
      hooked_sess = monitored_session._HookedSession(
          sess=recording_sess, hooks=[])
      a_tensor = constant_op.constant([0], name='a_tensor')
      self.evaluate(variables.global_variables_initializer())
      # Placeholder values: with no hooks they must reach the wrapped
      # session untouched.
      passthrough_kwargs = {
          'feed_dict': 'a_feed',
          'options': 'an_option',
          'run_metadata': 'a_metadata',
      }
      result = hooked_sess.run(fetches=a_tensor, **passthrough_kwargs)
      self.assertEqual(result, [0])
      self.assertEqual(recording_sess.args_called, passthrough_kwargs)

  def testCallsHooksBeginEnd(self):
    with ops.Graph().as_default(), session_lib.Session() as sess:
      hooked_sess, first_hook, second_hook = (
          self._hooked_session_with_two_fake_hooks(sess))
      a_tensor = constant_op.constant([0], name='a_tensor')
      self.evaluate(variables.global_variables_initializer())
      hooked_sess.run(a_tensor)

      # One run: only the per-run callbacks are expected to have fired.
      expected_counts = (
          ('begin', 0),
          ('after_create_session', 0),
          ('before_run', 1),
          ('after_run', 1),
      )
      for hook in (first_hook, second_hook):
        self.assertEqual(
            hook.last_run_values,
            session_run_hook.SessionRunValues(
                results=None,
                options=config_pb2.RunOptions(),
                run_metadata=config_pb2.RunMetadata()))
        self.assertEqual(hook.last_run_context.original_args,
                         session_run_hook.SessionRunArgs(a_tensor))
        self.assertEqual(hook.last_run_context.session, sess)
        for callback_name, count in expected_counts:
          self.assertEqual(hook.call_counter[callback_name], count)

  def testShouldStop(self):
    with ops.Graph().as_default(), session_lib.Session() as sess:
      hooked_sess, first_hook, _ = (
          self._hooked_session_with_two_fake_hooks(sess))
      constant_op.constant([0], name='a_tensor')
      self.evaluate(variables.global_variables_initializer())

      hooked_sess.run(fetches='a_tensor')
      self.assertFalse(hooked_sess.should_stop())

      # A single hook asking to stop is enough.
      first_hook.should_stop = True
      hooked_sess.run(fetches='a_tensor')
      self.assertTrue(hooked_sess.should_stop())

  def testFetchesHookRequests(self):
    with ops.Graph().as_default(), session_lib.Session() as sess:
      hooked_sess, first_hook, second_hook = (
          self._hooked_session_with_two_fake_hooks(sess))
      a_tensor = constant_op.constant([0], name='a_tensor')
      another_tensor = constant_op.constant([5], name='another_tensor')
      third_tensor = constant_op.constant([10], name='third_tensor')
      first_hook.request = session_run_hook.SessionRunArgs([another_tensor])
      second_hook.request = session_run_hook.SessionRunArgs([third_tensor])
      self.evaluate(variables.global_variables_initializer())

      self.assertEqual(hooked_sess.run(fetches=a_tensor), [0])
      # Each hook receives the results of its own extra fetches.
      self.assertEqual(first_hook.last_run_values.results, [5])
      self.assertEqual(second_hook.last_run_values.results, [10])

  def testOnlyHooksHaveFeeds(self):
    with ops.Graph().as_default(), session_lib.Session() as sess:
      hooked_sess, first_hook, second_hook = (
          self._hooked_session_with_two_fake_hooks(sess))
      a_tensor = constant_op.constant([0], name='a_tensor')
      b_tensor = constant_op.constant([0], name='b_tensor')
      total = a_tensor + b_tensor
      first_hook.request = session_run_hook.SessionRunArgs(
          None, feed_dict={a_tensor: [5]})
      second_hook.request = session_run_hook.SessionRunArgs(
          None, feed_dict={b_tensor: [10]})
      self.evaluate(variables.global_variables_initializer())

      self.assertEqual(hooked_sess.run(fetches=total), [15])

  def testBothHooksAndUserHaveFeeds(self):
    with ops.Graph().as_default(), session_lib.Session() as sess:
      hooked_sess, first_hook, second_hook = (
          self._hooked_session_with_two_fake_hooks(sess))
      a_tensor = constant_op.constant([0], name='a_tensor')
      b_tensor = constant_op.constant([0], name='b_tensor')
      c_tensor = constant_op.constant([0], name='c_tensor')
      total = a_tensor + b_tensor + c_tensor
      first_hook.request = session_run_hook.SessionRunArgs(
          None, feed_dict={a_tensor: [5]})
      second_hook.request = session_run_hook.SessionRunArgs(
          None, feed_dict={b_tensor: [10]})
      self.evaluate(variables.global_variables_initializer())

      user_feed = {c_tensor: [20]}
      self.assertEqual(
          hooked_sess.run(fetches=total, feed_dict=user_feed), [35])
      # The hook feeds must not leak into the caller's dict.
      self.assertEqual(len(user_feed), 1)

  def testHooksFeedConflicts(self):
    with ops.Graph().as_default(), session_lib.Session() as sess:
      hooked_sess, first_hook, second_hook = (
          self._hooked_session_with_two_fake_hooks(sess))
      a_tensor = constant_op.constant([0], name='a_tensor')
      b_tensor = constant_op.constant([0], name='b_tensor')
      total = a_tensor + b_tensor
      # Both hooks feed the same tensor.
      first_hook.request = session_run_hook.SessionRunArgs(
          None, feed_dict={a_tensor: [5]})
      second_hook.request = session_run_hook.SessionRunArgs(
          None, feed_dict={a_tensor: [10]})
      self.evaluate(variables.global_variables_initializer())

      with self.assertRaisesRegex(RuntimeError, 'Same tensor is fed'):
        hooked_sess.run(fetches=total)

  def testHooksAndUserFeedConflicts(self):
    with ops.Graph().as_default(), session_lib.Session() as sess:
      hooked_sess, first_hook, second_hook = (
          self._hooked_session_with_two_fake_hooks(sess))
      a_tensor = constant_op.constant([0], name='a_tensor')
      b_tensor = constant_op.constant([0], name='b_tensor')
      total = a_tensor + b_tensor
      first_hook.request = session_run_hook.SessionRunArgs(
          None, feed_dict={a_tensor: [5]})
      second_hook.request = session_run_hook.SessionRunArgs(
          None, feed_dict={b_tensor: [10]})
      self.evaluate(variables.global_variables_initializer())

      # The caller feeds a tensor that a hook already feeds.
      with self.assertRaisesRegex(RuntimeError, 'Same tensor is fed'):
        hooked_sess.run(fetches=total, feed_dict={b_tensor: [10]})


class RaiseOnceAtCountN(session_run_hook.SessionRunHook):
  """Hook that raises an Exception at step N.

  Each `before_run` decrements `n`. The call that brings it to exactly zero
  raises `ex`, and `raised` guarantees this happens at most once.
  """

  def __init__(self, n, ex):
    self.n = n  # Countdown of remaining before_run calls.
    self.ex = ex  # Exception instance to raise.
    self.raised = False  # True once `ex` has been raised.

  def before_run(self, run_context):
    del run_context  # Unused.
    self.n -= 1
    if self.n != 0 or self.raised:
      return None
    self.raised = True
    raise self.ex


class RunOptionsMetadataHook(session_run_hook.SessionRunHook):
  """A hook that observes & optionally modifies RunOptions and RunMetadata.

  `before_run` requests a `RunOptions` built from the constructor arguments.
  `after_run` appends `run_values.options` and `run_values.run_metadata` to
  `run_options_list` and `run_metadata_list`.
  """

  def __init__(self, trace_level, timeout_in_ms, output_partition_graphs,
               debug_tensor_watch, report_tensor_allocations_upon_oom):
    # Settings copied into the RunOptions requested by every `before_run`.
    self._trace_level = trace_level
    self._timeout_in_ms = timeout_in_ms
    self._output_partition_graphs = output_partition_graphs
    self._debug_tensor_watch = debug_tensor_watch
    self._report_tensor_allocations_upon_oom = (
        report_tensor_allocations_upon_oom)

    # One entry appended per `after_run` call, in call order.
    self.run_options_list = []
    self.run_metadata_list = []

  def before_run(self, run_context):
    requested_options = config_pb2.RunOptions(
        trace_level=self._trace_level,
        timeout_in_ms=self._timeout_in_ms,
        output_partition_graphs=self._output_partition_graphs,
        report_tensor_allocations_upon_oom=(
            self._report_tensor_allocations_upon_oom))
    watch_opts = requested_options.debug_options.debug_tensor_watch_opts
    watch_opts.extend([self._debug_tensor_watch])
    return session_run_hook.SessionRunArgs(
        None, None, options=requested_options)

  def after_run(self, run_context, run_values):
    del run_context  # Unused.
    self.run_options_list.append(run_values.options)
    self.run_metadata_list.append(run_values.run_metadata)


class MonitoredSessionTest(test.TestCase):
  """Tests for `MonitoredSession`: stopping, recovery, hooks and step_fn."""

  def test_defaults(self):
    with ops.Graph().as_default():
      a_var = variable_v1.VariableV1(0)
      with monitored_session.MonitoredSession() as session:
        self.assertEqual(0, session.run(a_var))

  def test_last_step(self):
    logdir = _test_dir(self.get_temp_dir(), 'test_last_step')
    with ops.Graph().as_default():
      gstep = training_util.get_or_create_global_step()
      do_step = state_ops.assign_add(gstep, 1)
      # Run till step 3 and save.
      hooks = [basic_session_run_hooks.StopAtStepHook(last_step=3)]
      with monitored_session.MonitoredSession(hooks=hooks) as session:
        self.assertEqual(0, session.run(gstep))
        self.assertFalse(session.should_stop())
        self.assertEqual(1, session.run(do_step))
        self.assertFalse(session.should_stop())
        self.assertEqual(2, session.run(do_step))
        self.assertFalse(session.should_stop())
        self.assertEqual(3, session.run(do_step))
        self.assertTrue(session.should_stop())
        save_path = saver_lib._get_saver_or_default().save(
            session._coordinated_creator.tf_sess,
            os.path.join(logdir, 'step-3'))
      # Run till step 5 and save.
      def load_ckpt(scaffold, sess):
        scaffold.saver.restore(sess, save_path)

      session_creator = monitored_session.ChiefSessionCreator(
          monitored_session.Scaffold(init_fn=load_ckpt))
      hooks = [basic_session_run_hooks.StopAtStepHook(last_step=5)]
      with monitored_session.MonitoredSession(
          hooks=hooks, session_creator=session_creator) as session:
        self.assertEqual(3, session.run(gstep))
        self.assertFalse(session.should_stop())
        self.assertEqual(4, session.run(do_step))
        self.assertFalse(session.should_stop())
        self.assertEqual(5, session.run(do_step))
        self.assertTrue(session.should_stop())

  def test_num_steps(self):
    logdir = _test_dir(self.get_temp_dir(), 'test_num_steps')
    with ops.Graph().as_default():
      gstep = training_util.get_or_create_global_step()
      do_step = state_ops.assign_add(gstep, 1)
      # Do 3 steps and save.
      hooks = [basic_session_run_hooks.StopAtStepHook(num_steps=3)]
      with monitored_session.MonitoredSession(hooks=hooks) as session:
        session.run(do_step)
        self.assertFalse(session.should_stop())
        session.run(do_step)
        self.assertFalse(session.should_stop())
        session.run(do_step)
        self.assertTrue(session.should_stop())
        save_path = saver_lib._get_saver_or_default().save(
            session._coordinated_creator.tf_sess,
            os.path.join(logdir, 'step-3'))
      # Restore and do 4 steps.
      def load_ckpt(scaffold, sess):
        scaffold.saver.restore(sess, save_path)

      session_creator = monitored_session.ChiefSessionCreator(
          scaffold=monitored_session.Scaffold(init_fn=load_ckpt))
      hooks = [basic_session_run_hooks.StopAtStepHook(num_steps=4)]
      with monitored_session.MonitoredSession(
          hooks=hooks, session_creator=session_creator) as session:
        self.assertEqual(4, session.run(do_step))
        self.assertFalse(session.should_stop())
        session.run(do_step)
        self.assertFalse(session.should_stop())
        session.run(do_step)
        self.assertFalse(session.should_stop())
        session.run(do_step)
        self.assertTrue(session.should_stop())

  # This set of tests verifies the MonitoredSession behavior when exceptions
  # are raised next to the innermost session run() call.

  @test_util.run_deprecated_v1
  def test_recovery(self):
    logdir = _test_dir(self.get_temp_dir(), 'test_recovery')
    with ops.Graph().as_default():
      gstep = training_util.get_or_create_global_step()
      do_step = state_ops.assign_add(gstep, 1)
      scaffold = monitored_session.Scaffold()
      # Use a hook to save the model after every step (save_steps=1). It also
      # saves it at the end.
      hooks = [
          basic_session_run_hooks.CheckpointSaverHook(
              logdir, save_steps=1, scaffold=scaffold)
      ]
      with monitored_session.MonitoredSession(
          session_creator=monitored_session.ChiefSessionCreator(
              scaffold, checkpoint_dir=logdir),
          hooks=hooks) as session:
        self.assertEqual(0, session.run(gstep))
        self.assertEqual(1, session.run(do_step))
        self.assertEqual(2, session.run(do_step))
      # A restart will find the checkpoint and recover automatically.
      with monitored_session.MonitoredSession(
          session_creator=monitored_session.ChiefSessionCreator(
              scaffold, checkpoint_dir=logdir)) as session:
        self.assertEqual(2, session.run(gstep))
      # A restart from an explicit checkpoint path also recovers.
      with monitored_session.MonitoredSession(
          session_creator=monitored_session.ChiefSessionCreator(
              scaffold,
              checkpoint_filename_with_path=checkpoint_management.
              latest_checkpoint(logdir))) as session:
        self.assertEqual(2, session.run(gstep))

  def test_retry_initialization_on_aborted_error(self):
    # Tests that we silently retry on abort during initialization.
    with ops.Graph().as_default():
      gstep = training_util.get_or_create_global_step()
      self.init_raised_aborted_error = False

      def _init_fn(scaffold, session):
        _, _ = scaffold, session
        if not self.init_raised_aborted_error:
          self.init_raised_aborted_error = True
          raise errors_impl.AbortedError(None, None, 'Abort')

      with monitored_session.MonitoredSession(
          session_creator=monitored_session.ChiefSessionCreator(
              scaffold=monitored_session.Scaffold(
                  init_fn=_init_fn))) as session:
        self.assertFalse(session.should_stop())
        self.assertEqual(0, session.run(gstep))
      self.assertTrue(self.init_raised_aborted_error)

  def _retry_test(self, ex):
    """Checks that `ex`, raised once from a hook, is silently retried.

    This does not test recovery, as no CheckpointSaver is used here.

    Args:
      ex: The error instance the hook raises on its 4th `before_run` call.
    """
    with ops.Graph().as_default():
      gstep = training_util.get_or_create_global_step()
      do_step = state_ops.assign_add(gstep, 1)
      hook = RaiseOnceAtCountN(4, ex)
      with monitored_session.MonitoredSession(hooks=[hook]) as session:
        self.assertEqual(0, session.run(gstep))
        self.assertEqual(1, session.run(do_step))
        self.assertEqual(2, session.run(do_step))
        self.assertFalse(session.should_stop())
        # Here at step 3, the hook triggers and raises `ex`. The
        # MonitoredSession automatically retries and restarts from a freshly
        # initialized session, so the step is back to 0 and running do_step
        # moves it to 1.
        self.assertEqual(1, session.run(do_step))
        self.assertFalse(session.should_stop())
        self.assertTrue(hook.raised)
        self.assertEqual(2, session.run(do_step))
        self.assertFalse(session.should_stop())

  def test_retry_on_aborted_error(self):
    self._retry_test(errors_impl.AbortedError(None, None, 'Abort'))

  def test_retry_on_unavailable_error(self):
    self._retry_test(errors_impl.UnavailableError(None, None, 'Unavailable'))

  def test_recover_and_retry_on_aborted_error(self):
    # Tests that we silently retry and recover on abort.  This test uses
    # a CheckpointSaver to have something to recover from.
    logdir = _test_dir(self.get_temp_dir(),
                       'test_recover_and_retry_on_aborted_error')
    with ops.Graph().as_default():
      gstep = training_util.get_or_create_global_step()
      do_step = state_ops.assign_add(gstep, 1)
      scaffold = monitored_session.Scaffold()
      abort_hook = RaiseOnceAtCountN(
          4, errors_impl.AbortedError(None, None, 'Abort'))
      # Save after each step.
      ckpt_hook = basic_session_run_hooks.CheckpointSaverHook(
          logdir, save_steps=1, scaffold=scaffold)
      hooks = [abort_hook, ckpt_hook]
      with monitored_session.MonitoredSession(
          session_creator=monitored_session.ChiefSessionCreator(
              scaffold, checkpoint_dir=logdir),
          hooks=hooks) as session:
        self.assertEqual(0, session.run(gstep))
        self.assertEqual(1, session.run(do_step))
        self.assertEqual(2, session.run(do_step))
        self.assertFalse(session.should_stop())
        # Here at step 3, the hook triggers and raises AbortedError.  The
        # MonitoredSession automatically restores and retries.
        self.assertEqual(3, session.run(do_step))
        self.assertTrue(abort_hook.raised)
        self.assertFalse(session.should_stop())
        self.assertEqual(4, session.run(do_step))
        self.assertFalse(session.should_stop())

  def test_exit_cleanly_on_out_of_range_exception(self):
    # Tests that we stop cleanly when OutOfRange is raised.
    with ops.Graph().as_default():
      gstep = training_util.get_or_create_global_step()
      do_step = state_ops.assign_add(gstep, 1)
      hook = RaiseOnceAtCountN(2, errors_impl.OutOfRangeError(None, None,
                                                              'EOI'))
      session = monitored_session.MonitoredSession(hooks=[hook])
      # session should cleanly exit from the context.
      with session:
        self.assertEqual(0, session.run(gstep))
        self.assertFalse(session.should_stop())
        # Here at step 1, the hook triggers and raises OutOfRange. The error
        # leaves run() and is absorbed when exiting the `with` block, so the
        # line after run() must never execute.
        session.run(do_step)
        self.fail('run() should have raised OutOfRangeError.')
      self.assertTrue(session.should_stop())

  def test_exit_cleanly_on_stop_iteration_exception(self):
    # Tests that we stop cleanly when StopIteration is raised.
    with ops.Graph().as_default():
      gstep = training_util.get_or_create_global_step()
      do_step = state_ops.assign_add(gstep, 1)
      hook = RaiseOnceAtCountN(2, StopIteration)
      session = monitored_session.MonitoredSession(hooks=[hook])
      # session should cleanly exit from the context.
      with session:
        self.assertEqual(0, session.run(gstep))
        self.assertFalse(session.should_stop())
        # Here at step 1, the hook triggers and raises StopIteration. The error
        # leaves run() and is absorbed when exiting the `with` block, so the
        # line after run() must never execute.
        session.run(do_step)
        self.fail('run() should have raised StopIteration.')
      self.assertTrue(session.should_stop())

  def test_regular_exception_pass_through_run(self):
    # Tests that regular exceptions just pass through a "with
    # MonitoredSession" block and set the session in stop mode.
    with ops.Graph().as_default():
      gstep = training_util.get_or_create_global_step()
      do_step = state_ops.assign_add(gstep, 1)
      hook = RaiseOnceAtCountN(4, RuntimeError('regular exception'))
      session = monitored_session.MonitoredSession(hooks=[hook])
      with self.assertRaisesRegex(RuntimeError, 'regular exception'):
        with session:
          self.assertEqual(0, session.run(gstep))
          self.assertEqual(1, session.run(do_step))
          self.assertEqual(2, session.run(do_step))
          self.assertFalse(session.should_stop())
          # This triggers the hook and raises the exception
          session.run(do_step)
          # We should not hit this
          self.fail('run() should have raised the hook exception.')
      self.assertTrue(hook.raised)
      self.assertTrue(session.should_stop())

  def test_regular_exception_reported_to_coord_pass_through_run(self):
    # Tests that a regular exception reported to the coordinator from a thread
    # does not break a later run() call inside a "with MonitoredSession"
    # block, and is re-raised when the block exits.
    with ops.Graph().as_default():
      gstep = training_util.get_or_create_global_step()
      session = monitored_session.MonitoredSession()
      run_performed_without_error = False
      with self.assertRaisesRegex(RuntimeError, 'a thread wants to stop'):
        with session:
          self.assertEqual(0, session.run(gstep))
          # Report an exception through the coordinator.
          try:
            raise RuntimeError('a thread wants to stop')
          except RuntimeError as e:
            session._coordinated_creator.coord.request_stop(e)
          # Call run() which should perform normally.
          self.assertEqual(0, session.run(gstep))
          run_performed_without_error = True
      self.assertTrue(run_performed_without_error)

  def test_regular_exception_reported_to_coord_pass_through_return(self):
    # Tests that a regular exception reported to the coordinator from a thread
    # puts the session in stop mode and is re-raised when returning from the
    # "with MonitoredSession" block.
    with ops.Graph().as_default():
      gstep = training_util.get_or_create_global_step()
      session = monitored_session.MonitoredSession()
      with self.assertRaisesRegex(RuntimeError, 'a thread wants to stop'):
        with session:
          self.assertEqual(0, session.run(gstep))
          # Report an exception through the coordinator.
          try:
            raise RuntimeError('a thread wants to stop')
          except RuntimeError as e:
            session._coordinated_creator.coord.request_stop(e)
          self.assertTrue(session.should_stop())

  # This set of tests verifies the session behavior when exceptions are raised
  # from code inside a "with MonitoredSession:" context.

  def test_stop_cleanly_when_no_exception_in_with_body(self):
    # Tests that the session stops and closes cleanly when the "with body"
    # raises nothing.
    with ops.Graph().as_default():
      gstep = training_util.get_or_create_global_step()
      do_step = state_ops.assign_add(gstep, 1)
      session = monitored_session.MonitoredSession()
      with session:
        self.assertEqual(1, session.run(do_step))
        self.assertEqual(2, session.run(do_step))
        self.assertFalse(session.should_stop())
      # Should have closed.
      self.assertTrue(session.should_stop())
      self.assertTrue(session._is_closed())

  def test_raises_regular_exceptions_in_with_body(self):
    # Tests that regular exceptions in "with body" are seen outside.
    with ops.Graph().as_default():
      gstep = training_util.get_or_create_global_step()
      do_step = state_ops.assign_add(gstep, 1)
      session = monitored_session.MonitoredSession()
      # We should see that exception.
      with self.assertRaisesRegex(RuntimeError, 'regular exception'):
        with session:
          self.assertEqual(1, session.run(do_step))
          self.assertEqual(2, session.run(do_step))
          self.assertFalse(session.should_stop())
          # Will be visible outside the "with body".
          raise RuntimeError('regular exception')
      # Should have closed.
      self.assertTrue(session.should_stop())
      self.assertTrue(session._is_closed())

  def test_graph(self):
    with ops.Graph().as_default() as g:
      with monitored_session.MonitoredSession() as session:
        self.assertEqual(g, session.graph)

  def test_graph_finalized_during_run_unfinalized_after_exit(self):
    with ops.Graph().as_default() as g:
      a_var = variable_v1.VariableV1(0)
      with monitored_session.MonitoredSession() as session:
        self.assertEqual(0, session.run(a_var))
        self.assertTrue(g.finalized)
      self.assertFalse(g.finalized)

  def test_keep_finalized_graph_as_finalized(self):
    with ops.Graph().as_default() as g:
      a_var = variable_v1.VariableV1(0)
      monitored_session.Scaffold().finalize()
      with monitored_session.MonitoredSession() as session:
        self.assertEqual(0, session.run(a_var))
        self.assertTrue(g.finalized)
      self.assertTrue(g.finalized)

  def test_merge_run_options_from_hooks(self):
    """Test for rewriting RunOptions and observing RunMetadata with hooks."""

    with ops.Graph().as_default():
      my_const = constant_op.constant(42, name='my_const')
      _ = constant_op.constant(24, name='my_const_2')

      watch_a = debug_pb2.DebugTensorWatch(
          node_name='my_const',
          output_slot=0,
          debug_ops=['DebugIdentity'],
          debug_urls=[])
      hook_a = RunOptionsMetadataHook(2, 30000, False, watch_a, False)
      watch_b = debug_pb2.DebugTensorWatch(
          node_name='my_const_2',
          output_slot=0,
          debug_ops=['DebugIdentity'],
          debug_urls=[])
      hook_b = RunOptionsMetadataHook(3, 60000, True, watch_b, True)
      with monitored_session.MonitoredSession(
          hooks=[hook_a, hook_b]) as session:
        self.assertEqual(42, session.run(my_const))

        # trace_level=3 should have overridden trace_level=2;
        # timeout_in_ms=60000 should have overridden 30000;
        # output_partition_graphs=True should have overridden False.
        # The two debug tensor watches should have been merged.
        self.assertEqual([
            config_pb2.RunOptions(
                trace_level=3,
                timeout_in_ms=60000,
                output_partition_graphs=True,
                debug_options=debug_pb2.DebugOptions(
                    debug_tensor_watch_opts=[watch_a, watch_b]),
                report_tensor_allocations_upon_oom=True),
        ], hook_b.run_options_list)
        self.assertEqual(1, len(hook_b.run_metadata_list))
        self.assertIsInstance(hook_b.run_metadata_list[0],
                              config_pb2.RunMetadata)
        self.assertGreater(len(hook_b.run_metadata_list[0].partition_graphs), 0)

  def test_merge_caller_and_hook_run_options(self):
    """Test that RunOptions from caller and hooks can be merged properly."""

    with ops.Graph().as_default():
      my_const = constant_op.constant(42, name='my_const')
      _ = constant_op.constant(24, name='my_const_2')

      hook_watch = debug_pb2.DebugTensorWatch(
          node_name='my_const_2',
          output_slot=0,
          debug_ops=['DebugIdentity'],
          debug_urls=[])
      hook = RunOptionsMetadataHook(2, 60000, False, hook_watch, False)
      with monitored_session.MonitoredSession(hooks=[hook]) as session:
        caller_watch = debug_pb2.DebugTensorWatch(
            node_name='my_const',
            output_slot=0,
            debug_ops=['DebugIdentity'],
            debug_urls=[])
        caller_options = config_pb2.RunOptions(
            trace_level=3,
            timeout_in_ms=30000,
            output_partition_graphs=True,
            report_tensor_allocations_upon_oom=True)
        caller_options.debug_options.debug_tensor_watch_opts.extend(
            [caller_watch])
        self.assertEqual(42, session.run(my_const, options=caller_options))

        # trace_level=3 from the caller should override 2 from the hook.
        # timeout_in_ms=60000 from the hook should override 30000 from the
        # caller.
        # output_partition_graphs=True from the caller should override False
        # from the hook.
        # The two debug watches from the caller and the hook should be merged,
        # in that order.
        self.assertEqual([
            config_pb2.RunOptions(
                trace_level=3,
                timeout_in_ms=60000,
                output_partition_graphs=True,
                debug_options=debug_pb2.DebugOptions(
                    debug_tensor_watch_opts=[caller_watch, hook_watch]),
                report_tensor_allocations_upon_oom=True),
        ], hook.run_options_list)
        self.assertEqual(1, len(hook.run_metadata_list))
        self.assertIsInstance(hook.run_metadata_list[0],
                              config_pb2.RunMetadata)
        self.assertGreater(len(hook.run_metadata_list[0].partition_graphs), 0)

  @test_util.run_deprecated_v1
  def test_with_statement_and_close(self):
    # Test case for https://github.com/tensorflow/tensorflow/issues/12224
    # where close() inside the with should have a better error message.
    with self.assertRaisesRegex(RuntimeError, 'Session is already closed'):
      with monitored_session.MonitoredSession() as session:
        session.close()

  def test_step_fn_example(self):
    with ops.Graph().as_default():
      c = array_ops.placeholder(dtypes.float32)
      v = array_ops.identity(c)

      def step_fn(step_context):
        value = step_context.run_with_hooks(fetches=v, feed_dict={c: 3.2})
        return value

      with monitored_session.MonitoredSession() as session:
        self.assertNear(3.2, session.run_step_fn(step_fn), 0.1)

  def test_step_function_stops(self):
    with ops.Graph().as_default():

      def step_fn(step_context):
        step_context.request_stop()

      with monitored_session.MonitoredSession() as session:
        self.assertIsNone(session.run_step_fn(step_fn))
        self.assertTrue(session.should_stop())

  def test_step_request_stop_without_a_with_block(self):
    with ops.Graph().as_default():

      def step_fn(step_context):
        step_context.request_stop()

      session = monitored_session.MonitoredSession()
      # Without a surrounding `with` block, the StopIteration produced by
      # request_stop() reaches the caller of run_step_fn().
      with self.assertRaises(StopIteration):
        session.run_step_fn(step_fn)

      self.assertFalse(session.should_stop())

  def test_step_request_stop_in_a_loop(self):
    with ops.Graph().as_default():
      def step_fn(step_context):
        step_context.request_stop()

      with monitored_session.MonitoredSession() as session:
        while not session.should_stop():
          _ = session.run_step_fn(step_fn)
          self.fail('An exception should be raised on the line above.')

  def test_step_request_stop_with_returning_a_type(self):
    with ops.Graph().as_default():

      def step_fn(step_context):
        del step_context
        return 'a type'

      with monitored_session.MonitoredSession() as session:
        self.assertEqual('a type', session.run_step_fn(step_fn))

  def test_step_with_extra_arguments(self):
    with ops.Graph().as_default():

      def step_fn(step_context, extra_foo):
        del step_context, extra_foo

      with monitored_session.MonitoredSession() as session:
        with self.assertRaisesRegex(
            ValueError,
            '`step_fn` may either have one `step_context` argument'):
          self.assertEqual(None, session.run_step_fn(step_fn))

  def test_step_fn_belongs_to_a_class(self):
    with ops.Graph().as_default():
      c = array_ops.placeholder(dtypes.float32)
      v = array_ops.identity(c)

      class Model:

        def step_fn(self, step_context):
          return step_context.run_with_hooks(fetches=v, feed_dict={c: 3.2})

      with monitored_session.MonitoredSession() as session:
        model = Model()
        self.assertNear(3.2, session.run_step_fn(model.step_fn), 0.1)

  def test_step_fn_belongs_to_a_class_and_has_extra_methods(self):
    with ops.Graph().as_default():

      class Model:

        def step_fn(self, step_context, extra_foo):
          del step_context, extra_foo

      with monitored_session.MonitoredSession() as session:
        with self.assertRaisesRegex(
            ValueError,
            '`step_fn` may either have one `step_context` argument'):
          model = Model()
          self.assertEqual(None, session.run_step_fn(model.step_fn))

  def test_step_fn_with_hooks(self):
    with ops.Graph().as_default():
      var = resource_variable_ops.ResourceVariable(0.0)

      # This test highlights the interaction of hooks with
      # `MonitoredSession.run_step_fn`.  The order of execution of operations
      # below is:
      #   0.  stage_0
      #   1.  stage_1_0 or stage_1_1 in an undefined order
      #   2.  stage_2

      stage_0 = state_ops.assign_add(var, 0.3)
      stage_1_0 = state_ops.assign_add(var, 0.7)
      # The order of `stage_1_0` and `stage_1_1` is undefined by
      # `MonitoredSession`, but we should be able to assert when both of them
      # are complete.  To obtain a consistent result of adding two different
      # constants to `var`, we rely on a control dependency and
      # `ResourceVariable`.  Otherwise, it is possible that one of the
      # additions overwrites the result of the other addition.
      with ops.control_dependencies([stage_1_0]):
        stage_1_1 = state_ops.assign_add(var, 0.5)
      stage_2 = state_ops.assign_add(var, 1.1)

      class Hook(session_run_hook.SessionRunHook):

        def __init__(self, testing):
          self._testing = testing

        def before_run(self, run_context):
          return session_run_hook.SessionRunArgs(fetches=stage_1_0)

        def after_run(self, run_context, run_values):
          self._testing.assertNear(0.3 + 0.5 + 0.7,
                                   run_context.session.run(var), 0.1)
          self._testing.assertNear(0.3 + 0.5 + 0.7 + 1.1,
                                   run_context.session.run(stage_2), 0.1)

      def step_fn(step_context):
        self.assertNear(0.3, step_context.session.run(stage_0), 0.1)
        return step_context.run_with_hooks(fetches=stage_1_1)

      with monitored_session.MonitoredSession(hooks=[Hook(self)]) as session:
        self.assertEqual(0.3 + 0.5 + 0.7, session.run_step_fn(step_fn))

  def test_step_fn_has_the_same_hooks_behavior_without_recovery(self):
    with ops.Graph().as_default():
      var = resource_variable_ops.ResourceVariable(0.0)

      stage_0 = state_ops.assign_add(var, 0.3)
      stage_1_0 = state_ops.assign_add(var, 0.7)
      with ops.control_dependencies([stage_1_0]):
        stage_1_1 = state_ops.assign_add(var, 0.5)
      stage_2 = state_ops.assign_add(var, 1.1)

      class Hook(session_run_hook.SessionRunHook):

        def __init__(self, testing):
          self._testing = testing

        def before_run(self, run_context):
          return session_run_hook.SessionRunArgs(fetches=stage_1_0)

        def after_run(self, run_context, run_values):
          self._testing.assertNear(0.3 + 0.5 + 0.7,
                                   run_context.session.run(var), 0.1)
          self._testing.assertNear(0.3 + 0.5 + 0.7 + 1.1,
                                   run_context.session.run(stage_2), 0.1)

      def step_fn(step_context):
        self.assertNear(0.3, step_context.session.run(stage_0), 0.1)
        return step_context.run_with_hooks(fetches=stage_1_1)

      with monitored_session.SingularMonitoredSession(
          hooks=[Hook(self)]) as session:
        self.assertEqual(0.3 + 0.5 + 0.7, session.run_step_fn(step_fn))

  def test_step_fn_with_hooks_and_request_stop(self):
    with ops.Graph().as_default():
      trace_the_hook = {'before_run': False, 'after_run': False}

      class Hook(session_run_hook.SessionRunHook):

        def before_run(self, run_context):
          trace_the_hook['before_run'] = True

        def after_run(self, run_context, run_values):
          trace_the_hook['after_run'] = True

      def step_fn(step_context):
        step_context.request_stop()

      with monitored_session.MonitoredSession(hooks=[Hook()]) as session:
        self.assertIsNone(session.run_step_fn(step_fn))
        self.assertTrue(session.should_stop())
        # `step_context.request_stop()` in a step_fn interrupts the flow of
        # running the hooks.
        self.assertFalse(trace_the_hook['before_run'])
        self.assertFalse(trace_the_hook['after_run'])

  def test_recovers_from_an_exception_in_step_fn(self):
    trace_the_exception = {'run_already': False}

    with ops.Graph().as_default():
      c = array_ops.placeholder(dtypes.float32)
      v = array_ops.identity(c)

      def step_fn(step_context):
        if not trace_the_exception['run_already']:
          trace_the_exception['run_already'] = True
          raise errors_impl.AbortedError(None, None, 'Abort')

        return step_context.run_with_hooks(fetches=v, feed_dict={c: 3.2})

      with monitored_session.MonitoredSession() as session:
        self.assertNear(3.2, session.run_step_fn(step_fn), 0.1)
      self.assertTrue(trace_the_exception['run_already'])

  def test_recovers_from_an_exception_in_step_fn_after_hooks(self):
    trace_the_exception = {'run_already': False, 'side_effect_counter': 0}

    with ops.Graph().as_default():
      c = array_ops.placeholder(dtypes.float32)
      v = array_ops.identity(c)
      graph_state = variable_v1.VariableV1(0.0)
      graph_side_effect = state_ops.assign_add(graph_state, 0.31)

      def step_fn(step_context):
        trace_the_exception['side_effect_counter'] += 1
        step_context.session.run(graph_side_effect)

        value = step_context.run_with_hooks(fetches=v, feed_dict={c: 3.2})

        if not trace_the_exception['run_already']:
          trace_the_exception['run_already'] = True
          raise errors_impl.AbortedError(None, None, 'Abort')

        return value

      with self.cached_session() as test_session:
        with monitored_session.MonitoredSession(
            CountingSessionCreator(test_session)) as session:
          session.run(variables.global_variables_initializer())

          self.assertNear(3.2, session.run_step_fn(step_fn), 0.1)
          self.assertTrue(trace_the_exception['run_already'])
          # Make sure the rest of the body of the step_fn is re-executed upon
          # AbortedError:
          self.assertEqual(2, trace_the_exception['side_effect_counter'])
          self.assertNear(0.62, session.run(graph_state), 0.1)

  def test_step_fn_doesnt_recover_when_it_wasnt_asked_to(self):
    trace_the_exception = {'run_already': False}

    with ops.Graph().as_default():
      c = array_ops.placeholder(dtypes.float32)
      v = array_ops.identity(c)

      def step_fn(step_context):
        if not trace_the_exception['run_already']:
          trace_the_exception['run_already'] = True
          raise errors_impl.AbortedError(None, None, 'Abort')

        value = step_context.run_with_hooks(fetches=v, feed_dict={c: 3.2})
        return value

      with monitored_session.SingularMonitoredSession() as session:
        with self.assertRaisesRegex(errors_impl.AbortedError, 'Abort'):
          self.assertNear(3.2, session.run_step_fn(step_fn), 0.1)
          self.fail()

      self.assertTrue(trace_the_exception['run_already'])

  def test_step_fn_exception_from_before_run(self):
    trace_the_exception = {'run_already': False, 'side_effect_counter': 0}

    with ops.Graph().as_default():
      c = array_ops.placeholder(dtypes.float32)
      v = array_ops.identity(c)
      vv = constant_op.constant(3.2)
      graph_state = variable_v1.VariableV1(0.0)
      graph_side_effect = state_ops.assign_add(graph_state, 0.31)

      class Hook(session_run_hook.SessionRunHook):

        def __init__(self, testing):
          self._testing = testing

        def before_run(self, run_context):
          if not trace_the_exception['run_already']:
            trace_the_exception['run_already'] = True
            raise errors_impl.AbortedError(None, None, 'Abort')
          return session_run_hook.SessionRunArgs(fetches=vv)

        def after_run(self, run_context, run_values):
          self._testing.assertNear(3.2, run_values.results, 0.1)

      def step_fn(step_context):
        trace_the_exception['side_effect_counter'] += 1
        step_context.session.run(graph_side_effect)
        return step_context.run_with_hooks(fetches=v, feed_dict={c: 1.3})

      with self.cached_session() as test_session:
        with monitored_session.MonitoredSession(
            CountingSessionCreator(test_session),
            hooks=[Hook(self)]) as session:
          test_session.run(variables.global_variables_initializer())
          self.assertNear(1.3, session.run_step_fn(step_fn), 0.1)
          self.assertEqual(2, trace_the_exception['side_effect_counter'])
          self.assertNear(0.62, session.run(graph_state), 0.1)


class SingularMonitoredSessionTest(test.TestCase):
  """Tests SingularMonitoredSession."""

  def test_handles_initialization(self):
    """Variables can be read without running an explicit initializer."""
    with ops.Graph().as_default():
      a_var = variable_v1.VariableV1(0)
      with monitored_session.SingularMonitoredSession() as session:
        # If it's not initialized, following statement raises an error.
        self.assertEqual(0, session.run(a_var))

  def test_do_not_handle_aborted_error(self):
    """An AbortedError raised by a hook propagates to the caller.

    It must escape both from run() inside the with-block and from the
    with-block itself.
    """
    with ops.Graph().as_default():
      gstep = training_util.get_or_create_global_step()

      class _RaiseAbortedHook(session_run_hook.SessionRunHook):

        def before_run(self, run_context):
          raise errors_impl.AbortedError(None, None, 'Abort')

      with monitored_session.SingularMonitoredSession(
          hooks=[_RaiseAbortedHook()]) as session:
        with self.assertRaises(errors_impl.AbortedError):
          self.assertEqual(0, session.run(gstep))

      with self.assertRaises(errors_impl.AbortedError):
        with monitored_session.SingularMonitoredSession(
            hooks=[_RaiseAbortedHook()]) as session:
          self.assertEqual(0, session.run(gstep))

  def test_exit_cleanly_on_out_of_range_exception(self):
    """OutOfRangeError raised by a hook ends the with-block without error."""
    with ops.Graph().as_default():
      gstep = training_util.get_or_create_global_step()
      do_step = state_ops.assign_add(gstep, 1)
      hook = RaiseOnceAtCountN(2, errors_impl.OutOfRangeError(None, None,
                                                              'EOI'))
      session = monitored_session.SingularMonitoredSession(hooks=[hook])
      # session should cleanly exit from the context.
      with session:
        self.assertEqual(0, session.run(gstep))
        self.assertFalse(session.should_stop())
        # Here at step 1, the hook triggers and raises OutOfRange. The
        # session should go into should_stop() mode. It should raise the
        # exception. So next step should not be executed.
        session.run(do_step)
        # Reaching this line means OutOfRangeError was not raised by run().
        self.fail('OutOfRangeError should have ended the with-block before '
                  'this line was reached.')
      self.assertTrue(session.should_stop())

  def test_regular_exception_reported_to_coord_pass_through_run(self):
    """Exceptions reported to the coordinator surface when the block exits.

    An exception reported to the coordinator from a thread does not break a
    "run()" call inside a "with SingularMonitoredSession" block. It is
    re-raised from the block, and the session is put in stop mode.
    """
    with ops.Graph().as_default():
      gstep = training_util.get_or_create_global_step()
      session = monitored_session.SingularMonitoredSession()
      run_performed_without_error = False
      with self.assertRaisesRegex(RuntimeError, 'a thread wants to stop'):
        with session:
          self.assertEqual(0, session.run(gstep))
          # Report an exception through the coordinator.
          try:
            raise RuntimeError('a thread wants to stop')
          except RuntimeError as e:
            session._coordinated_creator.coord.request_stop(e)
          # Call run() which should perform normally.
          self.assertEqual(0, session.run(gstep))
          run_performed_without_error = True
      self.assertTrue(run_performed_without_error)

  def test_stop_cleanly_when_no_exception_in_with_body(self):
    """The session stops and is closed when the with-body exits normally."""
    with ops.Graph().as_default():
      gstep = training_util.get_or_create_global_step()
      do_step = state_ops.assign_add(gstep, 1)
      session = monitored_session.SingularMonitoredSession()
      with session:
        self.assertEqual(1, session.run(do_step))
        self.assertEqual(2, session.run(do_step))
        self.assertFalse(session.should_stop())
      # Should have closed.
      self.assertTrue(session.should_stop())
      self.assertIsNone(session.raw_session())

  def test_graph(self):
    """session.graph is the default graph active at construction time."""
    with ops.Graph().as_default() as g:
      with monitored_session.SingularMonitoredSession() as session:
        self.assertEqual(g, session.graph)

  def test_raw_session(self):
    """raw_session() exposes the underlying client Session."""
    with ops.Graph().as_default():
      with monitored_session.SingularMonitoredSession() as session:
        self.assertIsInstance(session.raw_session(), session_lib.Session)


if __name__ == '__main__':
  # Hand control to the TensorFlow test runner (test.main) when this module
  # is executed directly rather than imported.
  test.main()