tensorflow/models

View on GitHub
official/projects/simclr/dataloaders/simclr_input.py

Summary

Maintainability
A
1 hr
Test Coverage
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Data parser and processing for SimCLR.

For pre-training:
- Preprocessing:
  -> random cropping
  -> resize back to the original size
  -> random color distortions
  -> random Gaussian blur (sequential)
- Each image needs to be processed randomly twice

```snippets
      if train_mode == 'pretrain':
        xs = []
        for _ in range(2):  # Two transformations
          xs.append(preprocess_fn_pretrain(image))
        image = tf.concat(xs, -1)
      else:
        image = preprocess_fn_finetune(image)
```

For fine-tuning:
typical image classification input
"""

from typing import List

import tensorflow as tf, tf_keras

from official.projects.simclr.dataloaders import preprocess_ops as simclr_preprocess_ops
from official.projects.simclr.modeling import simclr_model
from official.vision.dataloaders import decoder
from official.vision.dataloaders import parser
from official.vision.ops import preprocess_ops


class Decoder(decoder.Decoder):
  """Decodes serialized tf.Example protos for the classification task.

  The encoded image bytes are always parsed; the integer class label is parsed
  only when label decoding is enabled.
  """

  def __init__(self, decode_label=True):
    """Builds the feature spec used to parse each example.

    Args:
      decode_label: if True, also parse 'image/class/label' (int64, defaulting
        to -1 when the field is absent from the example).
    """
    self._decode_label = decode_label

    # Scalar string feature; an absent image parses as the empty string.
    feature_spec = {}
    feature_spec['image/encoded'] = tf.io.FixedLenFeature(
        (), tf.string, default_value='')
    if decode_label:
      feature_spec['image/class/label'] = tf.io.FixedLenFeature(
          (), tf.int64, default_value=-1)
    self._keys_to_features = feature_spec

  def decode(self, serialized_example):
    """Parses one serialized tf.Example into a dict keyed by feature name."""
    feature_spec = self._keys_to_features
    return tf.io.parse_single_example(serialized_example, feature_spec)


class TFDSDecoder(decoder.Decoder):
  """Adapts TFDS classification samples to the tf.Example-style key layout.

  The output uses the same keys as `Decoder` ('image/encoded' and, optionally,
  'image/class/label') so that the same parser can consume either source.
  """

  def __init__(self, decode_label=True):
    """Stores whether the label should be forwarded alongside the image.

    Args:
      decode_label: if True, copy the sample's 'label' entry to
        'image/class/label' in the decoded output.
    """
    self._decode_label = decode_label

  def decode(self, serialized_example):
    """Converts one TFDS sample into the decoded-tensor dict.

    Args:
      serialized_example: mapping that is indexed with 'image' and, when
        labels are decoded, 'label'. Despite the name, it is accessed as a
        mapping here rather than parsed from a serialized proto.

    Returns:
      A dict with 'image/encoded' (the image re-encoded as JPEG bytes at
      quality 100) and, if label decoding is enabled, 'image/class/label'.
    """
    # Re-encode so that downstream parsing can treat the sample like a
    # tf.Example carrying encoded JPEG bytes.
    jpeg_bytes = tf.io.encode_jpeg(serialized_example['image'], quality=100)
    decoded = {'image/encoded': jpeg_bytes}
    if not self._decode_label:
      return decoded

    decoded['image/class/label'] = serialized_example['label']
    return decoded


class Parser(parser.Parser):
  """Parser that turns decoded examples into SimCLR model inputs.

  Training in pre-train mode augments every image twice, independently, and
  concatenates the two views along the last (channel) axis. Training in
  fine-tune mode produces a single augmented view. Evaluation optionally
  center-crops and then resizes, with no random augmentation.
  """

  def __init__(self,
               output_size: List[int],
               aug_rand_crop: bool = True,
               aug_rand_hflip: bool = True,
               aug_color_distort: bool = True,
               aug_color_jitter_strength: float = 1.0,
               aug_color_jitter_impl: str = 'simclrv2',
               aug_rand_blur: bool = True,
               parse_label: bool = True,
               test_crop: bool = True,
               mode: str = simclr_model.PRETRAIN,
               dtype: str = 'float32'):
    """Initializes parameters for parsing annotations in the dataset.

    Args:
      output_size: `Tensor` or `list` for [height, width] of output image. The
        output_size should be divided by the largest feature stride 2^max_level.
      aug_rand_crop: `bool`, if True, augment training with random cropping.
      aug_rand_hflip: `bool`, if True, augment training with random
        horizontal flip.
      aug_color_distort: `bool`, if True, augment training with color
        distortion. Only applied when `mode` is the pre-train mode.
      aug_color_jitter_strength: `float`, the floating number for the strength
        of the color augmentation.
      aug_color_jitter_impl: `str`, 'simclrv1' or 'simclrv2'. Define whether
        to use simclrv1 or simclrv2's version of random brightness.
      aug_rand_blur: `bool`, if True, augment training with random blur. Only
        applied when `mode` is the pre-train mode.
      parse_label: `bool`, if True, parse label together with image.
      test_crop: `bool`, if True, augment eval with center cropping. Forced to
        False when the larger side of `output_size` is at most 32 pixels.
      mode: `str`, 'pretrain' or 'finetune' (see `simclr_model.PRETRAIN` and
        `simclr_model.FINETUNE`). Define training mode.
      dtype: `str`, cast output image in dtype. It can be 'float32', 'float16',
        or 'bfloat16'.

    Raises:
      ValueError: if `dtype` is not one of the supported names.
    """
    self._output_size = output_size
    self._aug_rand_crop = aug_rand_crop
    self._aug_rand_hflip = aug_rand_hflip
    self._aug_color_distort = aug_color_distort
    self._aug_color_jitter_strength = aug_color_jitter_strength
    self._aug_color_jitter_impl = aug_color_jitter_impl
    self._aug_rand_blur = aug_rand_blur
    self._parse_label = parse_label
    self._mode = mode
    self._test_crop = test_crop
    # Center cropping is disabled for small outputs (larger side <= 32).
    if max(self._output_size[0], self._output_size[1]) <= 32:
      self._test_crop = False

    if dtype == 'float32':
      self._dtype = tf.float32
    elif dtype == 'float16':
      self._dtype = tf.float16
    elif dtype == 'bfloat16':
      self._dtype = tf.bfloat16
    else:
      raise ValueError('dtype {!r} is not supported!'.format(dtype))

  def _finalize_image(self, image):
    """Resizes, clips and casts a float image to the configured output.

    Shared by the training and evaluation paths so both emit images with the
    same size, value range and dtype.

    Args:
      image: float image tensor with 3 channels.

    Returns:
      Tensor of shape [output_size[0], output_size[1], 3] with values clipped
      to [0, 1] and converted to the configured dtype.
    """
    image = tf.image.resize(
        image, self._output_size, method=tf.image.ResizeMethod.BILINEAR)
    # Reshape to the explicit [height, width, 3] layout.
    image = tf.reshape(image, [self._output_size[0], self._output_size[1], 3])
    image = tf.clip_by_value(image, 0., 1.)
    return tf.image.convert_image_dtype(image, self._dtype)

  def _with_optional_label(self, image, decoded_tensors):
    """Returns `image`, paired with its int32 label when labels are parsed."""
    if self._parse_label:
      label = tf.cast(decoded_tensors['image/class/label'], dtype=tf.int32)
      return image, label

    return image

  def _parse_one_train_image(self, image_bytes):
    """Decodes a JPEG and produces one randomly augmented training view.

    Random crop and horizontal flip are applied whenever enabled; color jitter
    and blur are additionally gated on the parser being in pre-train mode.

    Args:
      image_bytes: scalar string tensor holding the encoded JPEG.

    Returns:
      Augmented image of shape [output_size[0], output_size[1], 3] in the
      configured dtype.
    """
    image = tf.image.decode_jpeg(image_bytes, channels=3)
    # Convert the decoded integer image to float32 in the range 0.0 - 1.0.
    image = tf.image.convert_image_dtype(image, dtype=tf.float32)

    if self._aug_rand_crop:
      image = simclr_preprocess_ops.random_crop_with_resize(
          image, self._output_size[0], self._output_size[1])

    if self._aug_rand_hflip:
      image = tf.image.random_flip_left_right(image)

    if self._aug_color_distort and self._mode == simclr_model.PRETRAIN:
      image = simclr_preprocess_ops.random_color_jitter(
          image=image,
          color_jitter_strength=self._aug_color_jitter_strength,
          impl=self._aug_color_jitter_impl)

    if self._aug_rand_blur and self._mode == simclr_model.PRETRAIN:
      image = simclr_preprocess_ops.random_blur(
          image, self._output_size[0], self._output_size[1])

    return self._finalize_image(image)

  def _parse_train_data(self, decoded_tensors):
    """Parses data for training.

    Args:
      decoded_tensors: dict with 'image/encoded' and, when labels are parsed,
        'image/class/label'.

    Returns:
      The image, or an `(image, label)` tuple when labels are parsed. In
      pre-train mode the image holds two augmented views concatenated along
      the last axis; in fine-tune mode it is a single view.

    Raises:
      ValueError: if the configured mode is neither pre-train nor fine-tune.
    """
    image_bytes = decoded_tensors['image/encoded']

    if self._mode == simclr_model.FINETUNE:
      image = self._parse_one_train_image(image_bytes)

    elif self._mode == simclr_model.PRETRAIN:
      # Transform each example twice using a combination of
      # simple augmentations, resulting in 2N data points.
      views = [self._parse_one_train_image(image_bytes) for _ in range(2)]
      image = tf.concat(views, -1)

    else:
      raise ValueError('The mode {} is not supported by the Parser.'
                       .format(self._mode))

    return self._with_optional_label(image, decoded_tensors)

  def _parse_eval_data(self, decoded_tensors):
    """Parses data for evaluation.

    Args:
      decoded_tensors: dict with 'image/encoded' and, when labels are parsed,
        'image/class/label'.

    Returns:
      The resized image, or an `(image, label)` tuple when labels are parsed.
    """
    image_bytes = decoded_tensors['image/encoded']

    if self._test_crop:
      # The JPEG shape is only needed by the center crop, so it is read only
      # on this branch.
      image_shape = tf.image.extract_jpeg_shape(image_bytes)
      image = preprocess_ops.center_crop_image_v2(image_bytes, image_shape)
    else:
      image = tf.image.decode_jpeg(image_bytes, channels=3)
    # Convert the image to float32; integer inputs are scaled to 0.0 - 1.0.
    image = tf.image.convert_image_dtype(image, dtype=tf.float32)

    image = self._finalize_image(image)
    return self._with_optional_label(image, decoded_tensors)