tensorflow/tensorflow

View on GitHub
tensorflow/lite/python/lite_test.py

Summary

Maintainability
F
1 mo
Test Coverage
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for lite.py."""

import io
import logging
import os
import tempfile

from absl.testing import parameterized
import numpy as np
from tensorflow import keras

from tensorflow.lite.python import conversion_metadata_schema_py_generated as metadata_fb
from tensorflow.lite.python import lite
from tensorflow.lite.python import lite_constants
from tensorflow.lite.python import schema_py_generated as schema_fb
from tensorflow.lite.python import util
from tensorflow.lite.python.convert import ConverterError
from tensorflow.lite.python.convert import mlir_quantize
from tensorflow.lite.python.interpreter import Interpreter
from tensorflow.lite.python.interpreter import OpResolverType
from tensorflow.lite.python.util import get_conversion_metadata
from tensorflow.python.client import session
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import convert_to_constants
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.framework import versions
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.ops.variables import global_variables_initializer as _global_variables_initializer
from tensorflow.python.platform import gfile
from tensorflow.python.platform import resource_loader
from tensorflow.python.platform import test
from tensorflow.python.saved_model import saved_model
from tensorflow.python.training.training_util import write_graph


class LiteTest(test_util.TensorFlowTestCase):
  """Common parent class for every test case defined in this module."""


class TestModels(LiteTest):
  """Test base class that adds a debug-info assertion helper."""

  def assertValidDebugInfo(self, debug_info):
    """Asserts that `debug_info` lists this test file among its files.

    Args:
      debug_info: Object whose `files` attribute iterates over file paths.
    """
    # Only base names are compared so the check stays independent of how the
    # nodes were created.
    base_names = {os.path.basename(path) for path in debug_info.files}
    self.assertIn('lite_test.py', base_names)
    self.assertNotIn('lite_v2_test.py', base_names)


class FromConstructor(TestModels):
  """Tests that build a `TFLiteConverter` directly through its constructor."""

  def testInvalidConstructor(self):
    """Tests invalid constructors using a dummy value for the GraphDef."""
    message = (
        'If input_tensors and output_tensors are None, both '
        'input_arrays_with_shape and output_arrays|control_output_arrays must '
        'be defined.')

    # `output_arrays` is not defined.
    with self.assertRaises(ValueError) as error:
      lite.TFLiteConverter(
          None, None, [],
          input_arrays_with_shape=[('input', [3, 9])]).convert()
    self.assertEqual(message, str(error.exception))

    # `input_arrays_with_shape` is not defined.
    with self.assertRaises(ValueError) as error:
      lite.TFLiteConverter(None, [], None, output_arrays=['output']).convert()
    self.assertEqual(message, str(error.exception))

  def testValidConstructor(self):
    """Tests valid constructors using a dummy value for the GraphDef."""
    converter = lite.TFLiteConverter(
        None,
        None,
        None,
        input_arrays_with_shape=[('input', [3, 9])],
        output_arrays=['output'])
    self.assertFalse(converter._has_valid_tensors())
    self.assertEqual(converter.get_input_arrays(), ['input'])

    # Without real tensors the batch size cannot be overridden.
    with self.assertRaises(ValueError) as error:
      converter._set_batch_size(1)
    self.assertEqual(
        'The batch size cannot be set for this model. Please use '
        'input_shapes parameter.', str(error.exception))

    converter = lite.TFLiteConverter(None, ['input_tensor'], ['output_tensor'])
    self.assertTrue(converter._has_valid_tensors())

  def testRedundantArgumentsWarning(self):
    """Tests the warning messages emitted when arguments are redundant."""
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[None, 16, 16, 3], dtype=dtypes.float32, name='in_tensor')
      out_tensor = math_ops.add(in_tensor, in_tensor, name='add')
      sess = session.Session()

    frozen_graph_def = (
        convert_to_constants.convert_variables_to_constants_from_session_graph(
            sess, sess.graph_def, ['add']))

    # Capture root-logger output so the warnings can be inspected. The handler
    # is removed through a cleanup so that it does not leak into other tests
    # when an assertion below fails.
    log = io.StringIO()
    handler = logging.StreamHandler(log)
    logging.root.addHandler(handler)
    self.addCleanup(logging.root.removeHandler, handler)

    # Both tensors and array names are supplied; the array names are redundant.
    converter = lite.TFLiteConverter(frozen_graph_def, [in_tensor],
                                     [out_tensor],
                                     [('in_tensor', [2, 16, 16, 3])], ['add'])

    input_warning_message = 'input_arrays_with_shape will be ignored'
    output_warning_message = 'output_arrays will be ignored'

    # Convert model and ensure model is not None.
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)
    self.assertIn(input_warning_message, log.getvalue())
    self.assertIn(output_warning_message, log.getvalue())

  def testShapeOverriding(self):
    """Test a shape overriding case via the constructor."""
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[None, 16, 16, 3], dtype=dtypes.float32, name='in_tensor')
      math_ops.add(in_tensor, in_tensor, name='add')
      sess = session.Session()

    frozen_graph_def = (
        convert_to_constants.convert_variables_to_constants_from_session_graph(
            sess, sess.graph_def, ['add']))

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter(frozen_graph_def, None, None,
                                     [('in_tensor', [2, 16, 16, 3])], ['add'])
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual('in_tensor', input_details[0]['name'])
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertAllEqual([2, 16, 16, 3], input_details[0]['shape'])
    self.assertEqual((0., 0.), input_details[0]['quantization'])

    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertEqual('add', output_details[0]['name'])
    self.assertEqual(np.float32, output_details[0]['dtype'])
    self.assertAllEqual([2, 16, 16, 3], output_details[0]['shape'])
    self.assertEqual((0., 0.), output_details[0]['quantization'])

  def testPartialShapeOverriding(self):
    """Test a partial shape overriding case via the constructor."""
    with ops.Graph().as_default():
      in_tensor_a = array_ops.placeholder(
          shape=[None, 16, 16, 3], dtype=dtypes.float32, name='in_tensor_a')
      in_tensor_b = array_ops.placeholder(
          shape=[None, 16, 16, 3], dtype=dtypes.float32, name='in_tensor_b')
      math_ops.add(in_tensor_a, in_tensor_b, name='add')
      sess = session.Session()

    frozen_graph_def = (
        convert_to_constants.convert_variables_to_constants_from_session_graph(
            sess, sess.graph_def, ['add']))

    # Only one of the two placeholders is given a shape.
    converter = lite.TFLiteConverter(frozen_graph_def, None, None,
                                     [('in_tensor_a', [2, 16, 16, 3])], ['add'])
    # There is an unhandled Placeholder op, so the conversion must fail.
    with self.assertRaises(ConverterError):
      converter.convert()

  def testInvalidShapeOverriding(self):
    """Test an invalid shape overriding case via the constructor."""
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[None, 16, 16, 3], dtype=dtypes.float32, name='in_tensor')
      math_ops.add(in_tensor, in_tensor, name='add')
      sess = session.Session()

    frozen_graph_def = (
        convert_to_constants.convert_variables_to_constants_from_session_graph(
            sess, sess.graph_def, ['add']))

    # The shape is supplied for a tensor name that is not in the graph, so the
    # conversion must fail.
    converter = lite.TFLiteConverter(frozen_graph_def, None, None,
                                     [('wrong_tensor', [2, 16, 16, 3])],
                                     ['add'])
    with self.assertRaises(ConverterError):
      converter.convert()


class FromSessionTest(TestModels, parameterized.TestCase):

  def testFloatModel(self):
    """Converts a float add graph and checks its input/output details."""
    with ops.Graph().as_default():
      placeholder = array_ops.placeholder(
          dtype=dtypes.float32, shape=[1, 16, 16, 3])
      doubled = placeholder + placeholder
      sess = session.Session()

    # The conversion must yield a model.
    tflite_model = lite.TFLiteConverter.from_session(
        sess, [placeholder], [doubled]).convert()
    self.assertIsNotNone(tflite_model)

    # Load the converted model and inspect the tensors it exposes.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    inputs = interpreter.get_input_details()
    self.assertLen(inputs, 1)
    model_input = inputs[0]
    self.assertEqual('Placeholder', model_input['name'])
    self.assertEqual(np.float32, model_input['dtype'])
    self.assertAllEqual([1, 16, 16, 3], model_input['shape'])
    self.assertEqual((0., 0.), model_input['quantization'])

    outputs = interpreter.get_output_details()
    self.assertLen(outputs, 1)
    model_output = outputs[0]
    self.assertEqual('add', model_output['name'])
    self.assertEqual(np.float32, model_output['dtype'])
    self.assertAllEqual([1, 16, 16, 3], model_output['shape'])
    self.assertEqual((0., 0.), model_output['quantization'])

  def testFloatModelQuantizedInput(self):
    """Converts a float model whose input type is set to uint8."""
    with ops.Graph().as_default():
      placeholder = array_ops.placeholder(
          dtype=dtypes.float32, shape=[1, 16, 16, 3])
      doubled = placeholder + placeholder
      sess = session.Session()

    # Float inference with a uint8 input; the stats are (mean, std_dev).
    converter = lite.TFLiteConverter.from_session(
        sess, [placeholder], [doubled])
    converter.inference_input_type = dtypes.uint8
    converter.inference_type = dtypes.float32
    converter.quantized_input_stats = {'Placeholder': (0., 1.)}
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Load the converted model and inspect the tensors it exposes.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    inputs = interpreter.get_input_details()
    self.assertLen(inputs, 1)
    model_input = inputs[0]
    self.assertEqual('Placeholder', model_input['name'])
    self.assertEqual(np.uint8, model_input['dtype'])
    self.assertAllEqual([1, 16, 16, 3], model_input['shape'])
    self.assertEqual((1., 0.), model_input['quantization'])

    outputs = interpreter.get_output_details()
    self.assertLen(outputs, 1)
    model_output = outputs[0]
    self.assertEqual('add', model_output['name'])
    self.assertEqual(np.float32, model_output['dtype'])
    self.assertAllEqual([1, 16, 16, 3], model_output['shape'])
    # The output stays float, hence no quantization parameters.
    self.assertEqual((0., 0.), model_output['quantization'])

  def testForgottenCallToAllocateTensors(self):
    """Setting a tensor without calling `allocate_tensors()` must raise."""
    with ops.Graph().as_default():
      placeholder = array_ops.placeholder(
          dtype=dtypes.float32, shape=[1, 16, 16, 3])
      doubled = placeholder + placeholder
      sess = session.Session()

    # The conversion must yield a model.
    tflite_model = lite.TFLiteConverter.from_session(
        sess, [placeholder], [doubled]).convert()
    self.assertIsNotNone(tflite_model)

    # `allocate_tensors()` is deliberately not called on the interpreter.
    interpreter = Interpreter(model_content=tflite_model)
    first_input_index = interpreter.get_input_details()[0]['index']
    ones = np.ones(shape=[1, 16, 16, 3], dtype=np.float32)
    with self.assertRaises(ValueError):
      interpreter.set_tensor(first_input_index, ones)

  @parameterized.named_parameters(
      ('_INT8InputOutput', False, False, dtypes.int8),
      ('_UINT8InputOutput', False, False, dtypes.uint8),
      ('_INT16Quantize_INT16InputOutput', False, True, dtypes.int16),
      ('_IntOnly_INT8InputOutput', True, False, dtypes.int8),
      ('_IntOnly_UINT8InputOutput', True, False, dtypes.uint8),
      ('_IntOnly_INT16Quantize_INT16InputOutput', True, True, dtypes.int16),
      ('_IntOnly_INT8InputOutputMlirQuant', True, False, dtypes.int8, True),
      ('_IntOnly_UINT8InputOutputMlirQuant', True, False, dtypes.uint8, True))
  def testIntegerQuantizationWithUnsupportedOps(self,
                                                is_int_only,
                                                is_int16_quantize,
                                                inference_input_output_type,
                                                enable_mlir_quantizer=False):
    """Quantizes a graph that contains ops without integer kernels.

    Args:
      is_int_only: Whether `TFLITE_BUILTINS_INT8` is added to the supported ops
        (only has an effect when `is_int16_quantize` is False).
      is_int16_quantize: Whether the int16-activation/int8-weight op set is
        requested.
      inference_input_output_type: Dtype set as both the inference input and
        output type of the converter.
      enable_mlir_quantizer: Value assigned to `experimental_new_quantizer`.
    """
    with ops.Graph().as_default():
      in_tensor_a = array_ops.placeholder(shape=[3], dtype=dtypes.float32)
      in_tensor_b = array_ops.placeholder(shape=[3], dtype=dtypes.float32)
      # The ceil kernel supports neither int8 nor int16.
      ceil_a = math_ops.ceil(in_tensor_a)
      out_tensor_b = math_ops.tanh(in_tensor_b)
      summed = math_ops.add(ceil_a, out_tensor_b)
      # The ceil kernel supports neither int8 nor int16.
      out_tensor_a = math_ops.ceil(summed)
      sess = session.Session()

    def calibration_gen():
      # Five calibration steps, each providing one sample per model input.
      for _ in range(5):
        yield [
            np.random.uniform(-1, 1, size=3).astype(np.float32)
            for _ in range(2)
        ]

    quantized_converter = lite.TFLiteConverter.from_session(
        sess, [in_tensor_a, in_tensor_b], [out_tensor_a, out_tensor_b])
    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
    quantized_converter.representative_dataset = calibration_gen

    # The int16 op set is the same whether or not `is_int_only` is set.
    if is_int16_quantize:
      supported_ops = [
          lite.OpsSet
          .EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8,
          lite.OpsSet.TFLITE_BUILTINS
      ]
    elif is_int_only:
      supported_ops = [
          lite.OpsSet.TFLITE_BUILTINS_INT8, lite.OpsSet.TFLITE_BUILTINS
      ]
    else:
      supported_ops = [lite.OpsSet.TFLITE_BUILTINS]
    quantized_converter.target_spec.supported_ops = supported_ops

    quantized_converter.inference_input_type = inference_input_output_type
    quantized_converter.inference_output_type = inference_input_output_type
    quantized_converter.experimental_new_quantizer = enable_mlir_quantizer
    quantized_tflite_model = quantized_converter.convert()
    self.assertIsNotNone(quantized_tflite_model)

    expected_dtype = inference_input_output_type.as_numpy_dtype
    # Allow float32 for fallback on non-quantizable op.
    if enable_mlir_quantizer:
      expected_ceil_dtype = expected_dtype
    else:
      expected_ceil_dtype = dtypes.float32

    interpreter = Interpreter(model_content=quantized_tflite_model)
    interpreter.allocate_tensors()
    # Index 0 is the tensor attached to a ceil op on both the input and the
    # output side; index 1 is attached to the tanh op.
    for get_details in (interpreter.get_input_details,
                        interpreter.get_output_details):
      details = get_details()
      self.assertLen(details, 2)
      self.assertEqual(details[0]['dtype'], expected_ceil_dtype)
      self.assertEqual(details[1]['dtype'], expected_dtype)

  @parameterized.named_parameters(
      ('_PerChannelQuant', False, False),
      ('_PerChannelMlirQuant', False, True),
      ('_PerTensorQuant', True, False),
      ('_PerTensorMlirQuant', True, True),
      ('_PerChannelMlirDynamicRangeQuant', False, False, False),
      ('_PerTensorMlirDynamicRangeQuant', True, False, False))
  def testDisablePerChannelQuantization(self,
                                        disable_per_channel=False,
                                        enable_mlir_quantizer=False,
                                        representative_dataset=True):
    """Checks the number of quantization parameters of the conv tensor.

    Args:
      disable_per_channel: Whether `_experimental_disable_per_channel` is set
        on the converter.
      enable_mlir_quantizer: Value assigned to `experimental_new_quantizer`.
      representative_dataset: Whether a representative dataset is attached to
        the converter.
    """
    # The name of the tensor to inspect depends on the quantizer in use.
    if not enable_mlir_quantizer:
      k_conv_name = 'Conv2D1'
    elif disable_per_channel:
      k_conv_name = 'output1'
    else:
      k_conv_name = 'tfl.pseudo_qconst1'

    # Dynamic range quant requires total num elements of filters > 1024.
    k_num_filters = 38
    with ops.Graph().as_default():
      inp, output, calibration_gen = self._getIntegerQuantizeModel(
          k_num_filters)
      sess = session.Session()

    quantized_converter = lite.TFLiteConverter.from_session(
        sess, [inp], [output])
    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
    if representative_dataset:
      quantized_converter.representative_dataset = calibration_gen
    quantized_converter.experimental_new_quantizer = enable_mlir_quantizer
    if disable_per_channel:
      quantized_converter._experimental_disable_per_channel = (
          disable_per_channel)
    quantized_tflite_model = quantized_converter.convert()
    self.assertIsNotNone(quantized_tflite_model)

    # Do not apply delegates as XNNPack converts per tensor to per channel.
    resolver_type = OpResolverType.BUILTIN_WITHOUT_DEFAULT_DELEGATES
    interp = Interpreter(
        model_content=quantized_tflite_model,
        experimental_op_resolver_type=resolver_type,
    )
    interp.allocate_tensors()
    conv_detail = next(
        d for d in interp.get_tensor_details() if d['name'] == k_conv_name)
    quant_params = conv_detail['quantization_parameters']

    # Per-tensor quantization has a single scale/zero point; per-channel has
    # one per filter.
    if disable_per_channel:
      expected_num_params = 1
    else:
      expected_num_params = k_num_filters
    self.assertLen(quant_params['scales'], expected_num_params)
    self.assertLen(quant_params['zero_points'], expected_num_params)

  def testString(self):
    """Converts a string reshape graph and checks its tensor details."""
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(shape=[4], dtype=dtypes.string)
      out_tensor = array_ops.reshape(in_tensor, shape=[2, 2])
      sess = session.Session()

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                  [out_tensor])
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    # `np.bytes_` is used instead of the `np.string_` alias, which was removed
    # in NumPy 2.0; both names refer to the same type on older NumPy versions.
    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual('Placeholder', input_details[0]['name'])
    self.assertEqual(np.bytes_, input_details[0]['dtype'])
    self.assertAllEqual([4], input_details[0]['shape'])

    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertEqual('Reshape', output_details[0]['name'])
    self.assertEqual(np.bytes_, output_details[0]['dtype'])
    self.assertAllEqual([2, 2], output_details[0]['shape'])
    # TODO(b/122659643): Test setting/getting string data via the python
    # interpreter API after support has been added.

  def testIntermediateInputArray(self):
    """Converts a model whose input is an intermediate tensor of the graph."""
    with ops.Graph().as_default():
      placeholder = array_ops.placeholder(
          dtype=dtypes.float32, shape=[1, 16, 16, 3])
      intermediate = placeholder + placeholder
      result = intermediate + intermediate
      sess = session.Session()

    # Use the first add (not the placeholder) as the model input.
    tflite_model = lite.TFLiteConverter.from_session(
        sess, [intermediate], [result]).convert()
    self.assertIsNotNone(tflite_model)

    # Load the converted model and inspect the tensors it exposes.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    inputs = interpreter.get_input_details()
    self.assertLen(inputs, 1)
    model_input = inputs[0]
    self.assertEqual('add', model_input['name'])
    self.assertEqual(np.float32, model_input['dtype'])
    self.assertAllEqual([1, 16, 16, 3], model_input['shape'])
    self.assertEqual((0., 0.), model_input['quantization'])

    outputs = interpreter.get_output_details()
    self.assertLen(outputs, 1)
    model_output = outputs[0]
    self.assertEqual('add_1', model_output['name'])
    self.assertEqual(np.float32, model_output['dtype'])
    self.assertAllEqual([1, 16, 16, 3], model_output['shape'])
    self.assertEqual((0., 0.), model_output['quantization'])

  def testSizeNoneInvalid(self):
    """A placeholder without any shape is rejected by the old converter."""
    with ops.Graph().as_default():
      placeholder = array_ops.placeholder(dtype=dtypes.float32)
      doubled = placeholder + placeholder
      sess = session.Session()

    # Test None as shape when dynamic shapes are disabled. Run with TOCO in
    # order to invoke shape checking code.
    converter = lite.TFLiteConverter.from_session(
        sess, [placeholder], [doubled])
    converter.experimental_new_converter = False
    with self.assertRaises(ValueError) as error:
      converter.convert()

    expected_message = 'Provide an input shape for input array \'Placeholder\'.'
    self.assertEqual(expected_message, str(error.exception))

  def testScalarValid(self):
    """Converts and runs a model whose input and output are scalars."""
    # Construct a graph using a scalar (empty shape) input.
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(dtype=dtypes.float32, shape=[])
      out_tensor = in_tensor + in_tensor
      sess = session.Session()

    # Test conversion with the scalar input shape.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                  [out_tensor])
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual('Placeholder', input_details[0]['name'])
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertEmpty(input_details[0]['shape'])

    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertEqual('add', output_details[0]['name'])
    self.assertEqual(np.float32, output_details[0]['dtype'])
    # Check the shape of the output tensor (not the input tensor again).
    self.assertEmpty(output_details[0]['shape'])

    # Validate inference using the scalar inputs/outputs.
    test_input = np.array(4.0, dtype=np.float32)
    expected_output = np.array(8.0, dtype=np.float32)
    interpreter.set_tensor(input_details[0]['index'], test_input)
    interpreter.invoke()

    output_data = interpreter.get_tensor(output_details[0]['index'])
    self.assertEqual(expected_output, output_data)

  def testSizeInvalid(self):
    """An unknown non-leading dimension is rejected by the old converter."""
    with ops.Graph().as_default():
      placeholder = array_ops.placeholder(
          dtype=dtypes.float32, shape=[1, None, 16, 3])
      doubled = placeholder + placeholder
      sess = session.Session()

    # Test invalid shape. None after 1st dimension. Run with TOCO in order to
    # invoke shape checking code.
    converter = lite.TFLiteConverter.from_session(
        sess, [placeholder], [doubled])
    converter.experimental_new_converter = False
    with self.assertRaises(ValueError) as error:
      converter.convert()

    expected_message = (
        'None is only supported in the 1st dimension. Tensor '
        '\'Placeholder\' has invalid shape \'[1, None, 16, 3]\'.')
    self.assertEqual(expected_message, str(error.exception))

  def testSizeNone(self):
    """Converts a model with an unknown non-leading dimension and resizes it."""
    with ops.Graph().as_default():
      placeholder = array_ops.placeholder(
          dtype=dtypes.float32, shape=[1, None, 16, 3])
      doubled = placeholder + placeholder
      sess = session.Session()

    # Test None after 1st dimension.
    tflite_model = lite.TFLiteConverter.from_session(
        sess, [placeholder], [doubled]).convert()

    # Inspect the input before any tensor has been allocated.
    interpreter = Interpreter(model_content=tflite_model)
    inputs = interpreter.get_input_details()
    self.assertLen(inputs, 1)
    model_input = inputs[0]
    self.assertEqual('Placeholder', model_input['name'])
    self.assertEqual(np.float32, model_input['dtype'])
    self.assertAllEqual([1, 1, 16, 3], model_input['shape'])
    self.assertAllEqual([1, -1, 16, 3], model_input['shape_signature'])
    self.assertEqual((0., 0.), model_input['quantization'])

    # Strict resizing must refuse to change a dimension that is known.
    with self.assertRaises(RuntimeError) as error:
      interpreter.resize_tensor_input(0, [3, 16, 16, 3], strict=True)
    expected_fragment = (
        'ResizeInputTensorStrict only allows mutating unknown dimensions '
        'identified by -1.')
    self.assertIn(expected_fragment, str(error.exception))

    # Resize only the unknown dimension, then run the model.
    interpreter.resize_tensor_input(0, [1, 16, 16, 3], strict=True)
    interpreter.allocate_tensors()

    all_ones = np.full([1, 16, 16, 3], 1.0, dtype=np.float32)
    interpreter.set_tensor(model_input['index'], all_ones)
    interpreter.invoke()

    # The concrete shape reflects the resize; the signature is unchanged.
    inputs = interpreter.get_input_details()
    self.assertLen(inputs, 1)
    self.assertAllEqual([1, 16, 16, 3], inputs[0]['shape'])
    self.assertAllEqual([1, -1, 16, 3], inputs[0]['shape_signature'])

    outputs = interpreter.get_output_details()
    self.assertAllEqual([1, -1, 16, 3], outputs[0]['shape_signature'])

  def testResizeTensorInputStrict(self):
    """Ensures that resize_tensor_input(strict=True) works as expected."""
    with ops.Graph().as_default():
      placeholder = array_ops.placeholder(
          dtype=dtypes.float32, shape=[1, 16, 16, 3])
      doubled = placeholder + placeholder
      sess = session.Session()

    # The conversion must yield a model.
    tflite_model = lite.TFLiteConverter.from_session(
        sess, [placeholder], [doubled]).convert()
    self.assertIsNotNone(tflite_model)

    interpreter = Interpreter(model_content=tflite_model)

    # A shape that differs in a known dimension must be rejected.
    with self.assertRaises(RuntimeError) as error:
      interpreter.resize_tensor_input(0, [3, 16, 16, 3], strict=True)
    expected_fragment = (
        'ResizeInputTensorStrict only allows mutating unknown dimensions '
        'identified by -1.')
    self.assertIn(expected_fragment, str(error.exception))

    # A shape equal to the model's own shape is accepted.
    interpreter.resize_tensor_input(0, [1, 16, 16, 3], strict=True)
    interpreter.allocate_tensors()

  def testBatchSizeValid(self):
    """Converts a model whose batch dimension is unknown."""
    with ops.Graph().as_default():
      placeholder = array_ops.placeholder(
          dtype=dtypes.float32, shape=[None, 16, 16, 3])
      doubled = placeholder + placeholder
      sess = session.Session()

    # The conversion must yield a model.
    tflite_model = lite.TFLiteConverter.from_session(
        sess, [placeholder], [doubled]).convert()
    self.assertIsNotNone(tflite_model)

    # Load the converted model and inspect the tensors it exposes.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    inputs = interpreter.get_input_details()
    self.assertLen(inputs, 1)
    model_input = inputs[0]
    self.assertEqual('Placeholder', model_input['name'])
    self.assertEqual(np.float32, model_input['dtype'])
    self.assertAllEqual([1, 16, 16, 3], model_input['shape'])
    self.assertEqual((0., 0.), model_input['quantization'])

    outputs = interpreter.get_output_details()
    self.assertLen(outputs, 1)
    model_output = outputs[0]
    self.assertEqual('add', model_output['name'])
    self.assertEqual(np.float32, model_output['dtype'])
    self.assertAllEqual([1, 16, 16, 3], model_output['shape'])
    self.assertEqual((0., 0.), model_output['quantization'])

  def testBatchSizeNonZero(self):
    """Checks input shapes when only the first input has an unknown batch."""
    with ops.Graph().as_default():
      lhs = array_ops.placeholder(
          dtype=dtypes.float32, shape=[None, 4], name='input1')
      rhs = array_ops.placeholder(
          dtype=dtypes.float32, shape=[4, 10], name='input2')
      product = math_ops.matmul(lhs, rhs)
      sess = session.Session()

    # The conversion must yield a model.
    tflite_model = lite.TFLiteConverter.from_session(
        sess, [lhs, rhs], [product]).convert()
    self.assertIsNotNone(tflite_model)

    # Load the converted model and inspect its inputs.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    inputs = interpreter.get_input_details()
    self.assertLen(inputs, 2)
    first_input, second_input = inputs
    self.assertEqual('input1', first_input['name'])
    self.assertAllEqual([1, 4], first_input['shape'])
    self.assertEqual('input2', second_input['name'])
    self.assertAllEqual([4, 10], second_input['shape'])

  def testFreezeGraph(self):
    """Converts a session graph that contains a variable."""
    with ops.Graph().as_default():
      placeholder = array_ops.placeholder(
          dtype=dtypes.float32, shape=[1, 16, 16, 3])
      weights = variable_scope.get_variable(
          'weights', dtype=dtypes.float32, shape=[1, 16, 16, 3])
      # Get the second output to ensure freezing properly processes tensor names
      # like 'X:1'.
      top_k_outputs = nn_ops.top_k(placeholder + weights, name='top_k')
      top_k_indices = top_k_outputs[1]
      sess = session.Session()
      sess.run(_global_variables_initializer())

    # The conversion must yield a model.
    tflite_model = lite.TFLiteConverter.from_session(
        sess, [placeholder], [top_k_indices]).convert()
    self.assertIsNotNone(tflite_model)

    # Load the converted model and inspect the tensors it exposes.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    inputs = interpreter.get_input_details()
    self.assertLen(inputs, 1)
    model_input = inputs[0]
    self.assertEqual('Placeholder', model_input['name'])
    self.assertEqual(np.float32, model_input['dtype'])
    self.assertAllEqual([1, 16, 16, 3], model_input['shape'])
    self.assertEqual((0., 0.), model_input['quantization'])

    outputs = interpreter.get_output_details()
    self.assertLen(outputs, 1)
    model_output = outputs[0]
    self.assertEqual('top_k:1', model_output['name'])
    self.assertEqual(np.int32, model_output['dtype'])
    self.assertAllEqual([1, 16, 16, 1], model_output['shape'])
    self.assertEqual((0., 0.), model_output['quantization'])

  def testGraphviz(self):
    """Converts a simple add graph with GRAPHVIZ_DOT as the output format."""
    with ops.Graph().as_default():
      placeholder = array_ops.placeholder(
          dtype=dtypes.float32, shape=[1, 16, 16, 3])
      doubled = placeholder + placeholder
      sess = session.Session()

    converter = lite.TFLiteConverter.from_session(
        sess, [placeholder], [doubled])
    converter.output_format = lite_constants.GRAPHVIZ_DOT

    # Only the presence of an output is checked, not its contents.
    self.assertIsNotNone(converter.convert())

  def testDumpGraphviz(self):
    """Checks that setting `dump_graphviz_dir` produces graphviz dump files.

    Converts a simple add graph with `dump_graphviz_dir` pointing at the test's
    temporary directory, verifies the resulting model loads in the interpreter,
    and verifies that the expected `.dot` files exist afterwards.
    """
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32)
      out_tensor = in_tensor + in_tensor
      sess = session.Session()

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                  [out_tensor])
    graphviz_dir = self.get_temp_dir()
    converter.dump_graphviz_dir = graphviz_dir
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Ensure interpreter is able to allocate tensors for the converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    # `len(...)` and `os.path.exists(...)` never return None, so asserting
    # "is not None" on them cannot fail. Assert on the actual values so that a
    # missing dump is detected.
    self.assertNotEmpty(os.listdir(graphviz_dir))
    for dot_file in ('toco_AT_IMPORT.dot', 'toco_AFTER_TRANSFORMATIONS.dot'):
      self.assertTrue(
          os.path.exists(os.path.join(graphviz_dir, dot_file)),
          msg='Missing graphviz dump: ' + dot_file)

  def testDumpConversionSummary(self):
    """Checks that `conversion_summary_dir` is populated by a conversion."""
    with ops.Graph().as_default():
      placeholder = array_ops.placeholder(
          dtype=dtypes.float32, shape=[1, 16, 16, 3])
      doubled = placeholder + placeholder
      sess = session.Session()

    summary_dir = self.get_temp_dir()
    converter = lite.TFLiteConverter.from_session(
        sess, [placeholder], [doubled])
    converter.conversion_summary_dir = summary_dir

    # The conversion has to succeed and leave something in the directory.
    self.assertIsNotNone(converter.convert())
    self.assertNotEmpty(os.listdir(summary_dir))

  def testDumpConversionSummaryWithOldConverter(self):
    """Checks no conversion summary is written when the new converter is off."""
    with ops.Graph().as_default():
      placeholder = array_ops.placeholder(
          dtype=dtypes.float32, shape=[1, 16, 16, 3])
      doubled = placeholder + placeholder
      sess = session.Session()

    summary_dir = self.get_temp_dir()
    converter = lite.TFLiteConverter.from_session(
        sess, [placeholder], [doubled])
    converter.experimental_new_converter = False
    converter.conversion_summary_dir = summary_dir

    # The conversion itself still has to produce a model.
    self.assertIsNotNone(converter.convert())

    # The summary directory must stay empty.
    self.assertEqual(len(os.listdir(summary_dir)), 0)

  def testQuantizeDynamicRange(self):
    """Checks `Optimize.DEFAULT` yields a smaller model than float conversion."""
    np.random.seed(0)
    with ops.Graph().as_default():
      # The weights tensor needs more than 1024 elements for quantize_weights
      # to kick in, hence the [33, 33] shape.
      lhs = array_ops.placeholder(
          dtype=dtypes.float32, shape=[33, 33], name='inputA')
      weight_values = np.random.uniform(low=-10., high=10., size=(33, 33))
      rhs = constant_op.constant(
          weight_values, shape=[33, 33], dtype=dtypes.float32, name='inputB')
      product = math_ops.matmul(lhs, rhs, name='output')
      sess = session.Session()

    # Baseline: plain float conversion.
    float_tflite_model = lite.TFLiteConverter.from_session(
        sess, [lhs], [product]).convert()
    self.assertIsNotNone(float_tflite_model)

    # Same graph converted with the default optimizations enabled.
    quantized_converter = lite.TFLiteConverter.from_session(
        sess, [lhs], [product])
    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
    quantized_tflite_model = quantized_converter.convert()
    self.assertIsNotNone(quantized_tflite_model)

    # The optimized model has to be strictly smaller than the float one.
    self.assertLess(len(quantized_tflite_model), len(float_tflite_model))

  def testQuantizeDynamicRangeDeprecatedPostTrainingQuantizeAttribute(
      self):
    """Checks `post_training_quantize` maps onto `optimizations`."""
    with ops.Graph().as_default():
      lhs = array_ops.placeholder(
          dtype=dtypes.float32, shape=[33, 33], name='inputA')
      weight_values = np.random.uniform(low=-10., high=10., size=(33, 33))
      rhs = constant_op.constant(
          weight_values, shape=[33, 33], dtype=dtypes.float32, name='inputB')
      product = math_ops.matmul(lhs, rhs, name='output')
      sess = session.Session()

    quantized_converter = lite.TFLiteConverter.from_session(
        sess, [lhs], [product])

    # The attribute reads as False on a freshly built converter.
    self.assertFalse(quantized_converter.post_training_quantize)

    # Turning it on must be visible on both the attribute and `optimizations`.
    quantized_converter.post_training_quantize = True
    self.assertTrue(quantized_converter.post_training_quantize)
    self.assertEqual(quantized_converter.optimizations, [lite.Optimize.DEFAULT])

    self.assertIsNotNone(quantized_converter.convert())

  def _getIntegerQuantizeModel(self, num_filters=16):
    """Builds a conv2d + relu graph plus a calibration data generator.

    The ops are added to whichever graph is the default when this is called.

    Args:
      num_filters: Size of the last dimension of the all-ones conv filter.

    Returns:
      A tuple `(input placeholder, relu output, calibration generator)`. The
      generator function yields five single-element lists, each holding a
      random float32 array of shape (1, 5, 5, 3).
    """
    np.random.seed(0)
    input_shape = (1, 5, 5, 3)
    placeholder = array_ops.placeholder(
        shape=input_shape, dtype=dtypes.float32, name='input')
    kernel = array_ops.ones([3, 3, 3, num_filters])
    conv = nn_ops.conv2d(
        placeholder, filter=kernel, strides=[1, 1, 1, 1], padding='SAME')
    activation = nn_ops.relu(conv, name='output')

    def calibration_gen():
      for _ in range(5):
        sample = np.random.uniform(-1, 1, size=(1, 5, 5, 3))
        yield [sample.astype(np.float32)]

    return placeholder, activation, calibration_gen

  def testQuantizeInt8AllowFloat(self):
    """Quantizes with calibration data while keeping float inputs/outputs."""
    with ops.Graph().as_default():
      inp, output, calibration_gen = self._getIntegerQuantizeModel()
      sess = session.Session()

    # Baseline float conversion.
    float_tflite_model = lite.TFLiteConverter.from_session(
        sess, [inp], [output]).convert()
    self.assertIsNotNone(float_tflite_model)

    # Conversion metadata recorded for the float model.
    float_metadata = get_conversion_metadata(float_tflite_model)
    self.assertIsNotNone(float_metadata)
    environment = float_metadata.environment
    self.assertEqual(
        environment.tensorflowVersion.decode('utf-8'), versions.__version__)
    self.assertEqual(environment.apiVersion, 1)
    self.assertEqual(environment.modelType, metadata_fb.ModelType.TF_SESSION)
    options = float_metadata.options
    self.assertEqual(options.allowCustomOps, False)
    self.assertEqual(options.enableSelectTfOps, False)
    self.assertEqual(options.forceSelectTfOps, False)
    self.assertAllEqual([], options.modelOptimizationModes)

    # Quantized conversion driven by the representative dataset.
    quantized_converter = lite.TFLiteConverter.from_session(
        sess, [inp], [output])
    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
    quantized_converter.representative_dataset = calibration_gen
    quantized_tflite_model = quantized_converter.convert()
    self.assertIsNotNone(quantized_tflite_model)

    # Conversion metadata recorded for the quantized model.
    quantized_metadata = get_conversion_metadata(quantized_tflite_model)
    self.assertIsNotNone(quantized_metadata)
    self.assertAllEqual([metadata_fb.ModelOptimizationMode.PTQ_FULL_INTEGER],
                        quantized_metadata.options.modelOptimizationModes)

    # Without explicit inference types the model keeps float inputs/outputs.
    interpreter = Interpreter(model_content=quantized_tflite_model)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual(np.float32, input_details[0]['dtype'])
    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertEqual(np.float32, output_details[0]['dtype'])

    # The quantized model has to be strictly smaller than the float one.
    self.assertLess(len(quantized_tflite_model), len(float_tflite_model))

  @parameterized.named_parameters(
      # Int8 builtin ops.
      ('UseTfliteBuiltinsInt',
       [lite.OpsSet.TFLITE_BUILTINS_INT8],
       [metadata_fb.ModelOptimizationMode.PTQ_FULL_INTEGER]),
      # Int16 activations with int8 weights.
      ('UseTfliteBuiltinsInt16',
       [
           lite.OpsSet
           .EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8
       ],
       [metadata_fb.ModelOptimizationMode.PTQ_INT16]))
  def testQuantizeInt8And16x8(self, supported_ops, expected_opt_modes):
    """Quantizes with the given `supported_ops` and checks the metadata.

    Args:
      supported_ops: Value assigned to `target_spec.supported_ops`.
      expected_opt_modes: Optimization modes expected in the conversion
        metadata of the quantized model.
    """
    with ops.Graph().as_default():
      inp, output, calibration_gen = self._getIntegerQuantizeModel()
      sess = session.Session()

    # Baseline float conversion.
    float_tflite_model = lite.TFLiteConverter.from_session(
        sess, [inp], [output]).convert()
    self.assertIsNotNone(float_tflite_model)

    # Quantized conversion: default optimizations, the parameterized op set
    # and a representative dataset are all configured.
    quantized_converter = lite.TFLiteConverter.from_session(
        sess, [inp], [output])
    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
    quantized_converter.target_spec.supported_ops = supported_ops
    quantized_converter.representative_dataset = calibration_gen
    quantized_tflite_model = quantized_converter.convert()
    self.assertIsNotNone(quantized_tflite_model)

    # Conversion metadata recorded for the quantized model.
    metadata = get_conversion_metadata(quantized_tflite_model)
    self.assertIsNotNone(metadata)
    environment = metadata.environment
    self.assertEqual(
        environment.tensorflowVersion.decode('utf-8'), versions.__version__)
    self.assertEqual(environment.apiVersion, 1)
    self.assertEqual(environment.modelType, metadata_fb.ModelType.TF_SESSION)
    options = metadata.options
    self.assertEqual(options.allowCustomOps, False)
    self.assertEqual(options.enableSelectTfOps, False)
    self.assertEqual(options.forceSelectTfOps, False)
    self.assertAllEqual(expected_opt_modes, options.modelOptimizationModes)

    # Without explicit inference types the model keeps float inputs/outputs.
    interpreter = Interpreter(model_content=quantized_tflite_model)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual(np.float32, input_details[0]['dtype'])
    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertEqual(np.float32, output_details[0]['dtype'])

    # The quantized model has to be strictly smaller than the float one.
    self.assertLess(len(quantized_tflite_model), len(float_tflite_model))

  def testQuantizeInt8InputOutput(self):
    """Quantizes with int8 inference input/output types and checks dtypes."""
    with ops.Graph().as_default():
      inp, output, calibration_gen = self._getIntegerQuantizeModel()
      sess = session.Session()

    # Baseline float conversion.
    float_tflite_model = lite.TFLiteConverter.from_session(
        sess, [inp], [output]).convert()
    self.assertIsNotNone(float_tflite_model)

    # Quantized conversion requesting int8 at the model boundary.
    quantized_converter = lite.TFLiteConverter.from_session(
        sess, [inp], [output])
    quantized_converter.inference_input_type = dtypes.int8
    quantized_converter.inference_output_type = dtypes.int8
    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
    quantized_converter.representative_dataset = calibration_gen
    quantized_tflite_model = quantized_converter.convert()
    self.assertIsNotNone(quantized_tflite_model)

    # Both the single input and the single output have to be int8.
    interpreter = Interpreter(model_content=quantized_tflite_model)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual(np.int8, input_details[0]['dtype'])
    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertEqual(np.int8, output_details[0]['dtype'])

    # The quantized model has to be strictly smaller than the float one.
    self.assertLess(len(quantized_tflite_model), len(float_tflite_model))

  def testInvalidQuantizeInt8(self):
    """Checks int8-only conversion fails without a representative dataset."""
    np.random.seed(0)
    with ops.Graph().as_default():
      # The weights tensor needs more than 1024 elements for quantize_weights
      # to kick in, hence the [33, 33] shape.
      lhs = array_ops.placeholder(
          dtype=dtypes.float32, shape=[33, 33], name='inputA')
      weight_values = np.random.uniform(low=-10., high=10., size=(33, 33))
      rhs = constant_op.constant(
          weight_values, shape=[33, 33], dtype=dtypes.float32, name='inputB')
      product = math_ops.matmul(lhs, rhs, name='output')
      sess = session.Session()

    quantized_converter = lite.TFLiteConverter.from_session(
        sess, [lhs], [product])
    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
    # Only int8 is allowed as a supported type.
    quantized_converter.target_spec.supported_types = [dtypes.int8]

    expected_message = (
        'For full integer quantization, a `representative_dataset` '
        'must be specified.')
    # No representative dataset was provided, so convert() has to raise.
    with self.assertRaises(ValueError) as raised:
      quantized_converter.convert()
    self.assertEqual(expected_message, str(raised.exception))

  def testQuantizeUInt8(self):
    """Converts a fake-quant graph to uint8 and checks the tensor details."""
    with ops.Graph().as_default():
      input_a = array_ops.placeholder(
          dtype=dtypes.float32, shape=[1, 16, 16, 3], name='inputA')
      input_b = array_ops.placeholder(
          dtype=dtypes.float32, shape=[1, 16, 16, 3], name='inputB')
      fake_quant = array_ops.fake_quant_with_min_max_args(
          input_a + input_b, min=0., max=1., name='output')
      sess = session.Session()

    converter = lite.TFLiteConverter.from_session(
        sess, [input_a, input_b], [fake_quant])
    converter.inference_type = dtypes.uint8
    # Each entry is (mean, std_dev).
    converter.quantized_input_stats = {
        'inputA': (0., 1.),
        'inputB': (0., 1.)
    }
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Load the converted model and inspect its tensors.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 2)
    first_input, second_input = input_details

    self.assertEqual('inputA', first_input['name'])
    self.assertEqual(np.uint8, first_input['dtype'])
    self.assertAllEqual([1, 16, 16, 3], first_input['shape'])
    self.assertEqual((1., 0.), first_input['quantization'])

    self.assertEqual('inputB', second_input['name'])
    self.assertEqual(np.uint8, second_input['dtype'])
    self.assertAllEqual([1, 16, 16, 3], second_input['shape'])
    self.assertEqual((1., 0.), second_input['quantization'])

    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    model_output = output_details[0]
    self.assertEqual(np.uint8, model_output['dtype'])
    self.assertAllEqual([1, 16, 16, 3], model_output['shape'])
    # First quantization entry is the scale.
    self.assertGreater(model_output['quantization'][0], 0)

  def testQuantizeUInt8UsingDefaultRangeStats(self):
    """Converts a plain add graph to uint8 using `default_ranges_stats`."""
    with ops.Graph().as_default():
      placeholder = array_ops.placeholder(
          dtype=dtypes.float32, shape=[1, 16, 16, 3])
      doubled = placeholder + placeholder
      sess = session.Session()

    converter = lite.TFLiteConverter.from_session(
        sess, [placeholder], [doubled])
    converter.inference_type = dtypes.uint8
    # Input stats are (mean, std_dev); default range stats are (min, max).
    converter.quantized_input_stats = {'Placeholder': (0., 1.)}
    converter.default_ranges_stats = (0, 6)
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Load the converted model and inspect its tensors.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    model_input = input_details[0]
    self.assertEqual('Placeholder', model_input['name'])
    self.assertEqual(np.uint8, model_input['dtype'])
    self.assertAllEqual([1, 16, 16, 3], model_input['shape'])
    self.assertEqual((1., 0.), model_input['quantization'])

    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    model_output = output_details[0]
    self.assertEqual('add', model_output['name'])
    self.assertEqual(np.uint8, model_output['dtype'])
    self.assertAllEqual([1, 16, 16, 3], model_output['shape'])
    # First quantization entry is the scale.
    self.assertGreater(model_output['quantization'][0], 0)

  @parameterized.named_parameters(
      # Float16 quantization even though representative data is provided.
      ('UseRepresentativeData', True, False, True, False, False, False,
       [metadata_fb.ModelOptimizationMode.PTQ_FLOAT16]),
      # Float16 quantization without representative data.
      ('NoRepresentativeData', False, False, True, False, False, False,
       [metadata_fb.ModelOptimizationMode.PTQ_FLOAT16]),
      # Representative data plus int8 gives post-training quantization.
      ('SampleDataIncludeInt8', True, True, False, False, True, False,
       [metadata_fb.ModelOptimizationMode.PTQ_FULL_INTEGER]),
      # Same as the previous case, but with the MLIR quantizer.
      ('SampleDataIncludeInt8Quant', True, True, False, False, True, True,
       [metadata_fb.ModelOptimizationMode.PTQ_FULL_INTEGER]))
  def testQuantizeFloat16(self, use_rep_data, include_int8,
                          is_float16_quantized, is_float16_accumulation,
                          is_post_training_quantized, enable_mlir_quantizer,
                          expected_opt_modes):
    """Checks the bias dtype produced by float16 / int8 quantization setups.

    Args:
      use_rep_data: Whether a representative dataset is set on the converter.
      include_int8: Whether int8 is appended to the supported types.
      is_float16_quantized: Whether the bias is expected to be float16.
      is_float16_accumulation: Whether the float16 accumulation type is set.
      is_post_training_quantized: Whether the bias is expected to be int32.
      enable_mlir_quantizer: Value for `experimental_new_quantizer`.
      expected_opt_modes: Optimization modes expected in the metadata.

    Raises:
      ValueError: If neither `is_float16_quantized` nor
        `is_post_training_quantized` is set for a case that converts.
    """
    with ops.Graph().as_default():
      inp, output, calibration_gen = self._getIntegerQuantizeModel()
      sess = session.Session()

    # Float baseline: the tensor at index 1 is the float32 'Conv2D' bias.
    float_tflite_model = lite.TFLiteConverter.from_session(
        sess, [inp], [output]).convert()
    self.assertIsNotNone(float_tflite_model)
    interpreter = Interpreter(model_content=float_tflite_model)
    interpreter.allocate_tensors()
    float_bias = interpreter.get_tensor_details()[1]
    self.assertEqual(float_bias['name'], 'Conv2D')
    self.assertEqual(float_bias['dtype'], dtypes.float32)

    # Configure the quantized conversion.
    quantized_converter = lite.TFLiteConverter.from_session(
        sess, [inp], [output])
    quantized_converter.experimental_new_quantizer = enable_mlir_quantizer
    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
    quantized_converter.target_spec.supported_types = [dtypes.float16]
    if include_int8:
      quantized_converter.target_spec.supported_types.append(dtypes.int8)
    if use_rep_data:
      quantized_converter.representative_dataset = calibration_gen

    if is_float16_accumulation:
      # NOTE(review): this path only sets the accumulation type; nothing is
      # converted or asserted afterwards. No current parameter set enables it.
      quantized_converter.target_spec.experimental_supported_accumulation_type = dtypes.float16  # pylint: disable=line-too-long
      return

    quantized_tflite_model = quantized_converter.convert()
    self.assertIsNotNone(quantized_tflite_model)
    metadata = get_conversion_metadata(quantized_tflite_model)
    self.assertIsNotNone(metadata)
    self.assertAllEqual(expected_opt_modes,
                        metadata.options.modelOptimizationModes)
    interpreter = Interpreter(model_content=quantized_tflite_model)
    interpreter.allocate_tensors()

    # The bias is located by name; the name differs under the MLIR quantizer.
    if enable_mlir_quantizer:
      bias_name = 'tfl.pseudo_qconst'
    else:
      bias_name = 'Conv2D'
    matching_tensors = [
        details for details in interpreter.get_tensor_details()
        if details['name'] == bias_name
    ]
    self.assertLen(matching_tensors, 1)

    if is_float16_quantized:
      # The bias constant has to be float16.
      expected_bias_dtype = dtypes.float16
    elif is_post_training_quantized:
      # The bias constant has to be int32.
      expected_bias_dtype = dtypes.int32
    else:
      raise ValueError('Invalid test options.')
    self.assertEqual(matching_tensors[0]['dtype'], expected_bias_dtype)

  def testInvalidQuantizeFloat16(self):
    """Checks float16 types combined with int8-only builtin ops are rejected."""
    with ops.Graph().as_default():
      inp, output, _ = self._getIntegerQuantizeModel()
      sess = session.Session()

    quantized_converter = lite.TFLiteConverter.from_session(
        sess, [inp], [output])
    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
    target_spec = quantized_converter.target_spec
    # Request float16 quantization ...
    target_spec.supported_types = [dtypes.float16]
    # ... while restricting the op set to the int8 builtins.
    target_spec.supported_ops = [lite.OpsSet.TFLITE_BUILTINS_INT8]

    expected_message = (
        'As full integer quantization has been enabled by setting '
        '`target_spec.supported_ops`={tf.lite.OpsSet.TFLITE_BUILTINS_INT8}, '
        'thus `target_spec.supported_types` should be left uninitizalized '
        'or set to {tf.int8}.')
    with self.assertRaises(ValueError) as raised:
      quantized_converter.convert()
    self.assertEqual(expected_message, str(raised.exception))

  @parameterized.named_parameters(
      ('InferenceType_INT8', dtypes.int8),
      ('InferenceType_UINT8', dtypes.uint8))
  def testInvalidQuantizeQATModelRequiresInputStats(self, quantized_type):
    """Checks quantized inference types require `quantized_input_stats`.

    Args:
      quantized_type: The quantized dtype assigned to `inference_type` and
        `inference_input_type`.
    """
    with ops.Graph().as_default():
      placeholder = array_ops.placeholder(
          dtype=dtypes.float32, shape=[1, 16, 16, 3])
      fake_quant = array_ops.fake_quant_with_min_max_args(
          placeholder + placeholder, min=0., max=1.)
      sess = session.Session()

    quantized_converter = lite.TFLiteConverter.from_session(
        sess, [placeholder], [fake_quant])

    # Case 1: only `inference_type` is quantized and no stats are given.
    with self.assertRaises(ValueError) as raised:
      quantized_converter.inference_type = quantized_type
      quantized_converter.convert()
    expected_message = (
        'The `quantized_input_stats` flag must be defined when either '
        '`inference_type` flag or `inference_input_type` flag is set to '
        'tf.int8 or tf.uint8. Currently, `inference_type=tf.{}` and '
        '`inference_input_type=None`.'.format(quantized_type.name))
    self.assertEqual(expected_message, str(raised.exception))

    # Case 2: only `inference_input_type` is quantized and no stats are given.
    with self.assertRaises(ValueError) as raised:
      quantized_converter.inference_type = dtypes.float32
      quantized_converter.inference_input_type = quantized_type
      quantized_converter.convert()
    expected_message = (
        'The `quantized_input_stats` flag must be defined when either '
        '`inference_type` flag or `inference_input_type` flag is set to '
        'tf.int8 or tf.uint8. Currently, `inference_type=tf.float32` and '
        '`inference_input_type=tf.{}`.'.format(quantized_type.name))
    self.assertEqual(expected_message, str(raised.exception))

    # Case 3: with stats for the input, the conversion must not raise.
    quantized_converter.inference_type = quantized_type
    quantized_converter.inference_input_type = quantized_type
    first_input_name = quantized_converter.get_input_arrays()[0]
    quantized_converter.quantized_input_stats = {first_input_name: (0., 1.)}
    quantized_converter.convert()

  def testInvalidQuantizeQATModelMissingInputStats(self):
    """Checks conversion fails when stats are given for only one of two inputs."""
    with ops.Graph().as_default():
      input_a = array_ops.placeholder(
          dtype=dtypes.float32, shape=[1, 16, 16, 3], name='inputA')
      input_b = array_ops.placeholder(
          dtype=dtypes.float32, shape=[1, 16, 16, 3], name='inputB')
      fake_quant = array_ops.fake_quant_with_min_max_args(
          input_a + input_b, min=0., max=1., name='output')
      sess = session.Session()

    converter = lite.TFLiteConverter.from_session(
        sess, [input_a, input_b], [fake_quant])
    converter.inference_type = dtypes.uint8
    # Stats are (mean, std_dev); 'inputB' is deliberately left out.
    converter.quantized_input_stats = {'inputA': (0., 1.)}

    expected_message = (
        'Quantization input stats are not available for input tensors '
        '\'inputB\'.')
    with self.assertRaises(ValueError) as raised:
      converter.convert()
    self.assertEqual(expected_message, str(raised.exception))

  def testTrainingTimeAndPostTrainingCalibrateAndQuantize(self):
    """Combines training-time flags with post-training calibration."""
    with ops.Graph().as_default():
      inp, output, calibration_gen = self._getIntegerQuantizeModel()
      sess = session.Session()

    # Baseline float conversion.
    float_tflite_model = lite.TFLiteConverter.from_session(
        sess, [inp], [output]).convert()
    self.assertIsNotNone(float_tflite_model)

    converter = lite.TFLiteConverter.from_session(sess, [inp], [output])

    # Extra flags to trigger training-time quantization conversion.
    converter.inference_type = dtypes.int8
    converter.inference_input_type = dtypes.float32
    converter.inference_output_type = dtypes.float32
    first_input_name = converter.get_input_arrays()[0]
    converter.quantized_input_stats = {first_input_name: (0., 1.)}
    # Flags to trigger post-training quantization.
    converter.optimizations = [lite.Optimize.DEFAULT]
    converter.representative_dataset = calibration_gen
    converter.experimental_new_quantizer = True

    quantized_tflite_model = converter.convert()
    self.assertIsNotNone(quantized_tflite_model)
    self.assertLess(len(quantized_tflite_model), len(float_tflite_model))

    # Calibration-only API: calibrate first, then quantize in a second step.
    converter._experimental_calibrate_only = True
    calibrated_tflite = converter.convert()
    quantized_tflite_model = mlir_quantize(
        calibrated_tflite, fully_quantize=True)

    interpreter = Interpreter(model_content=quantized_tflite_model)
    interpreter.allocate_tensors()

    model_input = interpreter.get_input_details()[0]
    self.assertEqual(np.int8, model_input['dtype'])
    self.assertEqual((1., 0.), model_input['quantization'])

    model_output = interpreter.get_output_details()[0]
    self.assertEqual(np.int8, model_output['dtype'])

  def testFloatTocoConverter(self):
    """Checks the deprecated TocoConverter converts a simple add graph."""
    with ops.Graph().as_default():
      placeholder = array_ops.placeholder(
          dtype=dtypes.float32, shape=[1, 16, 16, 3])
      doubled = placeholder + placeholder
      sess = session.Session()

    # The conversion has to produce a model.
    tflite_model = lite.TocoConverter.from_session(
        sess, [placeholder], [doubled]).convert()
    self.assertIsNotNone(tflite_model)

    # The interpreter must be able to load the model and allocate tensors.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

  def testMultipleOutputNodeNames(self):
    """Converts a graph whose single split op feeds four model outputs."""
    with ops.Graph().as_default():
      input_tensor = array_ops.placeholder(dtype=dtypes.float32, shape=[4])
      pieces = array_ops.split(input_tensor, [1, 1, 1, 1], axis=0)
      out0, out1, out2, out3 = pieces
      sess = session.Session()

    # The conversion has to produce a model.
    converter = lite.TFLiteConverter.from_session(
        sess, [input_tensor], [out0, out1, out2, out3])
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Run the converted model on a known input.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    feed = np.asarray([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
    interpreter.set_tensor(input_details[0]['index'], feed)
    interpreter.invoke()

    # Each output has to carry the matching element of the input.
    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 4)
    for position, expected in enumerate([1.0, 2.0, 3.0, 4.0]):
      self.assertEqual(
          expected, interpreter.get_tensor(output_details[position]['index']))

  @test_util.run_in_graph_and_eager_modes
  def testFunctions(self):
    """Converts a 1.X session graph that calls a tf.function."""

    @def_function.function
    def plus_placeholder(x, placeholder):
      return x + placeholder

    with ops.Graph().as_default():
      input_node = array_ops.placeholder(
          shape=[1], dtype=dtypes.float32, name='input')
      variable_node = variables.Variable(1.0, name='variable_node')
      defun_node = plus_placeholder(variable_node, input_node)
      output_node = math_ops.multiply(defun_node, 2.0, name='output_node')

      # The variable has to be initialized in the session used for conversion.
      sess = session.Session()
      init_op = variables.variables_initializer([variable_node])
      sess.run(init_op)

    # The conversion has to produce a model.
    converter = lite.TFLiteConverter.from_session(
        sess, [input_node], [output_node])
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Load the converted model and inspect its tensors.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    model_input = input_details[0]
    self.assertEqual('input', model_input['name'])
    self.assertEqual(np.float32, model_input['dtype'])
    self.assertAllEqual([1], model_input['shape'])
    self.assertEqual((0., 0.), model_input['quantization'])

    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    model_output = output_details[0]
    self.assertEqual('output_node', model_output['name'])
    self.assertEqual(np.float32, model_output['dtype'])
    self.assertAllEqual([1], model_output['shape'])
    self.assertEqual((0., 0.), model_output['quantization'])

  def testInferenceInputOutputTypeFloatDefault(self):
    """Checks a default conversion keeps float32 inputs and outputs."""
    with ops.Graph().as_default():
      placeholder = array_ops.placeholder(
          dtype=dtypes.float32, shape=[1, 16, 16, 3])
      doubled = placeholder + placeholder
      sess = session.Session()

    # The conversion has to produce a model.
    tflite_model = lite.TFLiteConverter.from_session(
        sess, [placeholder], [doubled]).convert()
    self.assertIsNotNone(tflite_model)

    # Load the converted model and inspect its tensors.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    model_input = input_details[0]
    self.assertEqual('Placeholder', model_input['name'])
    self.assertEqual(np.float32, model_input['dtype'])
    self.assertAllEqual([1, 16, 16, 3], model_input['shape'])

    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    model_output = output_details[0]
    self.assertEqual('add', model_output['name'])
    self.assertEqual(np.float32, model_output['dtype'])
    self.assertAllEqual([1, 16, 16, 3], model_output['shape'])

  def testInferenceInputOutputTypeQuantizedUint8Default(self):
    """Checks uint8 `inference_type` gives uint8 model inputs and outputs."""
    with ops.Graph().as_default():
      placeholder = array_ops.placeholder(
          dtype=dtypes.float32, shape=[1, 16, 16, 3])
      fake_quant = array_ops.fake_quant_with_min_max_args(
          placeholder + placeholder, min=0., max=1., name='output')
      sess = session.Session()

    converter = lite.TFLiteConverter.from_session(
        sess, [placeholder], [fake_quant])
    converter.inference_type = dtypes.uint8
    # Stats are (mean, std_dev).
    converter.quantized_input_stats = {'Placeholder': (0., 1.)}
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Load the converted model and inspect its tensors.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    model_input = input_details[0]
    self.assertEqual('Placeholder', model_input['name'])
    self.assertEqual(np.uint8, model_input['dtype'])
    self.assertAllEqual([1, 16, 16, 3], model_input['shape'])

    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    model_output = output_details[0]
    self.assertEqual('output', model_output['name'])
    self.assertEqual(np.uint8, model_output['dtype'])
    self.assertAllEqual([1, 16, 16, 3], model_output['shape'])

  def testReusingConverterWithDifferentPostTrainingQuantization(self):
    """Test converting twice with one converter, toggling the PTQ flag."""
    with ops.Graph().as_default():
      placeholder = array_ops.placeholder(
          dtype=dtypes.float32, shape=[1, 16, 16, 3])
      quantized_sum = array_ops.fake_quant_with_min_max_args(
          placeholder + placeholder, min=0., max=1., name='output')
      sess = session.Session()

    converter = lite.TFLiteConverter.from_session(
        sess, [placeholder], [quantized_sum])

    # First conversion: post-training quantization switched on.
    converter.post_training_quantize = True
    quantized_model = converter.convert()
    self.assertIsNotNone(quantized_model)

    # Second conversion on the same converter: flag switched back off.
    converter.post_training_quantize = False
    plain_model = converter.convert()
    self.assertIsNotNone(plain_model)

  def testResizeWithShape(self):
    """Test that the output shape follows a resized dynamic input."""
    with ops.Graph().as_default():
      # The placeholder has a dynamically shaped 2-D input, and the reshape
      # below consumes the shape of that input.
      dynamic_input = array_ops.placeholder(
          dtype=dtypes.float32, shape=[None, None])
      constant_values = [[1, 2], [3, 4]]
      reshaped = array_ops.reshape(constant_values,
                                   array_ops.shape(dynamic_input))
      sess = session.Session()

    converter = lite.TFLiteConverter.from_session(
        sess, [dynamic_input], [reshaped])
    model_bytes = converter.convert()

    # Before resizing, the interpreter reports [1, 1] as the shape and
    # [-1, -1] as the shape signature.
    interp = Interpreter(model_content=model_bytes)
    inputs = interp.get_input_details()
    self.assertLen(inputs, 1)
    self.assertAllEqual([1, 1], inputs[0]['shape'])
    self.assertAllEqual([-1, -1], inputs[0]['shape_signature'])

    # Resize input 0 to a 1-D shape of four elements, then run the model.
    interp.resize_tensor_input(0, [4])
    interp.allocate_tensors()
    interp.invoke()

    # The output must take the resized input's shape and hold the constant's
    # values.
    outputs = interp.get_output_details()
    self.assertLen(outputs, 1)
    model_output = outputs[0]
    self.assertEqual(np.int32, model_output['dtype'])
    self.assertAllEqual([4], model_output['shape'])
    result = interp.get_tensor(model_output['index'])
    self.assertAllEqual([1, 2, 3, 4], result)

  def testResizingIntermediateDynamicTensor(self):
    """Regression test: a dynamic tensor's shape changes between invokes.

    See also https://github.com/tensorflow/tensorflow/issues/26549
    """
    with ops.Graph().as_default():
      padded_input = array_ops.placeholder(dtype=dtypes.float32, shape=[1, 1])
      addend_input = array_ops.placeholder(dtype=dtypes.float32, shape=[1])

      # The bug only shows up when the dynamic tensor is an intermediate one,
      # so the pad op is surrounded by other ops.
      negated = math_ops.negative(addend_input)
      paddings = array_ops.placeholder(dtype=dtypes.int32, shape=[2, 2])
      result = array_ops.pad(padded_input, paddings) + negated

      sess = session.Session()

    converter = lite.TFLiteConverter.from_session(
        sess, [padded_input, paddings, addend_input], [result])
    model_bytes = converter.convert()

    interp = Interpreter(model_content=model_bytes)
    interp.allocate_tensors()

    # Input 1 is the paddings tensor (second entry of the input list above).
    paddings_index = interp.get_input_details()[1]['index']
    interp.set_tensor(paddings_index,
                      np.array([[1, 1], [1, 1]], dtype=np.int32))
    interp.invoke()

    # Without the fix, this second invocation fails because the shape of the
    # intermediate dynamic tensor changes.
    interp.set_tensor(paddings_index,
                      np.array([[2, 2], [2, 2]], dtype=np.int32))
    interp.invoke()

  def testGraphDebugInfo(self):
    """Test that converting from a session captures debug info."""

    @def_function.function
    def plus_placeholder(x, placeholder):
      return x + placeholder

    with ops.Graph().as_default():
      input_node = array_ops.placeholder(
          shape=[1], dtype=dtypes.float32, name='input')
      variable_node = variables.Variable(1.0, name='variable_node')
      summed = plus_placeholder(variable_node, input_node)
      output_node = math_ops.multiply(summed, 2.0, name='output_node')

      # Run the initializer for the variable used by the model.
      sess = session.Session()
      init_op = variables.variables_initializer([variable_node])
      sess.run(init_op)

    converter = lite.TFLiteConverter.from_session(
        sess, [input_node], [output_node])
    converter.convert()
    self.assertValidDebugInfo(converter._debug_info)

    # The add node inside the inlined function must show up in the debug info,
    # keyed by the name of the first function in the graph's library.
    function_name = (
        sess.graph.as_graph_def().library.function[0].signature.name)
    expected_entry = 'add@' + function_name
    self.assertIn(expected_entry, repr(converter._debug_info))

  def testOutputOnlyModel(self):
    """Test converting a graph that has an output but no inputs."""
    with ops.Graph().as_default():
      random_output = random_ops.random_normal(shape=[3])
      sess = session.Session()

    no_inputs = []
    converter = lite.TFLiteConverter.from_session(
        sess, no_inputs, [random_output])
    converter.target_spec.supported_ops = [
        lite.OpsSet.TFLITE_BUILTINS,
        lite.OpsSet.SELECT_TF_OPS,
    ]

    # An empty input list still counts as valid tensors.
    self.assertTrue(converter._has_valid_tensors())

    model_bytes = converter.convert()
    self.assertIsNotNone(model_bytes)


class FromFrozenGraphFile(LiteTest):
  """Tests for converters built with `from_frozen_graph` on graph files."""

  def _writeAddGraph(self, batch_size=1, filename='model.pb', as_text=False):
    """Writes a `Placeholder + Placeholder` graph into the temp directory.

    Args:
      batch_size: First dimension of the [batch_size, 16, 16, 3] float32
        placeholder; pass None to leave it unspecified.
      filename: Name of the file created under `self.get_temp_dir()`.
      as_text: Forwarded as the fourth positional argument of `write_graph`.

    Returns:
      The path of the written graph file.
    """
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[batch_size, 16, 16, 3], dtype=dtypes.float32)
      _ = in_tensor + in_tensor
      sess = session.Session()

    graph_path = os.path.join(self.get_temp_dir(), filename)
    write_graph(sess.graph_def, '', graph_path, as_text)
    sess.close()
    return graph_path

  def _assertFloatAddModel(self, tflite_model):
    """Asserts the float32 `Placeholder` input and `add` output details.

    Args:
      tflite_model: Converted model content handed to the `Interpreter`.
    """
    interp = Interpreter(model_content=tflite_model)
    interp.allocate_tensors()

    inputs = interp.get_input_details()
    self.assertLen(inputs, 1)
    model_input = inputs[0]
    self.assertEqual('Placeholder', model_input['name'])
    self.assertEqual(np.float32, model_input['dtype'])
    self.assertAllEqual([1, 16, 16, 3], model_input['shape'])
    self.assertEqual((0., 0.), model_input['quantization'])

    outputs = interp.get_output_details()
    self.assertLen(outputs, 1)
    model_output = outputs[0]
    self.assertEqual('add', model_output['name'])
    self.assertEqual(np.float32, model_output['dtype'])
    self.assertAllEqual([1, 16, 16, 3], model_output['shape'])
    self.assertEqual((0., 0.), model_output['quantization'])

  def testFloat(self):
    """Test converting a binary graph file to a float model."""
    graph_path = self._writeAddGraph()

    converter = lite.TFLiteConverter.from_frozen_graph(
        graph_path, ['Placeholder'], ['add'])
    model_bytes = converter.convert()
    self.assertIsNotNone(model_bytes)

    # Inspect the tensor details of the converted model.
    self._assertFloatAddModel(model_bytes)

  def testFloatWithShapesArray(self):
    """Test a shape overriding case."""
    graph_path = self._writeAddGraph(batch_size=None)

    # The placeholder was written with a None batch size; override it with 2.
    converter = lite.TFLiteConverter.from_frozen_graph(
        graph_path, ['Placeholder'], ['add'],
        input_shapes={'Placeholder': [2, 16, 16, 3]})
    model_bytes = converter.convert()
    self.assertIsNotNone(model_bytes)

    interp = Interpreter(model_content=model_bytes)
    interp.allocate_tensors()

    inputs = interp.get_input_details()
    self.assertLen(inputs, 1)
    self.assertAllEqual([2, 16, 16, 3], inputs[0]['shape'])

  def testInvalidShapesArray(self):
    """Test an invalid shape overriding case, which has a wrong input name."""
    graph_path = self._writeAddGraph(batch_size=None)

    # Building the converter must raise because 'wrong_input' is the only key
    # in input_shapes.
    with self.assertRaises(ValueError):
      lite.TFLiteConverter.from_frozen_graph(
          graph_path, ['Placeholder'], ['add'],
          input_shapes={'wrong_input': [2, 16, 16, 3]})

  def testPartialShapesArray(self):
    """Test a shape overriding case, with the only one input among two."""
    with ops.Graph().as_default():
      first = array_ops.placeholder(
          shape=[None, 16, 16, 3], dtype=dtypes.float32, name='a')
      second = array_ops.placeholder(
          shape=[None, 16, 16, 3], dtype=dtypes.float32, name='b')
      _ = math_ops.add(first, second, name='add')
      sess = session.Session()

    graph_path = os.path.join(self.get_temp_dir(), 'model.pb')
    write_graph(sess.graph_def, '', graph_path, False)
    sess.close()

    # Only input 'a' gets an explicit shape.
    converter = lite.TFLiteConverter.from_frozen_graph(
        graph_path, ['a', 'b'], ['add'], input_shapes={'a': [2, 16, 16, 3]})
    model_bytes = converter.convert()
    self.assertIsNotNone(model_bytes)

    interp = Interpreter(model_content=model_bytes)
    interp.allocate_tensors()

    inputs = interp.get_input_details()
    self.assertLen(inputs, 2)
    self.assertAllEqual([2, 16, 16, 3], inputs[0]['shape'])
    self.assertAllEqual([1, 16, 16, 3], inputs[1]['shape'])

  def testFreezeGraph(self):
    """Test that a graph file containing a variable is rejected."""
    with ops.Graph().as_default():
      placeholder = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32)
      weights = variable_scope.get_variable(
          'weights', shape=[1, 16, 16, 3], dtype=dtypes.float32)
      _ = placeholder + weights
      sess = session.Session()

    graph_path = os.path.join(self.get_temp_dir(), 'model.pb')
    write_graph(sess.graph_def, '', graph_path, False)
    sess.close()

    # A graph that still holds variables cannot be converted.
    with self.assertRaises(ValueError) as raised:
      lite.TFLiteConverter.from_frozen_graph(
          graph_path, ['Placeholder'], ['add'])
    self.assertEqual('Please freeze the graph using freeze_graph.py.',
                     str(raised.exception))

  def testPbtxt(self):
    """Test converting a graph file written as `model.pbtxt`."""
    graph_path = self._writeAddGraph(filename='model.pbtxt', as_text=True)

    converter = lite.TFLiteConverter.from_frozen_graph(
        graph_path, ['Placeholder'], ['add'])
    model_bytes = converter.convert()
    self.assertIsNotNone(model_bytes)

    # Inspect the tensor details of the converted model.
    self._assertFloatAddModel(model_bytes)

  def testInvalidFileNotFound(self):
    """Test the error raised for a graph file that does not exist."""
    with self.assertRaises(IOError) as raised:
      lite.TFLiteConverter.from_frozen_graph(
          'invalid_file', ['Placeholder'], ['add'])
    self.assertEqual('File \'invalid_file\' does not exist.',
                     str(raised.exception))

  def testInvalidFileBadData(self):
    """Test the error raised for a file that is not a parsable graph."""
    bad_path = os.path.join(self.get_temp_dir(), 'invalid_file')
    with gfile.Open(bad_path, 'wb') as bad_file:
      bad_file.write('bad data')
      bad_file.flush()

    # Converting the unparsable file must fail with a parse error message.
    with self.assertRaises(IOError) as raised:
      lite.TFLiteConverter.from_frozen_graph(
          bad_path, ['Placeholder'], ['add'])
    expected_message = 'Unable to parse input file \'{}\'.'.format(bad_path)
    self.assertEqual(expected_message, str(raised.exception))

  def testFloatTocoConverter(self):
    """Test converting a graph file through `lite.TocoConverter`."""
    graph_path = self._writeAddGraph()

    converter = lite.TocoConverter.from_frozen_graph(
        graph_path, ['Placeholder'], ['add'])
    model_bytes = converter.convert()
    self.assertIsNotNone(model_bytes)

    # The converted model must load and allocate without error.
    interp = Interpreter(model_content=model_bytes)
    interp.allocate_tensors()

  def testGraphDebugInfo(self):
    """Test a frozen graph doesn't have debug info captured."""
    graph_path = self._writeAddGraph()

    converter = lite.TocoConverter.from_frozen_graph(
        graph_path, ['Placeholder'], ['add'])
    converter.convert()
    # The debug info is expected to be falsy for a frozen graph.
    self.assertFalse(converter._debug_info)

  def testExcludeConversionMetadata(self):
    """Test that `exclude_conversion_metadata` removes the metadata."""
    graph_path = self._writeAddGraph()

    converter = lite.TFLiteConverter.from_frozen_graph(
        graph_path, ['Placeholder'], ['add'])
    converter.exclude_conversion_metadata = True
    model_bytes = converter.convert()
    self.assertIsNotNone(model_bytes)

    # With the flag set, no conversion metadata can be read from the model.
    self.assertIsNone(get_conversion_metadata(model_bytes))


