# official/projects/simclr/dataloaders/simclr_input.py
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Data parser and processing for SimCLR.
For pre-training:
- Preprocessing:
-> random cropping
-> resize back to the original size
-> random color distortions
-> random Gaussian blur (sequential)
- Each image needs to be processed randomly twice
```snippets
if train_mode == 'pretrain':
xs = []
for _ in range(2): # Two transformations
xs.append(preprocess_fn_pretrain(image))
image = tf.concat(xs, -1)
else:
image = preprocess_fn_finetune(image)
```
For fine-tuning:
typical image classification input
"""
from typing import List
import tensorflow as tf, tf_keras
from official.projects.simclr.dataloaders import preprocess_ops as simclr_preprocess_ops
from official.projects.simclr.modeling import simclr_model
from official.vision.dataloaders import decoder
from official.vision.dataloaders import parser
from official.vision.ops import preprocess_ops
class Decoder(decoder.Decoder):
  """Decodes serialized tf.Example protos for the classification task."""

  def __init__(self, decode_label=True):
    """Builds the feature spec used to parse each serialized example.

    Args:
      decode_label: if truthy, the int64 'image/class/label' feature is parsed
        in addition to the encoded image bytes.
    """
    self._decode_label = decode_label

    # The encoded image is always parsed; a missing value defaults to ''.
    feature_spec = {
        'image/encoded': tf.io.FixedLenFeature((), tf.string, default_value=''),
    }
    if decode_label:
      # A missing label defaults to -1.
      feature_spec['image/class/label'] = tf.io.FixedLenFeature(
          (), tf.int64, default_value=-1)
    self._keys_to_features = feature_spec

  def decode(self, serialized_example):
    """Parses one serialized example into a dict keyed by feature name."""
    return tf.io.parse_single_example(
        serialized_example, self._keys_to_features)
class TFDSDecoder(decoder.Decoder):
  """Adapts TFDS samples to the feature dict of the classification task."""

  def __init__(self, decode_label=True):
    """Records whether the label should be forwarded by `decode`.

    Args:
      decode_label: if truthy, `decode` also emits 'image/class/label'.
    """
    self._decode_label = decode_label

  def decode(self, serialized_example):
    """Converts one TFDS sample into the keys used by the tf.Example decoder.

    Args:
      serialized_example: a mapping that provides 'image' and, when labels are
        decoded, 'label'.

    Returns:
      A dict with 'image/encoded' (the image re-encoded as JPEG at quality 100)
      and, when labels are decoded, 'image/class/label'.
    """
    jpeg_bytes = tf.io.encode_jpeg(serialized_example['image'], quality=100)
    if not self._decode_label:
      return {'image/encoded': jpeg_bytes}
    return {
        'image/encoded': jpeg_bytes,
        'image/class/label': serialized_example['label'],
    }
class Parser(parser.Parser):
  """Turns decoded examples into SimCLR pre-training or fine-tuning inputs."""

  def __init__(self,
               output_size: List[int],
               aug_rand_crop: bool = True,
               aug_rand_hflip: bool = True,
               aug_color_distort: bool = True,
               aug_color_jitter_strength: float = 1.0,
               aug_color_jitter_impl: str = 'simclrv2',
               aug_rand_blur: bool = True,
               parse_label: bool = True,
               test_crop: bool = True,
               mode: str = simclr_model.PRETRAIN,
               dtype: str = 'float32'):
    """Stores the augmentation and output settings used while parsing.

    Args:
      output_size: `Tensor` or `list` holding [height, width] of the output
        image.
      aug_rand_crop: `bool`, if True, augment training with random cropping.
      aug_rand_hflip: `bool`, if True, augment training with random
        horizontal flip.
      aug_color_distort: `bool`, if True, augment training with color
        distortion (applied in pretrain mode only).
      aug_color_jitter_strength: `float`, strength of the color augmentation.
      aug_color_jitter_impl: `str`, 'simclrv1' or 'simclrv2'; forwarded as the
        `impl` argument of the color jitter op.
      aug_rand_blur: `bool`, if True, augment training with random blur
        (applied in pretrain mode only).
      parse_label: `bool`, if True, parse the label together with the image.
      test_crop: `bool`, if True, center-crop images during evaluation. It is
        forced to False when the longer output side is at most 32.
      mode: `str`, the training mode; `simclr_model.PRETRAIN` or
        `simclr_model.FINETUNE`.
      dtype: `str`, dtype of the output image. One of 'float32', 'float16' or
        'bfloat16'.

    Raises:
      ValueError: if `dtype` is not one of the supported names.
    """
    self._output_size = output_size
    self._mode = mode
    self._parse_label = parse_label

    # Training-time augmentation switches.
    self._aug_rand_crop = aug_rand_crop
    self._aug_rand_hflip = aug_rand_hflip
    self._aug_color_distort = aug_color_distort
    self._aug_color_jitter_strength = aug_color_jitter_strength
    self._aug_color_jitter_impl = aug_color_jitter_impl
    self._aug_rand_blur = aug_rand_blur

    # Center cropping is switched off when the longer output side is <= 32.
    is_small_output = max(output_size[0], output_size[1]) <= 32
    self._test_crop = False if is_small_output else test_crop

    # Candidate names are matched one at a time with `==`.
    for dtype_name, tf_dtype in (('float32', tf.float32),
                                 ('float16', tf.float16),
                                 ('bfloat16', tf.bfloat16)):
      if dtype == dtype_name:
        self._dtype = tf_dtype
        break
    else:
      raise ValueError('dtype {!r} is not supported!'.format(dtype))

  def _resize_clip_and_cast(self, image):
    """Resizes to the output size, clips to [0, 1] and casts to the out dtype.

    Args:
      image: a float image tensor with 3 channels.

    Returns:
      The image with static shape [height, width, 3] in `self._dtype`.
    """
    height, width = self._output_size[0], self._output_size[1]
    resized = tf.image.resize(
        image, self._output_size, method=tf.image.ResizeMethod.BILINEAR)
    resized = tf.reshape(resized, [height, width, 3])
    clipped = tf.clip_by_value(resized, 0., 1.)
    return tf.image.convert_image_dtype(clipped, self._dtype)

  def _with_optional_label(self, image, decoded_tensors):
    """Returns `(image, label)` when labels are parsed, otherwise `image`."""
    if not self._parse_label:
      return image
    label = tf.cast(decoded_tensors['image/class/label'], dtype=tf.int32)
    return image, label

  def _parse_one_train_image(self, image_bytes):
    """Decodes one JPEG and applies the enabled training augmentations.

    Args:
      image_bytes: the JPEG-encoded image string.

    Returns:
      One augmented view of the image, resized to the output size and cast to
      `self._dtype`.
    """
    height, width = self._output_size[0], self._output_size[1]
    pretraining = self._mode == simclr_model.PRETRAIN

    decoded = tf.image.decode_jpeg(image_bytes, channels=3)
    # Integer pixel values become floats scaled to [0.0, 1.0].
    image = tf.image.convert_image_dtype(decoded, dtype=tf.float32)

    if self._aug_rand_crop:
      image = simclr_preprocess_ops.random_crop_with_resize(
          image, height, width)

    if self._aug_rand_hflip:
      image = tf.image.random_flip_left_right(image)

    # Color jitter and blur are only applied while pre-training.
    if self._aug_color_distort and pretraining:
      image = simclr_preprocess_ops.random_color_jitter(
          image=image,
          color_jitter_strength=self._aug_color_jitter_strength,
          impl=self._aug_color_jitter_impl)

    if self._aug_rand_blur and pretraining:
      image = simclr_preprocess_ops.random_blur(image, height, width)

    return self._resize_clip_and_cast(image)

  def _parse_train_data(self, decoded_tensors):
    """Parses data for training.

    Args:
      decoded_tensors: a dict with 'image/encoded' and, when labels are
        parsed, 'image/class/label'.

    Returns:
      The image, or an `(image, label)` tuple when labels are parsed. In
      pretrain mode the image holds two augmented views concatenated along the
      last axis.

    Raises:
      ValueError: if the mode is neither finetune nor pretrain.
    """
    image_bytes = decoded_tensors['image/encoded']

    if self._mode == simclr_model.FINETUNE:
      image = self._parse_one_train_image(image_bytes)
    elif self._mode == simclr_model.PRETRAIN:
      # Each example is augmented twice, independently, giving 2N data points
      # for N examples.
      views = [self._parse_one_train_image(image_bytes) for _ in range(2)]
      image = tf.concat(views, -1)
    else:
      raise ValueError('The mode {} is not supported by the Parser.'
                       .format(self._mode))

    return self._with_optional_label(image, decoded_tensors)

  def _parse_eval_data(self, decoded_tensors):
    """Parses data for evaluation.

    Args:
      decoded_tensors: a dict with 'image/encoded' and, when labels are
        parsed, 'image/class/label'.

    Returns:
      The image, or an `(image, label)` tuple when labels are parsed.
    """
    image_bytes = decoded_tensors['image/encoded']
    image_shape = tf.image.extract_jpeg_shape(image_bytes)

    if self._test_crop:
      image = preprocess_ops.center_crop_image_v2(image_bytes, image_shape)
    else:
      image = tf.image.decode_jpeg(image_bytes, channels=3)

    # Integer pixel values become floats scaled to [0.0, 1.0].
    image = tf.image.convert_image_dtype(image, dtype=tf.float32)
    image = self._resize_clip_and_cast(image)

    return self._with_optional_label(image, decoded_tensors)