research/seq_flow_lite/layers/normalization_layers.py
# Copyright 2020 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Layers for normalization."""
import tensorflow as tf

from layers import base_layers # import seq_flow_lite module
from layers import quantization_layers # import seq_flow_lite module
from tf_ops import tf_custom_ops_py # import seq_flow_lite module


class BatchNormalization(base_layers.BaseLayer):
"""A class that applies batch normalization to the input tensor."""

  def __init__(self, ema_decay=0.999, **kwargs):
self.ema_decay = ema_decay
super(BatchNormalization, self).__init__(**kwargs)

  def build(self, input_shapes):
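    # Compute batch moments over every dimension except the last (feature)
    # dimension.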
self.reduce_dims = list(range(len(input_shapes) - 1))
shape = [input_shapes[-1]]
self.offset = self.add_weight(
"offset",
shape=shape,
initializer=tf.keras.initializers.Zeros(),
trainable=True)
self.scale = self.add_weight(
"scale",
shape=shape,
initializer=tf.keras.initializers.Ones(),
trainable=True)
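    # Non-trainable moving averages of the batch mean and variance, updated
    # with exponential decay (ema_decay) during training.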
self.mva_mean = self.add_weight(
"mva_mean",
shape=shape,
initializer=tf.keras.initializers.Zeros(),
trainable=False)
self.mva_var = self.add_weight(
"mva_variance",
shape=shape,
initializer=tf.keras.initializers.Ones(),
trainable=False)

  def call(self, inputs):
mean_mom, var_mom = None, None
if self.parameters.mode == base_layers.TRAIN:
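      # Batch moments are only computed during training; eval/inference uses
      # the stored moving averages inside _batch_norm.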
mean_mom, var_mom = tf.nn.moments(inputs, self.reduce_dims)
return self._batch_norm(inputs, mean_mom, var_mom)

  def _batch_norm(self, inputs, mean_mom, var_mom):
if self.parameters.mode == base_layers.TRAIN:
      # During training, compute the batch statistics, update the moving
      # average variables, and use the batch statistics for normalization.
with tf.control_dependencies([
self.assign_moving_average(self.mva_mean, mean_mom, self.ema_decay),
self.assign_moving_average(self.mva_var, var_mom, self.ema_decay)
]):
tensor = tf.nn.batch_normalization(inputs, mean_mom, var_mom,
self.offset, self.scale, 1e-9)
else:
      # During eval/inference, use the moving average variables for batch
      # normalization. The variables are frozen to constants before the graph
      # is saved.
tensor = tf.nn.batch_normalization(inputs, self.mva_mean, self.mva_var,
self.offset, self.scale, 1e-9)
return tensor


class VarLenBatchNormalization(BatchNormalization):
  """A class that applies batch normalization to a variable-length (masked) input tensor."""

  def __init__(self, rank=2, **kwargs):
self.rank = rank
assert rank == 2 or rank == 4
super(VarLenBatchNormalization, self).__init__(**kwargs)

  def _reduce(self, tensor, multiplier):
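    # Sum over the reduce dimensions and rescale by the supplied multiplier,
    # which is expected to be the inverse of the number of unmasked elements.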
return tf.reduce_sum(tensor, axis=self.reduce_dims) * multiplier

  def call(self, inputs, mask, inverse_normalizer):
if self.parameters.mode == base_layers.TRAIN:
self._assert_rank_and_type(inputs, self.rank)
self._assert_rank_and_type(mask, self.rank)
inputs = mask * inputs
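      # First and second moments computed only over the valid (unmasked)
      # positions, using the caller-provided inverse_normalizer.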
mean_mom = self._reduce(inputs, inverse_normalizer)
var_mom = self._reduce(inputs * inputs, inverse_normalizer)
return mask * self._batch_norm(inputs, mean_mom, var_mom)
elif self.parameters.mode == base_layers.EVAL:
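      # Eval uses the moving averages; the mask keeps padded positions at zero.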
return mask * self._batch_norm(inputs, None, None)
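    # Remaining modes (e.g. prediction/TFLite) skip the mask multiplication.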
return self._batch_norm(inputs, None, None)


class LayerNormalization(base_layers.BaseLayer):
  """A class that applies layer normalization to the input tensor."""

  def __init__(self, axes=None, **kwargs):
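    # Default to normalizing over the last (feature) dimension.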
self.axes = axes or [-1]
self.qactivation = quantization_layers.ActivationQuantization(**kwargs)
super(LayerNormalization, self).__init__(**kwargs)

  def build(self, input_shape):
self.rank = len(input_shape)
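    # Resolve negative axes to absolute indices; normalizing over the batch
    # dimension (axis 0) is not supported.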
for i, axis in enumerate(self.axes):
if axis < 0:
self.axes[i] += self.rank
assert (self.axes[i] > 0 and self.axes[i] < self.rank)
self.offset = self.add_weight(
"offset",
shape=[1],
initializer=tf.keras.initializers.Zeros(),
trainable=True)
self.scale = self.add_weight(
"scale",
shape=[1],
initializer=tf.keras.initializers.Ones(),
trainable=True)

  def call(self, tensor):
tensor = self.qactivation(tensor)
if self.parameters.mode != base_layers.TFLITE:
mean, variance = tf.nn.moments(tensor, self.axes, keepdims=True)
      # If all values in the tensor are the same, the variance will be 0.
      # Adding a small epsilon to the variance ensures the normalized result
      # is 0 instead of NaN.
tensor = (tensor - mean) / tf.sqrt(variance + 1e-6)
return tensor * self.scale + self.offset
else:
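      # In TFLITE mode, delegate to the custom layer_norm op used when
      # exporting the model to TFLite.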
return tf_custom_ops_py.layer_norm(
tensor, self.scale, self.offset, axes=self.axes)