class FromFrozenGraphObjectDetection(LiteTest):
  """Tests for `from_frozen_graph` on the object detection graph."""

  def _initObjectDetectionArgs(self):
    """Sets the graph path and array arguments for the detection model.

    The model file is saved in a different location internally and
    externally, so both locations are tried.

    Raises:
      IOError: If the model file exists in neither location.
    """
    graph_path = resource_loader.get_path_to_datafile(
        'testdata/tflite_graph.pb')
    if not os.path.exists(graph_path):
      graph_path = os.path.join(
          resource_loader.get_root_dir_with_all_resources(),
          '../tflite_mobilenet_ssd_quant_protobuf/tflite_graph.pb')
      if not os.path.exists(graph_path):
        raise IOError("File '{0}' does not exist.".format(graph_path))

    self._graph_def_file = graph_path
    self._input_arrays = ['normalized_input_image_tensor']
    self._output_arrays = [
        'TFLite_Detection_PostProcess', 'TFLite_Detection_PostProcess:1',
        'TFLite_Detection_PostProcess:2', 'TFLite_Detection_PostProcess:3'
    ]
    self._input_shapes = {'normalized_input_image_tensor': [1, 300, 300, 3]}

  def testTFLiteGraphDef(self):
    """Tests the object detection model that cannot be loaded in TensorFlow."""
    self._initObjectDetectionArgs()

    converter = lite.TFLiteConverter.from_frozen_graph(
        self._graph_def_file, self._input_arrays, self._output_arrays,
        self._input_shapes)
    converter.allow_custom_ops = True
    model_bytes = converter.convert()
    self.assertIsNotNone(model_bytes)

    # Inspect the tensor details of the converted model.
    interp = Interpreter(model_content=model_bytes)
    interp.allocate_tensors()

    inputs = interp.get_input_details()
    self.assertLen(inputs, 1)
    model_input = inputs[0]
    self.assertEqual('normalized_input_image_tensor', model_input['name'])
    self.assertEqual(np.float32, model_input['dtype'])
    self.assertAllEqual([1, 300, 300, 3], model_input['shape'])
    self.assertEqual((0., 0.), model_input['quantization'])

    outputs = interp.get_output_details()
    self.assertLen(outputs, 4)
    first_output = outputs[0]
    self.assertEqual('TFLite_Detection_PostProcess', first_output['name'])
    self.assertEqual(np.float32, first_output['dtype'])
    self.assertAllEqual([1, 10, 4], first_output['shape'])
    self.assertEqual((0., 0.), first_output['quantization'])

    # For the remaining outputs only the name and the shape are checked.
    remaining_outputs = (
        (1, 'TFLite_Detection_PostProcess:1', [1, 10]),
        (2, 'TFLite_Detection_PostProcess:2', [1, 10]),
        (3, 'TFLite_Detection_PostProcess:3', [1]),
    )
    for position, expected_name, expected_shape in remaining_outputs:
      self.assertEqual(expected_name, outputs[position]['name'])
      self.assertAllEqual(expected_shape, outputs[position]['shape'])

  def testTFLiteGraphDefWithControlOutput(self):
    """Test converting a GraphDef whose print op is a control output."""
    with ops.Graph().as_default():
      placeholder = array_ops.placeholder(
          dtype=dtypes.float32, shape=[5, 5], name='input')
      doubled = placeholder + placeholder
      logging_ops.print_v2(doubled)
      sess = session.Session()

    # The converter is constructed directly from the GraphDef; inputs are
    # described by name and shape, and no output arrays are given.
    converter = lite.TFLiteConverter(
        sess.graph_def,
        input_tensors=None,
        output_tensors=None,
        input_arrays_with_shape=[('input', [5, 5])],
        output_arrays=None,
        experimental_debug_info_func=None)
    converter._control_output_arrays = ['PrintV2']
    converter.target_spec.supported_ops = [
        lite.OpsSet.TFLITE_BUILTINS,
        lite.OpsSet.SELECT_TF_OPS,
    ]
    model_bytes = converter.convert()
    self.assertIsNotNone(model_bytes)

    # Expected operator codes: builtin ADD first, then two custom ops.
    op_codes = util._convert_model_from_bytearray_to_object(
        model_bytes).operatorCodes
    self.assertEqual(op_codes[0].builtinCode, schema_fb.BuiltinOperator.ADD)
    expected_custom_ops = ((1, b'FlexStringFormat'), (2, b'FlexPrintV2'))
    for position, custom_code in expected_custom_ops:
      self.assertEqual(op_codes[position].builtinCode,
                       schema_fb.BuiltinOperator.CUSTOM)
      self.assertEqual(op_codes[position].customCode, custom_code)

    # Inspect the tensor details of the converted model.
    interp = Interpreter(model_content=model_bytes)
    interp.allocate_tensors()

    inputs = interp.get_input_details()
    self.assertLen(inputs, 1)
    model_input = inputs[0]
    self.assertEqual('input', model_input['name'])
    self.assertEqual(np.float32, model_input['dtype'])
    self.assertAllEqual([5, 5], model_input['shape'])
    self.assertEqual((0., 0.), model_input['quantization'])

    outputs = interp.get_output_details()
    self.assertLen(outputs, 0)

  def testModifyIOToUint8(self):
    """Tests the object detection model that cannot be loaded in TensorFlow."""
    self._initObjectDetectionArgs()

    def representative_dataset_gen():
      # Two random float32 samples with the model's input shape.
      for _ in range(2):
        yield [
            np.random.uniform(low=0, high=1,
                              size=(1, 300, 300, 3)).astype(np.float32)
        ]

    converter = lite.TFLiteConverter.from_frozen_graph(
        self._graph_def_file, self._input_arrays, self._output_arrays,
        self._input_shapes)
    converter.representative_dataset = representative_dataset_gen
    converter.target_spec.supported_ops = {lite.OpsSet.TFLITE_BUILTINS_INT8}
    converter.inference_type = dtypes.int8
    converter.inference_input_type = dtypes.uint8
    converter.inference_output_type = dtypes.uint8
    converter.experimental_new_quantizer = True
    # Stats are keyed by input name; each value is (mean, std_dev).
    converter.quantized_input_stats = {
        'normalized_input_image_tensor': (0., 1.)
    }
    converter.allow_custom_ops = True
    model_bytes = converter.convert()

    self.assertIsNotNone(model_bytes)

    model = util._convert_model_from_bytearray_to_object(model_bytes)
    quant_opcode_idxs = util.get_quantize_opcode_idx(model)

    subgraph = model.subgraphs[0]
    tensors = subgraph.tensors
    for op in subgraph.operators:
      if op.opcodeIndex not in quant_opcode_idxs:
        continue
      # Any quantize op that feeds a subgraph output must take float32 input.
      input_type = util._convert_tflite_enum_type_to_tf_type(
          tensors[op.inputs[0]].type)
      if op.outputs[0] in subgraph.outputs:
        self.assertEqual(input_type, dtypes.float32)


