tensorflow/models

View on GitHub
official/utils/flags/_base.py

Summary

Maintainability
C
1 day
Test Coverage
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Flags which will be nearly universal across models."""

from absl import flags
import tensorflow as tf, tf_keras
from official.utils.flags._conventions import help_wrap


def define_base(data_dir=True,
                model_dir=True,
                clean=False,
                train_epochs=False,
                epochs_between_evals=False,
                stop_threshold=False,
                batch_size=True,
                num_gpu=False,
                hooks=False,
                export_dir=False,
                distribution_strategy=False,
                run_eagerly=False):
  """Register base flags.

  Every argument is a switch: when it is truthy the corresponding absl flag is
  defined, otherwise that flag is skipped entirely. With the default arguments
  only `data_dir`, `model_dir` and `batch_size` are defined.

  Args:
    data_dir: Create a flag for specifying the input data directory.
    model_dir: Create a flag for specifying the model file directory.
    clean: Create a flag for removing the model_dir.
    train_epochs: Create a flag to specify the number of training epochs.
    epochs_between_evals: Create a flag to specify the frequency of testing.
    stop_threshold: Create a flag to specify a threshold accuracy or other eval
      metric which should trigger the end of training.
    batch_size: Create a flag to specify the batch size.
    num_gpu: Create a flag to specify the number of GPUs used. Note that the
      flag itself is named `num_gpus`.
    hooks: Create a flag to specify hooks for logging.
    export_dir: Create a flag to specify where a SavedModel should be exported.
    distribution_strategy: Create a flag to specify which Distribution Strategy
      to use.
    run_eagerly: Create a flag to specify to run eagerly op by op.

  Returns:
    A list of flag names for core.py to mark as key flags. Of the flags that
    were defined, only `data_dir`, `model_dir`, `clean`, `train_epochs`,
    `epochs_between_evals`, `batch_size`, `hooks` and `export_dir` are
    included in the list.
  """
  key_flags = []

  if data_dir:
    flags.DEFINE_string(
        name="data_dir",
        short_name="dd",
        default="/tmp",
        help=help_wrap("The location of the input data."))
    key_flags.append("data_dir")

  if model_dir:
    flags.DEFINE_string(
        name="model_dir",
        short_name="md",
        default="/tmp",
        help=help_wrap("The location of the model checkpoint files."))
    key_flags.append("model_dir")

  if clean:
    flags.DEFINE_boolean(
        name="clean",
        default=False,
        help=help_wrap("If set, model_dir will be removed if it exists."))
    key_flags.append("clean")

  if train_epochs:
    flags.DEFINE_integer(
        name="train_epochs",
        short_name="te",
        default=1,
        help=help_wrap("The number of epochs used to train."))
    key_flags.append("train_epochs")

  if epochs_between_evals:
    flags.DEFINE_integer(
        name="epochs_between_evals",
        short_name="ebe",
        default=1,
        help=help_wrap("The number of training epochs to run between "
                       "evaluations."))
    key_flags.append("epochs_between_evals")

  # NOTE(review): stop_threshold, num_gpus, run_eagerly and
  # distribution_strategy are defined below but never added to key_flags,
  # unlike every other flag here. The returned list is kept unchanged to
  # preserve existing behavior -- confirm whether the omission is intentional.
  if stop_threshold:
    flags.DEFINE_float(
        name="stop_threshold",
        short_name="st",
        default=None,
        help=help_wrap("If passed, training will stop at the earlier of "
                       "train_epochs and when the evaluation metric is "
                       "greater than or equal to stop_threshold."))

  if batch_size:
    flags.DEFINE_integer(
        name="batch_size",
        short_name="bs",
        default=32,
        help=help_wrap("Batch size for training and evaluation. When using "
                       "multiple gpus, this is the global batch size for "
                       "all devices. For example, if the batch size is 32 "
                       "and there are 4 GPUs, each GPU will get 8 examples on "
                       "each step."))
    key_flags.append("batch_size")

  if num_gpu:
    # The switch is `num_gpu` (singular) but the flag is `num_gpus` (plural).
    flags.DEFINE_integer(
        name="num_gpus",
        short_name="ng",
        default=1,
        help=help_wrap("How many GPUs to use at each worker with the "
                       "DistributionStrategies API. The default is 1."))

  if run_eagerly:
    flags.DEFINE_boolean(
        name="run_eagerly",
        default=False,
        # Wrapped like every other help string in this function.
        help=help_wrap(
            "Run the model op by op without building a model function."))

  if hooks:
    flags.DEFINE_list(
        name="hooks",
        short_name="hk",
        default="LoggingTensorHook",
        help=help_wrap(
            "A list of (case insensitive) strings to specify the names of "
            "training hooks. Example: `--hooks ProfilerHook,"
            "ExamplesPerSecondHook`\n See hooks_helper "
            "for details."))
    key_flags.append("hooks")

  if export_dir:
    flags.DEFINE_string(
        name="export_dir",
        short_name="ed",
        default=None,
        help=help_wrap("If set, a SavedModel serialization of the model will "
                       "be exported to this directory at the end of training. "
                       "See the README for more details and relevant links."))
    key_flags.append("export_dir")

  if distribution_strategy:
    # NOTE(review): the help text explains 'default' but does not list it
    # among the accepted values; the accepted set is validated elsewhere, so
    # the text is left untouched here -- confirm against the strategy factory.
    flags.DEFINE_string(
        name="distribution_strategy",
        short_name="ds",
        default="mirrored",
        help=help_wrap("The Distribution Strategy to use for training. "
                       "Accepted values are 'off', 'one_device', "
                       "'mirrored', 'parameter_server', 'collective', "
                       "case insensitive. 'off' means not to use "
                       "Distribution Strategy; 'default' means to choose "
                       "from `MirroredStrategy` or `OneDeviceStrategy` "
                       "according to the number of GPUs."))

  return key_flags


def get_num_gpus(flags_obj):
  """Resolve the `num_gpus` flag, treating -1 as 'use all'.

  Args:
    flags_obj: An object exposing a `num_gpus` attribute.

  Returns:
    `flags_obj.num_gpus` when it is anything other than -1; otherwise the
    number of local devices whose `device_type` is "GPU".
  """
  requested = flags_obj.num_gpus
  if requested != -1:
    return requested

  # Imported lazily so the lookup only happens when -1 was requested.
  from tensorflow.python.client import device_lib  # pylint: disable=g-import-not-at-top
  gpu_devices = [
      device for device in device_lib.list_local_devices()
      if device.device_type == "GPU"
  ]
  return len(gpu_devices)