official/projects/qat/vision/modeling/layers/nn_blocks.py
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains quantized neural blocks for the QAT."""
from typing import Any, Dict, Optional, Sequence, Tuple, Union
# Import libraries
from absl import logging
import tensorflow as tf, tf_keras
import tensorflow_model_optimization as tfmot
from official.modeling import tf_utils
from official.projects.qat.vision.modeling.layers import nn_layers as qat_nn_layers
from official.projects.qat.vision.quantization import configs
from official.projects.qat.vision.quantization import helper
from official.vision.modeling.layers import nn_layers


# This class is copied from modeling.layers.nn_blocks.BottleneckBlock and
# applies QAT.
@tf_keras.utils.register_keras_serializable(package='Vision')
class BottleneckBlockQuantized(tf_keras.layers.Layer):
"""A quantized standard bottleneck block."""
def __init__(self,
filters: int,
strides: int,
dilation_rate: int = 1,
use_projection: bool = False,
se_ratio: Optional[float] = None,
resnetd_shortcut: bool = False,
stochastic_depth_drop_rate: Optional[float] = None,
kernel_initializer: str = 'VarianceScaling',
               kernel_regularizer: Optional[
                   tf_keras.regularizers.Regularizer] = None,
               bias_regularizer: Optional[
                   tf_keras.regularizers.Regularizer] = None,
activation: str = 'relu',
use_sync_bn: bool = False,
norm_momentum: float = 0.99,
norm_epsilon: float = 0.001,
               bn_trainable: bool = True,
**kwargs):
"""Initializes a standard bottleneck block with BN after convolutions.
Args:
filters: An `int` number of filters for the first two convolutions. Note
that the third and final convolution will use 4 times as many filters.
strides: An `int` block stride. If greater than 1, this block will
ultimately downsample the input.
      dilation_rate: An `int` dilation rate of convolutions. Defaults to 1.
use_projection: A `bool` for whether this block should use a projection
shortcut (versus the default identity shortcut). This is usually `True`
for the first block of a block group, which may change the number of
filters and the resolution.
se_ratio: A `float` or None. Ratio of the Squeeze-and-Excitation layer.
resnetd_shortcut: A `bool`. If True, apply the resnetd style modification
to the shortcut connection.
stochastic_depth_drop_rate: A `float` or None. If not None, drop rate for
the stochastic depth layer.
kernel_initializer: A `str` of kernel_initializer for convolutional
layers.
      kernel_regularizer: A `tf_keras.regularizers.Regularizer` object for
        Conv2D. Defaults to None.
      bias_regularizer: A `tf_keras.regularizers.Regularizer` object for
        Conv2D. Defaults to None.
activation: A `str` name of the activation function.
use_sync_bn: A `bool`. If True, use synchronized batch normalization.
norm_momentum: A `float` of normalization momentum for the moving average.
norm_epsilon: A `float` added to variance to avoid dividing by zero.
      bn_trainable: A `bool` that indicates whether batch norm layers should be
        trainable. Defaults to True.
**kwargs: Additional keyword arguments to be passed.
"""
super(BottleneckBlockQuantized, self).__init__(**kwargs)
self._filters = filters
self._strides = strides
self._dilation_rate = dilation_rate
self._use_projection = use_projection
self._se_ratio = se_ratio
self._resnetd_shortcut = resnetd_shortcut
self._use_sync_bn = use_sync_bn
self._activation = activation
self._stochastic_depth_drop_rate = stochastic_depth_drop_rate
self._kernel_initializer = kernel_initializer
self._norm_momentum = norm_momentum
self._norm_epsilon = norm_epsilon
self._kernel_regularizer = kernel_regularizer
self._bias_regularizer = bias_regularizer
norm_layer = (
tf_keras.layers.experimental.SyncBatchNormalization
if use_sync_bn else tf_keras.layers.BatchNormalization)
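    # `_norm_with_quantize` wraps batch norm so its output is fake-quantized;
    # `_norm` leaves the output unquantized, for positions where a following
    # quantized op (such as a wrapped activation) already quantizes it.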
self._norm_with_quantize = helper.BatchNormalizationQuantized(norm_layer)
self._norm = helper.BatchNormalizationNoQuantized(norm_layer)
if tf_keras.backend.image_data_format() == 'channels_last':
self._bn_axis = -1
else:
self._bn_axis = 1
self._bn_trainable = bn_trainable
def build(self, input_shape: Optional[Union[Sequence[int], tf.Tensor]]):
"""Build variables and child layers to prepare for calling."""
if self._use_projection:
if self._resnetd_shortcut:
self._shortcut0 = tf_keras.layers.AveragePooling2D(
pool_size=2, strides=self._strides, padding='same')
self._shortcut1 = helper.Conv2DQuantized(
filters=self._filters * 4,
kernel_size=1,
strides=1,
use_bias=False,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activation=helper.NoOpActivation())
else:
self._shortcut = helper.Conv2DQuantized(
filters=self._filters * 4,
kernel_size=1,
strides=self._strides,
use_bias=False,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activation=helper.NoOpActivation())
self._norm0 = self._norm_with_quantize(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon,
trainable=self._bn_trainable)
self._conv1 = helper.Conv2DQuantized(
filters=self._filters,
kernel_size=1,
strides=1,
use_bias=False,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activation=helper.NoOpActivation())
self._norm1 = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon,
trainable=self._bn_trainable)
self._activation1 = tfmot.quantization.keras.QuantizeWrapperV2(
tf_utils.get_activation(self._activation, use_keras_layer=True),
configs.Default8BitActivationQuantizeConfig())
self._conv2 = helper.Conv2DQuantized(
filters=self._filters,
kernel_size=3,
strides=self._strides,
dilation_rate=self._dilation_rate,
padding='same',
use_bias=False,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activation=helper.NoOpActivation())
self._norm2 = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon,
trainable=self._bn_trainable)
self._activation2 = tfmot.quantization.keras.QuantizeWrapperV2(
tf_utils.get_activation(self._activation, use_keras_layer=True),
configs.Default8BitActivationQuantizeConfig())
self._conv3 = helper.Conv2DQuantized(
filters=self._filters * 4,
kernel_size=1,
strides=1,
use_bias=False,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activation=helper.NoOpActivation())
self._norm3 = self._norm_with_quantize(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon,
trainable=self._bn_trainable)
self._activation3 = tfmot.quantization.keras.QuantizeWrapperV2(
tf_utils.get_activation(self._activation, use_keras_layer=True),
configs.Default8BitActivationQuantizeConfig())
    if self._se_ratio and 0 < self._se_ratio <= 1:
self._squeeze_excitation = qat_nn_layers.SqueezeExcitationQuantized(
in_filters=self._filters * 4,
out_filters=self._filters * 4,
se_ratio=self._se_ratio,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)
else:
self._squeeze_excitation = None
if self._stochastic_depth_drop_rate:
self._stochastic_depth = nn_layers.StochasticDepth(
self._stochastic_depth_drop_rate)
else:
self._stochastic_depth = None
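    # Wrap Add so its output is quantized: the two empty lists are the weight
    # and activation attributes to quantize (none for Add), and `True`
    # requests quantization of the layer's output.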
self._add = tfmot.quantization.keras.QuantizeWrapperV2(
tf_keras.layers.Add(),
configs.Default8BitQuantizeConfig([], [], True))
super(BottleneckBlockQuantized, self).build(input_shape)
def get_config(self) -> Dict[str, Any]:
"""Get a config of this layer."""
config = {
'filters': self._filters,
'strides': self._strides,
'dilation_rate': self._dilation_rate,
'use_projection': self._use_projection,
'se_ratio': self._se_ratio,
'resnetd_shortcut': self._resnetd_shortcut,
'stochastic_depth_drop_rate': self._stochastic_depth_drop_rate,
'kernel_initializer': self._kernel_initializer,
'kernel_regularizer': self._kernel_regularizer,
'bias_regularizer': self._bias_regularizer,
'activation': self._activation,
'use_sync_bn': self._use_sync_bn,
'norm_momentum': self._norm_momentum,
'norm_epsilon': self._norm_epsilon,
'bn_trainable': self._bn_trainable
}
base_config = super(BottleneckBlockQuantized, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(
self,
inputs: tf.Tensor,
training: Optional[Union[bool, tf.Tensor]] = None) -> tf.Tensor:
"""Run the BottleneckBlockQuantized logics."""
shortcut = inputs
if self._use_projection:
if self._resnetd_shortcut:
shortcut = self._shortcut0(shortcut)
shortcut = self._shortcut1(shortcut)
else:
shortcut = self._shortcut(shortcut)
shortcut = self._norm0(shortcut)
x = self._conv1(inputs)
x = self._norm1(x)
x = self._activation1(x)
x = self._conv2(x)
x = self._norm2(x)
x = self._activation2(x)
x = self._conv3(x)
x = self._norm3(x)
if self._squeeze_excitation:
x = self._squeeze_excitation(x)
if self._stochastic_depth:
x = self._stochastic_depth(x, training=training)
x = self._add([x, shortcut])
return self._activation3(x)
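
# Example usage of BottleneckBlockQuantized (a minimal sketch; the input shape
# and hyperparameters below are illustrative, not taken from any model config):
#
#   block = BottleneckBlockQuantized(filters=64, strides=1, use_projection=True)
#   outputs = block(tf.ones([1, 56, 56, 64]))
#   # outputs.shape == [1, 56, 56, 256] with the default channels_last layout,
#   # since the third conv uses 4 * filters.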


# This class is copied from modeling.backbones.mobilenet.Conv2DBNBlock and
# applies QAT.
@tf_keras.utils.register_keras_serializable(package='Vision')
class Conv2DBNBlockQuantized(tf_keras.layers.Layer):
"""A quantized convolution block with batch normalization."""
def __init__(
self,
filters: int,
kernel_size: int = 3,
strides: int = 1,
use_bias: bool = False,
use_explicit_padding: bool = False,
activation: str = 'relu6',
kernel_initializer: str = 'VarianceScaling',
kernel_regularizer: Optional[tf_keras.regularizers.Regularizer] = None,
bias_regularizer: Optional[tf_keras.regularizers.Regularizer] = None,
use_normalization: bool = True,
use_sync_bn: bool = False,
norm_momentum: float = 0.99,
norm_epsilon: float = 0.001,
**kwargs):
"""A convolution block with batch normalization.
Args:
      filters: An `int` number of filters for the convolution.
kernel_size: An `int` specifying the height and width of the 2D
convolution window.
strides: An `int` of block stride. If greater than 1, this block will
ultimately downsample the input.
use_bias: If True, use bias in the convolution layer.
use_explicit_padding: Use 'VALID' padding for convolutions, but prepad
inputs so that the output dimensions are the same as if 'SAME' padding
were used.
activation: A `str` name of the activation function.
kernel_initializer: A `str` for kernel initializer of convolutional
layers.
      kernel_regularizer: A `tf_keras.regularizers.Regularizer` object for
        Conv2D. Defaults to None.
      bias_regularizer: A `tf_keras.regularizers.Regularizer` object for
        Conv2D. Defaults to None.
use_normalization: If True, use batch normalization.
use_sync_bn: If True, use synchronized batch normalization.
norm_momentum: A `float` of normalization momentum for the moving average.
norm_epsilon: A `float` added to variance to avoid dividing by zero.
**kwargs: Additional keyword arguments to be passed.
"""
super(Conv2DBNBlockQuantized, self).__init__(**kwargs)
self._filters = filters
self._kernel_size = kernel_size
self._strides = strides
self._activation = activation
self._use_bias = use_bias
self._use_explicit_padding = use_explicit_padding
self._kernel_initializer = kernel_initializer
self._kernel_regularizer = kernel_regularizer
self._bias_regularizer = bias_regularizer
self._use_normalization = use_normalization
self._use_sync_bn = use_sync_bn
self._norm_momentum = norm_momentum
self._norm_epsilon = norm_epsilon
if use_explicit_padding and kernel_size > 1:
self._padding = 'valid'
else:
self._padding = 'same'
norm_layer = (
tf_keras.layers.experimental.SyncBatchNormalization
if use_sync_bn else tf_keras.layers.BatchNormalization)
self._norm_with_quantize = helper.BatchNormalizationQuantized(norm_layer)
self._norm = helper.BatchNormalizationNoQuantized(norm_layer)
if tf_keras.backend.image_data_format() == 'channels_last':
self._bn_axis = -1
else:
self._bn_axis = 1
def get_config(self) -> Dict[str, Any]:
"""Get a config of this layer."""
config = {
'filters': self._filters,
'strides': self._strides,
'kernel_size': self._kernel_size,
'use_bias': self._use_bias,
'use_explicit_padding': self._use_explicit_padding,
'kernel_initializer': self._kernel_initializer,
'kernel_regularizer': self._kernel_regularizer,
'bias_regularizer': self._bias_regularizer,
'activation': self._activation,
'use_sync_bn': self._use_sync_bn,
'use_normalization': self._use_normalization,
'norm_momentum': self._norm_momentum,
'norm_epsilon': self._norm_epsilon
}
base_config = super(Conv2DBNBlockQuantized, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def build(self, input_shape: Optional[Union[Sequence[int], tf.Tensor]]):
"""Build variables and child layers to prepare for calling."""
if self._use_explicit_padding and self._kernel_size > 1:
padding_size = nn_layers.get_padding_for_kernel_size(self._kernel_size)
self._pad = tf_keras.layers.ZeroPadding2D(padding_size)
conv2d_quantized = (
helper.Conv2DQuantized
if self._use_normalization else helper.Conv2DOutputQuantized)
self._conv0 = conv2d_quantized(
filters=self._filters,
kernel_size=self._kernel_size,
strides=self._strides,
padding=self._padding,
use_bias=self._use_bias,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activation=helper.NoOpActivation())
if self._use_normalization:
self._norm0 = helper.norm_by_activation(self._activation,
self._norm_with_quantize,
self._norm)(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)
self._activation_layer = tfmot.quantization.keras.QuantizeWrapperV2(
tf_utils.get_activation(self._activation, use_keras_layer=True),
configs.Default8BitActivationQuantizeConfig())
super(Conv2DBNBlockQuantized, self).build(input_shape)
def call(
self,
inputs: tf.Tensor,
training: Optional[Union[bool, tf.Tensor]] = None) -> tf.Tensor:
"""Run the Conv2DBNBlockQuantized logics."""
if self._use_explicit_padding and self._kernel_size > 1:
inputs = self._pad(inputs)
x = self._conv0(inputs)
if self._use_normalization:
x = self._norm0(x)
return self._activation_layer(x)
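
# Example usage of Conv2DBNBlockQuantized (a minimal sketch; the shapes are
# illustrative):
#
#   block = Conv2DBNBlockQuantized(filters=32, kernel_size=3, strides=2)
#   outputs = block(tf.ones([1, 224, 224, 3]))
#   # outputs.shape == [1, 112, 112, 32]
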
@tf_keras.utils.register_keras_serializable(package='Vision')
class InvertedBottleneckBlockQuantized(tf_keras.layers.Layer):
"""A quantized inverted bottleneck block."""
def __init__(self,
in_filters,
out_filters,
expand_ratio,
strides,
kernel_size=3,
se_ratio=None,
stochastic_depth_drop_rate=None,
kernel_initializer='VarianceScaling',
kernel_regularizer=None,
bias_regularizer=None,
activation='relu',
se_inner_activation='relu',
se_gating_activation='sigmoid',
se_round_down_protect=True,
expand_se_in_filters=False,
depthwise_activation=None,
use_sync_bn=False,
dilation_rate=1,
divisible_by=1,
regularize_depthwise=False,
use_depthwise=True,
use_residual=True,
norm_momentum=0.99,
norm_epsilon=0.001,
output_intermediate_endpoints=False,
**kwargs):
"""Initializes an inverted bottleneck block with BN after convolutions.
Args:
in_filters: An `int` number of filters of the input tensor.
out_filters: An `int` number of filters of the output tensor.
expand_ratio: An `int` of expand_ratio for an inverted bottleneck block.
strides: An `int` block stride. If greater than 1, this block will
ultimately downsample the input.
kernel_size: An `int` kernel_size of the depthwise conv layer.
se_ratio: A `float` or None. If not None, se ratio for the squeeze and
excitation layer.
      stochastic_depth_drop_rate: A `float` or None. If not None, drop rate for
        the stochastic depth layer.
kernel_initializer: A `str` of kernel_initializer for convolutional
layers.
      kernel_regularizer: A `tf_keras.regularizers.Regularizer` object for
        Conv2D. Defaults to None.
      bias_regularizer: A `tf_keras.regularizers.Regularizer` object for
        Conv2D. Defaults to None.
activation: A `str` name of the activation function.
se_inner_activation: A `str` name of squeeze-excitation inner activation.
se_gating_activation: A `str` name of squeeze-excitation gating
activation.
      se_round_down_protect: A `bool` of whether rounding down by more than 10%
        is allowed in the SE layer.
      expand_se_in_filters: A `bool` of whether or not to expand in_filters in
        the squeeze and excitation layer.
depthwise_activation: A `str` name of the activation function for
depthwise only.
use_sync_bn: A `bool`. If True, use synchronized batch normalization.
      dilation_rate: An `int` that specifies the dilation rate to use for
        dilated convolution; the same value is used for all spatial
        dimensions.
      divisible_by: An `int` that ensures all inner dimensions are divisible by
        this number.
      regularize_depthwise: A `bool` of whether or not to apply regularization
        on the depthwise convolution.
      use_depthwise: A `bool`. If False, use a fused convolution instead of a
        depthwise one.
use_residual: A `bool` of whether to include residual connection between
input and output.
norm_momentum: A `float` of normalization momentum for the moving average.
norm_epsilon: A `float` added to variance to avoid dividing by zero.
output_intermediate_endpoints: A `bool` of whether or not output the
intermediate endpoints.
**kwargs: Additional keyword arguments to be passed.
"""
super(InvertedBottleneckBlockQuantized, self).__init__(**kwargs)
self._in_filters = in_filters
self._out_filters = out_filters
self._expand_ratio = expand_ratio
self._strides = strides
self._kernel_size = kernel_size
self._se_ratio = se_ratio
self._divisible_by = divisible_by
self._stochastic_depth_drop_rate = stochastic_depth_drop_rate
self._dilation_rate = dilation_rate
self._use_sync_bn = use_sync_bn
self._regularize_depthwise = regularize_depthwise
self._use_depthwise = use_depthwise
self._use_residual = use_residual
self._activation = activation
self._se_inner_activation = se_inner_activation
self._se_gating_activation = se_gating_activation
self._se_round_down_protect = se_round_down_protect
self._depthwise_activation = depthwise_activation
self._kernel_initializer = kernel_initializer
self._norm_momentum = norm_momentum
self._norm_epsilon = norm_epsilon
self._kernel_regularizer = kernel_regularizer
self._bias_regularizer = bias_regularizer
self._expand_se_in_filters = expand_se_in_filters
self._output_intermediate_endpoints = output_intermediate_endpoints
norm_layer = (
tf_keras.layers.experimental.SyncBatchNormalization
if use_sync_bn else tf_keras.layers.BatchNormalization)
self._norm_with_quantize = helper.BatchNormalizationQuantized(norm_layer)
self._norm = helper.BatchNormalizationNoQuantized(norm_layer)
if tf_keras.backend.image_data_format() == 'channels_last':
self._bn_axis = -1
else:
self._bn_axis = 1
if not depthwise_activation:
self._depthwise_activation = activation
if regularize_depthwise:
self._depthwise_regularizer = kernel_regularizer
else:
self._depthwise_regularizer = None
def build(self, input_shape: Optional[Union[Sequence[int], tf.Tensor]]):
"""Build variables and child layers to prepare for calling."""
expand_filters = self._in_filters
if self._expand_ratio > 1:
# First 1x1 conv for channel expansion.
expand_filters = nn_layers.make_divisible(
self._in_filters * self._expand_ratio, self._divisible_by)
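
      # If use_depthwise is False, the block is "fused": the expansion conv
      # takes on the spatial kernel size and stride that the depthwise conv
      # would otherwise apply.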
expand_kernel = 1 if self._use_depthwise else self._kernel_size
expand_stride = 1 if self._use_depthwise else self._strides
self._conv0 = helper.Conv2DQuantized(
filters=expand_filters,
kernel_size=expand_kernel,
strides=expand_stride,
padding='same',
use_bias=False,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activation=helper.NoOpActivation())
self._norm0 = helper.norm_by_activation(self._activation,
self._norm_with_quantize,
self._norm)(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)
self._activation_layer = tfmot.quantization.keras.QuantizeWrapperV2(
tf_utils.get_activation(self._activation, use_keras_layer=True),
configs.Default8BitActivationQuantizeConfig())
if self._use_depthwise:
# Depthwise conv.
self._conv1 = helper.DepthwiseConv2DQuantized(
kernel_size=(self._kernel_size, self._kernel_size),
strides=self._strides,
padding='same',
depth_multiplier=1,
dilation_rate=self._dilation_rate,
use_bias=False,
depthwise_initializer=self._kernel_initializer,
depthwise_regularizer=self._depthwise_regularizer,
bias_regularizer=self._bias_regularizer,
activation=helper.NoOpActivation(),
)
self._norm1 = helper.norm_by_activation(self._depthwise_activation,
self._norm_with_quantize,
self._norm)(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)
self._depthwise_activation_layer = (
tfmot.quantization.keras.QuantizeWrapperV2(
tf_utils.get_activation(self._depthwise_activation,
use_keras_layer=True),
configs.Default8BitActivationQuantizeConfig()))
# Squeeze and excitation.
    if self._se_ratio and 0 < self._se_ratio <= 1:
      logging.info('Using squeeze and excitation.')
in_filters = self._in_filters
if self._expand_se_in_filters:
in_filters = expand_filters
self._squeeze_excitation = qat_nn_layers.SqueezeExcitationQuantized(
in_filters=in_filters,
out_filters=expand_filters,
se_ratio=self._se_ratio,
divisible_by=self._divisible_by,
round_down_protect=self._se_round_down_protect,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activation=self._se_inner_activation,
gating_activation=self._se_gating_activation)
else:
self._squeeze_excitation = None
# Last 1x1 conv.
self._conv2 = helper.Conv2DQuantized(
filters=self._out_filters,
kernel_size=1,
strides=1,
padding='same',
use_bias=False,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activation=helper.NoOpActivation())
self._norm2 = self._norm_with_quantize(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)
if self._stochastic_depth_drop_rate:
self._stochastic_depth = nn_layers.StochasticDepth(
self._stochastic_depth_drop_rate)
else:
self._stochastic_depth = None
self._add = tfmot.quantization.keras.QuantizeWrapperV2(
tf_keras.layers.Add(),
configs.Default8BitQuantizeConfig([], [], True))
super(InvertedBottleneckBlockQuantized, self).build(input_shape)
def get_config(self) -> Dict[str, Any]:
"""Get a config of this layer."""
config = {
'in_filters': self._in_filters,
'out_filters': self._out_filters,
'expand_ratio': self._expand_ratio,
'strides': self._strides,
'kernel_size': self._kernel_size,
'se_ratio': self._se_ratio,
'divisible_by': self._divisible_by,
'stochastic_depth_drop_rate': self._stochastic_depth_drop_rate,
'kernel_initializer': self._kernel_initializer,
'kernel_regularizer': self._kernel_regularizer,
'bias_regularizer': self._bias_regularizer,
'activation': self._activation,
'se_inner_activation': self._se_inner_activation,
'se_gating_activation': self._se_gating_activation,
'se_round_down_protect': self._se_round_down_protect,
'expand_se_in_filters': self._expand_se_in_filters,
'depthwise_activation': self._depthwise_activation,
'dilation_rate': self._dilation_rate,
'use_sync_bn': self._use_sync_bn,
'regularize_depthwise': self._regularize_depthwise,
'use_depthwise': self._use_depthwise,
'use_residual': self._use_residual,
'norm_momentum': self._norm_momentum,
'norm_epsilon': self._norm_epsilon,
'output_intermediate_endpoints': self._output_intermediate_endpoints
}
base_config = super(InvertedBottleneckBlockQuantized, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(
self,
inputs: tf.Tensor,
training: Optional[Union[bool, tf.Tensor]] = None
) -> Union[tf.Tensor, Tuple[tf.Tensor, Dict[str, tf.Tensor]]]:
"""Run the InvertedBottleneckBlockQuantized logics."""
endpoints = {}
shortcut = inputs
if self._expand_ratio > 1:
x = self._conv0(inputs)
x = self._norm0(x)
x = self._activation_layer(x)
else:
x = inputs
if self._use_depthwise:
x = self._conv1(x)
x = self._norm1(x)
x = self._depthwise_activation_layer(x)
if self._output_intermediate_endpoints:
endpoints['depthwise'] = x
if self._squeeze_excitation:
x = self._squeeze_excitation(x)
x = self._conv2(x)
x = self._norm2(x)
if (self._use_residual and self._in_filters == self._out_filters and
self._strides == 1):
if self._stochastic_depth:
x = self._stochastic_depth(x, training=training)
x = self._add([x, shortcut])
if self._output_intermediate_endpoints:
return x, endpoints
return x
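
# Example usage of InvertedBottleneckBlockQuantized (a minimal sketch; the
# filter counts are illustrative). With in_filters == out_filters and
# strides == 1, the residual connection is active:
#
#   block = InvertedBottleneckBlockQuantized(
#       in_filters=24, out_filters=24, expand_ratio=6, strides=1, se_ratio=0.25)
#   outputs = block(tf.ones([1, 56, 56, 24]))
#   # outputs.shape == [1, 56, 56, 24]
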
@tf_keras.utils.register_keras_serializable(package='Vision')
class UniversalInvertedBottleneckBlockQuantized(tf_keras.layers.Layer):
"""A quantized inverted bottleneck block with optional depthwise convs."""
def __init__(
self,
in_filters: int,
out_filters: int,
expand_ratio: float,
strides: int,
middle_dw_downsample: bool = True,
start_dw_kernel_size: int = 0,
middle_dw_kernel_size: int = 3,
end_dw_kernel_size: int = 0,
stochastic_depth_drop_rate: float | None = None,
kernel_initializer: str = 'VarianceScaling',
kernel_regularizer: tf_keras.regularizers.Regularizer | None = None,
bias_regularizer: tf_keras.regularizers.Regularizer | None = None,
activation: str = 'relu',
depthwise_activation: str | None = None,
use_sync_bn: bool = False,
dilation_rate: int = 1,
divisible_by: int = 1,
regularize_depthwise: bool = False,
use_residual: bool = True,
use_layer_scale: bool = False,
layer_scale_init_value: float = 1e-5,
norm_momentum: float = 0.99,
norm_epsilon: float = 0.001,
output_intermediate_endpoints: bool = False,
**kwargs,
):
"""Initializes a UniversalInvertedBottleneckBlockQuantized.
This is an extension of IB with optional depthwise convs before expansion (
"starting" conv) and after projection ("ending" conv). Both of these convs
are executed without activation. The standard depthwise conv of IB ("middle"
conv) is optional too. This last one is followed by an activation, as in
standard IBs. Squeeze-and-Excite or fused types of IBs are not supported.
Args:
in_filters: The number of filters of the input tensor.
out_filters: The number of filters of the output tensor.
expand_ratio: The filter multiplier for the first inverted bottleneck
stage.
strides: The block stride. If greater than 1, this block will ultimately
downsample the input.
middle_dw_downsample: If True, downsample in the middle depthwise
otherwise downsample in the starting one.
start_dw_kernel_size: The kernel size of the starting depthwise. A value
of zero means that no starting depthwise will be added.
middle_dw_kernel_size: The kernel size of the middle depthwise. A value of
zero means that no middle depthwise will be added.
end_dw_kernel_size: The kernel size of the ending depthwise. A value of
zero means that no ending depthwise will be added.
stochastic_depth_drop_rate: If not None, drop rate for the stochastic
depth layer.
kernel_initializer: The name of the convolutional layer
kernel_initializer.
kernel_regularizer: An optional kernel regularizer for the Conv2ds.
bias_regularizer: An optional bias regularizer for the Conv2ds.
activation: The name of the activation function.
depthwise_activation: The name of the depthwise-only activation function.
use_sync_bn: If True, use synchronized batch normalization.
dilation_rate: The dilation rate to use for convolutions.
divisible_by: Ensures all inner dimensions are divisible by this number.
regularize_depthwise: If True, apply regularization on depthwise.
use_residual: If True, include residual connection between input and
output.
use_layer_scale: If True, use layer scale.
layer_scale_init_value: The initial layer scale value.
norm_momentum: Momentum value for the moving average in normalization.
norm_epsilon: Value added to variance to avoid dividing by zero in
normalization.
output_intermediate_endpoints: This block does not output any intermediate
endpoint, but this argument is included for compatibility with other
blocks.
**kwargs: Additional keyword arguments to be passed to
tf_keras.layers.Layer.
"""
super().__init__(**kwargs)
logging.info(
'UniversalInvertedBottleneckBlockQuantized with depthwise kernel sizes '
'{%d, %d, %d}, strides=%d, and middle downsampling: %s',
start_dw_kernel_size,
middle_dw_kernel_size,
end_dw_kernel_size,
strides,
middle_dw_downsample,
)
self._in_filters = in_filters
self._out_filters = out_filters
self._expand_ratio = expand_ratio
self._strides = strides
self._middle_dw_downsample = middle_dw_downsample
self._start_dw_kernel_size = start_dw_kernel_size
self._middle_dw_kernel_size = middle_dw_kernel_size
self._end_dw_kernel_size = end_dw_kernel_size
self._divisible_by = divisible_by
self._stochastic_depth_drop_rate = stochastic_depth_drop_rate
self._dilation_rate = dilation_rate
self._use_sync_bn = use_sync_bn
self._regularize_depthwise = regularize_depthwise
self._use_residual = use_residual
self._activation = activation
self._depthwise_activation = depthwise_activation
self._kernel_initializer = kernel_initializer
self._use_layer_scale = use_layer_scale
self._layer_scale_init_value = layer_scale_init_value
self._norm_momentum = norm_momentum
self._norm_epsilon = norm_epsilon
self._kernel_regularizer = kernel_regularizer
self._bias_regularizer = bias_regularizer
self._output_intermediate_endpoints = output_intermediate_endpoints
if strides > 1:
if middle_dw_downsample and not middle_dw_kernel_size:
raise ValueError(
'Requested downsampling at a non-existing middle depthwise conv.'
)
if not middle_dw_downsample and not start_dw_kernel_size:
raise ValueError(
'Requested downsampling at a non-existing starting depthwise conv.'
)
if use_sync_bn:
norm_layer = tf_keras.layers.experimental.SyncBatchNormalization
else:
norm_layer = tf_keras.layers.BatchNormalization
self._norm_with_quantize = helper.BatchNormalizationQuantized(norm_layer)
self._norm = helper.BatchNormalizationNoQuantized(norm_layer)
if tf_keras.backend.image_data_format() == 'channels_last':
self._bn_axis = -1
else:
self._bn_axis = 1
if not depthwise_activation:
self._depthwise_activation = activation
if regularize_depthwise:
self._depthwise_regularizer = kernel_regularizer
else:
self._depthwise_regularizer = None
  def build(self, input_shape):
    """Build variables and child layers to prepare for calling."""
# Starting depthwise conv.
if self._start_dw_kernel_size:
self._start_dw_conv = helper.DepthwiseConv2DQuantized(
kernel_size=self._start_dw_kernel_size,
strides=self._strides if not self._middle_dw_downsample else 1,
padding='same',
depth_multiplier=1,
dilation_rate=self._dilation_rate,
use_bias=False,
depthwise_initializer=tf_utils.clone_initializer(
self._kernel_initializer
),
depthwise_regularizer=self._depthwise_regularizer,
bias_regularizer=self._bias_regularizer,
)
# No activation -> quantized norm should be okay.
self._start_dw_norm = self._norm_with_quantize(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon,
)
# Expansion with 1x1 convs.
expand_filters = nn_layers.make_divisible(
self._in_filters * self._expand_ratio, self._divisible_by
)
self._expand_conv = helper.Conv2DQuantized(
filters=expand_filters,
kernel_size=1,
strides=1,
padding='same',
use_bias=False,
kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
)
self._expand_norm = helper.norm_by_activation(
self._activation, self._norm_with_quantize, self._norm
)(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon,
)
self._expand_act = tfmot.quantization.keras.QuantizeWrapperV2(
tf_utils.get_activation(self._activation, use_keras_layer=True),
configs.Default8BitActivationQuantizeConfig(),
)
# Middle depthwise conv.
if self._middle_dw_kernel_size:
self._middle_dw_conv = helper.DepthwiseConv2DQuantized(
kernel_size=self._middle_dw_kernel_size,
strides=self._strides if self._middle_dw_downsample else 1,
padding='same',
depth_multiplier=1,
dilation_rate=self._dilation_rate,
use_bias=False,
depthwise_initializer=tf_utils.clone_initializer(
self._kernel_initializer
),
depthwise_regularizer=self._depthwise_regularizer,
bias_regularizer=self._bias_regularizer,
)
self._middle_dw_norm = helper.norm_by_activation(
self._activation, self._norm_with_quantize, self._norm
)(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon,
)
self._middle_dw_act = tfmot.quantization.keras.QuantizeWrapperV2(
tf_utils.get_activation(
self._depthwise_activation, use_keras_layer=True
),
configs.Default8BitActivationQuantizeConfig(),
)
# Projection with 1x1 convs.
self._proj_conv = helper.Conv2DQuantized(
filters=self._out_filters,
kernel_size=1,
strides=1,
padding='same',
use_bias=False,
kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
)
# No activation -> quantized norm should be okay.
self._proj_norm = self._norm_with_quantize(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon,
)
# Ending depthwise conv.
if self._end_dw_kernel_size:
self._end_dw_conv = helper.DepthwiseConv2DQuantized(
kernel_size=self._end_dw_kernel_size,
strides=1,
padding='same',
depth_multiplier=1,
dilation_rate=self._dilation_rate,
use_bias=False,
depthwise_initializer=tf_utils.clone_initializer(
self._kernel_initializer
),
depthwise_regularizer=self._depthwise_regularizer,
bias_regularizer=self._bias_regularizer,
)
self._end_dw_norm = self._norm_with_quantize(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon,
)
    if self._use_layer_scale:
      raise NotImplementedError(
          'Layer scale is not yet supported in the quantized block.')
if self._stochastic_depth_drop_rate:
self._stochastic_depth = nn_layers.StochasticDepth(
self._stochastic_depth_drop_rate
)
else:
self._stochastic_depth = None
super().build(input_shape)
  def get_config(self):
    """Get a config of this layer."""
config = {
'in_filters': self._in_filters,
'out_filters': self._out_filters,
'expand_ratio': self._expand_ratio,
'strides': self._strides,
'middle_dw_downsample': self._middle_dw_downsample,
'start_dw_kernel_size': self._start_dw_kernel_size,
'middle_dw_kernel_size': self._middle_dw_kernel_size,
'end_dw_kernel_size': self._end_dw_kernel_size,
'divisible_by': self._divisible_by,
'stochastic_depth_drop_rate': self._stochastic_depth_drop_rate,
'kernel_initializer': self._kernel_initializer,
'kernel_regularizer': self._kernel_regularizer,
'bias_regularizer': self._bias_regularizer,
'activation': self._activation,
'depthwise_activation': self._depthwise_activation,
'dilation_rate': self._dilation_rate,
'use_sync_bn': self._use_sync_bn,
'regularize_depthwise': self._regularize_depthwise,
'use_residual': self._use_residual,
'use_layer_scale': self._use_layer_scale,
'layer_scale_init_value': self._layer_scale_init_value,
'norm_momentum': self._norm_momentum,
'norm_epsilon': self._norm_epsilon,
'output_intermediate_endpoints': self._output_intermediate_endpoints,
}
base_config = super().get_config()
return {**base_config, **config}
  def call(self, inputs, training=None):
    """Run the UniversalInvertedBottleneckBlockQuantized logic."""
shortcut = inputs
x = inputs
if self._start_dw_kernel_size:
x = self._start_dw_conv(x)
x = self._start_dw_norm(x)
x = self._expand_conv(x)
x = self._expand_norm(x)
x = self._expand_act(x)
if self._middle_dw_kernel_size:
x = self._middle_dw_conv(x)
x = self._middle_dw_norm(x)
x = self._middle_dw_act(x)
x = self._proj_conv(x)
x = self._proj_norm(x)
if self._end_dw_kernel_size:
x = self._end_dw_conv(x)
x = self._end_dw_norm(x)
    if self._use_layer_scale:
      # Note: currently unreachable, since build() raises NotImplementedError
      # when use_layer_scale is True and self._layer_scale is never created.
      x = self._layer_scale(x)
if (
self._use_residual
and self._in_filters == self._out_filters
and self._strides == 1
):
if self._stochastic_depth:
x = self._stochastic_depth(x, training=training)
x = x + shortcut
# Return empty intermediate endpoints to be compatible with other blocks.
if self._output_intermediate_endpoints:
return x, {}
return x
MaybeDwInvertedBottleneckBlockQuantized = (
UniversalInvertedBottleneckBlockQuantized
)
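
# Example usage of UniversalInvertedBottleneckBlockQuantized (a minimal sketch;
# the kernel sizes and shapes are illustrative). This configuration adds a
# starting depthwise conv and downsamples in the middle one:
#
#   block = UniversalInvertedBottleneckBlockQuantized(
#       in_filters=64, out_filters=96, expand_ratio=4.0, strides=2,
#       start_dw_kernel_size=3, middle_dw_kernel_size=5)
#   outputs = block(tf.ones([1, 32, 32, 64]))
#   # outputs.shape == [1, 16, 16, 96]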