class FromSavedModelTest(TestModels):
  """Tests for converters built with `from_saved_model`."""

  def _createSavedModel(self, shape):
    """Create a simple SavedModel.

    Args:
      shape: Shape used for both float32 input placeholders.

    Returns:
      The directory the SavedModel was saved to.
    """
    export_dir = os.path.join(self.get_temp_dir(), 'simple_savedmodel')
    with ops.Graph().as_default():
      with session.Session() as sess:
        # 'inputB' is created before 'inputA'.
        input_b = array_ops.placeholder(
            dtype=dtypes.float32, shape=shape, name='inputB')
        input_a = array_ops.placeholder(
            dtype=dtypes.float32, shape=shape, name='inputA')
        total = input_b + input_a
        saved_model.simple_save(
            sess, export_dir, {'x': input_b, 'y': input_a}, {'z': total})
    return export_dir

  def _assertAddModelDetails(self, tflite_model, shape):
    """Asserts the input/output tensor details of a converted add model.

    Args:
      tflite_model: Converted model content handed to the `Interpreter`.
      shape: Shape expected for both inputs and for the output.
    """
    interp = Interpreter(model_content=tflite_model)
    interp.allocate_tensors()

    inputs = interp.get_input_details()
    self.assertLen(inputs, 2)
    first_input = inputs[0]
    self.assertStartsWith(first_input['name'], 'inputA')
    self.assertEqual(np.float32, first_input['dtype'])
    self.assertAllEqual(shape, first_input['shape'])
    self.assertEqual((0., 0.), first_input['quantization'])

    second_input = inputs[1]
    self.assertStartsWith(second_input['name'], 'inputB')
    self.assertEqual(np.float32, second_input['dtype'])
    self.assertAllEqual(shape, second_input['shape'])
    self.assertEqual((0., 0.), second_input['quantization'])

    outputs = interp.get_output_details()
    self.assertLen(outputs, 1)
    model_output = outputs[0]
    self.assertStartsWith(model_output['name'], 'add')
    self.assertEqual(np.float32, model_output['dtype'])
    self.assertAllEqual(shape, model_output['shape'])
    self.assertEqual((0., 0.), model_output['quantization'])

  def testSimpleModel(self):
    """Test a SavedModel."""
    export_dir = self._createSavedModel(shape=[1, 16, 16, 3])

    converter = lite.TFLiteConverter.from_saved_model(export_dir)
    model_bytes = converter.convert()
    self.assertIsNotNone(model_bytes)

    self._assertAddModelDetails(model_bytes, [1, 16, 16, 3])

  def testNoneBatchSize(self):
    """Test a SavedModel, with None in input tensor's shape."""
    export_dir = self._createSavedModel(shape=[None, 16, 16, 3])

    converter = lite.TFLiteConverter.from_saved_model(export_dir)
    model_bytes = converter.convert()
    self.assertIsNotNone(model_bytes)

    # The converted model is expected to report a batch dimension of 1.
    self._assertAddModelDetails(model_bytes, [1, 16, 16, 3])

  def testOrderInputArrays(self):
    """Test a SavedModel ordering of input arrays."""
    export_dir = self._createSavedModel(shape=[1, 16, 16, 3])

    converter = lite.TFLiteConverter.from_saved_model(
        export_dir, input_arrays=['inputB', 'inputA'])
    model_bytes = converter.convert()
    self.assertIsNotNone(model_bytes)

    # 'inputA' is still expected first even though 'inputB' was listed first.
    self._assertAddModelDetails(model_bytes, [1, 16, 16, 3])

  def testShapeOverriding(self):
    """Test a SavedModel with the input_shapes argument."""
    export_dir = self._createSavedModel(shape=[None, 16, 16, 3])

    overridden_shapes = {
        'inputA': [2, 16, 16, 3],
        'inputB': [2, 16, 16, 3]
    }
    converter = lite.TFLiteConverter.from_saved_model(
        export_dir, input_shapes=overridden_shapes)
    model_bytes = converter.convert()
    self.assertIsNotNone(model_bytes)

    self._assertAddModelDetails(model_bytes, [2, 16, 16, 3])

  def testWrongInputShapes(self):
    """Test a SavedModel with a wrong name in the input_shapes argument."""
    export_dir = self._createSavedModel(shape=[1, 16, 16, 3])

    # Building the converter with an unknown name in input_shapes must raise.
    with self.assertRaises(ValueError):
      lite.TFLiteConverter.from_saved_model(
          export_dir,
          input_arrays=['inputA'],
          input_shapes={'wrong_input': [1, 16, 16, 3]})

  def testSubsetInputShaapes(self):
    """Test a SavedModel with a subset of the input array names of the model."""
    export_dir = self._createSavedModel(shape=[1, 16, 16, 3])

    # Case 1: an explicit shape is given for the single listed input.
    converter = lite.TFLiteConverter.from_saved_model(
        export_dir,
        input_arrays=['inputA'],
        input_shapes={'inputA': [1, 16, 16, 3]})

    # Specifying only part of the inputs is not allowed.
    with self.assertRaises(ConverterError):
      _ = converter.convert()

    # Case 2: the shape given for the single listed input is None.
    converter = lite.TFLiteConverter.from_saved_model(
        export_dir, input_arrays=['inputA'], input_shapes={'inputA': None})

    # Specifying only part of the inputs is not allowed.
    with self.assertRaises(ConverterError):
      _ = converter.convert()

  def testSimpleModelTocoConverter(self):
    """Test a SavedModel with deprecated TocoConverter."""
    export_dir = self._createSavedModel(shape=[1, 16, 16, 3])

    converter = lite.TocoConverter.from_saved_model(export_dir)
    model_bytes = converter.convert()
    self.assertIsNotNone(model_bytes)

    # The converted model must load and allocate without error.
    interp = Interpreter(model_content=model_bytes)
    interp.allocate_tensors()

  def testGraphDebugInfo(self):
    """Test a SavedModel has debug info captured."""
    # Skipped unconditionally; the statements below are kept for when the
    # referenced bug is resolved.
    self.skipTest(
        'b/221093690: The debug info is not from self._createSavedModel(), '
        'but from saved_model.loader_impl().')
    export_dir = self._createSavedModel(shape=[1, 16, 16, 3])
    converter = lite.TFLiteConverter.from_saved_model(export_dir)
    converter.convert()
    self.assertValidDebugInfo(converter._debug_info)


