official/projects/qat/vision/modeling/layers/nn_layers.py
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains quantization-aware (QAT) building blocks for vision models."""
import enum
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union
import tensorflow as tf, tf_keras
import tensorflow_model_optimization as tfmot
from official.modeling import tf_utils
from official.projects.qat.vision.quantization import configs
from official.projects.qat.vision.quantization import helper
from official.vision.modeling import heads
from official.vision.modeling.decoders import aspp
from official.vision.modeling.layers import nn_layers
# Type aliases used in annotations.
# An activation, given either by its registered name or as a callable.
Activation = Union[str, Callable]
# A dictionary whose keys are strings and whose values are tensors.
States = Dict[str, tf.Tensor]
# String constants.
class FeatureFusion(str, enum.Enum):
  """Names of the feature-fusion strategies a segmentation head may use.

  Mixing in `str` makes every member compare equal to its raw string value,
  so plain strings taken from a config can be tested against these members.
  """

  # Multiscale features are resized and fused at a target level.
  PYRAMID_FUSION = 'pyramid_fusion'
  PANOPTIC_FPN_FUSION = 'panoptic_fpn_fusion'
  # Decoder features are concatenated with low-level backbone features.
  DEEPLABV3PLUS = 'deeplabv3plus'
  # Decoder features are added to low-level backbone features.
  DEEPLABV3PLUS_SUM_TO_MERGE = 'deeplabv3plus_sum_to_merge'
@tf_keras.utils.register_keras_serializable(package='Vision')
class SqueezeExcitationQuantized(
    helper.LayerQuantizerHelper,
    tf_keras.layers.Layer):
  """Creates a quantized squeeze and excitation layer.

  The input is averaged over its spatial axes, passed through a reduce step
  (1x1 conv + activation) and an expand step (1x1 conv + gating activation),
  and the resulting gate is multiplied with the original input.
  """

  def __init__(self,
               in_filters,
               out_filters,
               se_ratio,
               divisible_by=1,
               use_3d_input=False,
               kernel_initializer='VarianceScaling',
               kernel_regularizer=None,
               bias_regularizer=None,
               activation='relu',
               gating_activation='sigmoid',
               round_down_protect=True,
               **kwargs):
    """Initializes a squeeze and excitation layer.

    Args:
      in_filters: An `int` number of filters of the input tensor.
      out_filters: An `int` number of filters of the output tensor.
      se_ratio: A `float` squeeze ratio. The reduce convolution uses
        `max(1, int(in_filters * se_ratio))` filters, made divisible by
        `divisible_by`.
      divisible_by: An `int` that ensures all inner dimensions are divisible by
        this number.
      use_3d_input: A `bool` of whether input is 2D or 3D image.
      kernel_initializer: A `str` of kernel_initializer for convolutional
        layers.
      kernel_regularizer: A `tf_keras.regularizers.Regularizer` object for
        Conv2D. Default to None.
      bias_regularizer: A `tf_keras.regularizers.Regularizer` object for Conv2d.
        Default to None.
      activation: A `str` name of the activation function.
      gating_activation: A `str` name of the activation function for final
        gating function. `hard_sigmoid` is decomposed into separately
        quantized ops.
      round_down_protect: A `bool` of whether round down more than 10% will be
        allowed.
      **kwargs: Additional keyword arguments to be passed.
    """
    super().__init__(**kwargs)
    self._in_filters = in_filters
    self._out_filters = out_filters
    self._se_ratio = se_ratio
    self._divisible_by = divisible_by
    self._round_down_protect = round_down_protect
    self._use_3d_input = use_3d_input
    self._activation = activation
    self._gating_activation = gating_activation
    self._kernel_initializer = kernel_initializer
    self._kernel_regularizer = kernel_regularizer
    self._bias_regularizer = bias_regularizer
    # Axes averaged over in `call`: all spatial axes, which follow the batch
    # axis for channels_last and follow the channel axis for channels_first.
    if tf_keras.backend.image_data_format() == 'channels_last':
      if not use_3d_input:
        self._spatial_axis = [1, 2]
      else:
        self._spatial_axis = [1, 2, 3]
    else:
      if not use_3d_input:
        self._spatial_axis = [2, 3]
      else:
        self._spatial_axis = [2, 3, 4]

  def _create_gating_activation_layer(self):
    """Creates the layers/quantizers implementing the gating activation."""
    if self._gating_activation == 'hard_sigmoid':
      # Convert hard_sigmoid activation to quantizable keras layers so each op
      # can be properly quantized.
      # Formula is hard_sigmoid(x) = relu6(x + 3) / 6.
      self._add_quantizer('add_three')
      self._add_quantizer('divide_six')
      self._relu6 = tfmot.quantization.keras.QuantizeWrapperV2(
          tf_utils.get_activation('relu6', use_keras_layer=True),
          configs.Default8BitActivationQuantizeConfig())
    else:
      self._gating_activation_layer = (
          tfmot.quantization.keras.QuantizeWrapperV2(
              tf_utils.get_activation(
                  self._gating_activation, use_keras_layer=True
              ),
              configs.Default8BitActivationQuantizeConfig(),
          )
      )

  def _apply_gating_activation_layer(
      self, x: tf.Tensor, training: bool) -> tf.Tensor:
    """Applies the gating activation created by the method above.

    Args:
      x: The output of the expand convolution.
      training: Forwarded to the `add_three` / `divide_six` quantizers.

    Returns:
      The gated tensor.
    """
    if self._gating_activation == 'hard_sigmoid':
      x = self._apply_quantizer('add_three', x + 3.0, training)
      x = self._relu6(x)
      # relu6 yields values in [0, 6]; scaling by 1/6 maps the gate to [0, 1].
      x = self._apply_quantizer('divide_six', x * (1.0 / 6.0), training)
    else:
      x = self._gating_activation_layer(x)
    return x

  def build(self, input_shape):
    """Creates the quantized sub-layers and quantizer variables."""
    num_reduced_filters = nn_layers.make_divisible(
        max(1, int(self._in_filters * self._se_ratio)),
        divisor=self._divisible_by,
        round_down_protect=self._round_down_protect)

    self._se_reduce = helper.Conv2DQuantized(
        filters=num_reduced_filters,
        kernel_size=1,
        strides=1,
        padding='same',
        use_bias=True,
        kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer,
        activation=helper.NoOpActivation())

    self._se_expand = helper.Conv2DOutputQuantized(
        filters=self._out_filters,
        kernel_size=1,
        strides=1,
        padding='same',
        use_bias=True,
        kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer,
        activation=helper.NoOpActivation())

    self._multiply = tfmot.quantization.keras.QuantizeWrapperV2(
        tf_keras.layers.Multiply(),
        configs.Default8BitQuantizeConfig([], [], True))
    # The spatial mean is computed with a raw TF op in `call`, so it gets its
    # own explicit quantizer.
    self._reduce_mean_quantizer = (
        tfmot.quantization.keras.quantizers.MovingAverageQuantizer(
            num_bits=8, per_axis=False, symmetric=False, narrow_range=False))
    self._reduce_mean_quantizer_vars = self._reduce_mean_quantizer.build(
        None, 'reduce_mean_quantizer_vars', self)

    self._activation_layer = tfmot.quantization.keras.QuantizeWrapperV2(
        tf_utils.get_activation(self._activation, use_keras_layer=True),
        configs.Default8BitActivationQuantizeConfig())
    # Quantizers registered by the gating layer must exist before their
    # variables are built.
    self._create_gating_activation_layer()

    self._build_quantizer_vars()
    super().build(input_shape)

  def get_config(self):
    """Returns the layer config, including all constructor arguments."""
    config = {
        'in_filters': self._in_filters,
        'out_filters': self._out_filters,
        'se_ratio': self._se_ratio,
        'divisible_by': self._divisible_by,
        'use_3d_input': self._use_3d_input,
        'kernel_initializer': self._kernel_initializer,
        'kernel_regularizer': self._kernel_regularizer,
        'bias_regularizer': self._bias_regularizer,
        'activation': self._activation,
        'gating_activation': self._gating_activation,
        'round_down_protect': self._round_down_protect,
    }
    base_config = super().get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def call(self, inputs, training=None):
    """Applies squeeze-and-excitation gating to `inputs`."""
    # Squeeze: spatial mean, fake-quantized.
    x = tf.reduce_mean(inputs, self._spatial_axis, keepdims=True)
    x = self._reduce_mean_quantizer(
        x, training, self._reduce_mean_quantizer_vars)
    # Excite: reduce -> activation -> expand -> gating.
    x = self._activation_layer(self._se_reduce(x))
    x = self._apply_gating_activation_layer(self._se_expand(x), training)
    # Rescale the input by the gate.
    x = self._multiply([x, inputs])
    return x
@tf_keras.utils.register_keras_serializable(package='Vision')
class SegmentationHeadQuantized(tf_keras.layers.Layer):
  """Creates a segmentation head."""

  def __init__(
      self,
      num_classes: int,
      level: Union[int, str],
      num_convs: int = 2,
      num_filters: int = 256,
      use_depthwise_convolution: bool = False,
      prediction_kernel_size: int = 1,
      upsample_factor: int = 1,
      feature_fusion: Optional[str] = None,
      decoder_min_level: Optional[int] = None,
      decoder_max_level: Optional[int] = None,
      low_level: int = 2,
      low_level_num_filters: int = 48,
      num_decoder_filters: int = 256,
      activation: str = 'relu',
      logit_activation: Optional[str] = None,
      use_sync_bn: bool = False,
      norm_momentum: float = 0.99,
      norm_epsilon: float = 0.001,
      kernel_regularizer: Optional[tf_keras.regularizers.Regularizer] = None,
      bias_regularizer: Optional[tf_keras.regularizers.Regularizer] = None,
      **kwargs):
    """Initializes a segmentation head.

    Args:
      num_classes: An `int` number of mask classification categories. The number
        of classes does not include background class.
      level: An `int` or `str`, level to use to build segmentation head.
      num_convs: An `int` number of stacked convolution before the last
        prediction layer.
      num_filters: An `int` number to specify the number of filters used.
        Default is 256.
      use_depthwise_convolution: A bool to specify if use depthwise separable
        convolutions.
      prediction_kernel_size: An `int` number to specify the kernel size of the
        prediction layer.
      upsample_factor: An `int` number to specify the upsampling factor to
        generate finer mask. Default 1 means no upsampling is applied.
      feature_fusion: One of `deeplabv3plus`, `deeplabv3plus_sum_to_merge`, or
        None. With `deeplabv3plus`, features from decoder_features[level] are
        resized and concatenated with low level feature maps from the
        backbone; with `deeplabv3plus_sum_to_merge` they are added instead.
        `pyramid_fusion` and `panoptic_fpn_fusion` are not supported in QAT
        and raise a `ValueError` when the layer is called.
      decoder_min_level: An `int` of minimum level from decoder to use in
        feature fusion. It is only used when feature_fusion is set to
        `panoptic_fpn_fusion`.
      decoder_max_level: An `int` of maximum level from decoder to use in
        feature fusion. It is only used when feature_fusion is set to
        `panoptic_fpn_fusion`.
      low_level: An `int` of backbone level to be used for feature fusion. It is
        used when feature_fusion is set to `deeplabv3plus`.
      low_level_num_filters: An `int` of reduced number of filters for the low
        level features before fusing it with higher level features. It is only
        used when feature_fusion is set to `deeplabv3plus`.
      num_decoder_filters: An `int` of number of filters in the decoder outputs.
        It is only used when feature_fusion is set to `panoptic_fpn_fusion`.
      activation: A `str` that indicates which activation is used, e.g. 'relu',
        'swish', etc.
      logit_activation: Unused. Must be falsy, otherwise a `ValueError` is
        raised.
      use_sync_bn: A `bool` that indicates whether to use synchronized batch
        normalization across different replicas.
      norm_momentum: A `float` of normalization momentum for the moving average.
      norm_epsilon: A `float` added to variance to avoid dividing by zero.
      kernel_regularizer: A `tf_keras.regularizers.Regularizer` object for
        Conv2D. Default is None.
      bias_regularizer: A `tf_keras.regularizers.Regularizer` object for Conv2D.
      **kwargs: Additional keyword arguments to be passed.
    """
    super().__init__(**kwargs)

    self._config_dict = {
        'num_classes': num_classes,
        'level': level,
        'num_convs': num_convs,
        'num_filters': num_filters,
        'use_depthwise_convolution': use_depthwise_convolution,
        'prediction_kernel_size': prediction_kernel_size,
        'upsample_factor': upsample_factor,
        'feature_fusion': feature_fusion,
        'decoder_min_level': decoder_min_level,
        'decoder_max_level': decoder_max_level,
        'low_level': low_level,
        'low_level_num_filters': low_level_num_filters,
        'num_decoder_filters': num_decoder_filters,
        'activation': activation,
        'use_sync_bn': use_sync_bn,
        'norm_momentum': norm_momentum,
        'norm_epsilon': norm_epsilon,
        'kernel_regularizer': kernel_regularizer,
        'bias_regularizer': bias_regularizer,
    }
    if tf_keras.backend.image_data_format() == 'channels_last':
      self._bn_axis = -1
    else:
      self._bn_axis = 1
    self._activation_layer = tfmot.quantization.keras.QuantizeWrapperV2(
        tf_utils.get_activation(activation, use_keras_layer=True),
        configs.Default8BitActivationQuantizeConfig())
    if logit_activation:
      raise ValueError('Unused logit_activation option inherited from '
                       'vision SegmentationHead modeling config.')

  def build(self, input_shape: Sequence[tf.TensorShape]):
    """Creates the variables of the segmentation head.

    Args:
      input_shape: A pair `(backbone_shape, decoder_shape)`. Each entry is
        either a single shape or a `dict` of shapes keyed by level string.
    """
    # The first entry corresponds to the backbone features. Their spatial size
    # is the target the decoder features are resized to for `deeplabv3plus`
    # fusion. When backbone features arrive as a dict (which `call` accepts),
    # use the `low_level` entry, i.e. the feature map that `call` fuses.
    backbone_shape = input_shape[0]
    if isinstance(backbone_shape, dict):
      backbone_shape = backbone_shape[str(self._config_dict['low_level'])]
    use_depthwise_convolution = self._config_dict['use_depthwise_convolution']
    random_initializer = tf_keras.initializers.RandomNormal(stddev=0.01)
    # `kernel_initializer` is passed separately per conv so that each layer
    # receives its own clone of `random_initializer`.
    conv_kwargs = {
        'kernel_size': 3 if not use_depthwise_convolution else 1,
        'padding': 'same',
        'use_bias': False,
        'kernel_regularizer': self._config_dict['kernel_regularizer'],
    }
    norm_layer = (
        tf_keras.layers.experimental.SyncBatchNormalization
        if self._config_dict['use_sync_bn'] else
        tf_keras.layers.BatchNormalization)
    norm_with_quantize = helper.BatchNormalizationQuantized(norm_layer)
    norm_no_quantize = helper.BatchNormalizationNoQuantized(norm_layer)
    norm = helper.norm_by_activation(self._config_dict['activation'],
                                     norm_with_quantize, norm_no_quantize)

    bn_kwargs = {
        'axis': self._bn_axis,
        'momentum': self._config_dict['norm_momentum'],
        'epsilon': self._config_dict['norm_epsilon'],
    }

    if self._config_dict['feature_fusion'] in [
        FeatureFusion.DEEPLABV3PLUS, FeatureFusion.DEEPLABV3PLUS_SUM_TO_MERGE
    ]:
      # Deeplabv3+ feature fusion layers.
      self._dlv3p_conv = helper.Conv2DQuantized(
          kernel_size=1,
          padding='same',
          use_bias=False,
          kernel_initializer=tf_utils.clone_initializer(random_initializer),
          kernel_regularizer=self._config_dict['kernel_regularizer'],
          name='segmentation_head_deeplabv3p_fusion_conv',
          filters=self._config_dict['low_level_num_filters'],
          activation=helper.NoOpActivation())

      self._dlv3p_norm = norm(
          name='segmentation_head_deeplabv3p_fusion_norm', **bn_kwargs)

    # Segmentation head layers.
    self._convs = []
    self._norms = []
    for i in range(self._config_dict['num_convs']):
      if use_depthwise_convolution:
        self._convs.append(
            helper.DepthwiseConv2DQuantized(
                name='segmentation_head_depthwise_conv_{}'.format(i),
                kernel_size=3,
                padding='same',
                use_bias=False,
                depthwise_initializer=tf_utils.clone_initializer(
                    random_initializer),
                depthwise_regularizer=self._config_dict['kernel_regularizer'],
                depth_multiplier=1,
                activation=helper.NoOpActivation()))
        norm_name = 'segmentation_head_depthwise_norm_{}'.format(i)
        self._norms.append(norm(name=norm_name, **bn_kwargs))
      conv_name = 'segmentation_head_conv_{}'.format(i)
      self._convs.append(
          helper.Conv2DQuantized(
              name=conv_name,
              filters=self._config_dict['num_filters'],
              # Clone per layer, as done for every other layer in this file,
              # so no initializer instance is shared between layers.
              kernel_initializer=tf_utils.clone_initializer(
                  random_initializer),
              activation=helper.NoOpActivation(),
              **conv_kwargs))
      norm_name = 'segmentation_head_norm_{}'.format(i)
      self._norms.append(norm(name=norm_name, **bn_kwargs))

    self._classifier = helper.Conv2DOutputQuantized(
        name='segmentation_output',
        filters=self._config_dict['num_classes'],
        kernel_size=self._config_dict['prediction_kernel_size'],
        padding='same',
        bias_initializer=tf.zeros_initializer(),
        kernel_initializer=tf_utils.clone_initializer(random_initializer),
        kernel_regularizer=self._config_dict['kernel_regularizer'],
        bias_regularizer=self._config_dict['bias_regularizer'],
        activation=helper.NoOpActivation())

    self._upsampling_layer = helper.UpSampling2DQuantized(
        size=(self._config_dict['upsample_factor'],
              self._config_dict['upsample_factor']),
        interpolation='nearest')
    # Resizing needs a static target size for QAT, taken from the backbone
    # shape resolved above.
    self._resizing_layer = helper.ResizingQuantized(
        backbone_shape[1], backbone_shape[2], interpolation='bilinear')

    self._concat_layer = helper.ConcatenateQuantized(axis=self._bn_axis)
    self._add_layer = tfmot.quantization.keras.QuantizeWrapperV2(
        tf_keras.layers.Add(), configs.Default8BitQuantizeConfig([], [], True))

    super().build(input_shape)

  def call(self, inputs: Tuple[Union[tf.Tensor, Mapping[str, tf.Tensor]],
                               Union[tf.Tensor, Mapping[str, tf.Tensor]]]):
    """Forward pass of the segmentation head.

    It supports both a tuple of 2 tensors or 2 dictionaries. The first is
    backbone endpoints, and the second is decoder endpoints. When inputs are
    tensors, they are from a single level of feature maps. When inputs are
    dictionaries, they contain multiple levels of feature maps, where the key
    is the index of feature map.

    Args:
      inputs: A tuple of 2 feature map tensors of shape
        [batch, height_l, width_l, channels] or 2 dictionaries of tensors:
        - key: A `str` of the level of the multilevel features.
        - values: A `tf.Tensor` of the feature map tensors, whose shape is
            [batch, height_l, width_l, channels].

    Returns:
      segmentation prediction mask: A `tf.Tensor` of the segmentation mask
        scores predicted from input features.

    Raises:
      ValueError: If `feature_fusion` is `pyramid_fusion` or
        `panoptic_fpn_fusion`, which are not supported in QAT.
    """
    if self._config_dict['feature_fusion'] in (
        FeatureFusion.PYRAMID_FUSION, FeatureFusion.PANOPTIC_FPN_FUSION):
      raise ValueError(
          'The feature fusion method `{}` is not supported in QAT.'.format(
              self._config_dict['feature_fusion']))

    backbone_output = inputs[0]
    decoder_output = inputs[1]
    if self._config_dict['feature_fusion'] in {
        FeatureFusion.DEEPLABV3PLUS, FeatureFusion.DEEPLABV3PLUS_SUM_TO_MERGE
    }:
      # deeplabv3+ feature fusion.
      x = decoder_output[str(self._config_dict['level'])] if isinstance(
          decoder_output, dict) else decoder_output
      y = backbone_output[str(self._config_dict['low_level'])] if isinstance(
          backbone_output, dict) else backbone_output
      y = self._dlv3p_norm(self._dlv3p_conv(y))
      y = self._activation_layer(y)
      # Bring the decoder features to the low-level backbone resolution.
      x = self._resizing_layer(x)
      x = tf.cast(x, dtype=y.dtype)
      if self._config_dict['feature_fusion'] == FeatureFusion.DEEPLABV3PLUS:
        x = self._concat_layer([x, y])
      else:
        x = self._add_layer([x, y])
    else:
      x = decoder_output[str(self._config_dict['level'])] if isinstance(
          decoder_output, dict) else decoder_output

    for conv, norm in zip(self._convs, self._norms):
      x = conv(x)
      x = norm(x)
      x = self._activation_layer(x)
    if self._config_dict['upsample_factor'] > 1:
      # Use keras layer for nearest upsampling so it is QAT compatible.
      x = self._upsampling_layer(x)

    return self._classifier(x)

  def get_config(self):
    """Returns the layer config (base config plus constructor arguments)."""
    base_config = super().get_config()
    return dict(list(base_config.items()) + list(self._config_dict.items()))

  @classmethod
  def from_config(cls, config):
    """Creates the layer from a config produced by `get_config`."""
    return cls(**config)
@tf_keras.utils.register_keras_serializable(package='Vision')
class SpatialPyramidPoolingQuantized(nn_layers.SpatialPyramidPooling):
  """Implements the quantized Atrous Spatial Pyramid Pooling.

  References:
    [Rethinking Atrous Convolution for Semantic Image Segmentation](
      https://arxiv.org/pdf/1706.05587.pdf)
    [Encoder-Decoder with Atrous Separable Convolution for Semantic Image
    Segmentation](https://arxiv.org/pdf/1802.02611.pdf)
  """

  def __init__(
      self,
      output_channels: int,
      dilation_rates: List[int],
      pool_kernel_size: Optional[List[int]] = None,
      use_sync_bn: bool = False,
      batchnorm_momentum: float = 0.99,
      batchnorm_epsilon: float = 0.001,
      activation: str = 'relu',
      dropout: float = 0.5,
      kernel_initializer: str = 'GlorotUniform',
      kernel_regularizer: Optional[tf_keras.regularizers.Regularizer] = None,
      interpolation: str = 'bilinear',
      use_depthwise_convolution: bool = False,
      **kwargs):
    """Initializes `SpatialPyramidPooling`.

    Args:
      output_channels: Number of channels produced by SpatialPyramidPooling.
      dilation_rates: A list of integers for parallel dilated conv.
      pool_kernel_size: A list of integers or None. If None, global average
        pooling is applied, otherwise an average pooling of pool_kernel_size is
        applied.
      use_sync_bn: A bool, whether or not to use sync batch normalization.
      batchnorm_momentum: A float for the momentum in BatchNorm. Defaults to
        0.99.
      batchnorm_epsilon: A float for the epsilon value in BatchNorm. Defaults to
        0.001.
      activation: A `str` for type of activation to be used. Defaults to 'relu'.
      dropout: A float for the dropout rate before output. Defaults to 0.5.
      kernel_initializer: Kernel initializer for conv layers. Defaults to
        `glorot_uniform`.
      kernel_regularizer: Kernel regularizer for conv layers. Defaults to None.
      interpolation: The interpolation method for upsampling. Defaults to
        `bilinear`.
      use_depthwise_convolution: Allows spatial pooling to be separable
        depthwise convolutions. [Encoder-Decoder with Atrous Separable
        Convolution for Semantic Image Segmentation](
         https://arxiv.org/pdf/1802.02611.pdf)
      **kwargs: Other keyword arguments for the layer (e.g. `name`), forwarded
        to the base layer.
    """
    super().__init__(
        output_channels=output_channels,
        dilation_rates=dilation_rates,
        use_sync_bn=use_sync_bn,
        batchnorm_momentum=batchnorm_momentum,
        batchnorm_epsilon=batchnorm_epsilon,
        activation=activation,
        dropout=dropout,
        kernel_initializer=kernel_initializer,
        kernel_regularizer=kernel_regularizer,
        interpolation=interpolation,
        pool_kernel_size=pool_kernel_size,
        use_depthwise_convolution=use_depthwise_convolution,
        **kwargs)

    self._activation_fn = tfmot.quantization.keras.QuantizeWrapperV2(
        tf_utils.get_activation(activation, use_keras_layer=True),
        configs.Default8BitActivationQuantizeConfig())
    self._activation_fn_no_quant = (
        tf_utils.get_activation(activation, use_keras_layer=True))

  def build(self, input_shape):
    """Creates the quantized ASPP branches, projection and fusion layers.

    The base class `build` is intentionally not called; this method creates
    quantized counterparts of the branch layers instead.
    """
    height = input_shape[1]
    width = input_shape[2]
    channels = input_shape[3]

    norm_layer = (
        tf_keras.layers.experimental.SyncBatchNormalization
        if self._use_sync_bn else tf_keras.layers.BatchNormalization)
    norm_with_quantize = helper.BatchNormalizationQuantized(norm_layer)
    norm_no_quantize = helper.BatchNormalizationNoQuantized(norm_layer)
    norm = helper.norm_by_activation(self._activation, norm_with_quantize,
                                     norm_no_quantize)

    # Each entry of `aspp_layers` is one branch: a list of layers applied in
    # sequence by `call`.
    self.aspp_layers = []

    # Branch 1: 1x1 convolution.
    conv1 = helper.Conv2DQuantized(
        filters=self._output_channels,
        kernel_size=(1, 1),
        kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
        kernel_regularizer=self._kernel_regularizer,
        use_bias=False,
        activation=helper.NoOpActivation())
    norm1 = norm(
        axis=self._bn_axis,
        momentum=self._batchnorm_momentum,
        epsilon=self._batchnorm_epsilon)

    self.aspp_layers.append([conv1, norm1])

    # One dilated branch per dilation rate, optionally depthwise-separable.
    for dilation_rate in self._dilation_rates:
      leading_layers = []
      kernel_size = (3, 3)
      if self._use_depthwise_convolution:
        leading_layers += [
            helper.DepthwiseConv2DOutputQuantized(
                depth_multiplier=1,
                kernel_size=kernel_size,
                padding='same',
                depthwise_regularizer=self._kernel_regularizer,
                depthwise_initializer=tf_utils.clone_initializer(
                    self._kernel_initializer),
                dilation_rate=dilation_rate,
                use_bias=False,
                activation=helper.NoOpActivation())
        ]
        kernel_size = (1, 1)
      conv_dilation = leading_layers + [
          helper.Conv2DQuantized(
              filters=self._output_channels,
              kernel_size=kernel_size,
              padding='same',
              kernel_regularizer=self._kernel_regularizer,
              kernel_initializer=tf_utils.clone_initializer(
                  self._kernel_initializer),
              dilation_rate=dilation_rate,
              use_bias=False,
              activation=helper.NoOpActivation())
      ]
      norm_dilation = norm(
          axis=self._bn_axis,
          momentum=self._batchnorm_momentum,
          epsilon=self._batchnorm_epsilon)

      self.aspp_layers.append(conv_dilation + [norm_dilation])

    # Last branch: image pooling followed by 1x1 conv; `call` resizes its
    # output back to the input resolution.
    if self._pool_kernel_size is None:
      pooling = [
          helper.GlobalAveragePooling2DQuantized(),
          helper.ReshapeQuantized((1, 1, channels))
      ]
    else:
      pooling = [helper.AveragePooling2DQuantized(self._pool_kernel_size)]

    conv2 = helper.Conv2DQuantized(
        filters=self._output_channels,
        kernel_size=(1, 1),
        kernel_initializer=tf_utils.clone_initializer(
            self._kernel_initializer),
        kernel_regularizer=self._kernel_regularizer,
        use_bias=False,
        activation=helper.NoOpActivation())
    norm2 = norm(
        axis=self._bn_axis,
        momentum=self._batchnorm_momentum,
        epsilon=self._batchnorm_epsilon)

    self.aspp_layers.append(pooling + [conv2, norm2])
    # Static target size, taken from the build-time input shape.
    self._resizing_layer = helper.ResizingQuantized(
        height, width, interpolation=self._interpolation)

    self._projection = [
        helper.Conv2DQuantized(
            filters=self._output_channels,
            kernel_size=(1, 1),
            kernel_initializer=tf_utils.clone_initializer(
                self._kernel_initializer),
            kernel_regularizer=self._kernel_regularizer,
            use_bias=False,
            activation=helper.NoOpActivation()),
        norm(
            axis=self._bn_axis,
            momentum=self._batchnorm_momentum,
            epsilon=self._batchnorm_epsilon)
    ]
    self._dropout_layer = tf_keras.layers.Dropout(rate=self._dropout)
    self._concat_layer = helper.ConcatenateQuantized(axis=-1)

  def call(self,
           inputs: tf.Tensor,
           training: Optional[bool] = None) -> tf.Tensor:
    """Runs all ASPP branches, concatenates them and projects the result.

    Args:
      inputs: The input feature map.
      training: Optional training flag. When None, the Keras learning phase is
        used. It is forwarded to every branch, projection and dropout layer.

    Returns:
      The projected, activated and dropout-applied feature map.
    """
    if training is None:
      training = tf_keras.backend.learning_phase()
    result = []
    for i, layers in enumerate(self.aspp_layers):
      x = inputs
      for layer in layers:
        # Apply layers sequentially.
        x = layer(x, training=training)
      x = self._activation_fn(x)

      # Apply resize layer to the end of the last set of layers.
      if i == len(self.aspp_layers) - 1:
        x = self._resizing_layer(x)

      result.append(tf.cast(x, inputs.dtype))
    x = self._concat_layer(result)
    for layer in self._projection:
      x = layer(x, training=training)
    x = self._activation_fn(x)
    return self._dropout_layer(x, training=training)
@tf_keras.utils.register_keras_serializable(package='Vision')
class ASPPQuantized(aspp.ASPP):
  """Creates a quantized Atrous Spatial Pyramid Pooling (ASPP) layer."""

  def __init__(
      self,
      level: int,
      dilation_rates: List[int],
      num_filters: int = 256,
      pool_kernel_size: Optional[int] = None,
      use_sync_bn: bool = False,
      norm_momentum: float = 0.99,
      norm_epsilon: float = 0.001,
      activation: str = 'relu',
      dropout_rate: float = 0.0,
      kernel_initializer: str = 'VarianceScaling',
      kernel_regularizer: Optional[tf_keras.regularizers.Regularizer] = None,
      interpolation: str = 'bilinear',
      use_depthwise_convolution: bool = False,
      spp_layer_version: str = 'v1',
      output_tensor: bool = True,
      **kwargs):
    """Initializes an Atrous Spatial Pyramid Pooling (ASPP) layer.

    Args:
      level: An `int` level to apply ASPP.
      dilation_rates: A `list` of dilation rates.
      num_filters: An `int` number of output filters in ASPP.
      pool_kernel_size: A `list` of [height, width] of pooling kernel size or
        None. Pooling size is with respect to original image size, it will be
        scaled down by 2**level. If None, global average pooling is used.
      use_sync_bn: A `bool`. If True, use synchronized batch normalization.
      norm_momentum: A `float` of normalization momentum for the moving average.
      norm_epsilon: A `float` added to variance to avoid dividing by zero.
      activation: A `str` activation to be used in ASPP.
      dropout_rate: A `float` rate for dropout regularization.
      kernel_initializer: A `str` name of kernel_initializer for convolutional
        layers.
      kernel_regularizer: A `tf_keras.regularizers.Regularizer` object for
        Conv2D. Default is None.
      interpolation: A `str` of interpolation method. It should be one of
        `bilinear`, `nearest`, `bicubic`, `area`, `lanczos3`, `lanczos5`,
        `gaussian`, or `mitchellcubic`.
      use_depthwise_convolution: If True depthwise separable convolutions will
        be added to the Atrous spatial pyramid pooling.
      spp_layer_version: A `str` of spatial pyramid pooling layer version.
      output_tensor: Forwarded to the base `aspp.ASPP`. Note that `call` in
        this class does not consult it: the output is never wrapped in a dict.
      **kwargs: Additional keyword arguments to be passed.
    """
    super().__init__(
        level=level,
        dilation_rates=dilation_rates,
        num_filters=num_filters,
        pool_kernel_size=pool_kernel_size,
        use_sync_bn=use_sync_bn,
        norm_momentum=norm_momentum,
        norm_epsilon=norm_epsilon,
        activation=activation,
        dropout_rate=dropout_rate,
        kernel_initializer=kernel_initializer,
        kernel_regularizer=kernel_regularizer,
        interpolation=interpolation,
        use_depthwise_convolution=use_depthwise_convolution,
        spp_layer_version=spp_layer_version,
        output_tensor=output_tensor,
        **kwargs)

    # Layer class the base ASPP should instantiate; presumably consumed by the
    # base `build` to create `self.aspp` -- confirm against aspp.ASPP.
    self._aspp_layer = SpatialPyramidPoolingQuantized

  def call(self, inputs: Union[tf.Tensor, Mapping[str,
                                                  tf.Tensor]]) -> tf.Tensor:
    """Calls the Atrous Spatial Pyramid Pooling (ASPP) layer on an input.

    If `inputs` is a dict, the entry for this layer's `level` is selected;
    otherwise `inputs` is used directly. The output of `self.aspp` is returned
    as-is. This method does not consult `output_tensor` and never builds a
    per-level dict.

    Args:
      inputs: A `tf.Tensor` of shape [batch, height_l, width_l, filter_size] or
        a `dict` of `tf.Tensor` where
        - key: A `str` of the level of the multilevel feature maps.
        - values: A `tf.Tensor` of shape [batch, height_l, width_l,
          filter_size].

    Returns:
      The output of `self.aspp` applied to the selected feature map.
    """
    level = str(self._config_dict['level'])
    backbone_output = inputs[level] if isinstance(inputs, dict) else inputs
    return self.aspp(backbone_output)
class BatchNormalizationWrapper(tf_keras.layers.Wrapper):
  """A batch normalization wrapper whose normalization is explicitly not folded.

  It inserts an identity 1x1 depthwise convolution right before the wrapped
  normalization layer, so that the normalization gets folded into that
  identity depthwise conv.

  Use it only when regular batch normalization folding does not work. The
  normalization is then quantized as a 1x1 depthwise conv that behaves like
  the normalization in inference mode (basically a multiply and an add).
  """

  def call(self, inputs: tf.Tensor, *args: Any, **kwargs: Any) -> tf.Tensor:
    """Applies an identity depthwise conv, then the wrapped layer's `call`.

    Args:
      inputs: The input tensor; its last axis is treated as the channel axis.
      *args: Forwarded to the wrapped layer's `call`.
      **kwargs: Forwarded to the wrapped layer's `call`.

    Returns:
      The output of the wrapped layer.
    """
    channels = tf.shape(inputs)[-1]
    # All-ones [1, 1, channels, 1] depthwise kernel = identity. The kernel must
    # share the input dtype (e.g. float16 under mixed precision), otherwise
    # `depthwise_conv2d` rejects the mismatched input/filter types.
    identity_kernel = tf.ones([1, 1, channels, 1], dtype=inputs.dtype)
    x = tf.nn.depthwise_conv2d(inputs, identity_kernel, [1, 1, 1, 1], 'VALID')
    # Calls the wrapped layer's `call` directly rather than `__call__`.
    outputs = self.layer.call(x, *args, **kwargs)
    return outputs
class MaskScoringQuantized(heads.MaskScoring):
  """Creates a quantized mask scoring layer.

  This implements mask scoring layer from the paper:

  Zhaojin Huang, Lichao Huang, Yongchao Gong, Chang Huang, Xinggang Wang.
  Mask Scoring R-CNN.
  (https://arxiv.org/pdf/1903.00241.pdf)
  """

  def build(self, input_shape: Union[tf.TensorShape, List[tf.TensorShape]]):
    """Creates the variables of the mask scoring head.

    Builds quantized conv, fully-connected, normalization, activation,
    resizing and output layers used by `call`.

    Args:
      input_shape: The shape(s) of the layer input, forwarded to the base
        `build`.
    """
    self._activation_layer = tfmot.quantization.keras.QuantizeWrapperV2(
        tf_utils.get_activation(
            self._config_dict['activation'], use_keras_layer=True
        ),
        configs.Default8BitActivationQuantizeConfig(),
    )

    # Shared keyword arguments for the main convolutions. The initializer here
    # is only a template: it is cloned before every use below, so no instance
    # is shared between layers.
    conv_kwargs = {
        'filters': self._config_dict['num_filters'],
        'kernel_size': 3,
        'padding': 'same',
        'kernel_initializer': tf_keras.initializers.VarianceScaling(
            scale=2, mode='fan_out', distribution='untruncated_normal'
        ),
        'bias_initializer': tf.zeros_initializer(),
        'kernel_regularizer': self._config_dict['kernel_regularizer'],
        'bias_regularizer': self._config_dict['bias_regularizer'],
    }
    norm_layer = (
        tf_keras.layers.experimental.SyncBatchNormalization
        if self._config_dict['use_sync_bn']
        else tf_keras.layers.BatchNormalization
    )
    norm_with_quantize = helper.BatchNormalizationQuantized(norm_layer)
    norm_no_quantize = helper.BatchNormalizationNoQuantized(norm_layer)
    bn_op = helper.norm_by_activation(
        self._config_dict['activation'], norm_with_quantize, norm_no_quantize
    )

    bn_kwargs = {
        'axis': self._bn_axis,
        'momentum': self._config_dict['norm_momentum'],
        'epsilon': self._config_dict['norm_epsilon'],
    }

    self._convs = []
    self._conv_norms = []
    for i in range(self._config_dict['num_convs']):
      if self._config_dict['use_depthwise_convolution']:
        self._convs.append(
            helper.DepthwiseConv2DQuantized(
                name='mask-scoring-depthwise-conv-{}'.format(i),
                kernel_size=3,
                padding='same',
                use_bias=False,
                depthwise_initializer=tf_keras.initializers.RandomNormal(
                    stddev=0.01),
                depthwise_regularizer=self._config_dict['kernel_regularizer'],
                depth_multiplier=1,
                activation=helper.NoOpActivation()))
        norm_name = 'mask-scoring-depthwise-bn-{}'.format(i)
        self._conv_norms.append(bn_op(name=norm_name, **bn_kwargs))
        # The 3x3 depthwise conv above covers the spatial extent; the
        # following conv becomes a 1x1 (pointwise) conv.
        conv_kwargs['kernel_size'] = 1
      conv_kwargs['kernel_initializer'] = tf_utils.clone_initializer(
          conv_kwargs['kernel_initializer']
      )
      conv_name = 'mask-scoring_{}'.format(i)
      self._convs.append(
          helper.Conv2DQuantized(
              name=conv_name, activation=helper.NoOpActivation(), **conv_kwargs
          )
      )
      bn_name = 'mask-scoring-bn_{}'.format(i)
      self._conv_norms.append(bn_op(name=bn_name, **bn_kwargs))

    self._fcs = []
    self._fc_norms = []
    for i in range(self._config_dict['num_fcs']):
      fc_name = 'mask-scoring-fc_{}'.format(i)
      self._fcs.append(
          helper.DenseQuantized(
              units=self._config_dict['fc_dims'],
              kernel_initializer=tf_keras.initializers.VarianceScaling(
                  scale=1 / 3.0, mode='fan_out', distribution='uniform'
              ),
              kernel_regularizer=self._config_dict['kernel_regularizer'],
              bias_regularizer=self._config_dict['bias_regularizer'],
              name=fc_name,
              activation=helper.NoOpActivation(),
          )
      )
      bn_name = 'mask-scoring-fc-bn_{}'.format(i)
      self._fc_norms.append(bn_op(name=bn_name, **bn_kwargs))

    self._classifier = helper.DenseOutputQuantized(
        units=self._config_dict['num_classes'],
        kernel_initializer=tf_keras.initializers.RandomNormal(stddev=0.01),
        bias_initializer=tf.zeros_initializer(),
        kernel_regularizer=self._config_dict['kernel_regularizer'],
        bias_regularizer=self._config_dict['bias_regularizer'],
        name='iou-scores',
    )
    # Static resize target so the flattened size in `call` is known.
    self._resizing_layer = helper.ResizingQuantized(
        self._config_dict['fc_input_size'][0],
        self._config_dict['fc_input_size'][1],
        interpolation='bilinear',
    )
    self._identity_layer = helper.IdentityQuantized(trainable=False)
    super().build(input_shape)

  def call(self, inputs: tf.Tensor, training: Optional[bool] = None):
    """Forward pass mask scoring head.

    Args:
      inputs: A `tf.Tensor` of the shape [batch_size, height, width,
        num_classes], representing the segmentation logits.
      training: An optional `bool` indicating whether it is in `training` mode.
        This method does not forward it explicitly to its sub-layers.

    Returns:
      mask_scores: A `tf.Tensor` of predicted mask scores
        [batch_size, num_classes].
    """
    # No gradient flows back into the input logits.
    x = tf.stop_gradient(inputs)
    for conv, bn in zip(self._convs, self._conv_norms):
      x = conv(x)
      x = bn(x)
      x = self._activation_layer(x)

    x = self._resizing_layer(x)
    # Flatten everything except the batch dimension for the dense layers.
    _, h, w, filters = x.get_shape().as_list()
    x = tf.reshape(x, [-1, h * w * filters])

    for fc, bn in zip(self._fcs, self._fc_norms):
      x = fc(x)
      x = bn(x)
      x = self._activation_layer(x)

    ious = self._classifier(x)
    ious = self._identity_layer(ious)
    return ious