tensorflow/models

View on GitHub
official/projects/simclr/tasks/simclr.py

Summary

Maintainability
C
1 day
Test Coverage
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Image SimCLR task definition.

SimCLR training has two different modes:
- pretrain
- fine-tuning

For the above two different modes, the following components are different in
the task definition:
- training data format
- training loss
- projection_head and/or supervised_head
"""
from typing import Dict, Optional

from absl import logging
import tensorflow as tf, tf_keras

from official.core import base_task
from official.core import config_definitions
from official.core import input_reader
from official.core import task_factory
from official.modeling import optimization
from official.modeling import performance
from official.modeling import tf_utils
from official.projects.simclr.configs import simclr as exp_cfg
from official.projects.simclr.dataloaders import simclr_input
from official.projects.simclr.heads import simclr_head
from official.projects.simclr.losses import contrastive_losses
from official.projects.simclr.modeling import simclr_model
from official.vision.modeling import backbones

OptimizationConfig = optimization.OptimizationConfig
RuntimeConfig = config_definitions.RuntimeConfig


@task_factory.register_task_cls(exp_cfg.SimCLRPretrainTask)
class SimCLRPretrainTask(base_task.Task):
  """A task for SimCLR contrastive pretraining."""

  def create_optimizer(self,
                       optimizer_config: OptimizationConfig,
                       runtime_config: Optional[RuntimeConfig] = None):
    """Builds the optimizer described by `optimizer_config`.

    Args:
      optimizer_config: the parameters of the Optimization settings.
      runtime_config: the parameters of the runtime; only `loss_scale` and
        `mixed_precision_dtype` are read here.

    Returns:
      The optimizer built by `optimization.OptimizerFactory`, passed through
      `performance.configure_optimizer` when a loss scale is configured.

    Raises:
      ValueError: if the 'lars' optimizer is combined with a positive
        `l2_weight_decay` in the task's loss config.
    """
    uses_lars = optimizer_config.optimizer.type == 'lars'
    if uses_lars and self.task_config.loss.l2_weight_decay > 0.0:
      raise ValueError('The l2_weight_decay cannot be used together with lars '
                       'optimizer. Please set it to 0.')

    factory = optimization.OptimizerFactory(optimizer_config)
    learning_rate = factory.build_learning_rate()
    optimizer = factory.build_optimizer(learning_rate)

    # Without a configured loss scale the optimizer is returned untouched.
    if not (runtime_config and runtime_config.loss_scale):
      return optimizer

    # Loss scaling helps avoid overflow/underflow in float16 computations.
    is_float16 = runtime_config.mixed_precision_dtype == 'float16'
    return performance.configure_optimizer(
        optimizer,
        use_float16=is_float16,
        loss_scale=runtime_config.loss_scale)

  def build_model(self):
    """Assembles the SimCLR model from the task's model config.

    The backbone, the projection head and the optional supervised head all
    receive the same L2 kernel regularizer, derived from
    `task_config.loss.l2_weight_decay`.

    Returns:
      The `simclr_model.SimCLRModel` built from those components.
    """
    cfg = self.task_config.model
    norm_cfg = cfg.norm_activation
    proj_cfg = cfg.projection_head
    sup_cfg = cfg.supervised_head

    input_specs = tf_keras.layers.InputSpec(shape=[None] + cfg.input_size)

    weight_decay = self.task_config.loss.l2_weight_decay
    if weight_decay:
      # Divide weight decay by 2.0 to match the implementation of
      # tf.nn.l2_loss.
      # (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
      # (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
      regularizer = tf_keras.regularizers.l2(weight_decay / 2.0)
    else:
      regularizer = None

    backbone = backbones.factory.build_backbone(
        input_specs=input_specs,
        backbone_config=cfg.backbone,
        norm_activation_config=norm_cfg,
        l2_regularizer=regularizer)

    projection_head = simclr_head.ProjectionHead(
        proj_output_dim=proj_cfg.proj_output_dim,
        num_proj_layers=proj_cfg.num_proj_layers,
        ft_proj_idx=proj_cfg.ft_proj_idx,
        kernel_regularizer=regularizer,
        use_sync_bn=norm_cfg.use_sync_bn,
        norm_momentum=norm_cfg.norm_momentum,
        norm_epsilon=norm_cfg.norm_epsilon)

    # The supervised head is optional during pretraining.
    supervised_head = None
    if sup_cfg:
      head_initializer = 'zeros' if sup_cfg.zero_init else 'random_uniform'
      supervised_head = simclr_head.ClassificationHead(
          num_classes=sup_cfg.num_classes,
          kernel_initializer=head_initializer,
          kernel_regularizer=regularizer)

    model = simclr_model.SimCLRModel(
        input_specs=input_specs,
        backbone=backbone,
        projection_head=projection_head,
        supervised_head=supervised_head,
        mode=cfg.mode,
        backbone_trainable=cfg.backbone_trainable)

    logging.info(model.get_config())
    return model

  def initialize(self, model: tf_keras.Model):
    """Restores weights from `task_config.init_checkpoint`, if it is set.

    `task_config.init_checkpoint_modules` selects what is restored: 'all'
    reads every entry of `model.checkpoint_items`, 'backbone' reads only
    `model.backbone`.

    Args:
      model: the model whose weights are restored.

    Raises:
      ValueError: if `init_checkpoint_modules` is neither 'all' nor
        'backbone'.
    """
    checkpoint_path = self.task_config.init_checkpoint
    if not checkpoint_path:
      return

    # A directory is resolved to the latest checkpoint found inside it.
    if tf.io.gfile.isdir(checkpoint_path):
      checkpoint_path = tf.train.latest_checkpoint(checkpoint_path)

    modules = self.task_config.init_checkpoint_modules
    if modules == 'all':
      checkpoint = tf.train.Checkpoint(**model.checkpoint_items)
    elif modules == 'backbone':
      checkpoint = tf.train.Checkpoint(backbone=model.backbone)
    else:
      raise ValueError(
          "Only 'all' or 'backbone' can be used to initialize the model.")

    restore_status = checkpoint.read(checkpoint_path)
    restore_status.expect_partial().assert_existing_objects_matched()

    logging.info('Finished loading pretrained checkpoint from %s',
                 checkpoint_path)

  def build_inputs(self, params, input_context=None):
    """Builds the input dataset for pretraining.

    Args:
      params: the data config; `tfds_name`, `decoder`, `parser`, `dtype` and
        `is_training` are read here, and the whole config is handed to the
        `InputReader`.
      input_context: optional input context forwarded to `InputReader.read`.

    Returns:
      The dataset returned by `input_reader.InputReader.read`.
    """
    # TFDS sources use their own decoder.
    decoder_cls = (
        simclr_input.TFDSDecoder if params.tfds_name else simclr_input.Decoder)
    decoder = decoder_cls(params.decoder.decode_label)

    parser_cfg = params.parser
    parser = simclr_input.Parser(
        output_size=self.task_config.model.input_size[:2],
        aug_rand_crop=parser_cfg.aug_rand_crop,
        aug_rand_hflip=parser_cfg.aug_rand_hflip,
        aug_color_distort=parser_cfg.aug_color_distort,
        aug_color_jitter_strength=parser_cfg.aug_color_jitter_strength,
        aug_color_jitter_impl=parser_cfg.aug_color_jitter_impl,
        aug_rand_blur=parser_cfg.aug_rand_blur,
        parse_label=parser_cfg.parse_label,
        test_crop=parser_cfg.test_crop,
        mode=parser_cfg.mode,
        dtype=params.dtype)

    return input_reader.InputReader(
        params,
        dataset_fn=tf.data.TFRecordDataset,
        decoder_fn=decoder.decode,
        parser_fn=parser.parse_fn(params.is_training)).read(
            input_context=input_context)

  def build_losses(self,
                   labels,
                   model_outputs,
                   aux_losses=None) -> Dict[str, tf.Tensor]:
    """Computes the contrastive loss, the optional supervised loss and stats.

    Args:
      labels: class labels; only read when the model config has a supervised
        head. Treated as one-hot when `evaluation.one_hot` is set and as
        class ids otherwise.
      model_outputs: model outputs keyed by
        `simclr_model.PROJECTION_OUTPUT_KEY` and, when a supervised head is
        configured, `simclr_model.SUPERVISED_OUTPUT_KEY`.
      aux_losses: optional auxiliary loss tensors, i.e. `losses` in
        keras.Model.

    Returns:
      A dict with 'contrast_loss', 'contrast_accuracy', 'contrast_entropy'
      and 'total_loss', plus 'supervised_loss' and 'accuracy' when a
      supervised head is configured.
    """
    # Compute contrastive relative loss
    con_losses_obj = contrastive_losses.ContrastiveLoss(
        projection_norm=self.task_config.loss.projection_norm,
        temperature=self.task_config.loss.temperature)
    # The projection outputs from model has the size of
    # (2 * bsz, project_dim)
    projection_outputs = model_outputs[simclr_model.PROJECTION_OUTPUT_KEY]
    projection1, projection2 = tf.split(projection_outputs, 2, 0)
    contrast_loss, (contrast_logits, contrast_labels) = con_losses_obj(
        projection1=projection1, projection2=projection2)

    contrast_accuracy = tf.equal(
        tf.argmax(contrast_labels, axis=1), tf.argmax(contrast_logits, axis=1))
    contrast_accuracy = tf.reduce_mean(tf.cast(contrast_accuracy, tf.float32))

    contrast_prob = tf.nn.softmax(contrast_logits)
    # The 1e-8 offset keeps the log finite when a probability is zero.
    contrast_entropy = -tf.reduce_mean(
        tf.reduce_sum(contrast_prob * tf.math.log(contrast_prob + 1e-8), -1))

    model_loss = contrast_loss

    losses = {
        'contrast_loss': contrast_loss,
        'contrast_accuracy': contrast_accuracy,
        'contrast_entropy': contrast_entropy
    }

    if self.task_config.model.supervised_head is not None:
      outputs = model_outputs[simclr_model.SUPERVISED_OUTPUT_KEY]
      # Labels are repeated so they line up with the 2 * bsz output rows.
      labels = tf.concat([labels, labels], 0)

      no_reduction = tf_keras.losses.Reduction.NONE
      if self.task_config.evaluation.one_hot:
        sup_loss = tf_keras.losses.CategoricalCrossentropy(
            from_logits=True, reduction=no_reduction)(labels, outputs)
        label_ids = tf.argmax(labels, axis=1)
      else:
        sup_loss = tf_keras.losses.SparseCategoricalCrossentropy(
            from_logits=True, reduction=no_reduction)(labels, outputs)
        # Sparse labels already hold class ids. Taking argmax over axis 1 is
        # only valid for one-hot labels, so flatten and compare directly.
        label_ids = tf.cast(tf.reshape(labels, [-1]), tf.int64)
      sup_loss = tf.reduce_mean(sup_loss)

      label_acc = tf.equal(label_ids, tf.argmax(outputs, axis=1))
      label_acc = tf.reduce_mean(tf.cast(label_acc, tf.float32))

      model_loss = contrast_loss + sup_loss

      losses.update({
          'accuracy': label_acc,
          'supervised_loss': sup_loss,
      })

    total_loss = model_loss
    if aux_losses:
      reg_loss = tf.reduce_sum(aux_losses)
      total_loss = model_loss + reg_loss

    losses['total_loss'] = total_loss

    return losses

  def build_metrics(self, training=True):
    """Creates the metrics for training or evaluation.

    Args:
      training: when True, returns one `Mean` metric per loss-dict entry that
        is tracked during training; otherwise returns accuracy and top-k
        accuracy metrics matching the `evaluation.one_hot` label format.

    Returns:
      A list of keras metric objects.
    """
    if training:
      names = [
          'total_loss', 'contrast_loss', 'contrast_accuracy', 'contrast_entropy'
      ]
      if self.task_config.model.supervised_head:
        names += ['supervised_loss', 'accuracy']
      return [tf_keras.metrics.Mean(name, dtype=tf.float32) for name in names]

    top_k = self.task_config.evaluation.top_k
    top_k_name = 'top_{}_accuracy'.format(top_k)
    if self.task_config.evaluation.one_hot:
      return [
          tf_keras.metrics.CategoricalAccuracy(name='accuracy'),
          tf_keras.metrics.TopKCategoricalAccuracy(k=top_k, name=top_k_name),
      ]
    return [
        tf_keras.metrics.SparseCategoricalAccuracy(name='accuracy'),
        tf_keras.metrics.SparseTopKCategoricalAccuracy(
            k=top_k, name=top_k_name),
    ]

  def train_step(self, inputs, model, optimizer, metrics=None):
    """Runs one forward/backward pass and applies the gradients.

    Args:
      inputs: a `(features, labels)` pair.
      model: the model, forward pass definition.
      optimizer: the optimizer for this training step.
      metrics: optional iterable of metric objects; each one is updated with
        the entry of the loss dict that carries the metric's name.

    Returns:
      A dictionary of logs holding the total loss and the metric results.
    """
    features, labels = inputs

    # To do a sanity check that we absolutely use no labels when pretraining, we
    # can set the labels here to zero.
    if self.task_config.train_data.input_set_label_to_zero:
      labels *= 0

    if (self.task_config.model.supervised_head is not None and
        self.task_config.evaluation.one_hot):
      num_classes = self.task_config.model.supervised_head.num_classes
      labels = tf.one_hot(labels, num_classes)

    num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
    with tf.GradientTape() as tape:
      outputs = model(features, training=True)
      # Casting output layer as float32 is necessary when mixed_precision is
      # mixed_float16 or mixed_bfloat16 to ensure output is casted as float32.
      outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)

      # Computes per-replica loss.
      losses = self.build_losses(
          model_outputs=outputs, labels=labels, aux_losses=model.losses)

      scaled_loss = losses['total_loss'] / num_replicas
      # For mixed_precision policy, when LossScaleOptimizer is used, loss is
      # scaled for numerical stability.
      if isinstance(optimizer, tf_keras.mixed_precision.LossScaleOptimizer):
        scaled_loss = optimizer.get_scaled_loss(scaled_loss)

    tvars = model.trainable_variables
    logging.info('Trainable variables:')
    for var in tvars:
      logging.info(var.name)
    grads = tape.gradient(scaled_loss, tvars)
    # Scales back gradient when LossScaleOptimizer is used.
    if isinstance(optimizer, tf_keras.mixed_precision.LossScaleOptimizer):
      grads = optimizer.get_unscaled_gradients(grads)
    optimizer.apply_gradients(list(zip(grads, tvars)))

    logs = {self.loss: losses['total_loss']}

    # `metrics` defaults to None, so only iterate when metrics were passed in.
    for m in metrics or ():
      m.update_state(losses[m.name])
      logs[m.name] = m.result()

    return logs

  def validation_step(self, inputs, model, metrics=None):
    """Evaluates the supervised head on one batch.

    Args:
      inputs: a `(features, labels)` pair.
      model: the keras.Model.
      metrics: optional metric objects; when empty, the model's compiled
        metrics are used if present.

    Returns:
      A dictionary of logs; the loss entry is always 0.

    Raises:
      ValueError: if the model config has no supervised head.
    """
    head_cfg = self.task_config.model.supervised_head
    if head_cfg is None:
      raise ValueError(
          'Skipping eval during pretraining without supervised head.')

    features, labels = inputs
    if self.task_config.evaluation.one_hot:
      labels = tf.one_hot(labels, head_cfg.num_classes)

    model_outputs = model(features, training=False)
    logits = model_outputs[simclr_model.SUPERVISED_OUTPUT_KEY]
    logits = tf.nest.map_structure(lambda t: tf.cast(t, tf.float32), logits)

    # No loss is computed during pretraining evaluation.
    logs = {self.loss: 0}

    if metrics:
      self.process_metrics(metrics, labels, logits)
      for metric in metrics:
        logs[metric.name] = metric.result()
    elif model.compiled_metrics:
      self.process_compiled_metrics(model.compiled_metrics, labels, logits)
      for metric in model.metrics:
        logs[metric.name] = metric.result()

    return logs


@task_factory.register_task_cls(exp_cfg.SimCLRFinetuneTask)
class SimCLRFinetuneTask(base_task.Task):
  """A task for image classification."""

  def create_optimizer(self,
                       optimizer_config: OptimizationConfig,
                       runtime_config: Optional[RuntimeConfig] = None):
    """Creates the optimizer for fine-tuning from configurations.

    Args:
      optimizer_config: the parameters of the Optimization settings.
      runtime_config: the parameters of the runtime; only `loss_scale` and
        `mixed_precision_dtype` are read here.

    Returns:
      The optimizer built by `optimization.OptimizerFactory`, passed through
      `performance.configure_optimizer` when a loss scale is configured.

    Raises:
      ValueError: if the 'lars' optimizer is combined with a positive
        `l2_weight_decay` in the task's loss config.
    """
    if optimizer_config.optimizer.type == 'lars':
      if self.task_config.loss.l2_weight_decay > 0.0:
        raise ValueError(
            'The l2_weight_decay cannot be used together with lars '
            'optimizer. Please set it to 0.')

    builder = optimization.OptimizerFactory(optimizer_config)
    lr_schedule = builder.build_learning_rate()
    optimizer = builder.build_optimizer(lr_schedule)

    wants_loss_scale = runtime_config and runtime_config.loss_scale
    if wants_loss_scale:
      # Loss scaling helps avoid overflow/underflow in float16 computations.
      float16_enabled = runtime_config.mixed_precision_dtype == 'float16'
      optimizer = performance.configure_optimizer(
          optimizer,
          use_float16=float16_enabled,
          loss_scale=runtime_config.loss_scale)

    return optimizer

  def build_model(self):
    """Assembles the SimCLR model used for fine-tuning.

    The backbone, projection head and supervised head all receive the same
    L2 kernel regularizer, derived from `task_config.loss.l2_weight_decay`.
    A supervised head is always built here.

    Returns:
      The `simclr_model.SimCLRModel` built from those components.
    """
    cfg = self.task_config.model
    norm_cfg = cfg.norm_activation
    proj_cfg = cfg.projection_head
    sup_cfg = cfg.supervised_head

    input_specs = tf_keras.layers.InputSpec(shape=[None] + cfg.input_size)

    # Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
    # (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
    # (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
    weight_decay = self.task_config.loss.l2_weight_decay
    regularizer = None
    if weight_decay:
      regularizer = tf_keras.regularizers.l2(weight_decay / 2.0)

    backbone = backbones.factory.build_backbone(
        input_specs=input_specs,
        backbone_config=cfg.backbone,
        norm_activation_config=norm_cfg,
        l2_regularizer=regularizer)

    projection_head = simclr_head.ProjectionHead(
        proj_output_dim=proj_cfg.proj_output_dim,
        num_proj_layers=proj_cfg.num_proj_layers,
        ft_proj_idx=proj_cfg.ft_proj_idx,
        kernel_regularizer=regularizer,
        use_sync_bn=norm_cfg.use_sync_bn,
        norm_momentum=norm_cfg.norm_momentum,
        norm_epsilon=norm_cfg.norm_epsilon)

    head_initializer = 'zeros' if sup_cfg.zero_init else 'random_uniform'
    supervised_head = simclr_head.ClassificationHead(
        num_classes=sup_cfg.num_classes,
        kernel_initializer=head_initializer,
        kernel_regularizer=regularizer)

    model = simclr_model.SimCLRModel(
        input_specs=input_specs,
        backbone=backbone,
        projection_head=projection_head,
        supervised_head=supervised_head,
        mode=cfg.mode,
        backbone_trainable=cfg.backbone_trainable)

    logging.info(model.get_config())
    return model

  def initialize(self, model: tf_keras.Model):
    """Loads a pretrained checkpoint into `model`, if one is configured.

    `task_config.init_checkpoint_modules` selects what is restored:
    - 'all': every entry of `model.checkpoint_items`.
    - 'backbone_projection': `model.backbone` and `model.projection_head`.
    - 'backbone': `model.backbone` only.

    Args:
      model: the model whose weights are restored.

    Raises:
      ValueError: if `init_checkpoint_modules` is not one of the values above.
    """
    if not self.task_config.init_checkpoint:
      return

    ckpt_dir_or_file = self.task_config.init_checkpoint
    # A directory is resolved to the latest checkpoint found inside it.
    if tf.io.gfile.isdir(ckpt_dir_or_file):
      ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)

    # Pick the objects to restore, then read the checkpoint once.
    modules = self.task_config.init_checkpoint_modules
    if modules == 'all':
      ckpt = tf.train.Checkpoint(**model.checkpoint_items)
    elif modules == 'backbone_projection':
      ckpt = tf.train.Checkpoint(
          backbone=model.backbone, projection_head=model.projection_head)
    elif modules == 'backbone':
      ckpt = tf.train.Checkpoint(backbone=model.backbone)
    else:
      raise ValueError(
          "Only 'all', 'backbone_projection' or 'backbone' can be used to "
          'initialize the model.')

    status = ckpt.read(ckpt_dir_or_file)
    status.expect_partial().assert_existing_objects_matched()

    # If the checkpoint is from pretraining, reset the following parameters
    model.backbone_trainable = self.task_config.model.backbone_trainable
    logging.info('Finished loading pretrained checkpoint from %s',
                 ckpt_dir_or_file)

  def build_inputs(self, params, input_context=None):
    """Builds the input dataset for fine-tuning.

    Args:
      params: the data config; `tfds_name`, `decoder`, `parser`, `dtype` and
        `is_training` are read here, and the whole config is handed to the
        `InputReader`.
      input_context: optional input context forwarded to `InputReader.read`.

    Returns:
      The dataset returned by `input_reader.InputReader.read`.
    """
    decode_label = params.decoder.decode_label
    # TFDS sources use their own decoder.
    if params.tfds_name:
      decoder = simclr_input.TFDSDecoder(decode_label)
    else:
      decoder = simclr_input.Decoder(decode_label)

    parser_cfg = params.parser
    parser = simclr_input.Parser(
        output_size=self.task_config.model.input_size[:2],
        parse_label=parser_cfg.parse_label,
        test_crop=parser_cfg.test_crop,
        mode=parser_cfg.mode,
        dtype=params.dtype)

    reader = input_reader.InputReader(
        params,
        dataset_fn=tf.data.TFRecordDataset,
        decoder_fn=decoder.decode,
        parser_fn=parser.parse_fn(params.is_training))
    return reader.read(input_context=input_context)

  def build_losses(self, labels, model_outputs, aux_losses=None):
    """Computes the classification cross-entropy plus auxiliary losses.

    Categorical cross-entropy (with `loss.label_smoothing`) is used when
    `loss.one_hot` is set; sparse categorical cross-entropy otherwise.

    Args:
      labels: labels.
      model_outputs: Output logits of the classifier.
      aux_losses: auxiliary loss tensors, i.e. `losses` in keras.Model.

    Returns:
      The total loss tensor.
    """
    loss_cfg = self.task_config.loss
    if not loss_cfg.one_hot:
      per_example_loss = tf_keras.losses.sparse_categorical_crossentropy(
          labels, model_outputs, from_logits=True)
    else:
      per_example_loss = tf_keras.losses.categorical_crossentropy(
          labels,
          model_outputs,
          from_logits=True,
          label_smoothing=loss_cfg.label_smoothing)

    mean_loss = tf_utils.safe_mean(per_example_loss)
    if not aux_losses:
      return mean_loss
    return mean_loss + tf.add_n(aux_losses)

  def build_metrics(self, training=True):
    """Gets streaming accuracy metrics for training/validation.

    Args:
      training: unused; the same metrics are returned in both modes.

    Returns:
      A list with an accuracy and a top-k accuracy metric, in the categorical
      or sparse variant depending on `evaluation.one_hot`.
    """
    # NOTE(review): labels are one-hot encoded based on `loss.one_hot` in the
    # train/validation steps, while the metric variant follows
    # `evaluation.one_hot` -- presumably both flags are kept in sync; confirm.
    top_k = self.task_config.evaluation.top_k
    top_k_name = 'top_{}_accuracy'.format(top_k)
    if self.task_config.evaluation.one_hot:
      return [
          tf_keras.metrics.CategoricalAccuracy(name='accuracy'),
          tf_keras.metrics.TopKCategoricalAccuracy(k=top_k, name=top_k_name),
      ]
    return [
        tf_keras.metrics.SparseCategoricalAccuracy(name='accuracy'),
        tf_keras.metrics.SparseTopKCategoricalAccuracy(
            k=top_k, name=top_k_name),
    ]

  def train_step(self, inputs, model, optimizer, metrics=None):
    """Runs one supervised forward/backward pass and applies the gradients.

    Args:
      inputs: a `(features, labels)` pair.
      model: the model, forward pass definition.
      optimizer: the optimizer for this training step.
      metrics: optional metric objects; when empty, the model's compiled
        metrics are used if present.

    Returns:
      A dictionary of logs holding the loss and the metric results.
    """
    features, labels = inputs
    if self.task_config.loss.one_hot:
      labels = tf.one_hot(
          labels, self.task_config.model.supervised_head.num_classes)

    uses_loss_scaling = isinstance(
        optimizer, tf_keras.mixed_precision.LossScaleOptimizer)
    replica_count = tf.distribute.get_strategy().num_replicas_in_sync

    with tf.GradientTape() as tape:
      model_outputs = model(features, training=True)
      logits = model_outputs[simclr_model.SUPERVISED_OUTPUT_KEY]
      # Under mixed_float16 or mixed_bfloat16 the output has to be cast to
      # float32 before the loss is computed.
      logits = tf.nest.map_structure(lambda t: tf.cast(t, tf.float32), logits)

      # Per-replica loss, divided by the replica count because the default
      # gradients allreduce performs sum inside the optimizer.
      loss = self.build_losses(
          model_outputs=logits, labels=labels, aux_losses=model.losses)
      tape_loss = loss / replica_count

      # With a LossScaleOptimizer the loss is scaled for numerical stability.
      if uses_loss_scaling:
        tape_loss = optimizer.get_scaled_loss(tape_loss)

    trainable = model.trainable_variables
    logging.info('Trainable variables:')
    for variable in trainable:
      logging.info(variable.name)

    gradients = tape.gradient(tape_loss, trainable)
    # Undo the loss scaling on the gradients before applying them.
    if uses_loss_scaling:
      gradients = optimizer.get_unscaled_gradients(gradients)
    optimizer.apply_gradients(list(zip(gradients, trainable)))

    logs = {self.loss: loss}
    if metrics:
      self.process_metrics(metrics, labels, logits)
      for metric in metrics:
        logs[metric.name] = metric.result()
    elif model.compiled_metrics:
      self.process_compiled_metrics(model.compiled_metrics, labels, logits)
      for metric in model.metrics:
        logs[metric.name] = metric.result()
    return logs

  def validation_step(self, inputs, model, metrics=None):
    """Validation step.

    Args:
      inputs: a `(features, labels)` pair.
      model: the keras.Model.
      metrics: optional metric objects; when empty, the model's compiled
        metrics are used if present.

    Returns:
      A dictionary of logs holding the loss and the metric results.
    """
    features, labels = inputs
    if self.task_config.loss.one_hot:
      labels = tf.one_hot(
          labels, self.task_config.model.supervised_head.num_classes)

    model_outputs = self.inference_step(features, model)
    logits = model_outputs[simclr_model.SUPERVISED_OUTPUT_KEY]
    logits = tf.nest.map_structure(lambda t: tf.cast(t, tf.float32), logits)

    loss = self.build_losses(
        model_outputs=logits, labels=labels, aux_losses=model.losses)
    logs = {self.loss: loss}

    if metrics:
      self.process_metrics(metrics, labels, logits)
      for metric in metrics:
        logs[metric.name] = metric.result()
    elif model.compiled_metrics:
      self.process_compiled_metrics(model.compiled_metrics, labels, logits)
      for metric in model.metrics:
        logs[metric.name] = metric.result()
    return logs