class MyAddLayer(keras.layers.Layer):
  """Custom Keras layer that adds a fixed increment to its inputs."""

  def __init__(self, increment, **kwargs):
    """Stores the increment; remaining kwargs go to the base layer."""
    super().__init__(**kwargs)
    self._increment = increment

  def call(self, inputs):
    """Returns `inputs` plus the stored increment."""
    return inputs + self._increment

  def get_config(self):
    """Returns the base layer config extended with an 'increment' entry."""
    layer_config = super().get_config()
    layer_config['increment'] = self._increment
    return layer_config


class FromKerasFile(TestModels, parameterized.TestCase):

  def setUp(self):
    """Resets per-test state and, outside eager mode, the Keras session."""
    super().setUp()
    # Both attributes are filled in later by the model-building helpers.
    self._keras_file = None
    self._custom_objects = None
    if context.executing_eagerly():
      return
    keras.backend.clear_session()

  def tearDown(self):
    """Deletes the Keras file recorded in `self._keras_file`, if any."""
    keras_file = self._keras_file
    if keras_file:
      os.remove(keras_file)
    super().tearDown()

  def _getSequentialModel(self, include_custom_layer=False):
    """Builds, trains and saves a small Sequential model to a temp .h5 file.

    The path of the saved file is stored in `self._keras_file`. When
    `include_custom_layer` is true, `self._custom_objects` is set to the
    mapping needed to load the custom `MyAddLayer`.

    Args:
      include_custom_layer: Whether to add a `MyAddLayer` after the first
        Dense layer.
    """
    model = keras.models.Sequential()
    model.add(keras.layers.Dense(2, input_shape=(3,)))
    if include_custom_layer:
      model.add(MyAddLayer(1.0))
    model.add(keras.layers.RepeatVector(3))
    model.add(keras.layers.TimeDistributed(keras.layers.Dense(3)))
    model.compile(
        loss=keras.losses.MSE,
        optimizer='sgd',
        metrics=[keras.metrics.categorical_accuracy],
        sample_weight_mode='temporal')
    x = np.random.random((1, 3))
    y = np.random.random((1, 3, 3))
    model.train_on_batch(x, y)
    model.predict(x)

    # Create the temp file before entering the `try` so that `fd` is always
    # bound when the `finally` clause runs; a failing mkstemp then surfaces
    # its own error instead of a NameError.
    fd, self._keras_file = tempfile.mkstemp('.h5')
    try:
      keras.models.save_model(model, self._keras_file)
    finally:
      os.close(fd)

    if include_custom_layer:
      self._custom_objects = {'MyAddLayer': MyAddLayer}

  @parameterized.named_parameters(('_graph', context.graph_mode),
                                  ('_eager', context.eager_mode))
  def testSequentialModel(self, test_context):
    """Converts a saved Sequential model and compares it against Keras.

    Args:
      test_context: Callable returning the context manager (graph or eager
        mode) under which the model is built and converted.
    """
    with test_context():
      self._getSequentialModel()
      tflite_model = lite.TFLiteConverter.from_keras_model_file(
          self._keras_file).convert()
      self.assertIsNotNone(tflite_model)

    # Load the converted model and inspect its tensors.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    input_detail = input_details[0]
    self.assertEndsWith(input_detail['name'], 'dense_input')
    self.assertEqual(np.float32, input_detail['dtype'])
    self.assertAllEqual([1, 3], input_detail['shape'])
    self.assertEqual((0., 0.), input_detail['quantization'])

    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    output_detail = output_details[0]
    self.assertEqual(np.float32, output_detail['dtype'])
    self.assertAllEqual([1, 3, 3], output_detail['shape'])
    self.assertEqual((0., 0.), output_detail['quantization'])

    # Run the same sample through TFLite and through the reloaded Keras model.
    sample = np.array([[1, 2, 3]], dtype=np.float32)
    interpreter.set_tensor(input_detail['index'], sample)
    interpreter.invoke()
    tflite_result = interpreter.get_tensor(output_detail['index'])

    keras_result = keras.models.load_model(self._keras_file).predict(sample)

    np.testing.assert_almost_equal(tflite_result, keras_result, decimal=5)

  @parameterized.named_parameters(('_graph', context.graph_mode),
                                  ('_eager', context.eager_mode))
  def testCustomLayer(self, test_context):
    """Converts a Sequential model containing the custom `MyAddLayer`.

    Args:
      test_context: Callable returning the context manager (graph or eager
        mode) under which the model is built and converted.
    """
    with test_context():
      self._getSequentialModel(include_custom_layer=True)
      tflite_model = lite.TFLiteConverter.from_keras_model_file(
          self._keras_file, custom_objects=self._custom_objects).convert()
      self.assertIsNotNone(tflite_model)

    # Load the converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()
    input_index = interpreter.get_input_details()[0]['index']
    output_index = interpreter.get_output_details()[0]['index']

    # Run the same sample through TFLite and through the reloaded Keras model.
    sample = np.array([[1, 2, 3]], dtype=np.float32)
    interpreter.set_tensor(input_index, sample)
    interpreter.invoke()
    tflite_result = interpreter.get_tensor(output_index)

    reloaded = keras.models.load_model(
        self._keras_file, custom_objects=self._custom_objects)
    keras_result = reloaded.predict(sample)

    np.testing.assert_almost_equal(tflite_result, keras_result, decimal=5)

  def testSequentialModelInputArray(self):
    """Exercises the `input_arrays` argument on a Sequential Keras model."""
    ops.disable_eager_execution()
    self._getSequentialModel()
    keras_file = self._keras_file

    # An unknown input array name is rejected when the converter is built.
    with self.assertRaises(ValueError) as raised:
      lite.TFLiteConverter.from_keras_model_file(
          keras_file, input_arrays=['invalid-input'])
    self.assertEqual("Invalid tensors 'invalid-input' were found.",
                     str(raised.exception))

    # A known input array name converts successfully.
    tflite_model = lite.TFLiteConverter.from_keras_model_file(
        keras_file, input_arrays=['dense_input']).convert()
    self.assertIsNotNone(tflite_model)

  def testSequentialModelInputShape(self):
    """Test a Sequential tf.keras model testing input shapes argument."""
    self._getSequentialModel()

    # Passing in shape of invalid input array raises error. The result is not
    # bound to a name because the call is expected to raise.
    with self.assertRaises(ValueError) as error:
      lite.TFLiteConverter.from_keras_model_file(
          self._keras_file, input_shapes={'invalid-input': [2, 3]})
    self.assertEqual(
        "Invalid tensor 'invalid-input' found in tensor shapes map.",
        str(error.exception))

    # Passing in shape of valid input array.
    converter = lite.TFLiteConverter.from_keras_model_file(
        self._keras_file, input_shapes={'dense_input': [2, 3]})
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check input shape from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEndsWith(input_details[0]['name'], 'dense_input')
    self.assertAllEqual([2, 3], input_details[0]['shape'])

  def testSequentialModelOutputArray(self):
    """Exercises the `output_arrays` argument on a Sequential Keras model."""
    ops.disable_eager_execution()
    self._getSequentialModel()
    keras_file = self._keras_file

    # An unknown output array name is rejected when the converter is built.
    with self.assertRaises(ValueError) as raised:
      lite.TFLiteConverter.from_keras_model_file(
          keras_file, output_arrays=['invalid-output'])
    self.assertEqual("Invalid tensors 'invalid-output' were found.",
                     str(raised.exception))

    # A known output array name converts successfully.
    tflite_model = lite.TFLiteConverter.from_keras_model_file(
        keras_file, output_arrays=['time_distributed/Reshape_1']).convert()
    self.assertIsNotNone(tflite_model)

  @parameterized.named_parameters(('_graph', context.graph_mode),
                                  ('_eager', context.eager_mode))
  def testFunctionalModel(self, test_context):
    """Converts a saved Functional model and compares it against Keras.

    Args:
      test_context: Callable returning the context manager (graph or eager
        mode) under which the model is built and converted.
    """
    with test_context():
      # Two stacked Dense layers on a 3-element input.
      model_input = keras.layers.Input(shape=(3,), name='input')
      hidden = keras.layers.Dense(2)(model_input)
      model_output = keras.layers.Dense(3)(hidden)

      model = keras.models.Model(model_input, model_output)
      model.compile(
          loss=keras.losses.MSE,
          optimizer='sgd',
          metrics=[keras.metrics.categorical_accuracy])
      features = np.random.random((1, 3))
      targets = np.random.random((1, 3))
      model.train_on_batch(features, targets)

      model.predict(features)
      fd, self._keras_file = tempfile.mkstemp('.h5')
      try:
        keras.models.save_model(model, self._keras_file)
      finally:
        os.close(fd)

      # Convert the saved file to a TFLite model.
      tflite_model = lite.TFLiteConverter.from_keras_model_file(
          self._keras_file).convert()
      self.assertIsNotNone(tflite_model)

    # Load the converted model and inspect its tensors.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    input_detail = input_details[0]
    self.assertEqual('input', input_detail['name'])
    self.assertEqual(np.float32, input_detail['dtype'])
    self.assertAllEqual([1, 3], input_detail['shape'])
    self.assertEqual((0., 0.), input_detail['quantization'])

    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    output_detail = output_details[0]
    self.assertEqual(np.float32, output_detail['dtype'])
    self.assertAllEqual([1, 3], output_detail['shape'])
    self.assertEqual((0., 0.), output_detail['quantization'])

    # Run the same sample through TFLite and through the reloaded Keras model.
    sample = np.array([[1, 2, 3]], dtype=np.float32)
    interpreter.set_tensor(input_detail['index'], sample)
    interpreter.invoke()
    tflite_result = interpreter.get_tensor(output_detail['index'])

    keras_result = keras.models.load_model(self._keras_file).predict(sample)

    np.testing.assert_almost_equal(tflite_result, keras_result, decimal=5)

  def _getFunctionalModelMultipleInputs(self):
    """Builds, trains and saves a two-input, two-output Functional model.

    Both inputs go through one shared Dense layer; the first branch is also
    passed through Dropout. The saved file path is stored in
    `self._keras_file`.
    """
    input_a = keras.layers.Input(shape=(3,), name='input_a')
    input_b = keras.layers.Input(shape=(3,), name='input_b')
    shared_dense = keras.layers.Dense(4, name='dense')
    dense_a = shared_dense(input_a)
    dense_b = shared_dense(input_b)
    dropped_a = keras.layers.Dropout(0.5, name='dropout')(dense_a)

    model = keras.models.Model([input_a, input_b], [dense_b, dropped_a])
    model.compile(
        loss=keras.losses.MSE,
        optimizer='sgd',
        metrics=[keras.metrics.mae],
        loss_weights=[1., 0.5])

    # Random training data: a batch of 10 samples for each input and output.
    model_inputs = [np.random.random((10, 3)), np.random.random((10, 3))]
    model_targets = [np.random.random((10, 4)), np.random.random((10, 4))]
    model.train_on_batch(model_inputs, model_targets)

    model.predict(model_inputs, batch_size=5)
    fd, self._keras_file = tempfile.mkstemp('.h5')
    try:
      keras.models.save_model(model, self._keras_file)
    finally:
      os.close(fd)

  def testFunctionalModelMultipleInputs(self):
    """Converts a Functional model that has two inputs and two outputs."""
    self._getFunctionalModelMultipleInputs()

    # Convert the saved file to a TFLite model.
    tflite_model = lite.TFLiteConverter.from_keras_model_file(
        self._keras_file).convert()
    self.assertIsNotNone(tflite_model)

    # Load the converted model and inspect its tensors.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 2)
    for detail, name_suffix in zip(input_details, ('input_a', 'input_b')):
      self.assertEndsWith(detail['name'], name_suffix)
      self.assertEqual(np.float32, detail['dtype'])
      self.assertAllEqual([1, 3], detail['shape'])
      self.assertEqual((0., 0.), detail['quantization'])

    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 2)
    for detail in output_details:
      self.assertEqual(np.float32, detail['dtype'])
      self.assertAllEqual([1, 4], detail['shape'])
      self.assertEqual((0., 0.), detail['quantization'])

  def testShapeOverriding(self):
    """Test a Functional tf.keras model with input shape overriding."""
    self._getFunctionalModelMultipleInputs()

    # Convert to TFLite model. Shapes are given as lists: a set literal such
    # as {2, 3} has no defined order and would drop repeated dimensions.
    converter = lite.TFLiteConverter.from_keras_model_file(
        self._keras_file, input_shapes={
            'input_a': [2, 3],
            'input_b': [2, 3]
        })
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 2)
    self.assertEndsWith(input_details[0]['name'], 'input_a')
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertAllEqual([2, 3], input_details[0]['shape'])
    self.assertEqual((0., 0.), input_details[0]['quantization'])

    self.assertEndsWith(input_details[1]['name'], 'input_b')
    self.assertEqual(np.float32, input_details[1]['dtype'])
    self.assertAllEqual([2, 3], input_details[1]['shape'])
    self.assertEqual((0., 0.), input_details[1]['quantization'])

    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 2)
    self.assertEqual(np.float32, output_details[0]['dtype'])
    self.assertAllEqual([2, 4], output_details[0]['shape'])
    self.assertEqual((0., 0.), output_details[0]['quantization'])

    self.assertEqual(np.float32, output_details[1]['dtype'])
    self.assertAllEqual([2, 4], output_details[1]['shape'])
    self.assertEqual((0., 0.), output_details[1]['quantization'])

  def testPartialShapeOverriding(self):
    """Test a Functional tf.keras model with partial input shape overriding."""
    self._getFunctionalModelMultipleInputs()

    # Convert to TFLite model. Only 'input_a' is overridden. The shape is
    # given as a list: a set literal such as {2, 3} has no defined order and
    # would drop repeated dimensions.
    converter = lite.TFLiteConverter.from_keras_model_file(
        self._keras_file, input_shapes={'input_a': [2, 3]})
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 2)
    self.assertEndsWith(input_details[0]['name'], 'input_a')
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertAllEqual([2, 3], input_details[0]['shape'])
    self.assertEqual((0., 0.), input_details[0]['quantization'])

    self.assertEndsWith(input_details[1]['name'], 'input_b')
    self.assertEqual(np.float32, input_details[1]['dtype'])
    self.assertAllEqual([1, 3], input_details[1]['shape'])
    self.assertEqual((0., 0.), input_details[1]['quantization'])

    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 2)
    self.assertEqual(np.float32, output_details[0]['dtype'])
    self.assertAllEqual([1, 4], output_details[0]['shape'])
    self.assertEqual((0., 0.), output_details[0]['quantization'])

    self.assertEqual(np.float32, output_details[1]['dtype'])
    self.assertAllEqual([2, 4], output_details[1]['shape'])
    self.assertEqual((0., 0.), output_details[1]['quantization'])

  def testWrongShapeOverriding(self):
    """Test a Functional tf.keras model with wrong input shape overriding."""
    self._getFunctionalModelMultipleInputs()

    # Building the converter with a shape for an unknown input name must
    # fail. The shape is a list, not an unordered set literal.
    with self.assertRaises(ValueError):
      lite.TFLiteConverter.from_keras_model_file(
          self._keras_file, input_shapes={'wrong_input': [2, 3]})

  def testFunctionalSequentialModel(self):
    """Test a Functional tf.keras model containing a Sequential model."""
    model = keras.models.Sequential()
    model.add(keras.layers.Dense(2, input_shape=(3,)))
    model.add(keras.layers.RepeatVector(3))
    model.add(keras.layers.TimeDistributed(keras.layers.Dense(3)))
    # Wrap the Sequential model's input/output in a Functional model.
    model = keras.models.Model(model.input, model.output)

    model.compile(
        loss=keras.losses.MSE,
        optimizer='sgd',
        metrics=[keras.metrics.categorical_accuracy],
        sample_weight_mode='temporal')
    x = np.random.random((1, 3))
    y = np.random.random((1, 3, 3))
    model.train_on_batch(x, y)
    # A single predict call is enough before saving.
    model.predict(x)

    fd, self._keras_file = tempfile.mkstemp('.h5')
    try:
      keras.models.save_model(model, self._keras_file)
    finally:
      os.close(fd)

    # Convert to TFLite model.
    converter = lite.TFLiteConverter.from_keras_model_file(self._keras_file)
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check tensor details of converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEndsWith(input_details[0]['name'], 'dense_input')
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertAllEqual([1, 3], input_details[0]['shape'])
    self.assertEqual((0., 0.), input_details[0]['quantization'])

    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertEqual(np.float32, output_details[0]['dtype'])
    self.assertAllEqual([1, 3, 3], output_details[0]['shape'])
    self.assertEqual((0., 0.), output_details[0]['quantization'])

    # Check inference of converted model.
    input_data = np.array([[1, 2, 3]], dtype=np.float32)
    interpreter.set_tensor(input_details[0]['index'], input_data)
    interpreter.invoke()
    tflite_result = interpreter.get_tensor(output_details[0]['index'])

    keras_model = keras.models.load_model(self._keras_file)
    keras_result = keras_model.predict(input_data)

    np.testing.assert_almost_equal(tflite_result, keras_result, 5)

  def testSequentialModelTocoConverter(self):
    """Checks that the deprecated TocoConverter converts a Keras model file."""
    self._getSequentialModel()

    tflite_model = lite.TocoConverter.from_keras_model_file(
        self._keras_file).convert()
    self.assertIsNotNone(tflite_model)

    # The produced model must be loadable by the interpreter.
    loaded = Interpreter(model_content=tflite_model)
    loaded.allocate_tensors()

  @parameterized.named_parameters(('_graph', context.graph_mode),
                                  ('_eager', context.eager_mode))
  def testGraphDebugInfo(self, test_context):
    """Checks that a Sequential Keras model captures debug info (skipped).

    Args:
      test_context: Callable returning the context manager (graph or eager
        mode) under which the model is built and converted.
    """
    self.skipTest('TODO(b/291005679): will not be able to fix on OSS')
    # Not reached while the test is skipped; kept so the check can be
    # re-enabled by removing the skip above.
    with test_context():
      self._getSequentialModel()
      keras_converter = lite.TFLiteConverter.from_keras_model_file(
          self._keras_file)
      keras_converter.convert()
      self.assertValidDebugInfo(keras_converter._debug_info)


