official/projects/simclr/heads/simclr_head.py
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""SimCLR prediction heads."""
from typing import Optional, Text
import tensorflow as tf, tf_keras
from official.projects.simclr.modeling.layers import nn_blocks
regularizers = tf_keras.regularizers
layers = tf_keras.layers


class ProjectionHead(tf_keras.layers.Layer):
  """Projection head."""

def __init__(
self,
num_proj_layers: int = 3,
proj_output_dim: Optional[int] = None,
ft_proj_idx: int = 0,
kernel_initializer: Text = 'VarianceScaling',
kernel_regularizer: Optional[regularizers.Regularizer] = None,
bias_regularizer: Optional[regularizers.Regularizer] = None,
use_sync_bn: bool = False,
norm_momentum: float = 0.99,
norm_epsilon: float = 0.001,
**kwargs):
"""The projection head used during pretraining of SimCLR.
Args:
num_proj_layers: `int` number of Dense layers used.
proj_output_dim: `int` output dimension of projection head, i.e., output
dimension of the final layer.
ft_proj_idx: `int` index of layer to use during fine-tuning. 0 means no
projection head during fine tuning, -1 means the final layer.
kernel_initializer: kernel_initializer for convolutional layers.
kernel_regularizer: tf_keras.regularizers.Regularizer object for Conv2D.
Default to None.
bias_regularizer: tf_keras.regularizers.Regularizer object for Conv2d.
Default to None.
use_sync_bn: if True, use synchronized batch normalization.
norm_momentum: `float` normalization omentum for the moving average.
norm_epsilon: `float` small float added to variance to avoid dividing by
zero.
**kwargs: keyword arguments to be passed.
"""
super(ProjectionHead, self).__init__(**kwargs)
assert proj_output_dim is not None or num_proj_layers == 0
assert ft_proj_idx <= num_proj_layers, (num_proj_layers, ft_proj_idx)
self._proj_output_dim = proj_output_dim
self._num_proj_layers = num_proj_layers
self._ft_proj_idx = ft_proj_idx
self._kernel_initializer = kernel_initializer
self._kernel_regularizer = kernel_regularizer
self._bias_regularizer = bias_regularizer
self._use_sync_bn = use_sync_bn
self._norm_momentum = norm_momentum
self._norm_epsilon = norm_epsilon
self._layers = []

  def get_config(self):
    config = {
        'proj_output_dim': self._proj_output_dim,
        'num_proj_layers': self._num_proj_layers,
        'ft_proj_idx': self._ft_proj_idx,
        'kernel_initializer': self._kernel_initializer,
        'kernel_regularizer': self._kernel_regularizer,
        'bias_regularizer': self._bias_regularizer,
        'use_sync_bn': self._use_sync_bn,
        'norm_momentum': self._norm_momentum,
        'norm_epsilon': self._norm_epsilon
    }
base_config = super(ProjectionHead, self).get_config()
return dict(list(base_config.items()) + list(config.items()))

  def build(self, input_shape):
self._layers = []
if self._num_proj_layers > 0:
intermediate_dim = int(input_shape[-1])
for j in range(self._num_proj_layers):
if j != self._num_proj_layers - 1:
          # For the intermediate layers, use bias and ReLU for the output.
layer = nn_blocks.DenseBN(
output_dim=intermediate_dim,
use_bias=True,
use_normalization=True,
activation='relu',
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
use_sync_bn=self._use_sync_bn,
norm_momentum=self._norm_momentum,
norm_epsilon=self._norm_epsilon,
name='nl_%d' % j)
else:
          # For the final layer, use neither bias nor ReLU.
layer = nn_blocks.DenseBN(
output_dim=self._proj_output_dim,
use_bias=False,
use_normalization=True,
activation=None,
              kernel_initializer=self._kernel_initializer,
              kernel_regularizer=self._kernel_regularizer,
use_sync_bn=self._use_sync_bn,
norm_momentum=self._norm_momentum,
norm_epsilon=self._norm_epsilon,
name='nl_%d' % j)
self._layers.append(layer)
super(ProjectionHead, self).build(input_shape)

  def call(self, inputs, training=None):
hiddens_list = [tf.identity(inputs, 'proj_head_input')]
if self._num_proj_layers == 0:
proj_head_output = inputs
proj_finetune_output = inputs
else:
for j in range(self._num_proj_layers):
hiddens = self._layers[j](hiddens_list[-1], training)
hiddens_list.append(hiddens)
proj_head_output = tf.identity(
hiddens_list[-1], 'proj_head_output')
proj_finetune_output = tf.identity(
hiddens_list[self._ft_proj_idx], 'proj_finetune_output')
# The first element is the output of the projection head.
# The second element is the input of the finetune head.
return proj_head_output, proj_finetune_output
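
# Example usage of ProjectionHead (a minimal sketch; the batch size, feature
# dimension, and hyperparameter values below are illustrative assumptions,
# not part of this file):
#
#   head = ProjectionHead(num_proj_layers=3, proj_output_dim=128,
#                         ft_proj_idx=1)
#   features = tf.random.normal([32, 2048])  # e.g. pooled backbone features
#   proj_output, ft_input = head(features, training=True)
#   # proj_output: [32, 128], fed to the contrastive loss during pretraining.
#   # ft_input: [32, 2048], the output of layer `ft_proj_idx`, used as the
#   # input to the supervised head during fine-tuning.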


class ClassificationHead(tf_keras.layers.Layer):
  """Classification Head."""

def __init__(
self,
num_classes: int,
kernel_initializer: Text = 'random_uniform',
kernel_regularizer: Optional[regularizers.Regularizer] = None,
bias_regularizer: Optional[regularizers.Regularizer] = None,
name: Text = 'head_supervised',
**kwargs):
"""The classification head used during pretraining or fine tuning.
Args:
num_classes: `int` size of the output dimension or number of classes
for classification task.
kernel_initializer: kernel_initializer for convolutional layers.
kernel_regularizer: tf_keras.regularizers.Regularizer object for Conv2D.
Default to None.
bias_regularizer: tf_keras.regularizers.Regularizer object for Conv2d.
Default to None.
name: `str`, name of the layer.
**kwargs: keyword arguments to be passed.
"""
super(ClassificationHead, self).__init__(name=name, **kwargs)
self._num_classes = num_classes
self._kernel_initializer = kernel_initializer
self._kernel_regularizer = kernel_regularizer
self._bias_regularizer = bias_regularizer
self._name = name

  def get_config(self):
config = {
'num_classes': self._num_classes,
'kernel_initializer': self._kernel_initializer,
'kernel_regularizer': self._kernel_regularizer,
'bias_regularizer': self._bias_regularizer,
}
base_config = super(ClassificationHead, self).get_config()
return dict(list(base_config.items()) + list(config.items()))

  def build(self, input_shape):
self._dense0 = layers.Dense(
units=self._num_classes,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activation=None)
super(ClassificationHead, self).build(input_shape)

  def call(self, inputs, training=None):
inputs = self._dense0(inputs)
return inputs
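
# Example usage of ClassificationHead (a minimal sketch; the shapes and class
# count below are illustrative assumptions):
#
#   cls_head = ClassificationHead(num_classes=1000)
#   logits = cls_head(tf.random.normal([32, 2048]), training=True)
#   # logits: [32, 1000] unnormalized class scores; apply a softmax or use a
#   # from_logits=True loss downstream.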