orbit/utils/loop_fns.py

# Copyright 2024 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for creating loop functions."""

from absl import logging
from orbit.utils import tpu_summaries

import tensorflow as tf
import tf_keras  # pylint: disable=unused-import


def create_loop_fn(step_fn):
  """Creates a loop function driven by a Python `while` loop.

  Args:
    step_fn: A function taking a nested structure of `tf.data.Iterator` or
      `DistributedIterator`. There are no constraints on the return value of the
      function (except that it must be compatible with any `reduce_fn` provided
      to the returned `loop_fn`).

  Returns:
    A loop function taking required `iterator` and `num_steps` parameters, as
    well as optional `state` and `reduce_fn` parameters for accumulating state
    over multiple iterations of the loop. See the `loop_fn` definition below for
    additional details.
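
  Example:

    A minimal usage sketch, assuming a `step_fn` that returns a scalar
    per-step loss (`compute_loss` and `dataset` are hypothetical):

        def step_fn(iterator):
          features, labels = next(iterator)
          return compute_loss(features, labels)

        loop_fn = create_loop_fn(step_fn)
        total_loss = loop_fn(
            iter(dataset),
            num_steps=100,
            state=0.0,
            reduce_fn=lambda state, outputs: state + outputs)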
  """

  def loop_fn(iterator, num_steps, state=None, reduce_fn=None):
    """Makes `num_steps` calls to `step_fn(iterator)`.

    Additionally, state may be accumulated across iterations of the loop.
    Conceptually, state accumulation is handled roughly as follows:

        for _ in range(num_steps):
          step_outputs  = step_fn(iterator)
          state = reduce_fn(state, step_outputs)
        return state

    However, the implementation is slightly more complicated in order to support
    looping until the iterator is exhausted (when `num_steps == -1`) and to
    properly catch exceptions when running under async remote eager (as is the
    case in TPU training setups involving separate coordinator/worker machines).

    Args:
      iterator: A nested structure of `tf.data.Iterator` or
        `DistributedIterator`.
      num_steps: The number of steps in the loop. If `num_steps == -1`, the
        loop will iterate until the iterator is exhausted.
      state: An optional initial state before running the loop.
      reduce_fn: A callable taking two inputs, `state` and `value`, where
        `state` is the previous output from `reduce_fn`, and `value` is the
        output from `step_fn`.

    Returns:
      The final state returned by `reduce_fn`, or `None` if `state` and
      `reduce_fn` are not provided.
    """
    step = 0
    try:
      # Wrap the loop body in `async_scope` so that the `OutOfRangeError`
      # raised when the iterator is exhausted can be handled properly under
      # async remote eager execution.
      with tf.experimental.async_scope():
        while num_steps == -1 or step < num_steps:
          outputs = step_fn(iterator)
          if reduce_fn is not None:
            state = reduce_fn(state, outputs)
          step += 1
        return state
    except (StopIteration, tf.errors.OutOfRangeError):
      logging.info("The dataset iterator is exhausted after %d steps.", step)
      tf.experimental.async_clear_error()
      return state

  return loop_fn


def create_tf_while_loop_fn(step_fn):
  """Creates a loop function compatible with TF's AutoGraph loop conversion.

  Args:
    step_fn: A function taking a nested structure of `tf.data.Iterator` or
      `DistributedIterator`. Currently, any return values are ignored.

  Returns:
    A loop function taking required `iterator` and `num_steps` parameters. If
    called inside a `tf.function`, the loop will be converted by AutoGraph into
    a `tf.while_loop` construct. See the `loop_fn` definition below for
    additional details.
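
  Example:

    A minimal sketch, assuming a `step_fn` that updates model variables as a
    side effect (its return values are ignored; `dataset` is hypothetical):

        loop_fn = tf.function(create_tf_while_loop_fn(step_fn))
        # Pass `num_steps` as a `tf.Tensor` to avoid retracing.
        loop_fn(iter(dataset), tf.constant(100))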
  """

  def loop_fn(iterator, num_steps):
    """Makes `num_steps` calls to `step_fn(iterator)`.

    Args:
      iterator: A nested structure of `tf.data.Iterator` or
        `DistributedIterator`.
      num_steps: The number of steps in the loop. Should be passed as a
        `tf.Tensor`. Iterating until iterator exhaustion is not supported.
    """
    if not isinstance(num_steps, tf.Tensor):
      raise ValueError(
          "`num_steps` should be a `tf.Tensor`. Passing a Python value can "
          "cause unnecessary retracing when wrapped by `tf.function`.")

    for _ in tf.range(num_steps):
      # Clear out the outer name scope so the ops created inside `tf.while_loop`
      # don't get "while/" as a name prefix.
      with tf.name_scope(""):
        step_fn(iterator)

  return loop_fn


def create_tf_while_loop_fn_with_state(step_fn):
  """Creates a TF while loop function with state.

  This function is similar to `create_tf_while_loop_fn`, but allows a `state`
  to be accumulated over multiple iterations of the loop. Note that the
  structure of the `state` cannot be changed across iterations.

  Args:
    step_fn: A function taking a nested structure of `tf.data.Iterator` or
      `DistributedIterator`. Currently, any return values are ignored.

  Returns:
    A loop function taking required `iterator`, `num_steps`, `state` and
    `reduce_fn` parameters. If called inside a `tf.function`, the loop will be
    converted by AutoGraph into a `tf.while_loop` construct. See the `loop_fn`
    definition below for additional details.
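
  Example:

    A minimal sketch in which per-step outputs of shape `[batch, output_dim]`
    are accumulated by concatenation, the use case enabled by the relaxed
    shape invariants below (`dataset` and `output_dim` are hypothetical):

        loop_fn = tf.function(create_tf_while_loop_fn_with_state(step_fn))
        outputs = loop_fn(
            iter(dataset),
            tf.constant(100),
            state=tf.zeros([0, output_dim]),
            reduce_fn=lambda state, value: tf.concat([state, value], axis=0))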
  """

  def loop_fn_with_state(iterator, num_steps, state, reduce_fn):
    """Makes `num_steps` calls to `step_fn(iterator)`.

    Args:
      iterator: A nested structure of `tf.data.Iterator` or
        `DistributedIterator`.
      num_steps: The number of steps in the loop. Should be passed as a
        `tf.Tensor`. Iterating until iterator exhaustion is not supported.
      state: An initial state before running the loop.
      reduce_fn: A callable taking two inputs, `state` and `value`, where
        `state` is the previous output from `reduce_fn`, and `value` is the
        output from `step_fn`.

    Returns:
      The final state returned by `reduce_fn`.
    """
    if not isinstance(num_steps, tf.Tensor):
      raise ValueError(
          "`num_steps` should be a `tf.Tensor`. Passing a Python value can "
          "cause unnecessary retracing when wrapped by `tf.function`.")

    def _get_relaxed_tensor_shape(t):
      """Returns a `TensorShape` with all `None` dimensions."""
      if not tf.is_tensor(t):
        return None

      shape = t.shape
      if shape.rank is not None and shape.rank > 0:
        return tf.TensorShape([None] * shape.rank)
      return shape

    def _get_relaxed_shape_structure(s):
      """Returns the relaxed shapes of the input nested structure `s`."""
      return tf.nest.pack_sequence_as(
          s, [_get_relaxed_tensor_shape(t) for t in tf.nest.flatten(s)])

    for _ in tf.range(num_steps):
      # Clear out the outer name scope so the ops created inside `tf.while_loop`
      # don't get "while/" as a name prefix.
      with tf.name_scope(""):
        # Relax the shapes within the loop, so the shape of `state` can change
        # across iterations. This is useful for aggregating outputs from each
        # step and concatenating them onto `state`.
        tf.autograph.experimental.set_loop_options(
            shape_invariants=[(state, _get_relaxed_shape_structure(state))])
        outputs = step_fn(iterator)
        state = reduce_fn(state, outputs)
    return state

  return loop_fn_with_state


class LoopFnWithSummaries(tpu_summaries.OptionalSummariesFunction):
  """Implements a two-program approach for optimizing summaries on TPU.

  This version works with the result of `create_tf_while_loop_fn`.
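
  Example (a usage sketch, assuming a `step_fn` suitable for
  `create_tf_while_loop_fn` and an `iterator` over the input data):

      loop_fn = LoopFnWithSummaries(create_tf_while_loop_fn(step_fn))
      loop_fn(iterator, tf.constant(100))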
  """

  def __call__(self, iterator, num_steps):
    # Note: assumes `num_steps >= 1`, so that `output` is always defined.
    if tf.summary.should_record_summaries():
      # Run a single step with summary writing enabled, then fall through to
      # run any remaining steps without summaries.
      output = self.with_summaries(iterator, tf.constant(1))
      num_steps -= 1
    if num_steps >= 1:
      output = self.without_summaries(iterator, num_steps)
    return output