class SparsityTest(TestModels):

  def _getSparsificableModel(self, matrix_b_values):
    """Builds a graph multiplying a [16, 4] placeholder by a constant.

    Args:
      matrix_b_values: Values for the [4, 8] float32 constant operand.

    Returns:
      A tuple `(session, [input_tensor], [output_tensor])`.
    """
    graph = ops.Graph()
    with graph.as_default():
      lhs = array_ops.placeholder(
          shape=[16, 4], dtype=dtypes.float32, name='input1')
      rhs = constant_op.constant(
          matrix_b_values, shape=[4, 8], dtype=dtypes.float32)
      product = math_ops.matmul(lhs, rhs)
      sess = session.Session()

    return sess, [lhs], [product]

  def testRandomSparsity(self):
    """Checks that scattered non-zeros are recorded as RANDOM_SPARSITY."""
    # 32 values for the [4, 8] constant; the ones are scattered.
    weights = [
        0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 1
    ]
    sess, in_tensors, out_tensors = self._getSparsificableModel(weights)
    float_converter = lite.TFLiteConverter.from_session(
        sess, in_tensors, out_tensors)
    float_converter.optimizations = [lite.Optimize.EXPERIMENTAL_SPARSITY]
    float_tflite_model = float_converter.convert()
    self.assertIsNotNone(float_tflite_model)

    # The conversion metadata must list the sparsity mode that was applied.
    conversion_metadata = get_conversion_metadata(float_tflite_model)
    self.assertIsNotNone(conversion_metadata)
    expected_modes = [metadata_fb.ModelOptimizationMode.RANDOM_SPARSITY]
    self.assertAllEqual(expected_modes,
                        conversion_metadata.options.modelOptimizationModes)

  def testSparsifyModel(self):
    """Checks that a nearly all-zero matrix is recorded as BLOCK_SPARSITY."""
    # 32 values for the [4, 8] constant; a single one among zeros.
    weights = [
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 1, 0
    ]
    sess, in_tensors, out_tensors = self._getSparsificableModel(weights)
    sparse_converter = lite.TFLiteConverter.from_session(
        sess, in_tensors, out_tensors)
    sparse_converter.optimizations = {lite.Optimize.EXPERIMENTAL_SPARSITY}
    tflite_model = sparse_converter.convert()
    self.assertTrue(tflite_model)

    # The conversion metadata must list the sparsity mode that was applied.
    conversion_metadata = get_conversion_metadata(tflite_model)
    self.assertIsNotNone(conversion_metadata)
    expected_modes = [metadata_fb.ModelOptimizationMode.BLOCK_SPARSITY]
    self.assertAllEqual(expected_modes,
                        conversion_metadata.options.modelOptimizationModes)

  def testSparsifyQuantizedModel(self):
    """Checks metadata when DEFAULT and sparsity optimizations are combined."""
    # 32 values for the [4, 8] constant; a single one among zeros.
    weights = [
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 1, 0
    ]
    sess, in_tensors, out_tensors = self._getSparsificableModel(weights)
    combined_converter = lite.TFLiteConverter.from_session(
        sess, in_tensors, out_tensors)
    combined_converter.optimizations = {
        lite.Optimize.DEFAULT, lite.Optimize.EXPERIMENTAL_SPARSITY
    }
    tflite_model = combined_converter.convert()
    self.assertIsNotNone(tflite_model)

    # The conversion metadata must list both optimization modes.
    conversion_metadata = get_conversion_metadata(tflite_model)
    self.assertIsNotNone(conversion_metadata)
    expected_modes = [
        metadata_fb.ModelOptimizationMode.PTQ_DYNAMIC_RANGE,
        metadata_fb.ModelOptimizationMode.BLOCK_SPARSITY,
    ]
    self.assertAllEqual(expected_modes,
                        conversion_metadata.options.modelOptimizationModes)


