# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Contains common building blocks for simclr neural networks."""
from typing import Text, Optional

import tensorflow as tf
import tf_keras

from official.modeling import tf_utils

regularizers = tf_keras.regularizers


class DenseBN(tf_keras.layers.Layer):
  """Modified Dense layer to help build simclr system.

  The layer is a standards combination of Dense, BatchNorm and Activation.
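
  Example (an illustrative usage sketch; the shapes and sizes below are
  assumptions, not part of the original file):

  >>> layer = DenseBN(output_dim=128, use_normalization=True)
  >>> x = tf.ones((4, 64))
  >>> y = layer(x, training=True)
  >>> y.shape
  TensorShape([4, 128])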
  """

  def __init__(
      self,
      output_dim: int,
      use_bias: bool = True,
      use_normalization: bool = False,
      use_sync_bn: bool = False,
      norm_momentum: float = 0.99,
      norm_epsilon: float = 0.001,
      activation: Optional[Text] = 'relu',
      kernel_initializer: Text = 'VarianceScaling',
      kernel_regularizer: Optional[regularizers.Regularizer] = None,
      bias_regularizer: Optional[regularizers.Regularizer] = None,
      name='linear_layer',
      **kwargs):
    """Customized Dense layer.

    Args:
      output_dim: `int` size of output dimension.
      use_bias: if True, use bias in the dense layer.
      use_normalization: if True, use batch normalization.
      use_sync_bn: if True, use synchronized batch normalization.
      norm_momentum: `float` normalization momentum for the moving average.
      norm_epsilon: `float` small float added to variance to avoid dividing by
        zero.
      activation: `str` name of the activation function.
      kernel_initializer: kernel initializer for the dense layer.
      kernel_regularizer: tf_keras.regularizers.Regularizer object for the
        dense kernel. Defaults to None.
      bias_regularizer: tf_keras.regularizers.Regularizer object for the dense
        bias. Defaults to None.
      name: `str`, name of the layer.
      **kwargs: keyword arguments to be passed.
    """
    # Note: when use_normalization=True, use_bias is ignored for the dense
    # layer (the bias would be redundant with batch norm's offset). It still
    # controls whether the batch norm layer learns a center (beta).
    super(DenseBN, self).__init__(**kwargs)
    self._output_dim = output_dim
    self._use_bias = use_bias
    self._use_normalization = use_normalization
    self._use_sync_bn = use_sync_bn
    self._norm_momentum = norm_momentum
    self._norm_epsilon = norm_epsilon
    self._activation = activation
    self._kernel_initializer = kernel_initializer
    self._kernel_regularizer = kernel_regularizer
    self._bias_regularizer = bias_regularizer
    self._name = name

    if use_sync_bn:
      self._norm = tf_keras.layers.experimental.SyncBatchNormalization
    else:
      self._norm = tf_keras.layers.BatchNormalization
    if tf_keras.backend.image_data_format() == 'channels_last':
      self._bn_axis = -1
    else:
      self._bn_axis = 1
    if activation:
      self._activation_fn = tf_utils.get_activation(activation)
    else:
      self._activation_fn = None

  def get_config(self):
    config = {
        'output_dim': self._output_dim,
        'use_bias': self._use_bias,
        'activation': self._activation,
        'use_sync_bn': self._use_sync_bn,
        'use_normalization': self._use_normalization,
        'norm_momentum': self._norm_momentum,
        'norm_epsilon': self._norm_epsilon,
        'kernel_initializer': self._kernel_initializer,
        'kernel_regularizer': self._kernel_regularizer,
        'bias_regularizer': self._bias_regularizer,
    }
    base_config = super(DenseBN, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def build(self, input_shape):
    self._dense0 = tf_keras.layers.Dense(
        self._output_dim,
        kernel_initializer=self._kernel_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer,
        use_bias=self._use_bias and not self._use_normalization)

    if self._use_normalization:
      self._norm0 = self._norm(
          axis=self._bn_axis,
          momentum=self._norm_momentum,
          epsilon=self._norm_epsilon,
          center=self._use_bias,
          scale=True)

    super(DenseBN, self).build(input_shape)

  def call(self, inputs, training=None):
    assert inputs.shape.ndims == 2, inputs.shape
    x = self._dense0(inputs)
    if self._use_normalization:
      # Pass `training` explicitly so batch norm uses batch statistics only
      # during training.
      x = self._norm0(x, training=training)
    if self._activation:
      x = self._activation_fn(x)
    return x
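

if __name__ == '__main__':
  # Minimal usage sketch (illustrative only; the layer sizes and the input
  # shape are assumptions, not part of the original module). Builds a
  # SimCLR-style two-layer projection head from stacked DenseBN blocks.
  features = tf.random.normal((8, 2048))  # e.g. pooled backbone features.
  # Hidden layer: Dense + BatchNorm + ReLU.
  hidden = DenseBN(output_dim=256, use_normalization=True)(
      features, training=True)
  # Output layer: no activation; batch norm without a learned center
  # (use_bias=False also disables the dense bias).
  projection = DenseBN(
      output_dim=128,
      use_bias=False,
      use_normalization=True,
      activation=None)(hidden, training=True)
  print(projection.shape)  # (8, 128)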