# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Common TF utilities."""

import functools
import inspect
import six
import tensorflow as tf, tf_keras

from tensorflow.python.util import deprecation
from official.modeling import activations


@deprecation.deprecated(
    None,
    "tf_keras.layers.Layer supports multiple positional args and kwargs as "
    "input tensors. pack/unpack inputs to override __call__ is no longer "
    "needed.")
def pack_inputs(inputs):
  """Pack a list of `inputs` tensors to a tuple.

  Args:
    inputs: a list of tensors.

  Returns:
    a tuple of tensors. if any input is None, replace it with a special constant
    tensor.
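
  Example (illustrative; see `unpack_inputs` for the inverse):

    packed = pack_inputs([tf.ones([2, 3]), None])
    # packed[1] is a scalar int32 zero standing in for None.
    unpacked = unpack_inputs(packed)
    # unpacked[1] is None again.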
  """
  inputs = tf.nest.flatten(inputs)
  outputs = []
  for x in inputs:
    if x is None:
      outputs.append(tf.constant(0, shape=[], dtype=tf.int32))
    else:
      outputs.append(x)
  return tuple(outputs)


@deprecation.deprecated(
    None,
    "tf_keras.layers.Layer supports multiple positional args and kwargs as "
    "input tensors. pack/unpack inputs to override __call__ is no longer "
    "needed.")
def unpack_inputs(inputs):
  """unpack a tuple of `inputs` tensors to a tuple.

  Args:
    inputs: a list of tensors.

  Returns:
    a tuple of tensors. if any input is a special constant tensor, replace it
    with None.
  """
  inputs = tf.nest.flatten(inputs)
  outputs = []
  for x in inputs:
    if is_special_none_tensor(x):
      outputs.append(None)
    else:
      outputs.append(x)
  x = tuple(outputs)

  # To keep the rather pointless 'unbalanced-tuple-unpacking' pylint check
  # from triggering.
  if len(x) == 1:
    return x[0]
  return tuple(outputs)


def is_special_none_tensor(tensor):
  """Checks if a tensor is a special None Tensor."""
  return tensor.shape.ndims == 0 and tensor.dtype == tf.int32


def get_activation(identifier, use_keras_layer=False, **kwargs):
  """Maps an identifier to a Python function, e.g., "relu" => `tf.nn.relu`.

  String identifiers are checked first: if the name is one of the customized
  activations not in TF, the corresponding activation is returned. For
  non-customized activation names and callable identifiers, it always falls
  back to `tf_keras.activations.get`.

  When `use_keras_layer=True`, a Keras layer is returned for allow-listed
  identifiers; the allowlist currently covers 'relu', 'linear', 'identity',
  'swish', 'sigmoid', 'relu6', 'leaky_relu', 'hard_swish', 'hard_sigmoid',
  'mish', and 'gelu'.

  Args:
    identifier: String name of the activation function or callable.
    use_keras_layer: If True, use keras layer if identifier is allow-listed.
    **kwargs: Keyword arguments to use to instantiate an activation function.
      Available only for 'leaky_relu' and 'gelu' when using keras layers.
      For example: get_activation('leaky_relu', use_keras_layer=True, alpha=0.1)

  Returns:
    A Python function corresponding to the activation function or a keras
    activation layer when use_keras_layer=True.
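
  Example (illustrative, showing both return modes):

    gelu_fn = get_activation('gelu')  # A plain Python callable.
    layer = get_activation('leaky_relu', use_keras_layer=True, alpha=0.1)
    y = layer(tf.constant([-1.0, 2.0]))  # `layer` is a tf_keras layer.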
  """
  if isinstance(identifier, six.string_types):
    identifier = str(identifier).lower()
    if use_keras_layer:
      keras_layer_allowlist = {
          "relu": "relu",
          "linear": "linear",
          "identity": "linear",
          "swish": "swish",
          "sigmoid": "sigmoid",
          "relu6": tf.nn.relu6,
          "leaky_relu": functools.partial(tf.nn.leaky_relu, **kwargs),
          "hard_swish": activations.hard_swish,
          "hard_sigmoid": activations.hard_sigmoid,
          "mish": activations.mish,
          "gelu": functools.partial(tf.nn.gelu, **kwargs),
      }
      if identifier in keras_layer_allowlist:
        return tf_keras.layers.Activation(keras_layer_allowlist[identifier])
    name_to_fn = {
        "gelu": activations.gelu,
        "simple_swish": activations.simple_swish,
        "hard_swish": activations.hard_swish,
        "relu6": activations.relu6,
        "hard_sigmoid": activations.hard_sigmoid,
        "identity": activations.identity,
        "mish": activations.mish,
    }
    if identifier in name_to_fn:
      return tf_keras.activations.get(name_to_fn[identifier])
  return tf_keras.activations.get(identifier)


def get_shape_list(tensor, expected_rank=None, name=None):
  """Returns a list of the shape of tensor, preferring static dimensions.

  Args:
    tensor: A tf.Tensor object to find the shape of.
    expected_rank: (optional) int. The expected rank of `tensor`. If this is
      specified and `tensor` has a different rank, an exception will be
      thrown.
    name: Optional name of the tensor for the error message.

  Returns:
    A list of dimensions of the shape of tensor. All static dimensions will
    be returned as Python integers, and dynamic dimensions will be returned
    as tf.Tensor scalars.
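
  Example (an illustrative sketch inside a `tf.function` with a dynamic
  batch dimension):

    @tf.function(input_signature=[tf.TensorSpec([None, 128])])
    def fn(x):
      batch_size, width = get_shape_list(x, expected_rank=2)
      # `batch_size` is a dynamic scalar tf.Tensor; `width` is the int 128.
      return tf.reshape(x, [batch_size, width // 2, 2])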
  """
  if expected_rank is not None:
    assert_rank(tensor, expected_rank, name)

  shape = tensor.shape.as_list()

  non_static_indexes = []
  for (index, dim) in enumerate(shape):
    if dim is None:
      non_static_indexes.append(index)

  if not non_static_indexes:
    return shape

  dyn_shape = tf.shape(tensor)
  for index in non_static_indexes:
    shape[index] = dyn_shape[index]
  return shape


def assert_rank(tensor, expected_rank, name=None):
  """Raises an exception if the tensor rank is not of the expected rank.

  Args:
    tensor: A tf.Tensor to check the rank of.
    expected_rank: Python integer or list of integers, expected rank.
    name: Optional name of the tensor for the error message.

  Raises:
    ValueError: If the expected shape doesn't match the actual shape.
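
  Example (illustrative):

    assert_rank(tf.zeros([2, 3]), 2)            # Passes.
    assert_rank(tf.zeros([2, 3]), [2, 3], "x")  # Passes; rank 2 or 3 is ok.
    assert_rank(tf.zeros([4]), 2)               # Raises ValueError.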
  """
  expected_rank_dict = {}
  if isinstance(expected_rank, six.integer_types):
    expected_rank_dict[expected_rank] = True
  else:
    for x in expected_rank:
      expected_rank_dict[x] = True

  actual_rank = tensor.shape.ndims
  if actual_rank not in expected_rank_dict:
    raise ValueError(
        "For the tensor `%s`, the actual tensor rank `%d` (shape = %s) is not "
        "equal to the expected tensor rank `%s`" %
        (name, actual_rank, str(tensor.shape), str(expected_rank)))


def safe_mean(losses):
  """Computes a safe mean of the losses.

  Args:
    losses: `Tensor` whose elements contain individual loss measurements.

  Returns:
    A scalar representing the mean of `losses`. If `losses` has no elements,
    zero is returned rather than NaN.
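
  Example (illustrative):

    safe_mean(tf.constant([1.0, 3.0]))            # -> 2.0
    safe_mean(tf.constant([], dtype=tf.float32))  # -> 0.0, not NaN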
  """
  total = tf.reduce_sum(losses)
  num_elements = tf.cast(tf.size(losses), dtype=losses.dtype)
  return tf.math.divide_no_nan(total, num_elements)


def get_replica_id():
  """Gets replica id depending on the environment."""
  context = tf.distribute.get_replica_context()
  if context is not None:
    return context.replica_id_in_sync_group
  else:
    raise RuntimeError("Unknown replica context. The `get_replica_id` method "
                       "relies on TF 2.x tf.distribute API.")


def cross_replica_concat(value, axis, name="cross_replica_concat"):
  """Concatenates the given `value` across (GPU/TPU) cores, along `axis`.

  In general, each core ("replica") will pass a
  replica-specific value as `value` (corresponding to some element of a
  data-parallel computation taking place across replicas).

  The resulting concatenated `Tensor` will have the same shape as `value` for
  all dimensions except `axis`, where it will be larger by a factor of the
  number of replicas. It will also have the same `dtype` as `value`.

  The position of a given replica's `value` within the resulting concatenation
  is determined by that replica's replica ID. For
  example:

  With `value` for replica 0 given as

      0 0 0
      0 0 0

  and `value` for replica 1 given as

      1 1 1
      1 1 1

  the resulting concatenation along axis 0 will be

      0 0 0
      0 0 0
      1 1 1
      1 1 1

  and this result will be identical across all replicas.

  Note that this API only works in TF2 with `tf.distribute`.

  Args:
    value: The `Tensor` to concatenate across replicas. Each replica will have a
      different value for this `Tensor`, and these replica-specific values will
      be concatenated.
    axis: The axis along which to perform the concatenation as a Python integer
      (not a `Tensor`). E.g., `axis=0` to concatenate along the batch dimension.
    name: A name for the operation (used to create a name scope).

  Returns:
    The result of concatenating `value` along `axis` across replicas.

  Raises:
    RuntimeError: when the batch (0-th) dimension is None.
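
  Example (a minimal sketch; assumes a synchronous strategy such as
  `tf.distribute.MirroredStrategy`):

    strategy = tf.distribute.MirroredStrategy()

    @tf.function
    def step(x):
      return cross_replica_concat(x, axis=0)

    # Every replica sees the same concatenation of all replicas' shards.
    strategy.run(step, args=(tf.ones([8, 16]),))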
  """
  with tf.name_scope(name):
    context = tf.distribute.get_replica_context()
    # Typically this is hit only when the tensor comes from a dataset with
    # finite epochs and drop_remainder=False, where the last batch may have a
    # different size, making dim 0 dynamic.
    if value.shape.as_list()[0] is None:
      raise RuntimeError(f"{value} has an unknown batch (0-th) dimension.")
    return context.all_gather(value, axis=axis)


def clone_initializer(initializer):
  """Clones a Keras initializer so that each use yields independent values.

  Keras initializers are stateless, which means that reusing the same
  initializer instance produces identical initial values whenever the
  requested shapes are the same. Cloning gives each variable its own instance.
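
  Example (illustrative):

    init = tf_keras.initializers.TruncatedNormal(stddev=0.02)
    dense_a = tf_keras.layers.Dense(
        8, kernel_initializer=clone_initializer(init))
    dense_b = tf_keras.layers.Dense(
        8, kernel_initializer=clone_initializer(init))
  """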
  if isinstance(initializer, tf_keras.initializers.Initializer):
    return initializer.__class__.from_config(initializer.get_config())
  # Strings, dicts, and other serialized configs are returned unchanged; the
  # caller will construct a fresh Keras Initializer instance from them, so no
  # cloning is needed.
  return initializer


def serialize_keras_object(obj):
  """Serializes a Keras object, using the legacy API when available."""
  if hasattr(tf_keras.utils, "legacy"):
    return tf_keras.utils.legacy.serialize_keras_object(obj)
  else:
    return tf_keras.utils.serialize_keras_object(obj)


def deserialize_keras_object(
    config, module_objects=None, custom_objects=None, printable_module_name=None
):
  """Deserializes a Keras object, using the legacy API when available.
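
  Example (an illustrative round trip through the helpers above):

    cfg = serialize_keras_object(tf_keras.regularizers.L2(1e-4))
    reg = deserialize_keras_object(
        cfg, module_objects=vars(tf_keras.regularizers))
  """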
  if hasattr(tf_keras.utils, "legacy"):
    return tf_keras.utils.legacy.deserialize_keras_object(
        config, module_objects, custom_objects, printable_module_name
    )
  else:
    return tf_keras.utils.deserialize_keras_object(
        config, module_objects, custom_objects, printable_module_name
    )


def serialize_layer(layer, use_legacy_format=False):
  """Serializes a Keras layer, using `use_legacy_format` if supported."""
  if (
      "use_legacy_format"
      in inspect.getfullargspec(tf_keras.layers.serialize).args
  ):
    return tf_keras.layers.serialize(layer, use_legacy_format=use_legacy_format)
  else:
    return tf_keras.layers.serialize(layer)


def serialize_initializer(initializer, use_legacy_format=False):
  """Serializes an initializer, using `use_legacy_format` if supported."""
  if (
      "use_legacy_format"
      in inspect.getfullargspec(tf_keras.initializers.serialize).args
  ):
    return tf_keras.initializers.serialize(
        initializer, use_legacy_format=use_legacy_format
    )
  else:
    return tf_keras.initializers.serialize(initializer)


def serialize_regularizer(regularizer, use_legacy_format=False):
  """Serializes a regularizer, using `use_legacy_format` if supported."""
  if (
      "use_legacy_format"
      in inspect.getfullargspec(tf_keras.regularizers.serialize).args
  ):
    return tf_keras.regularizers.serialize(
        regularizer, use_legacy_format=use_legacy_format
    )
  else:
    return tf_keras.regularizers.serialize(regularizer)


def serialize_constraint(constraint, use_legacy_format=False):
  """Serializes a constraint, using `use_legacy_format` if supported."""
  if (
      "use_legacy_format"
      in inspect.getfullargspec(tf_keras.constraints.serialize).args
  ):
    return tf_keras.constraints.serialize(
        constraint, use_legacy_format=use_legacy_format
    )
  else:
    return tf_keras.constraints.serialize(constraint)


def serialize_activation(activation, use_legacy_format=False):
  """Serializes an activation, using `use_legacy_format` if supported."""
  if (
      "use_legacy_format"
      in inspect.getfullargspec(tf_keras.activations.serialize).args
  ):
    return tf_keras.activations.serialize(
        activation, use_legacy_format=use_legacy_format
    )
  else:
    return tf_keras.activations.serialize(activation)