class GrapplerTest(TestModels, parameterized.TestCase):

  def testConstantFolding(self):
    """Converts a graph whose matmul operand is a broadcast constant."""
    ops.disable_eager_execution()
    # Constant folding handles the tf.broadcast_to operation which was not
    # supported by the TFLite at the time this test was added.
    with ops.Graph().as_default():
      placeholder = array_ops.placeholder(shape=[3, 3], dtype=dtypes.float32)
      row = constant_op.constant([1., 2., 3.])
      broadcast_rows = array_ops.broadcast_to(row, [3, 3])
      product = math_ops.matmul(placeholder, broadcast_rows, name='output')
      sess = session.Session()

    # Convert the session graph.
    tflite_model = lite.TFLiteConverter.from_session(
        sess, [placeholder], [product]).convert()

    # Load the converted model and inspect its tensors.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    input_detail = input_details[0]
    self.assertEqual('Placeholder', input_detail['name'])
    self.assertEqual(np.float32, input_detail['dtype'])
    self.assertAllEqual([3, 3], input_detail['shape'])

    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    output_detail = output_details[0]
    self.assertEqual('output', output_detail['name'])
    self.assertEqual(np.float32, output_detail['dtype'])
    self.assertAllEqual([3, 3], output_detail['shape'])

  def testInputNodeIsNotFolded(self):
    """Checks that a constant passed as a model input stays an input."""
    ops.disable_eager_execution()
    # The constant is listed among the converter inputs below, so the
    # converted model is expected to expose it as a second input.
    with ops.Graph().as_default():
      placeholder = array_ops.placeholder(shape=[3], dtype=dtypes.float32)
      const_input = constant_op.constant([1., 2., 3.])
      doubled = const_input + const_input
      product = placeholder * doubled
      sess = session.Session()

    # Convert the session graph with both tensors declared as inputs.
    tflite_model = lite.TFLiteConverter.from_session(
        sess, [placeholder, const_input], [product]).convert()

    # Load the converted model and inspect its inputs.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 2)
    first_input, second_input = input_details
    self.assertEqual('Placeholder', first_input['name'])
    self.assertEqual('Const', second_input['name'])

  def testGrapplerConstFolding(self):
    """Converts a graph that adds a zeros tensor to a reshaped placeholder."""
    # Constant folding converts the following add operation to tf.broadcast_to
    # operation which was not supported by the TFLite at the time this test was
    # added.
    @def_function.function
    def plus_placeholder(x, placeholder):
      return x + placeholder

    with ops.Graph().as_default():
      placeholder = array_ops.placeholder(shape=[2, 2], dtype=dtypes.float32)
      zeros = array_ops.zeros([2, 2, 2])
      reshaped = array_ops.reshape(placeholder, shape=[2, 2])
      total = plus_placeholder(zeros, reshaped)
      sess = session.Session()

    # Convert the session graph.
    tflite_model = lite.TFLiteConverter.from_session(
        sess, [placeholder], [total]).convert()

    # Load the converted model and inspect its inputs.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual('Placeholder', input_details[0]['name'])


