tensorflow/models

View on GitHub
official/core/train_lib_test.py

Summary

Maintainability
B
4 hrs
Test Coverage
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for train_ctl_lib."""
import json
import os

from absl import flags
from absl.testing import flagsaver
from absl.testing import parameterized
import numpy as np
import tensorflow as tf, tf_keras

from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.common import flags as tfm_flags
# pylint: disable=unused-import
from official.common import registry_imports
# pylint: enable=unused-import
from official.core import task_factory
from official.core import train_lib
from official.core import train_utils
from official.utils.testing import mock_task

FLAGS = flags.FLAGS

tfm_flags.define_flags()


class TrainTest(tf.test.TestCase, parameterized.TestCase):

  def setUp(self):
    super(TrainTest, self).setUp()
    self._test_config = {
        'trainer': {
            'checkpoint_interval': 10,
            'steps_per_loop': 10,
            'summary_interval': 10,
            'train_steps': 10,
            'validation_steps': 5,
            'validation_interval': 10,
            'continuous_eval_timeout': 1,
            'validation_summary_subdir': 'validation',
            'optimizer_config': {
                'optimizer': {
                    'type': 'sgd',
                },
                'learning_rate': {
                    'type': 'constant'
                }
            }
        },
    }

  @combinations.generate(
      combinations.combine(
          distribution_strategy=[
              strategy_combinations.default_strategy,
              strategy_combinations.cloud_tpu_strategy,
              strategy_combinations.one_device_strategy_gpu,
          ],
          flag_mode=['train', 'eval', 'train_and_eval'],
          run_post_eval=[True, False]))
  def test_end_to_end(self, distribution_strategy, flag_mode, run_post_eval):
    model_dir = self.get_temp_dir()
    flags_dict = dict(
        experiment='mock',
        mode=flag_mode,
        model_dir=model_dir,
        params_override=json.dumps(self._test_config))
    with flagsaver.flagsaver(**flags_dict):
      params = train_utils.parse_configuration(flags.FLAGS)
      train_utils.serialize_config(params, model_dir)
      with distribution_strategy.scope():
        task = task_factory.get_task(params.task, logging_dir=model_dir)

      _, logs = train_lib.run_experiment(
          distribution_strategy=distribution_strategy,
          task=task,
          mode=flag_mode,
          params=params,
          model_dir=model_dir,
          run_post_eval=run_post_eval)

    if 'eval' in flag_mode:
      self.assertTrue(
          tf.io.gfile.exists(
              os.path.join(model_dir,
                           params.trainer.validation_summary_subdir)))
    if run_post_eval:
      self.assertNotEmpty(logs)
    else:
      self.assertEmpty(logs)
    self.assertNotEmpty(
        tf.io.gfile.glob(os.path.join(model_dir, 'params.yaml')))
    if flag_mode == 'eval':
      return
    self.assertNotEmpty(
        tf.io.gfile.glob(os.path.join(model_dir, 'checkpoint')))
    # Tests continuous evaluation.
    _, logs = train_lib.run_experiment(
        distribution_strategy=distribution_strategy,
        task=task,
        mode='continuous_eval',
        params=params,
        model_dir=model_dir,
        run_post_eval=run_post_eval)

  @combinations.generate(
      combinations.combine(
          distribution_strategy=[
              strategy_combinations.default_strategy,
              strategy_combinations.cloud_tpu_strategy,
              strategy_combinations.one_device_strategy_gpu,
          ],
          flag_mode=['train', 'eval', 'train_and_eval'],
          run_post_eval=[True, False]))
  def test_end_to_end_class(self, distribution_strategy, flag_mode,
                            run_post_eval):
    model_dir = self.get_temp_dir()
    flags_dict = dict(
        experiment='mock',
        mode=flag_mode,
        model_dir=model_dir,
        params_override=json.dumps(self._test_config))
    with flagsaver.flagsaver(**flags_dict):
      params = train_utils.parse_configuration(flags.FLAGS)
      train_utils.serialize_config(params, model_dir)
      with distribution_strategy.scope():
        task = task_factory.get_task(params.task, logging_dir=model_dir)

      _, logs = train_lib.OrbitExperimentRunner(
          distribution_strategy=distribution_strategy,
          task=task,
          mode=flag_mode,
          params=params,
          model_dir=model_dir,
          run_post_eval=run_post_eval).run()

    if 'eval' in flag_mode:
      self.assertTrue(
          tf.io.gfile.exists(
              os.path.join(model_dir,
                           params.trainer.validation_summary_subdir)))
    if run_post_eval:
      self.assertNotEmpty(logs)
    else:
      self.assertEmpty(logs)
    self.assertNotEmpty(
        tf.io.gfile.glob(os.path.join(model_dir, 'params.yaml')))
    if flag_mode == 'eval':
      return
    self.assertNotEmpty(
        tf.io.gfile.glob(os.path.join(model_dir, 'checkpoint')))
    # Tests continuous evaluation.
    _, logs = train_lib.OrbitExperimentRunner(
        distribution_strategy=distribution_strategy,
        task=task,
        mode='continuous_eval',
        params=params,
        model_dir=model_dir,
        run_post_eval=run_post_eval).run()

  @combinations.generate(
      combinations.combine(
          distribution_strategy=[
              strategy_combinations.default_strategy,
              strategy_combinations.cloud_tpu_strategy,
              strategy_combinations.one_device_strategy_gpu,
          ],
          flag_mode=['train', 'train_and_eval'],
      ))
  def test_recovery_nan_error(self, distribution_strategy, flag_mode):
    model_dir = self.get_temp_dir()
    flags_dict = dict(
        experiment='mock',
        mode=flag_mode,
        model_dir=model_dir,
        params_override=json.dumps(self._test_config))
    with flagsaver.flagsaver(**flags_dict):
      params = train_utils.parse_configuration(flags.FLAGS)
      train_utils.serialize_config(params, model_dir)
      with distribution_strategy.scope():
        # task = task_factory.get_task(params.task, logging_dir=model_dir)
        task = mock_task.MockTask(params.task, logging_dir=model_dir)

        # Set the loss to NaN to trigger RunTimeError.
        def build_losses(labels, model_outputs, aux_losses=None):
          del labels, model_outputs
          return tf.constant([np.nan], tf.float32) + aux_losses

        task.build_losses = build_losses

      with self.assertRaises(RuntimeError):
        train_lib.OrbitExperimentRunner(
            distribution_strategy=distribution_strategy,
            task=task,
            mode=flag_mode,
            params=params,
            model_dir=model_dir).run()

  @combinations.generate(
      combinations.combine(
          distribution_strategy=[
              strategy_combinations.default_strategy,
              strategy_combinations.cloud_tpu_strategy,
              strategy_combinations.one_device_strategy_gpu,
          ],
          flag_mode=['train'],
      ))
  def test_recovery(self, distribution_strategy, flag_mode):
    loss_threshold = 1.0
    model_dir = self.get_temp_dir()
    flags_dict = dict(
        experiment='mock',
        mode=flag_mode,
        model_dir=model_dir,
        params_override=json.dumps(self._test_config))
    with flagsaver.flagsaver(**flags_dict):
      params = train_utils.parse_configuration(flags.FLAGS)
      params.trainer.loss_upper_bound = loss_threshold
      params.trainer.recovery_max_trials = 1
      train_utils.serialize_config(params, model_dir)
      with distribution_strategy.scope():
        task = task_factory.get_task(params.task, logging_dir=model_dir)

      # Saves a checkpoint for reference.
      model = task.build_model()
      checkpoint = tf.train.Checkpoint(model=model)
      checkpoint_manager = tf.train.CheckpointManager(
          checkpoint, self.get_temp_dir(), max_to_keep=2)
      checkpoint_manager.save()
      before_weights = model.get_weights()

      def build_losses(labels, model_outputs, aux_losses=None):
        del labels, model_outputs
        return tf.constant([loss_threshold], tf.float32) + aux_losses

      task.build_losses = build_losses

      model, _ = train_lib.OrbitExperimentRunner(
          distribution_strategy=distribution_strategy,
          task=task,
          mode=flag_mode,
          params=params,
          model_dir=model_dir).run()
      after_weights = model.get_weights()
      for left, right in zip(before_weights, after_weights):
        self.assertAllEqual(left, right)

  def test_parse_configuration(self):
    model_dir = self.get_temp_dir()
    flags_dict = dict(
        experiment='mock',
        mode='train',
        model_dir=model_dir,
        params_override=json.dumps(self._test_config))
    with flagsaver.flagsaver(**flags_dict):
      params = train_utils.parse_configuration(flags.FLAGS, lock_return=True)
      with self.assertRaises(ValueError):
        params.override({'task': {'init_checkpoint': 'Foo'}})

      params = train_utils.parse_configuration(flags.FLAGS, lock_return=False)
      params.override({'task': {'init_checkpoint': 'Bar'}})
      self.assertEqual(params.task.init_checkpoint, 'Bar')


if __name__ == '__main__':
  tf.test.main()