# tensorflow/models

# View on GitHub
# research/seq_flow_lite/layers/normalization_layers.py

# Summary

# Maintainability
# A
# 1 hr
# Test Coverage
# Copyright 2020 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Layers for normalization."""
import tensorflow as tf

from layers import base_layers # import seq_flow_lite module
from layers import quantization_layers # import seq_flow_lite module
from tf_ops import tf_custom_ops_py # import seq_flow_lite module


class BatchNormalization(base_layers.BaseLayer):
  """Batch normalization over every axis except the last one.

  In `base_layers.TRAIN` mode the layer normalizes with the moments of the
  current batch and folds those moments into the `mva_mean` / `mva_variance`
  weights through `assign_moving_average`. In every other mode it normalizes
  with the stored moving averages.

  Args:
    ema_decay: Decay handed to `assign_moving_average` when the moving mean
      and variance are updated.
    variance_epsilon: Small constant passed to `tf.nn.batch_normalization` as
      its variance epsilon. Defaults to 1e-9, the value that used to be
      hard-coded.
    **kwargs: Forwarded unchanged to `base_layers.BaseLayer`.
  """

  def __init__(self, ema_decay=0.999, variance_epsilon=1e-9, **kwargs):
    self.ema_decay = ema_decay
    self.variance_epsilon = variance_epsilon
    super(BatchNormalization, self).__init__(**kwargs)

  def build(self, input_shapes):
    """Creates the per-channel weights; the channel axis is the last one."""
    # Statistics are reduced over every axis but the last.
    self.reduce_dims = list(range(len(input_shapes) - 1))
    shape = [input_shapes[-1]]
    self.offset = self.add_weight(
        "offset",
        shape=shape,
        initializer=tf.keras.initializers.Zeros(),
        trainable=True)
    self.scale = self.add_weight(
        "scale",
        shape=shape,
        initializer=tf.keras.initializers.Ones(),
        trainable=True)
    # Moving averages are updated explicitly in `_batch_norm`, not by the
    # optimizer, hence `trainable=False`.
    self.mva_mean = self.add_weight(
        "mva_mean",
        shape=shape,
        initializer=tf.keras.initializers.Zeros(),
        trainable=False)
    self.mva_var = self.add_weight(
        "mva_variance",
        shape=shape,
        initializer=tf.keras.initializers.Ones(),
        trainable=False)

  def call(self, inputs):
    """Normalizes `inputs`, computing batch moments only in TRAIN mode."""
    mean_mom, var_mom = None, None
    if self.parameters.mode == base_layers.TRAIN:
      mean_mom, var_mom = tf.nn.moments(inputs, self.reduce_dims)
    return self._batch_norm(inputs, mean_mom, var_mom)

  def _batch_norm(self, inputs, mean_mom, var_mom):
    """Applies batch normalization with batch or moving-average statistics.

    Args:
      inputs: Tensor to normalize.
      mean_mom: Batch mean; only read in TRAIN mode (may be None otherwise).
      var_mom: Batch variance; only read in TRAIN mode (may be None otherwise).

    Returns:
      The normalized tensor, scaled by `self.scale` and shifted by
      `self.offset`.
    """
    if self.parameters.mode == base_layers.TRAIN:
      # During training compute summary stats, update them to moving average
      # variables and use the summary stats for batch normalization. The
      # control dependency forces the moving-average updates to run whenever
      # the normalized tensor is evaluated.
      with tf.control_dependencies([
          self.assign_moving_average(self.mva_mean, mean_mom, self.ema_decay),
          self.assign_moving_average(self.mva_var, var_mom, self.ema_decay)
      ]):
        tensor = tf.nn.batch_normalization(inputs, mean_mom, var_mom,
                                           self.offset, self.scale,
                                           self.variance_epsilon)
    else:
      # During eval/inference use the moving average variable for batch
      # normalization. The variables would be frozen to constants before
      # saving graph.
      tensor = tf.nn.batch_normalization(inputs, self.mva_mean, self.mva_var,
                                         self.offset, self.scale,
                                         self.variance_epsilon)
    return tensor