class DefaultConverterAttrsTest(LiteTest):
  """Checks the default attribute values of a session-based converter."""

  def testAttrs(self):
    """Asserts the defaults of a freshly created `TFLiteConverter`."""
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(shape=[2, 2], dtype=dtypes.float32)
      out_tensor = in_tensor + in_tensor
      sess = session.Session()

    # Convert model.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                  [out_tensor])

    # Assert output format.
    self.assertEqual(converter.output_format, lite_constants.TFLITE)

    # Assert the default inference type is float.
    self.assertEqual(converter.inference_type, dtypes.float32)

    # Assert the default inference type overrides are None.
    self.assertIsNone(converter.inference_input_type)
    self.assertIsNone(converter.inference_output_type)

    # Assert the default quantization options are not set.
    self.assertEqual(converter.quantized_input_stats, {})
    self.assertIsNone(converter.default_ranges_stats)
    self.assertFalse(converter.reorder_across_fake_quant)
    self.assertFalse(converter.change_concat_input_ranges)

    # Assert dropping control dependency is enabled by default. assertTrue is
    # used because assertIsNotNone would also accept a disabled (False) flag.
    self.assertTrue(converter.drop_control_dependency)

    # Assert dumping extra information is disabled by default.
    self.assertIsNone(converter.dump_graphviz_dir)
    self.assertFalse(converter.dump_graphviz_video)
    self.assertIsNone(converter.conversion_summary_dir)


