# Repository: tensorflow/models (view on GitHub)
# File: official/legacy/image_classification/resnet/resnet_ctl_imagenet_main.py
#
# Summary:
#   Maintainability: B (6 hrs)
#   Test Coverage: (no value captured)
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Runs a ResNet model on the ImageNet dataset using custom training loops."""

import math
import os

# Import libraries
from absl import app
from absl import flags
from absl import logging
import orbit
import tensorflow as tf, tf_keras
from official.common import distribute_utils
from official.legacy.image_classification.resnet import common
from official.legacy.image_classification.resnet import imagenet_preprocessing
from official.legacy.image_classification.resnet import resnet_runnable
from official.modeling import performance
from official.utils.flags import core as flags_core
from official.utils.misc import keras_utils
from official.utils.misc import model_helpers

# Flags defined by this entry point itself, in addition to any registered
# elsewhere before flag parsing.
flags.DEFINE_boolean(
    name='use_tf_function',
    default=True,
    help='Wrap the train and test step inside a tf.function.')
flags.DEFINE_boolean(
    name='single_l2_loss_op',
    default=False,
    help=('Calculate L2_loss on concatenated weights, instead of using Keras '
          'per-layer L2 loss.'))


def build_stats(runnable, time_callback):
  """Collects end-of-run metrics and timing data into a plain dictionary.

  Args:
    runnable: Object exposing `flags_obj` plus the `test_loss`,
      `test_accuracy`, `train_loss` and `train_accuracy` metrics.
    time_callback: Time tracking callback instance, or a falsy value to leave
      the timing entries out.

  Returns:
    Dictionary of stats. The four metric entries are present only when
    `flags_obj.skip_eval` is false; the timing entries only when
    `time_callback` is truthy.
  """
  results = {}

  if not runnable.flags_obj.skip_eval:
    # NOTE(review): the train metrics are gated on skip_eval together with the
    # eval metrics, so a skip-eval run reports neither -- confirm intended.
    metric_sources = (
        ('eval_loss', runnable.test_loss),
        ('eval_acc', runnable.test_accuracy),
        ('train_loss', runnable.train_loss),
        ('train_acc', runnable.train_accuracy),
    )
    for key, metric in metric_sources:
      results[key] = metric.result().numpy()

  if not time_callback:
    return results

  results['step_timestamp_log'] = time_callback.timestamp_log
  results['train_finish_time'] = time_callback.train_finish_time
  # Throughput is only reported once at least one epoch runtime was recorded.
  if time_callback.epoch_runtime_log:
    results['avg_exp_per_second'] = time_callback.average_examples_per_second

  return results


def get_num_train_iterations(flags_obj):
  """Derives the step and epoch counts for training and evaluation.

  Args:
    flags_obj: Parsed flags; `batch_size`, `train_epochs` and `train_steps`
      are read.

  Returns:
    A `(train_steps, train_epochs, eval_steps)` tuple. `train_steps` is the
    number of steps per epoch (floor of train images / batch size); when the
    `train_steps` flag is set it is capped at that value and `train_epochs`
    becomes 1. `eval_steps` is the ceiling of validation images / batch size.
  """
  image_counts = imagenet_preprocessing.NUM_IMAGES
  batch_size = flags_obj.batch_size

  steps_per_epoch = image_counts['train'] // batch_size
  num_epochs = flags_obj.train_epochs

  requested_steps = flags_obj.train_steps
  if requested_steps:
    # An explicit step budget overrides the epoch count but never exceeds the
    # length of one epoch.
    steps_per_epoch = min(requested_steps, steps_per_epoch)
    num_epochs = 1

  # Round up so the last, possibly partial, validation batch is counted.
  eval_steps = math.ceil(float(image_counts['validation']) / batch_size)

  return steps_per_epoch, num_epochs, eval_steps


def run(flags_obj):
  """Run ResNet ImageNet training and eval loop using custom training loops.

  Configures the session, mixed-precision policy, image data format and
  distribution strategy from `flags_obj`, builds a `ResnetRunnable` under the
  strategy scope, and drives it with an `orbit.Controller`.

  Args:
    flags_obj: An object containing parsed flag values.

  Returns:
    Dictionary of training and eval stats, as produced by `build_stats`.
  """
  keras_utils.set_session_config()
  performance.set_mixed_precision_policy(flags_core.get_tf_dtype(flags_obj))

  # Query the GPU list once; it drives both the GPU tuning below and the
  # default image data format.
  has_gpu = bool(tf.config.list_physical_devices('GPU'))

  if has_gpu:
    if flags_obj.tf_gpu_thread_mode:
      keras_utils.set_gpu_thread_mode_and_count(
          per_gpu_thread_count=flags_obj.per_gpu_thread_count,
          gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
          num_gpus=flags_obj.num_gpus,
          datasets_num_private_threads=flags_obj.datasets_num_private_threads)
    common.set_cudnn_batchnorm_mode()

  # An explicit --data_format wins; otherwise pick channels_first when a GPU
  # is visible and channels_last when not.
  data_format = flags_obj.data_format
  if data_format is None:
    data_format = 'channels_first' if has_gpu else 'channels_last'
  tf_keras.backend.set_image_data_format(data_format)

  strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus,
      all_reduce_alg=flags_obj.all_reduce_alg,
      num_packs=flags_obj.num_packs,
      tpu_address=flags_obj.tpu)

  per_epoch_steps, train_epochs, eval_steps = get_num_train_iterations(
      flags_obj)
  total_train_steps = per_epoch_steps * train_epochs

  # A loop never spans more than one epoch: an unset or oversized
  # steps_per_loop is clamped to the per-epoch step count.
  if flags_obj.steps_per_loop is None:
    steps_per_loop = per_epoch_steps
  elif flags_obj.steps_per_loop > per_epoch_steps:
    steps_per_loop = per_epoch_steps
    # `logging.warn` is a deprecated alias in absl; use `logging.warning`.
    logging.warning('Setting steps_per_loop to %d to respect epoch boundary.',
                    steps_per_loop)
  else:
    steps_per_loop = flags_obj.steps_per_loop

  logging.info(
      'Training %d epochs, each epoch has %d steps, '
      'total steps: %d; Eval %d steps', train_epochs, per_epoch_steps,
      total_train_steps, eval_steps)

  time_callback = keras_utils.TimeHistory(
      flags_obj.batch_size,
      flags_obj.log_steps,
      logdir=flags_obj.model_dir if flags_obj.enable_tensorboard else None)
  with distribute_utils.get_strategy_scope(strategy):
    runnable = resnet_runnable.ResnetRunnable(flags_obj, time_callback,
                                              per_epoch_steps)

  eval_interval = flags_obj.epochs_between_evals * per_epoch_steps
  # Interval-based checkpoints (every five loops) and summaries (every loop)
  # are only requested when the corresponding feature flag is on.
  checkpoint_interval = (
      steps_per_loop * 5 if flags_obj.enable_checkpoint_and_export else None)
  summary_interval = steps_per_loop if flags_obj.enable_tensorboard else None

  checkpoint_manager = tf.train.CheckpointManager(
      runnable.checkpoint,
      directory=flags_obj.model_dir,
      max_to_keep=10,
      step_counter=runnable.global_step,
      checkpoint_interval=checkpoint_interval)

  # The same runnable acts as trainer and, unless eval is skipped, evaluator.
  resnet_controller = orbit.Controller(
      strategy=strategy,
      trainer=runnable,
      evaluator=runnable if not flags_obj.skip_eval else None,
      global_step=runnable.global_step,
      steps_per_loop=steps_per_loop,
      checkpoint_manager=checkpoint_manager,
      summary_interval=summary_interval,
      summary_dir=flags_obj.model_dir,
      eval_summary_dir=os.path.join(flags_obj.model_dir, 'eval'))

  time_callback.on_train_begin()
  if not flags_obj.skip_eval:
    resnet_controller.train_and_evaluate(
        train_steps=total_train_steps,
        eval_steps=eval_steps,
        eval_interval=eval_interval)
  else:
    resnet_controller.train(steps=total_train_steps)
  time_callback.on_train_end()

  stats = build_stats(runnable, time_callback)
  return stats


def main(_):
  """Entry point handed to `app.run`.

  Passes the parsed flags to `model_helpers.apply_clean`, runs the
  training/eval loop via `run`, and logs the returned stats dictionary.

  Args:
    _: Unused positional argument supplied by `app.run`.
  """
  parsed_flags = flags.FLAGS
  model_helpers.apply_clean(parsed_flags)
  logging.info('Run stats:\n%s', run(parsed_flags))


if __name__ == '__main__':
  # Script entry: raise log verbosity to INFO so the messages logged with
  # logging.info in this file (e.g. 'Run stats') are emitted.
  logging.set_verbosity(logging.INFO)
  # NOTE(review): presumably registers the shared flags that run() reads from
  # flags_obj; it is called before app.run so they exist when flags are
  # parsed -- confirm against common.define_keras_flags.
  common.define_keras_flags()
  app.run(main)