class VarLenBatchNormalization(BatchNormalization):
  """Batch normalization for masked (variable length) inputs.

  The batch statistics are computed from `mask * inputs` and rescaled by a
  caller-supplied `inverse_normalizer` instead of using `tf.nn.moments`, so
  that masked-out positions do not contribute to the mean and variance.

  Args:
    rank: Expected rank of `inputs` and `mask`; must be 2 or 4.
    **kwargs: Forwarded unchanged to `BatchNormalization`.
  """

  def __init__(self, rank=2, **kwargs):
    self.rank = rank
    assert rank == 2 or rank == 4, "rank must be 2 or 4"
    super(VarLenBatchNormalization, self).__init__(**kwargs)

  def _reduce(self, tensor, multiplier):
    """Sums `tensor` over `self.reduce_dims` and scales it by `multiplier`."""
    return tf.reduce_sum(tensor, axis=self.reduce_dims) * multiplier

  def call(self, inputs, mask, inverse_normalizer):
    """Normalizes `inputs` using statistics of the unmasked positions.

    Args:
      inputs: Tensor to normalize; its rank is checked against `self.rank` in
        TRAIN mode.
      mask: Tensor multiplied element-wise with `inputs`; its rank is checked
        against `self.rank` in TRAIN mode. Presumably a 0/1 mask that
        broadcasts against `inputs` -- confirm against callers.
      inverse_normalizer: Multiplier applied to the masked sums. Presumably
        the reciprocal of the number of unmasked positions -- confirm against
        callers.

    Returns:
      The normalized tensor. It is multiplied by `mask` again in TRAIN and
      EVAL modes and returned unmasked in any other mode.
    """
    if self.parameters.mode == base_layers.TRAIN:
      self._assert_rank_and_type(inputs, self.rank)
      self._assert_rank_and_type(mask, self.rank)
      inputs = mask * inputs
      mean_mom = self._reduce(inputs, inverse_normalizer)
      # Variance must be taken around the mean. The masked mean of
      # `inputs * inputs` is the raw second moment E[x^2], which overstates
      # the variance whenever the mean is non-zero. Re-applying the mask keeps
      # masked-out positions (where `inputs - mean_mom` is `-mean_mom`) out of
      # the sum. The two-pass form is also guaranteed non-negative.
      centered = mask * (inputs - mean_mom)
      var_mom = self._reduce(centered * centered, inverse_normalizer)
      return mask * self._batch_norm(inputs, mean_mom, var_mom)
    elif self.parameters.mode == base_layers.EVAL:
      return mask * self._batch_norm(inputs, None, None)
    return self._batch_norm(inputs, None, None)


class LayerNormalization(base_layers.BaseLayer):
  """Layer normalization over a configurable set of axes.

  The input is first passed through an `ActivationQuantization` layer. It is
  then normalized to zero mean and unit variance over `axes`, and finally
  scaled and shifted by a single learned `scale` and `offset` (both of shape
  `[1]`).

  Args:
    axes: Sequence of axes to normalize over. Negative values count from the
      end. Axis 0 is rejected in `build`. Defaults to `[-1]`.
    **kwargs: Forwarded both to `quantization_layers.ActivationQuantization`
      and to `base_layers.BaseLayer`.
  """

  def __init__(self, axes=None, **kwargs):
    # Copy so that `build` never rewrites a list owned by the caller (or one
    # shared between several layers); this also lets callers pass a tuple.
    self.axes = list(axes) if axes else [-1]
    self.qactivation = quantization_layers.ActivationQuantization(**kwargs)
    super(LayerNormalization, self).__init__(**kwargs)

  def build(self, input_shape):
    """Resolves negative axes against the input rank and creates weights."""
    self.rank = len(input_shape)
    resolved_axes = []
    for axis in self.axes:
      if axis < 0:
        axis += self.rank
      # Axis 0 is deliberately excluded by the strict lower bound.
      assert 0 < axis < self.rank, "axes must satisfy 0 < axis < rank"
      resolved_axes.append(axis)
    self.axes = resolved_axes
    self.offset = self.add_weight(
        "offset",
        shape=[1],
        initializer=tf.keras.initializers.Zeros(),
        trainable=True)
    self.scale = self.add_weight(
        "scale",
        shape=[1],
        initializer=tf.keras.initializers.Ones(),
        trainable=True)

  def call(self, tensor):
    """Quantizes then layer-normalizes `tensor`.

    In `base_layers.TFLITE` mode the normalization is delegated to the
    `tf_custom_ops_py.layer_norm` custom op; in every other mode it is
    computed with regular TensorFlow ops.
    """
    tensor = self.qactivation(tensor)
    if self.parameters.mode != base_layers.TFLITE:
      mean, variance = tf.nn.moments(tensor, self.axes, keepdims=True)
      # If all the values in the tensor are same, variance will be 0. Adding a
      # small epsilon to variance ensures that we get 0 as the normalized result
      # instead of NaN in the resulting tensor.
      tensor = (tensor - mean) / tf.sqrt(variance + 1e-6)
      return tensor * self.scale + self.offset
    else:
      return tf_custom_ops_py.layer_norm(
          tensor, self.scale, self.offset, axes=self.axes)