class ControlFlowV1OpsTest(LiteTest):
  """Conversion of a frozen graph that contains Control Flow V1 ops."""

  def testConverterErrorOnControlFlowV1Ops(self):
    """Checks the error message raised when converting Control Flow V1 ops."""
    frozen_graph_path = resource_loader.get_path_to_datafile(
        'testdata/control_flow_v1.pbtxt')
    converter = lite.TFLiteConverter.from_frozen_graph(
        frozen_graph_path, ['a', 'b', 'c', 'd'], ['Merge'])

    with self.assertRaises(ConverterError) as raised:
      converter.convert()

    expected_message = (
        'Failed to functionalize Control Flow V1 ops. Consider using Control '
        'Flow V2 ops instead. See https://www.tensorflow.org/api_docs/python/'
        'tf/compat/v1/enable_control_flow_v2.')
    self.assertIn(expected_message, str(raised.exception))


class QuantizationModeTest(LiteTest, parameterized.TestCase):
  """Tests for `lite.QuantizationMode`."""

  @parameterized.named_parameters(
      ('size', lite.Optimize.OPTIMIZE_FOR_SIZE),
      ('latency', lite.Optimize.OPTIMIZE_FOR_LATENCY))
  def testDeprecatedOptionWarning(self, optimization):
    """Test that a deprecated optimization option logs the expected message.

    Args:
      optimization: The `lite.Optimize` value passed to `QuantizationMode`.
    """
    log = io.StringIO()
    handler = logging.StreamHandler(log)
    logging.root.addHandler(handler)
    # Register removal right away so the handler is detached from the root
    # logger even when the constructor or the assertion below raises.
    self.addCleanup(logging.root.removeHandler, handler)
    warning_message = 'please use optimizations=[Optimize.DEFAULT] instead.'
    lite.QuantizationMode([optimization], lite.TargetSpec(), None, None)
    self.assertIn(warning_message, log.getvalue())


# Hand control to the test runner when this file is executed as a script.
if __name__ == '__main__':
  test.main()