tensorflow/models
official/projects/const_cl/losses/losses.py

# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""The losses for ConST-CL."""

from typing import Mapping

import tensorflow as tf, tf_keras

from tensorflow.compiler.tf2xla.python import xla  # pylint: disable=g-direct-tensorflow-import
from official.projects.video_ssl.losses import losses as video_ssl_losses

tpu_cross_replica_concat = video_ssl_losses.tpu_cross_replica_concat


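# A large constant subtracted from (or added to) masked logits so that the
# corresponding entries are effectively ignored by softmax/min/max reductions.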
_LARGE_NUM = 1e9


class ContrastiveLoss(object):
  """InfoNCE loss.

  Reference: van den Oord et al. "Representation Learning with Contrastive
    Predictive Coding." arXiv:1807.03748, 2018.
  """

  def __init__(self,
               normalize_inputs: bool,
               temperature: float):
    """Computes contrastive loss.

    Args:
      normalize_inputs: whether or not to l2 normalize the inputs vector.
      temperature: temperature in the InfoNCE contrastive loss.
    """
    self._normalize_inputs = normalize_inputs
    self._temperature = temperature

  def __call__(self,
               inputs: tf.Tensor,
               num_replicas: int = 1) -> Mapping[str, tf.Tensor]:
    """Calculates the loss.

    Args:
      inputs: the embeddings (in shape [2*B, C]) from video clips after the
        projection head.
      num_replicas: the number of TPU replicas.

    Returns:
      A dictionary containing the calculated loss and statistics.
    """
    inputs1, inputs2 = tf.split(inputs, num_or_size_splits=2, axis=0)
    if self._normalize_inputs:
      inputs1 = tf.math.l2_normalize(inputs1, -1)
      inputs2 = tf.math.l2_normalize(inputs2, -1)
    batch_size = tf.shape(inputs1)[0]

    if num_replicas == 1:
      # This is the local version.
      inputs1_large = inputs1
      inputs2_large = inputs2
      labels = tf.one_hot(tf.range(batch_size), batch_size * 2)
      masks = tf.one_hot(tf.range(batch_size), batch_size)
    else:
      # This is the cross-TPU version.
      inputs1_large = tpu_cross_replica_concat(inputs1, num_replicas)
      inputs2_large = tpu_cross_replica_concat(inputs2, num_replicas)
      enlarged_batch_size = tf.shape(inputs1_large)[0]
      replica_id = tf.cast(tf.cast(xla.replica_id(), tf.uint32), tf.int32)
      labels_idx = tf.range(batch_size) + replica_id * batch_size
      labels = tf.one_hot(labels_idx, enlarged_batch_size * 2)
      masks = tf.one_hot(labels_idx, enlarged_batch_size)

    logits_aa = tf.matmul(
        inputs1, inputs1_large, transpose_b=True) / self._temperature
    logits_aa = logits_aa - tf.cast(masks, logits_aa.dtype) * _LARGE_NUM
    logits_bb = tf.matmul(
        inputs2, inputs2_large, transpose_b=True) / self._temperature
    logits_bb = logits_bb - tf.cast(masks, logits_bb.dtype) * _LARGE_NUM
    logits_ab = tf.matmul(
        inputs1, inputs2_large, transpose_b=True) / self._temperature
    logits_ba = tf.matmul(
        inputs2, inputs1_large, transpose_b=True) / self._temperature

    loss_a = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
        labels, tf.concat([logits_ab, logits_aa], 1)))
    loss_b = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
        labels, tf.concat([logits_ba, logits_bb], 1)))
    loss = loss_a + loss_b

    contrast_prob = tf.nn.softmax(logits_ab)
    contrast_entropy = - tf.reduce_mean(
        tf.reduce_sum(contrast_prob * tf.math.log(contrast_prob + 1e-8), -1))

    contrast_acc = tf.equal(tf.argmax(labels, 1), tf.argmax(logits_ab, axis=1))
    contrast_acc = tf.reduce_mean(tf.cast(contrast_acc, tf.float32))

    return {
        'loss': loss,
        'contrastive_accuracy': contrast_acc,
        'contrastive_entropy': contrast_entropy,
    }
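

# Illustrative usage sketch for ContrastiveLoss (not part of the original
# module; the batch size, feature dimension, and temperature below are
# arbitrary assumptions):
#
#   loss_fn = ContrastiveLoss(normalize_inputs=True, temperature=0.1)
#   embeddings = tf.random.normal([2 * 8, 128])  # [2*B, C]: view 1 then view 2.
#   outputs = loss_fn(embeddings, num_replicas=1)
#   outputs['loss']                  # scalar InfoNCE loss
#   outputs['contrastive_accuracy']  # fraction of correctly matched clips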


class InstanceContrastiveLoss(object):
  """Instance Contrastive Loss.

  Reference: Yuan et al. "Contextualized Spatio-Temporal Contrastive Learning
    with Self-Supervision" CVPR 2022.
  """

  def __init__(self,
               normalize_inputs: bool,
               temperature: float):
    """Initializes the loss; see ContrastiveLoss for argument semantics."""
    self._normalize_inputs = normalize_inputs
    self._temperature = temperature

  def __call__(self,
               predictions: Mapping[str, tf.Tensor],
               num_replicas: int = 1) -> Mapping[str, tf.Tensor]:
    """Computes contrastive loss for spatio-temporal instance embeddings.

    Args:
      predictions: a dictionary of the model outputs, containing:
        'instances_a2b': the reconstructed instance features from view a -> b.
          In shape [B, N, C].
        'instances_b2a': the reconstructed instance features from view b -> a.
          In shape [B, N, C].
        'instances_a': the target instance features in view a. In shape
          [B, N, C].
        'instances_b': the target instance features in view b. In shape
          [B, N, C].
        'masks_a': the validity boolean mask for instances in view a. In shape
          [B, N].
        'masks_b': the validity boolean mask for instances in view b. In shape
          [B, N].
      num_replicas: the number of TPU replicas.

    Returns:
      A dictionary containing the loss scalar and the similarity statistics
      for positive and negative examples.
    """

    inst_a2b = predictions['instances_a2b']
    inst_b2a = predictions['instances_b2a']
    inst_a = predictions['instances_a']
    inst_b = predictions['instances_b']
    masks_a = tf.cast(predictions['masks_a'][..., None], dtype=inst_a.dtype)
    masks_b = tf.cast(predictions['masks_b'][..., None], dtype=inst_b.dtype)

    if self._normalize_inputs:
      inst_a2b = tf.math.l2_normalize(inst_a2b, axis=-1)
      inst_b2a = tf.math.l2_normalize(inst_b2a, axis=-1)
      inst_a = tf.math.l2_normalize(inst_a, axis=-1)
      inst_b = tf.math.l2_normalize(inst_b, axis=-1)

    b, n = inst_a.shape.as_list()[:2]
    batch_index = tf.range(b)

    # Computes similarity based on raw features in view a and b.
    similarity_ab = tf.einsum('ijc,ikc->ijk', inst_a, inst_b)

    # Loss on translated_a2b.
    similarity_ab_index = tf.argmax(similarity_ab, axis=2, output_type=tf.int32)
    lookup_a2b_index = tf.stack(
        [tf.tile(batch_index[:, None], [1, n]), similarity_ab_index], axis=-1)
    loss_and_stats_a = self._compute_contrastive_loss(
        positive_lookup_index=lookup_a2b_index,
        inst_translated=inst_a2b,
        inst_target=inst_b,
        inst_mask=masks_a,
        num_replicas=num_replicas)

    # Loss on translated_b2a.
    similarity_ba_index = tf.argmax(similarity_ab, axis=1, output_type=tf.int32)
    lookup_b2a_index = tf.stack(
        [tf.tile(batch_index[:, None], [1, n]), similarity_ba_index], axis=-1)
    loss_and_stats_b = self._compute_contrastive_loss(
        positive_lookup_index=lookup_b2a_index,
        inst_translated=inst_b2a,
        inst_target=inst_a,
        inst_mask=masks_b,
        num_replicas=num_replicas)

    loss_and_stats = {}
    for key in loss_and_stats_a:
      loss_and_stats[key] = 0.5 * (
          loss_and_stats_a[key] + loss_and_stats_b[key])
    return loss_and_stats

  def _get_negative_similarity_statistics(
      self,
      logits: tf.Tensor,
      batch_masks: tf.Tensor,
      inst_mask: tf.Tensor) -> Mapping[str, tf.Tensor]:
    """Gets negative examples similarity statistics.

    Args:
      logits: the logits matrix.
      batch_masks: the batch validity mask.
      inst_mask: the instance validity mask.

    Returns:
      logs: a dictionary of logs.
    """
    # logits = [b, n, bl, n]
    # batch_masks = [b, n, bl, n]
    # inst_mask = [b, n, 1]
    inst_mask = tf.cast(inst_mask, logits.dtype)
    batch_masks = tf.cast(batch_masks, logits.dtype)
    batch_masks = tf.ones_like(batch_masks) - batch_masks
    masks = batch_masks * inst_mask[..., None]
    # Undo the temperature scaling to recover the raw similarity, keeping only
    # valid cross-batch entries; self-similarity is excluded from the
    # negative-sample statistics.
    similarity = logits * masks * self._temperature
    similarity_mean = tf.reduce_sum(similarity) / tf.reduce_sum(masks)

    similarity_masks = tf.squeeze(inst_mask, axis=-1)
    similarity_max = similarity - (1.0 - masks) * _LARGE_NUM
    similarity_max = tf.reduce_max(similarity_max, axis=[-1, -2])
    similarity_max = tf.reduce_sum(
        similarity_max * similarity_masks) / tf.reduce_sum(similarity_masks)

    similarity_min = similarity + (1.0 - masks) * _LARGE_NUM
    similarity_min = tf.reduce_min(similarity_min, axis=[-1, -2])
    similarity_min = tf.reduce_sum(
        similarity_min * similarity_masks) / tf.reduce_sum(similarity_masks)
    logs = {
        'negative_similarity_mean': similarity_mean,
        'negative_similarity_min': similarity_min,
        'negative_similarity_max': similarity_max,
    }
    return logs

  def _get_positive_similarity_statistics(
      self,
      logits: tf.Tensor,
      inst_mask: tf.Tensor) -> Mapping[str, tf.Tensor]:
    """Gets positive examples similarity statistics.

    Args:
      logits: the logits matrix.
      inst_mask: the instance validity mask.

    Returns:
      logs: a dictionary of logs.
    """
    # logits in shape [b, n]
    # inst_mask in shape [b, n, 1]
    inst_mask = tf.squeeze(inst_mask, axis=-1)
    inst_mask = tf.cast(inst_mask, dtype=logits.dtype)
    similarity = logits * inst_mask * self._temperature

    num_instances = tf.reduce_sum(inst_mask)
    similarity_mean = tf.reduce_sum(similarity) / num_instances

    similarity_max = similarity - (1.0 - inst_mask) * _LARGE_NUM
    similarity_max = tf.reduce_max(similarity_max)

    similarity_min = similarity + (1.0 - inst_mask) * _LARGE_NUM
    similarity_min = tf.reduce_min(similarity_min)

    logs = {
        'positive_similarity_mean': similarity_mean,
        'positive_similarity_min': similarity_min,
        'positive_similarity_max': similarity_max,
    }
    return logs

  def _compute_contrastive_loss(
      self,
      positive_lookup_index: tf.Tensor,
      inst_translated: tf.Tensor,
      inst_target: tf.Tensor,
      inst_mask: tf.Tensor,
      num_replicas: int = 1) -> Mapping[str, tf.Tensor]:
    """Computes constrastive loss.

    Args:
      positive_lookup_index: the index tensor used to look up the corresponding
        features in inst_target. In shape [B, N, 2].
      inst_translated: a float tensor of shape [B, N, C] of instance features
        translated by the transformer head.
      inst_target: a float tensor of shape [B, N, C] of instance features on the
        target domain. Note that the order of inst_target is not necessarily
        matched to inst_translated.
      inst_mask: a boolean tensor of shape [B, N, 1] indicating valid instances
        in inst_translated.
      num_replicas: the number of TPU replicas.

    Returns:
      loss_and_stats: a dictionary of loss and intermediate statistics.
    """
    b, n = inst_translated.shape.as_list()[:2]

    if num_replicas == 1:
      inst_target_large = inst_target
      b_large = tf.shape(inst_target_large)[0]
      labels_idx = tf.range(b)
    else:
      inst_target_large = tpu_cross_replica_concat(
          inst_target,
          num_replicas)
      b_large = tf.shape(inst_target_large)[0]
      # NOTE: make sure to use xla.replica_id() here and in
      # tpu_cross_replica_concat to consistently align the replica_id.
      # replicator.replica_id != xla.replica_id()
      replica_id = tf.cast(tf.cast(xla.replica_id(), tf.uint32), tf.int32)
      labels_idx = tf.range(b) + replica_id * b

    # [B, BL], 1 indicates positive batches.
    batch_masks = tf.one_hot(labels_idx, b_large)
    # [B, N, BL, N]
    batch_masks = tf.tile(batch_masks[:, None, :, None], [1, n, 1, n])

    # Construct negative examples.
    logits_negative = tf.einsum(
        'ijc,pqc->ijpq',
        inst_translated, inst_target_large) / self._temperature
    # Get negative statistics.
    negative_stats = self._get_negative_similarity_statistics(
        logits_negative, batch_masks, inst_mask)
    logits_negative = logits_negative - tf.cast(
        batch_masks, logits_negative.dtype) * _LARGE_NUM
    logits_negative = tf.reshape(logits_negative, [b * n, b_large * n])

    # Construct positive examples.
    inst_matched = tf.gather_nd(
        inst_target, positive_lookup_index, name='matched_inst')
    logits_positive = tf.einsum(
        'ijc,ijc->ij',
        inst_translated, inst_matched) / self._temperature
    # Get positive statistics.
    positive_stats = self._get_positive_similarity_statistics(
        logits_positive, inst_mask)
    logits_positive = tf.reshape(logits_positive, [b * n, 1])

    logits_all = tf.concat([logits_positive, logits_negative], axis=1)
    loss_pos = tf.reduce_logsumexp(logits_positive, 1)
    loss_all = tf.reduce_logsumexp(logits_all, 1)
    loss = (loss_all - loss_pos) * tf.reshape(inst_mask, [b * n])

    # Average across instances.
    loss = tf.math.divide_no_nan(
        tf.reduce_sum(loss), tf.reduce_sum(inst_mask))

    loss_and_stats = {'loss': loss}
    loss_and_stats.update(negative_stats)
    loss_and_stats.update(positive_stats)
    return loss_and_stats
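

# Illustrative usage sketch for InstanceContrastiveLoss (not part of the
# original module; the shapes and values below are arbitrary assumptions):
#
#   inst_loss_fn = InstanceContrastiveLoss(normalize_inputs=True,
#                                          temperature=0.1)
#   b, n, c = 4, 6, 128
#   predictions = {
#       'instances_a2b': tf.random.normal([b, n, c]),
#       'instances_b2a': tf.random.normal([b, n, c]),
#       'instances_a': tf.random.normal([b, n, c]),
#       'instances_b': tf.random.normal([b, n, c]),
#       'masks_a': tf.ones([b, n], dtype=tf.bool),
#       'masks_b': tf.ones([b, n], dtype=tf.bool),
#   }
#   outputs = inst_loss_fn(predictions, num_replicas=1)
#   outputs['loss']  # scalar loss averaged over both translation directions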