tensorflow/tensorflow

View on GitHub
tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test_base.py

Summary

Maintainability
D
1 day
Test Coverage
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base test class for quantize_model Tests."""

from typing import List, Mapping, Optional, Sequence, Tuple

from absl.testing import parameterized
from mlir import ir
from mlir.dialects import stablehlo as stablehlo_dialect
import numpy as np
import tensorflow  # pylint: disable=unused-import

from tensorflow.compiler.mlir.stablehlo import stablehlo
from tensorflow.python.eager import def_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.module import module
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import load
from tensorflow.python.saved_model import loader_impl
from tensorflow.python.saved_model import save as saved_model_save
from tensorflow.python.types import core


FUNC_ALIAS = 'some_alias'


class QuantizedModelTest(test.TestCase, parameterized.TestCase):
  """Base test class for StableHLO quant tests."""

  def setUp(self) -> None:
    super().setUp()

    # Many test cases for quantization involve creating and saving the input
    # model and saving the output quantized model. These two member
    # attributes can be used to specify the paths for such models,
    # respectively. These paths will be cleaned up after each test case.
    self._input_saved_model_path = self.create_tempdir('input').full_path
    self._output_saved_model_path = self.create_tempdir('output').full_path
    # Extra output path occasionally used for comparing two different
    # quantized models.
    self._output_saved_model_path_2 = self.create_tempdir('output2').full_path

  def _extract_first_xla_call_module_op(
      self, output_saved_model_path: str
  ) -> str:
    """Extracts the first XlaCallModule op from output saved model to string."""
    root = load.load(output_saved_model_path)
    tf_graph_def = root.signatures['serving_default'].graph.as_graph_def()
    for function in tf_graph_def.library.function:
      for node_def in function.node_def:
        if node_def.op == 'XlaCallModule':
          with ir.Context() as context:
            stablehlo_dialect.register_dialect(context)
            # Serialization in VHLO dialect.
            serialized = node_def.attr.get('module').s
            # MLIR bytecode matching StableHLO version.
            mlir_bytecode = stablehlo.deserialize_portable_artifact(serialized)
            stablehlo_module = ir.Module.parse(mlir_bytecode, context=context)
            return str(stablehlo_module)
    raise ValueError('No XlaCallModule found in saved model.')

  def _get_num_xla_call_module_op(self, output_saved_model_path: str) -> int:
    """Gets the number of XlaCallModule ops in the output saved model."""
    root = load.load(output_saved_model_path)
    tf_graph_def = root.signatures['serving_default'].graph.as_graph_def()
    count = 0
    for node_def in tf_graph_def.node:
      if node_def.op == 'XlaCallModule':
        count += 1
    for function in tf_graph_def.library.function:
      for node_def in function.node_def:
        if node_def.op == 'XlaCallModule':
          count += 1
    return count

  def _get_function_aliases(
      self, output_saved_model_path: str, tags: List[str]
  ) -> dict[str, str]:
    """Gets the function aliases in the output saved model."""
    loader = loader_impl.SavedModelLoader(output_saved_model_path)
    return loader.get_meta_graph_def_from_tags(
        tags
    ).meta_info_def.function_aliases

  def _create_matmul_model(
      self,
      input_shape: Sequence[int],
      weight_shape: Sequence[int],
      saved_model_path: str,
      bias_fn: Optional[ops.Operation] = None,
      activation_fn: Optional[ops.Operation] = None,
  ) -> module.Module:
    class MatmulModel(module.Module):
      """A simple model with a single matmul.

      Bias and activation function are optional.
      """

      def __init__(
          self,
          weight_shape: Sequence[int],
      ) -> None:
        """Initializes a MatmulModel.

        Args:
          weight_shape: Shape of the weight tensor.
        """
        self.filters = np.random.uniform(low=-1.0, high=1.0, size=weight_shape)

        if bias_fn is not None:
          self.bias = np.random.uniform(
              low=-1.0, high=1.0, size=weight_shape[-1]
          )

      def has_reshape(self) -> bool:
        return self.bias_fn() and self.bias_size != self.filters.shape[-1]

      @def_function.function
      def matmul(self, input_tensor: core.Tensor) -> Mapping[str, core.Tensor]:
        """Performs a matrix multiplication.

        Depending on self.bias_fn and self.activation_fn, it may add a bias
        term or go through the activaction function.

        Args:
          input_tensor: Input tensor to matmul with the filter.

        Returns:
          A map of: output key -> output result.
        """
        out = math_ops.matmul(input_tensor, self.filters, name='sample/matmul')
        if bias_fn is not None:
          out = bias_fn(out, self.bias)
        if activation_fn is not None:
          out = activation_fn(out)
        return {'output': out}

    model = MatmulModel(weight_shape)
    saved_model_save.save(
        model,
        saved_model_path,
        signatures=model.matmul.get_concrete_function(
            tensor_spec.TensorSpec(
                shape=input_shape, dtype=dtypes.float32, name='input_tensor'
            )
        ),
    )
    return model

  def _any_log_contains(
      self, substring: str, log_record_list: List['logging.LogRecord']
  ) -> bool:
    """Returns True if any of the log contains a given substring.

    Args:
      substring: A piece of string to check whether it exists in the log
        message.
      log_record_list: A list of `absl.logging.LogRecord`s.

    Returns:
      True if and only if the substring exists in any of the log in
      `log_record_list`.
    """
    return any(
        map(
            lambda log_record: substring in str(log_record.message),
            log_record_list,
        )
    )

  def _create_matmul_and_same_scale_model(
      self,
      input_shape: Sequence[int],
      weight_shape: Sequence[int],
      saved_model_path: str,
      same_scale_op: str,
  ) -> module.Module:
    class MatmulAndSameScaleModel(module.Module):
      """A simple model with a same-scale op.

      Op name in StableHLO dialect is given as a string.
      """

      def __init__(
          self,
          weight_shape: Sequence[int],
          same_scale_op: str,
      ) -> None:
        """Initializes a MatmulModel.

        Args:
          weight_shape: Shape of the weight tensor.
          same_scale_op: Name of the same-scale op to be tested. Raises error
            when an unknown name is given.
        """
        self.filters = np.random.uniform(low=-1.0, high=1.0, size=weight_shape)
        self.same_scale_op = same_scale_op

      @def_function.function
      def matmul_and_same_scale(
          self, input_tensor: core.Tensor
      ) -> Mapping[str, core.Tensor]:
        """Performs a matrix multiplication.

        Args:
          input_tensor: Input tensor to matmul with the filter.

        Returns:
          A map of: output key -> output result.
        """
        out = math_ops.matmul(input_tensor, self.filters, name='sample/matmul')

        if self.same_scale_op == 'concatenate':
          ones = array_ops.ones_like(out)
          out = array_ops.concat([out, ones], 0)
        elif self.same_scale_op == 'gather':
          out = array_ops.gather(out, indices=[0], axis=0)
        elif self.same_scale_op == 'max_pool':
          out = nn_ops.max_pool(out, ksize=3, strides=1, padding='SAME')
        elif self.same_scale_op == 'pad':
          paddings = array_ops.ones(
              (array_ops.rank(out), 2), dtype=dtypes.int32
          )
          out = array_ops.pad(out, paddings, 'CONSTANT')
        elif self.same_scale_op == 'reshape':
          out = array_ops.reshape(out, [-1])
        elif self.same_scale_op == 'select':
          rng = np.random.default_rng(seed=1234)
          condition = ops.convert_to_tensor(
              rng.uniform(low=0.0, high=1.0, size=out.shape) < 0.5
          )
          ones = array_ops.ones_like(out)
          out = math_ops.select(condition, out, ones)
        elif self.same_scale_op == 'slice':
          begin = array_ops.zeros((array_ops.rank(out)), dtype=dtypes.int32)
          size = array_ops.ones((array_ops.rank(out)), dtype=dtypes.int32)
          out = array_ops.slice(out, begin, size)
        elif self.same_scale_op == 'transpose':
          out = array_ops.transpose(out)
        else:
          raise NotImplementedError(
              '{} is not implemented for integration test.'.format(
                  self.same_scale_op
              )
          )

        return {'output': out}

    model = MatmulAndSameScaleModel(weight_shape, same_scale_op)
    saved_model_save.save(
        model,
        saved_model_path,
        signatures=model.matmul_and_same_scale.get_concrete_function(
            tensor_spec.TensorSpec(
                shape=input_shape, dtype=dtypes.float32, name='input_tensor'
            )
        ),
    )
    return model

  def _create_conv2d_model(
      self,
      input_shape: Sequence[int],
      filter_shape: Sequence[int],
      saved_model_path: str,
      bias_fn: Optional[ops.Operation] = None,
      activation_fn: Optional[ops.Operation] = None,
      has_batch_norm: bool = False,
      strides: Sequence[int] = (1, 1, 1, 1),
      dilations: Sequence[int] = (1, 1, 1, 1),
      padding: str = 'SAME',
      has_func_alias: bool = False,
  ) -> module.Module:
    class ConvModel(module.Module):
      """A simple model with a single conv2d, bias and relu."""

      def __init__(self):
        self.out_channel_size = filter_shape[-1]

        # This ensures filters will have different value range per out channel
        self.filters = np.stack(
            [
                np.random.uniform(
                    low=-(i + 1), high=(i + 1), size=filter_shape[:-1]
                ).astype('f4')
                for i in range(self.out_channel_size)
            ],
            axis=-1,
        )

        self.bias = np.random.uniform(
            low=0, high=10, size=(self.out_channel_size)
        ).astype('f4')

      @def_function.function
      def conv2d(self, input_tensor: core.Tensor) -> Mapping[str, core.Tensor]:
        """Performs a 2D convolution operation.

        Args:
          input_tensor: Input tensor to perform convolution on.

        Returns:
          A map of: output key -> output result.
        """
        scale = [1.0] * self.out_channel_size
        offset = [0.5] * self.out_channel_size
        mean, variance = scale, offset
        out = nn_ops.conv2d(
            input_tensor,
            self.filters,
            strides=strides,
            dilations=dilations,
            padding=padding,
            data_format='NHWC',
            name='sample/conv',
        )
        if bias_fn is not None:
          out = nn_ops.bias_add(out, self.bias)
        if has_batch_norm:
          # Fusing is supported for non-training case.
          out, _, _, _, _, _ = nn_ops.fused_batch_norm_v3(
              out, scale, offset, mean, variance, is_training=False
          )
        if activation_fn is not None:
          out = activation_fn(out)
        return {'output': out}

    model = ConvModel()
    save_options = None
    if has_func_alias:
      save_options = tensorflow.saved_model.SaveOptions(
          function_aliases={FUNC_ALIAS: model.conv2d}
      )
    saved_model_save.save(
        model,
        saved_model_path,
        signatures=model.conv2d.get_concrete_function(
            tensor_spec.TensorSpec(
                shape=input_shape, dtype=dtypes.float32, name='input_tensor'
            )
        ),
        options=save_options,
    )
    return model

  def _create_gather_model(self, input_type, use_variable) -> module.Module:
    class GatherModel(module.Module):
      """A simple model with a single gather."""

      def __init__(self, use_variable):
        """Initializes a GatherModel.

        Args:
          use_variable: If True, creates a variable for weight.
        """
        super().__init__()
        w_val = np.random.randn(128, 32).astype('f4')
        if use_variable:
          self.w = variables.Variable(w_val)
        else:
          self.w = w_val

      @def_function.function(
          input_signature=[
              tensor_spec.TensorSpec(
                  shape=[6], dtype=input_type, name='input_tensor'
              )
          ]
      )
      def __call__(
          self, input_tensor: core.Tensor
      ) -> Mapping[str, core.Tensor]:
        """Performs a gather operation."""
        out = array_ops.gather_v2(self.w, input_tensor)
        return {'output': out}

    return GatherModel(use_variable)

  def _create_add_model(
      self,
      shape: Sequence[int],
      saved_model_path: str,
  ) -> module.Module:
    class AddModel(module.Module):
      """A simple model with a single add."""

      def __init__(self):
        pass

      @def_function.function
      def add(self, input_tensor: core.Tensor) -> Mapping[str, core.Tensor]:
        """Performs an add operation.

        Args:
          input_tensor: Input tensor to perform add on.

        Returns:
          A map of: output key -> output result.
        """
        out = math_ops.add(input_tensor, input_tensor)
        return {'output': out}

    model = AddModel()
    saved_model_save.save(
        model,
        saved_model_path,
        signatures=model.add.get_concrete_function(
            tensor_spec.TensorSpec(
                shape=shape, dtype=dtypes.float32, name='input_tensor'
            )
        ),
    )
    return model

  # Prepares sample einsum input data shapes.
  # This function returns:
  # 1. Shape for input 1
  # 2. Shape for input 2
  # 3. Shape for bias
  # 4. Signature for input 1 (Could contain None dimension)
  # 5. Signature for input 2 (Could contain None dimension)
  def _prepare_sample_einsum_datashapes(
      self,
      equation: str,
      generate_unknown_shape_signature: bool = False,
      use_bias: bool = False,
  ) -> Tuple[
      List[Optional[int]],
      List[Optional[int]],
      Optional[List[Optional[int]]],
      List[Optional[int]],
      List[Optional[int]],
  ]:
    # 1. Parse equation.
    comma_pos = equation.find(',')
    arrow_pos = equation.find('->')
    x_labels = equation[0:comma_pos]
    y_labels = equation[comma_pos + 1 : arrow_pos]
    out_labels = equation[arrow_pos + 1 :]

    # 2. Create sample shapes.
    label_to_size = {'a': 4, 'b': 32, 'c': 64, 'd': 128, 'e': 8}
    x_shape = [label_to_size.get(x_label) for x_label in x_labels]
    y_shape = [label_to_size.get(y_label) for y_label in y_labels]
    bias_shape = None
    if use_bias:
      bias_shape = [label_to_size.get(out_label) for out_label in out_labels]
      bias_shape = bias_shape[-1:]
    contracting_dims = set()

    x_signature = list(x_shape)
    y_signature = list(y_shape)
    if generate_unknown_shape_signature:
      for c in x_labels:
        if c in y_labels:
          contracting_dims.add(c)
      x_signature = [
          None if c not in contracting_dims else x_shape[cidx]
          for cidx, c in enumerate(x_labels)
      ]
      y_signature = [
          None if c not in contracting_dims else y_shape[cidx]
          for cidx, c in enumerate(y_labels)
      ]
    return x_shape, y_shape, bias_shape, x_signature, y_signature

  def _create_einsum_model(
      self,
      saved_model_path: str,
      equation: str,
      y_shape: Sequence[int],
      x_signature: Sequence[Optional[int]],
      y_signature: Sequence[Optional[int]],
      bias_shape: Optional[Sequence[int]] = None,
  ) -> module.Module:
    class EinsumModel(module.Module):
      """Einsum class."""

      def __init__(self):
        self._bias = None
        if bias_shape is not None:
          self._bias = array_ops.constant(
              np.random.uniform(size=bias_shape), dtype=dtypes.float32
          )

        self._kernel = np.random.uniform(size=y_shape).astype('f4')
        self._min = (-0.8, -0.8, -0.9)
        self._max = (0.9, 0.9, 1.0)

      @def_function.function(
          input_signature=[
              tensor_spec.TensorSpec(
                  name='x', shape=x_signature, dtype=dtypes.float32
              )
          ]
      )
      def einsum_with_kernel(self, x: core.Tensor) -> Mapping[str, core.Tensor]:
        return self._einsum(x, self._kernel)

      @def_function.function(
          input_signature=[
              tensor_spec.TensorSpec(
                  name='x', shape=x_signature, dtype=dtypes.float32
              ),
              tensor_spec.TensorSpec(
                  name='y', shape=y_signature, dtype=dtypes.float32
              ),
          ]
      )
      def einsum_without_kernel(
          self, x: core.Tensor, y: core.Tensor
      ) -> Mapping[str, core.Tensor]:
        return self._einsum(x, y)

      def _einsum(self, x, y):

        out = tensorflow.einsum(equation, x, y)
        if self._bias is not None:
          out = nn_ops.bias_add(out, self._bias)
        return {'output': out}

    model = EinsumModel()
    signatures = {
        'serving_default': model.einsum_with_kernel.get_concrete_function(
            tensor_spec.TensorSpec(
                name='x', shape=x_signature, dtype=dtypes.float32
            )
        ),
    }
    saved_model_save.save(model, saved_model_path, signatures=signatures)
    return model