tensorflow/models

View on GitHub
official/projects/qat/vision/n_bit/configs.py

Summary

Maintainability
B
4 hrs
Test Coverage
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Default N-bit QuantizeConfigs."""
from typing import Sequence, Callable, Tuple, Any, Dict

import tensorflow as tf, tf_keras
import tensorflow_model_optimization as tfmot


# Shorthand aliases used by the type annotations throughout this module.
# TFMOT Keras quantizer base class.
Quantizer = tfmot.quantization.keras.quantizers.Quantizer
# Keras layer base class.
Layer = tf_keras.layers.Layer
# An activation is any callable that maps a tensor to a tensor.
Activation = Callable[[tf.Tensor], tf.Tensor]
# A weight variable paired with the quantizer to apply to it.
WeightAndQuantizer = Tuple[tf.Variable, Quantizer]
# An activation callable paired with the quantizer to apply to it.
ActivationAndQuantizer = Tuple[Activation, Quantizer]


class DefaultNBitOutputQuantizeConfig(
    tfmot.quantization.keras.QuantizeConfig):
  """QuantizeConfig that quantizes nothing but a layer's output.

  No weights or activations are quantized. The output is passed through a
  per-tensor, asymmetric `MovingAverageQuantizer` that uses
  `num_bits_activation` bits.
  """

  def __init__(self, num_bits_weight: int = 8, num_bits_activation: int = 8):
    """Stores the bit widths.

    Args:
      num_bits_weight: Bit width for weights. This config only stores it and
        reports it from `get_config()`.
      num_bits_activation: Bit width used by the output quantizer.
    """
    self._num_bits_activation = num_bits_activation
    self._num_bits_weight = num_bits_weight

  def get_weights_and_quantizers(
      self, layer: Layer) -> Sequence[WeightAndQuantizer]:
    """Returns an empty list: no weights are quantized."""
    del layer  # Unused.
    return []

  def get_activations_and_quantizers(
      self, layer: Layer) -> Sequence[ActivationAndQuantizer]:
    """Returns an empty list: no activations are quantized."""
    del layer  # Unused.
    return []

  def set_quantize_weights(
      self,
      layer: Layer,
      quantize_weights: Sequence[tf.Tensor]):
    """Does nothing: there are no quantized weights to set."""

  def set_quantize_activations(
      self,
      layer: Layer,
      quantize_activations: Sequence[Activation]):
    """Does nothing: there are no quantized activations to set."""

  def get_output_quantizers(self, layer: Layer) -> Sequence[Quantizer]:
    """Returns a single freshly built quantizer for the layer output."""
    del layer  # Unused.
    quantizers = tfmot.quantization.keras.quantizers
    output_quantizer = quantizers.MovingAverageQuantizer(
        num_bits=self._num_bits_activation,
        per_axis=False,
        symmetric=False,
        narrow_range=False)
    return [output_quantizer]

  def get_config(self) -> Dict[str, Any]:
    """Returns the constructor arguments as a dictionary."""
    return dict(
        num_bits_weight=self._num_bits_weight,
        num_bits_activation=self._num_bits_activation)


class NoOpQuantizeConfig(tfmot.quantization.keras.QuantizeConfig):
  """QuantizeConfig that leaves every part of the layer unquantized.

  Every getter returns an empty list and every setter is a no-op. The bit
  widths are only stored so that `get_config()` can report them.
  """

  def __init__(self, num_bits_weight: int = 8, num_bits_activation: int = 8):
    """Stores the bit widths reported by `get_config()`."""
    self._num_bits_activation = num_bits_activation
    self._num_bits_weight = num_bits_weight

  def get_weights_and_quantizers(
      self, layer: Layer) -> Sequence[WeightAndQuantizer]:
    """Returns an empty list: no weights are quantized."""
    del layer  # Unused.
    return []

  def get_activations_and_quantizers(
      self, layer: Layer) -> Sequence[ActivationAndQuantizer]:
    """Returns an empty list: no activations are quantized."""
    del layer  # Unused.
    return []

  def set_quantize_weights(self,
                           layer: Layer,
                           quantize_weights: Sequence[tf.Tensor]):
    """Does nothing: there are no quantized weights to set."""

  def set_quantize_activations(self,
                               layer: Layer,
                               quantize_activations: Sequence[Activation]):
    """Does nothing: there are no quantized activations to set."""

  def get_output_quantizers(self, layer: Layer) -> Sequence[Quantizer]:
    """Returns an empty list: the output is not quantized."""
    del layer  # Unused.
    return []

  def get_config(self) -> Dict[str, Any]:
    """Returns the constructor arguments as a dictionary."""
    return dict(
        num_bits_weight=self._num_bits_weight,
        num_bits_activation=self._num_bits_activation)


class DefaultNBitQuantizeConfig(tfmot.quantization.keras.QuantizeConfig):
  """QuantizeConfig for non recurrent Keras layers.

  Weights named in `weight_attrs` are quantized with a symmetric, narrow-range
  `LastValueQuantizer`. Activations named in `activation_attrs` (and the layer
  output, when `quantize_output` is set) are quantized with an asymmetric
  `MovingAverageQuantizer`. Both quantizers are created with `per_axis=False`.
  """

  def __init__(self,
               weight_attrs: Sequence[str],
               activation_attrs: Sequence[str],
               quantize_output: bool,
               num_bits_weight: int = 8,
               num_bits_activation: int = 8):
    """Initializes a default N-bit quantize config.

    Args:
      weight_attrs: Names of the layer attributes that hold the weights to
        quantize.
      activation_attrs: Names of the layer attributes that hold the
        activations to quantize.
      quantize_output: Whether the layer output should be quantized.
      num_bits_weight: Bit width of the weight quantizer.
      num_bits_activation: Bit width of the activation/output quantizer.
    """
    self.weight_attrs = weight_attrs
    self.activation_attrs = activation_attrs
    self.quantize_output = quantize_output
    self._num_bits_weight = num_bits_weight
    self._num_bits_activation = num_bits_activation

    # TODO(pulkitb): For some layers such as Conv2D, per_axis should be True.
    # Add mapping for which layers support per_axis.
    self.weight_quantizer = tfmot.quantization.keras.quantizers.LastValueQuantizer(
        num_bits=num_bits_weight, per_axis=False,
        symmetric=True, narrow_range=True)  # weight
    self.activation_quantizer = tfmot.quantization.keras.quantizers.MovingAverageQuantizer(
        num_bits=num_bits_activation, per_axis=False,
        symmetric=False, narrow_range=False)  # activation/output

  def get_weights_and_quantizers(
      self, layer: Layer) -> Sequence[WeightAndQuantizer]:
    """Pairs each attribute in `weight_attrs` with the weight quantizer."""
    return [(getattr(layer, weight_attr), self.weight_quantizer)
            for weight_attr in self.weight_attrs]

  def get_activations_and_quantizers(
      self, layer: Layer) -> Sequence[ActivationAndQuantizer]:
    """Pairs each attribute in `activation_attrs` with the activation quantizer."""
    return [(getattr(layer, activation_attr), self.activation_quantizer)
            for activation_attr in self.activation_attrs]

  def set_quantize_weights(
      self,
      layer: Layer,
      quantize_weights: Sequence[tf.Tensor]):
    """Replaces the layer's weight attributes with their quantized versions.

    Args:
      layer: Layer whose attributes named in `weight_attrs` are overwritten.
      quantize_weights: Quantized weights, in the same order as
        `weight_attrs`.

    Raises:
      ValueError: If the number of weights does not match `weight_attrs`, or
        if a provided weight's shape differs from the existing weight's shape.
    """
    if len(self.weight_attrs) != len(quantize_weights):
      raise ValueError(
          '`set_quantize_weights` called on layer {} with {} '
          'weight parameters, but layer expects {} values.'.format(
              layer.name, len(quantize_weights), len(self.weight_attrs)))

    for weight_attr, weight in zip(self.weight_attrs, quantize_weights):
      current_weight = getattr(layer, weight_attr)
      if current_weight.shape != weight.shape:
        # The trailing space after "with" is required; without it the two
        # literals concatenate into "withprovided".
        raise ValueError('Existing layer weight shape {} is incompatible with '
                         'provided weight shape {}'.format(
                             current_weight.shape, weight.shape))

      setattr(layer, weight_attr, weight)

  def set_quantize_activations(
      self,
      layer: Layer,
      quantize_activations: Sequence[Activation]):
    """Replaces the layer's activation attributes with quantized versions.

    Args:
      layer: Layer whose attributes named in `activation_attrs` are
        overwritten.
      quantize_activations: Quantized activations, in the same order as
        `activation_attrs`.

    Raises:
      ValueError: If the number of activations does not match
        `activation_attrs`.
    """
    if len(self.activation_attrs) != len(quantize_activations):
      raise ValueError(
          '`set_quantize_activations` called on layer {} with {} '
          'activation parameters, but layer expects {} values.'.format(
              layer.name, len(quantize_activations),
              len(self.activation_attrs)))

    for activation_attr, activation in zip(
        self.activation_attrs, quantize_activations):
      setattr(layer, activation_attr, activation)

  def get_output_quantizers(self, layer: Layer) -> Sequence[Quantizer]:
    """Returns the activation quantizer if `quantize_output` is set, else []."""
    if self.quantize_output:
      return [self.activation_quantizer]
    return []

  @classmethod
  def from_config(cls, config: Dict[str, Any]) -> object:
    """Instantiates a `DefaultNBitQuantizeConfig` from its config.

    Args:
        config: Output of `get_config()`.

    Returns:
        A `DefaultNBitQuantizeConfig` instance.
    """
    return cls(**config)

  def get_config(self) -> Dict[str, Any]:
    """Get a config for this quantize config."""
    # TODO(pulkitb): Add weight and activation quantizer to config.
    # Currently it's created internally, but ideally the quantizers should be
    # part of the constructor and passed in from the registry.
    return {
        'weight_attrs': self.weight_attrs,
        'activation_attrs': self.activation_attrs,
        'quantize_output': self.quantize_output,
        'num_bits_weight': self._num_bits_weight,
        'num_bits_activation': self._num_bits_activation
    }

  def __eq__(self, other):
    """Compares attrs, quantizers and the output flag of two configs."""
    if not isinstance(other, DefaultNBitQuantizeConfig):
      return False

    # `activation_attrs` must be compared against `other`; comparing it with
    # itself is always true and would hide differing activation attrs.
    return (self.weight_attrs == other.weight_attrs and
            self.activation_attrs == other.activation_attrs and
            self.weight_quantizer == other.weight_quantizer and
            self.activation_quantizer == other.activation_quantizer and
            self.quantize_output == other.quantize_output)

  def __ne__(self, other):
    """Returns the negation of `__eq__`."""
    return not self.__eq__(other)


class DefaultNBitConvWeightsQuantizer(
    tfmot.quantization.keras.quantizers.LastValueQuantizer):
  """Per-axis weight quantizer for Conv2D/DepthwiseConv2D layers."""

  def __init__(self, num_bits_weight: int = 8, num_bits_activation: int = 8):
    """Configures the LastValueQuantizer with the settings used for convs.

    Args:
      num_bits_weight: Bit width passed to the base quantizer as `num_bits`.
      num_bits_activation: Bit width for activations. It is only stored here.
    """
    super().__init__(
        num_bits=num_bits_weight,
        per_axis=True,
        symmetric=True,
        narrow_range=True)  # weight
    self._num_bits_activation = num_bits_activation
    self._num_bits_weight = num_bits_weight

  def build(self,
            tensor_shape: tf.TensorShape,
            name: str,
            layer: Layer):
    """Creates the non-trainable min/max variables on `layer`.

    Both variables have one entry per element of the last dimension of
    `tensor_shape` and are initialised to -6.0 and 6.0 respectively.

    Args:
      tensor_shape: Shape of the tensor to be quantized.
      name: Prefix for the variable names (`<name>_min`, `<name>_max`).
      layer: Layer that will own the created variables.

    Returns:
      A dict with the keys `min_var` and `max_var`.
    """
    range_shape = (tensor_shape[-1],)
    constant = tf_keras.initializers.Constant

    lower_bound = layer.add_weight(
        name + '_min',
        shape=range_shape,
        initializer=constant(-6.0),
        trainable=False)
    upper_bound = layer.add_weight(
        name + '_max',
        shape=range_shape,
        initializer=constant(6.0),
        trainable=False)

    return dict(min_var=lower_bound, max_var=upper_bound)


class NoQuantizer(tfmot.quantization.keras.quantizers.Quantizer):
  """Pass-through quantizer used to explicitly skip quantization."""

  def build(self, tensor_shape, name, layer):
    """Creates no variables and returns an empty dict."""
    del tensor_shape, name, layer  # Unused.
    return {}

  def __call__(self, inputs, training, weights, **kwargs):
    """Returns `tf.identity(inputs)`; the other arguments are ignored."""
    del training, weights, kwargs  # Unused.
    return tf.identity(inputs)

  def get_config(self):
    """Returns an empty config: there is nothing to serialize."""
    return {}


class DefaultNBitConvQuantizeConfig(DefaultNBitQuantizeConfig):
  """QuantizeConfig for Conv2D/DepthwiseConv2D layers.

  Same as `DefaultNBitQuantizeConfig`, except that the weight quantizer is
  replaced by a `DefaultNBitConvWeightsQuantizer`.
  """

  def __init__(self,
               weight_attrs: Sequence[str],
               activation_attrs: Sequence[str],
               quantize_output: bool,
               num_bits_weight: int = 8,
               num_bits_activation: int = 8):
    """Initializes default N-bit quantization config for the conv layer.

    Args:
      weight_attrs: Forwarded to the base class.
      activation_attrs: Forwarded to the base class.
      quantize_output: Forwarded to the base class.
      num_bits_weight: Bit width for the conv weight quantizer.
      num_bits_activation: Bit width for activations.
    """
    super().__init__(
        weight_attrs=weight_attrs,
        activation_attrs=activation_attrs,
        quantize_output=quantize_output,
        num_bits_weight=num_bits_weight,
        num_bits_activation=num_bits_activation)
    self._num_bits_weight = num_bits_weight
    self._num_bits_activation = num_bits_activation

    # Swap the weight quantizer created by the base class for the conv one.
    conv_weight_quantizer = DefaultNBitConvWeightsQuantizer(
        num_bits_weight=num_bits_weight,
        num_bits_activation=num_bits_activation)
    self.weight_quantizer = conv_weight_quantizer


class DefaultNBitActivationQuantizeConfig(
    tfmot.quantization.keras.QuantizeConfig):
  """QuantizeConfig for keras.layers.Activation.

  `keras.layers.Activation` gets its own `QuantizeConfig` because whether the
  output is quantized depends on which activation function the layer wraps.
  """

  def __init__(self, num_bits_weight: int = 8, num_bits_activation: int = 8):
    """Stores the bit widths.

    Args:
      num_bits_weight: Bit width for weights. It is only stored and reported
        by `get_config()`.
      num_bits_activation: Bit width used by the output quantizer.
    """
    self._num_bits_activation = num_bits_activation
    self._num_bits_weight = num_bits_weight

  def _assert_activation_layer(self, layer: Layer):
    """Raises RuntimeError unless `layer` is a `keras.layers.Activation`."""
    if isinstance(layer, tf_keras.layers.Activation):
      return
    raise RuntimeError(
        'DefaultNBitActivationQuantizeConfig can only be used with '
        '`keras.layers.Activation`.')

  def get_weights_and_quantizers(
      self, layer: Layer) -> Sequence[WeightAndQuantizer]:
    """Validates the layer type and returns an empty list."""
    self._assert_activation_layer(layer)
    return []

  def get_activations_and_quantizers(
      self, layer: Layer) -> Sequence[ActivationAndQuantizer]:
    """Validates the layer type and returns an empty list."""
    self._assert_activation_layer(layer)
    return []

  def set_quantize_weights(self,
                           layer: Layer,
                           quantize_weights: Sequence[tf.Tensor]):
    """Validates the layer type; there are no weights to set."""
    self._assert_activation_layer(layer)

  def set_quantize_activations(self,
                               layer: Layer,
                               quantize_activations: Sequence[Activation]):
    """Validates the layer type; there are no activations to set."""
    self._assert_activation_layer(layer)

  def get_output_quantizers(self, layer: Layer) -> Sequence[Quantizer]:
    """Returns the output quantizers for the wrapped activation function.

    Args:
      layer: The `keras.layers.Activation` layer to inspect.

    Returns:
      A single `MovingAverageQuantizer` for 'relu', 'relu6' and 'swish', and
      an empty list for 'linear', 'softmax' and 'sigmoid'.

    Raises:
      RuntimeError: If `layer` is not a `keras.layers.Activation`.
      ValueError: If the activation has no `__name__` or is not one of the
        names listed above.
    """
    self._assert_activation_layer(layer)

    activation_fn = layer.activation
    # A missing `__name__` yields None, which matches neither group below and
    # therefore ends in the same ValueError as an unknown name.
    activation_name = getattr(activation_fn, '__name__', None)

    # This code is copied from TFMOT repo, but added relu6 to support mobilenet.
    if activation_name in ('relu', 'relu6', 'swish'):
      # 'relu' should generally get fused into the previous layer.
      quantizers = tfmot.quantization.keras.quantizers
      return [
          quantizers.MovingAverageQuantizer(
              num_bits=self._num_bits_activation,
              per_axis=False,
              symmetric=False,
              narrow_range=False)  # activation/output
      ]

    if activation_name in ('linear', 'softmax', 'sigmoid'):
      return []

    raise ValueError('Activation {} not supported by '
                     'DefaultNBitActivationQuantizeConfig.'.format(
                         activation_fn))

  def get_config(self) -> Dict[str, Any]:
    """Returns the constructor arguments as a dictionary."""
    return dict(
        num_bits_weight=self._num_bits_weight,
        num_bits_activation=self._num_bits_activation)


def _types_dict():
  """Returns a dict mapping class name to class for this module's types.

  Each key is identical to the `__name__` of the class it maps to.
  `NoQuantizer` is not part of the mapping.
  """
  registered_types = (
      DefaultNBitOutputQuantizeConfig,
      NoOpQuantizeConfig,
      DefaultNBitQuantizeConfig,
      DefaultNBitConvWeightsQuantizer,
      DefaultNBitConvQuantizeConfig,
      DefaultNBitActivationQuantizeConfig,
  )
  return {
      registered.__name__: registered for registered in registered_types
  }