# official/projects/qat/vision/quantization/configs_test.py
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for configs.py."""
# Import libraries
import numpy as np
import tensorflow as tf, tf_keras
import tensorflow_model_optimization as tfmot
from official.modeling import tf_utils
from official.projects.qat.vision.quantization import configs
class _TestHelper(object):
  """Mixin with assertion helpers shared by the quantize-config tests.

  The helpers call `assertIsInstance` / `assertAllEqual` on `self`, so this
  class is meant to be combined with a test-case base class providing them.
  """

  def _convert_list(self, list_of_tuples):
    """Splits a list of 2-tuples into two parallel lists.

    `QuantizeConfig` methods return a list of 2-tuples in the form
    [(weight1, quantizer1), (weight2, quantizer2)]. This helper regroups such
    a list as ([weight1, weight2], [quantizer1, quantizer2]).

    Args:
      list_of_tuples: List of 2-tuples.

    Returns:
      2-tuple of lists: all first items, then all second items.
    """
    # Walk the input exactly once and unpack every entry as a pair, so an
    # entry that is not a 2-tuple still fails at unpacking time.
    pairs = [(first, second) for first, second in list_of_tuples]
    firsts = [pair[0] for pair in pairs]
    seconds = [pair[1] for pair in pairs]
    return firsts, seconds

  # TODO(pulkitb): Consider asserting on full equality for quantizers.
  def _assert_weight_quantizers(self, quantizer_list):
    """Asserts every item in `quantizer_list` is a `LastValueQuantizer`."""
    for candidate in quantizer_list:
      expected_cls = tfmot.quantization.keras.quantizers.LastValueQuantizer
      self.assertIsInstance(candidate, expected_cls)

  def _assert_activation_quantizers(self, quantizer_list):
    """Asserts every item in `quantizer_list` is a `MovingAverageQuantizer`."""
    for candidate in quantizer_list:
      expected_cls = tfmot.quantization.keras.quantizers.MovingAverageQuantizer
      self.assertIsInstance(candidate, expected_cls)

  def _assert_kernel_equality(self, a, b):
    """Asserts that `a.numpy()` and `b.numpy()` are element-wise equal."""
    first_values = a.numpy()
    second_values = b.numpy()
    self.assertAllEqual(first_values, second_values)
class Default8BitQuantizeConfigTest(tf.test.TestCase, _TestHelper):
  """Tests for `configs.Default8BitQuantizeConfig`."""

  def _simple_dense_layer(self):
    """Returns a 2-unit `Dense` layer built with `input_shape=(3,)`."""
    dense = tf_keras.layers.Dense(2)
    dense.build(input_shape=(3,))
    return dense

  def _kernel_activation_config(self):
    """Returns a config for ['kernel'] / ['activation'] with `False` last."""
    return configs.Default8BitQuantizeConfig(
        ['kernel'], ['activation'], False)

  def testGetsQuantizeWeightsAndQuantizers(self):
    """Checks the weights and quantizers reported for a dense layer."""
    dense = self._simple_dense_layer()
    config = self._kernel_activation_config()

    weight_pairs = config.get_weights_and_quantizers(dense)
    found_weights, found_quantizers = self._convert_list(weight_pairs)

    self._assert_weight_quantizers(found_quantizers)
    self.assertEqual([dense.kernel], found_weights)

  def testGetsQuantizeActivationsAndQuantizers(self):
    """Checks the activations and quantizers reported for a dense layer."""
    dense = self._simple_dense_layer()
    config = self._kernel_activation_config()

    activation_pairs = config.get_activations_and_quantizers(dense)
    found_activations, found_quantizers = self._convert_list(activation_pairs)

    self._assert_activation_quantizers(found_quantizers)
    self.assertEqual([dense.activation], found_activations)

  def testSetsQuantizeWeights(self):
    """Checks that `set_quantize_weights` replaces the layer kernel values."""
    dense = self._simple_dense_layer()
    kernel_shape = dense.kernel.shape.as_list()
    replacement_kernel = tf_keras.backend.variable(np.ones(kernel_shape))
    config = self._kernel_activation_config()

    config.set_quantize_weights(dense, [replacement_kernel])

    self._assert_kernel_equality(dense.kernel, replacement_kernel)

  def testSetsQuantizeActivations(self):
    """Checks that `set_quantize_activations` replaces the layer activation."""
    dense = self._simple_dense_layer()
    replacement_activation = tf_keras.activations.relu
    config = self._kernel_activation_config()

    config.set_quantize_activations(dense, [replacement_activation])

    self.assertEqual(dense.activation, replacement_activation)

  def testSetsQuantizeWeights_ErrorOnWrongNumberOfWeights(self):
    """Expects `ValueError` when zero or two weights are given for one attr."""
    dense = self._simple_dense_layer()
    kernel_shape = dense.kernel.shape.as_list()
    replacement_kernel = tf_keras.backend.variable(np.ones(kernel_shape))
    config = self._kernel_activation_config()

    # Too few weights.
    with self.assertRaises(ValueError):
      config.set_quantize_weights(dense, [])

    # Too many weights.
    with self.assertRaises(ValueError):
      config.set_quantize_weights(
          dense, [replacement_kernel, replacement_kernel])

  def testSetsQuantizeWeights_ErrorOnWrongShapeOfWeight(self):
    """Expects `ValueError` when a [1, 2] weight is set on the dense layer."""
    dense = self._simple_dense_layer()
    mismatched_kernel = tf_keras.backend.variable(np.ones([1, 2]))
    config = self._kernel_activation_config()

    with self.assertRaises(ValueError):
      config.set_quantize_weights(dense, [mismatched_kernel])

  def testSetsQuantizeActivations_ErrorOnWrongNumberOfActivations(self):
    """Expects `ValueError` when zero or two activations are given."""
    dense = self._simple_dense_layer()
    replacement_activation = tf_keras.activations.relu
    config = self._kernel_activation_config()

    # Too few activations.
    with self.assertRaises(ValueError):
      config.set_quantize_activations(dense, [])

    # Too many activations.
    with self.assertRaises(ValueError):
      config.set_quantize_activations(
          dense, [replacement_activation, replacement_activation])

  def testGetsResultQuantizers_ReturnsQuantizer(self):
    """Checks one output quantizer is returned when the last arg is `True`."""
    dense = self._simple_dense_layer()
    config = configs.Default8BitQuantizeConfig([], [], True)

    result_quantizers = config.get_output_quantizers(dense)

    self.assertLen(result_quantizers, 1)
    self._assert_activation_quantizers(result_quantizers)

  def testGetsResultQuantizers_EmptyWhenFalse(self):
    """Checks no output quantizers are returned when the last arg is `False`."""
    dense = self._simple_dense_layer()
    config = configs.Default8BitQuantizeConfig([], [], False)

    result_quantizers = config.get_output_quantizers(dense)

    self.assertEqual([], result_quantizers)

  def testSerialization(self):
    """Checks the serialized form and that deserializing gives an equal config."""
    config = self._kernel_activation_config()
    inner_config = {
        'weight_attrs': ['kernel'],
        'activation_attrs': ['activation'],
        'quantize_output': False
    }
    expected_config = {
        'class_name': 'Default8BitQuantizeConfig',
        'config': inner_config
    }

    serialized = tf_utils.serialize_keras_object(config)
    self.assertEqual(expected_config, serialized)

    restored = tf_utils.deserialize_keras_object(
        serialized,
        module_objects=globals(),
        custom_objects=configs._types_dict(),
    )
    self.assertEqual(config, restored)
if __name__ == '__main__':
  # Run the tests in this module through TensorFlow's test entry point when
  # the file is executed as a script.
  tf.test.main()