tensorflow/models

View on GitHub
official/legacy/transformer/metrics.py

Summary

Maintainability
F
5 days
Test Coverage
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Functions for calculating loss, accuracy, and other model metrics.

Metrics:
 - Padded loss, accuracy, and negative log perplexity. Source:
     https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/metrics.py
 - BLEU approximation. Source:
     https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/bleu_hook.py
 - ROUGE score. Source:
     https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/rouge.py
"""

import functools

import tensorflow as tf, tf_keras


def _pad_tensors_to_same_length(x, y):
  """Pad x and y so that the results have the same length (second dimension)."""
  with tf.name_scope("pad_to_same_length"):
    x_length = tf.shape(x)[1]
    y_length = tf.shape(y)[1]

    max_length = tf.maximum(x_length, y_length)

    x = tf.pad(x, [[0, 0], [0, max_length - x_length], [0, 0]])
    y = tf.pad(y, [[0, 0], [0, max_length - y_length]])
    return x, y


def padded_cross_entropy_loss(logits, labels, smoothing, vocab_size):
  """Calculate cross entropy loss while ignoring padding.

  Args:
    logits: Tensor of size [batch_size, length_logits, vocab_size]
    labels: Tensor of size [batch_size, length_labels]
    smoothing: Label smoothing constant, used to determine the on and off values
    vocab_size: int size of the vocabulary

  Returns:
    Returns the cross entropy loss and weight tensors: float32 tensors with
      shape [batch_size, max(length_logits, length_labels)]
  """
  with tf.name_scope("loss"):
    logits, labels = _pad_tensors_to_same_length(logits, labels)

    # Calculate smoothing cross entropy
    with tf.name_scope("smoothing_cross_entropy"):
      confidence = 1.0 - smoothing
      low_confidence = (1.0 - confidence) / tf.cast(vocab_size - 1, tf.float32)
      soft_targets = tf.one_hot(
          tf.cast(labels, tf.int32),
          depth=vocab_size,
          on_value=confidence,
          off_value=low_confidence)
      xentropy = tf.nn.softmax_cross_entropy_with_logits(
          logits=logits, labels=soft_targets)

      # Calculate the best (lowest) possible value of cross entropy, and
      # subtract from the cross entropy loss.
      normalizing_constant = -(
          confidence * tf.math.log(confidence) +
          tf.cast(vocab_size - 1, tf.float32) * low_confidence *
          tf.math.log(low_confidence + 1e-20))
      xentropy -= normalizing_constant

    weights = tf.cast(tf.not_equal(labels, 0), tf.float32)
    return xentropy * weights, weights


def padded_accuracy(logits, labels):
  """Percentage of times that predictions matches labels on non-0s."""
  with tf.name_scope("padded_accuracy"):
    logits, labels = _pad_tensors_to_same_length(logits, labels)
    weights = tf.cast(tf.not_equal(labels, 0), tf.float32)
    outputs = tf.cast(tf.argmax(logits, axis=-1), tf.int32)
    padded_labels = tf.cast(labels, tf.int32)
    return tf.cast(tf.equal(outputs, padded_labels), tf.float32), weights


def padded_accuracy_topk(logits, labels, k):
  """Percentage of times that top-k predictions matches labels on non-0s."""
  with tf.name_scope("padded_accuracy_topk"):
    logits, labels = _pad_tensors_to_same_length(logits, labels)
    weights = tf.cast(tf.not_equal(labels, 0), tf.float32)
    effective_k = tf.minimum(k, tf.shape(logits)[-1])
    _, outputs = tf.nn.top_k(logits, k=effective_k)
    outputs = tf.cast(outputs, tf.int32)
    padded_labels = tf.cast(labels, tf.int32)
    padded_labels = tf.expand_dims(padded_labels, axis=-1)
    padded_labels += tf.zeros_like(outputs)  # Pad to same shape.
    same = tf.cast(tf.equal(outputs, padded_labels), tf.float32)
    same_topk = tf.reduce_sum(same, axis=-1)
    return same_topk, weights


def padded_accuracy_top5(logits, labels):
  return padded_accuracy_topk(logits, labels, 5)


def padded_sequence_accuracy(logits, labels):
  """Percentage of times that predictions matches labels everywhere (non-0)."""
  with tf.name_scope("padded_sequence_accuracy"):
    logits, labels = _pad_tensors_to_same_length(logits, labels)
    weights = tf.cast(tf.not_equal(labels, 0), tf.float32)
    outputs = tf.cast(tf.argmax(logits, axis=-1), tf.int32)
    padded_labels = tf.cast(labels, tf.int32)
    not_correct = tf.cast(tf.not_equal(outputs, padded_labels),
                          tf.float32) * weights
    axis = list(range(1, len(outputs.get_shape())))
    correct_seq = 1.0 - tf.minimum(1.0, tf.reduce_sum(not_correct, axis=axis))
    return correct_seq, tf.constant(1.0)


def padded_neg_log_perplexity(logits, labels, vocab_size):
  """Average log-perplexity excluding padding 0s. No smoothing."""
  num, den = padded_cross_entropy_loss(logits, labels, 0, vocab_size)
  return -num, den


class MetricLayer(tf_keras.layers.Layer):
  """Custom a layer of metrics for Transformer model."""

  def __init__(self, vocab_size):
    super(MetricLayer, self).__init__()
    self.vocab_size = vocab_size
    self.metric_mean_fns = []

  def build(self, input_shape):
    """"Builds metric layer."""
    neg_log_perplexity = functools.partial(
        padded_neg_log_perplexity, vocab_size=self.vocab_size)
    self.metric_mean_fns = [
        (tf_keras.metrics.Mean("accuracy"), padded_accuracy),
        (tf_keras.metrics.Mean("accuracy_top5"), padded_accuracy_top5),
        (tf_keras.metrics.Mean("accuracy_per_sequence"),
         padded_sequence_accuracy),
        (tf_keras.metrics.Mean("neg_log_perplexity"), neg_log_perplexity),
    ]
    super(MetricLayer, self).build(input_shape)

  def get_config(self):
    return {"vocab_size": self.vocab_size}

  def call(self, inputs):
    logits, targets = inputs[0], inputs[1]
    for mean, fn in self.metric_mean_fns:
      m = mean(*fn(logits, targets))
      self.add_metric(m)
    return logits


def transformer_loss(logits, labels, smoothing, vocab_size):
  """Calculates total loss containing cross entropy with padding ignored.

  Args:
    logits: Tensor of size [batch_size, length_logits, vocab_size]
    labels: Tensor of size [batch_size, length_labels]
    smoothing: Label smoothing constant, used to determine the on and off values
    vocab_size: int size of the vocabulary

  Returns:
    A scalar float tensor for loss.
  """
  xentropy, weights = padded_cross_entropy_loss(logits, labels, smoothing,
                                                vocab_size)
  return tf.reduce_sum(xentropy) / tf.reduce_sum(weights)