tensorflow/models

View on GitHub
research/lstm_object_detection/builders/graph_rewriter_builder_test.py

Summary

Maintainability
D
2 days
Test Coverage
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Tests for graph_rewriter_builder."""
import mock
import tensorflow.compat.v1 as tf
from tensorflow.contrib import layers as contrib_layers
from tensorflow.contrib import quantize as contrib_quantize
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from lstm_object_detection.builders import graph_rewriter_builder
from lstm_object_detection.protos import quant_overrides_pb2
from object_detection.protos import graph_rewriter_pb2


class QuantizationBuilderTest(tf.test.TestCase):
  """Tests graph_rewriter_builder.build with quantization settings."""

  def testQuantizationBuilderSetsUpCorrectTrainArguments(self):
    """Checks the arguments forwarded when building with is_training=True.

    With the contrib quantize and summarize entry points mocked out, the
    rewrite function is expected to call experimental_create_training_graph
    with the default graph as `input_graph` and `quant_delay` of 10, and to
    summarize the 'quant_vars' collection.
    """
    quant_patch = mock.patch.object(
        contrib_quantize, 'experimental_create_training_graph')
    summary_patch = mock.patch.object(contrib_layers, 'summarize_collection')
    with quant_patch as quantize_mock, summary_patch as summarize_mock:
      rewriter_config = graph_rewriter_pb2.GraphRewriter()
      quantization = rewriter_config.quantization
      quantization.delay = 10
      quantization.weight_bits = 8
      quantization.activation_bits = 8

      rewrite_fn = graph_rewriter_builder.build(
          rewriter_config, is_training=True)
      rewrite_fn()

      # Only the keyword arguments of the mocked quantize call are inspected.
      _, call_kwargs = quantize_mock.call_args
      self.assertEqual(call_kwargs['input_graph'], tf.get_default_graph())
      self.assertEqual(call_kwargs['quant_delay'], 10)
      summarize_mock.assert_called_with('quant_vars')

  def testQuantizationBuilderSetsUpCorrectEvalArguments(self):
    """Checks the arguments forwarded when building with is_training=False.

    With the contrib quantize and summarize entry points mocked out, the
    rewrite function is expected to call experimental_create_eval_graph with
    the default graph as `input_graph`, and to summarize the 'quant_vars'
    collection.
    """
    quant_patch = mock.patch.object(
        contrib_quantize, 'experimental_create_eval_graph')
    summary_patch = mock.patch.object(contrib_layers, 'summarize_collection')
    with quant_patch as quantize_mock, summary_patch as summarize_mock:
      rewriter_config = graph_rewriter_pb2.GraphRewriter()
      rewriter_config.quantization.delay = 10

      rewrite_fn = graph_rewriter_builder.build(
          rewriter_config, is_training=False)
      rewrite_fn()

      # Only the keyword arguments of the mocked quantize call are inspected.
      _, call_kwargs = quantize_mock.call_args
      self.assertEqual(call_kwargs['input_graph'], tf.get_default_graph())
      summarize_mock.assert_called_with('quant_vars')

  def testQuantizationBuilderAddsQuantOverride(self):
    """Checks that a quant override shows up in the rewritten graph.

    Builds a small graph, configures a fixed-range override for the op
    'test_graph/add_ab', runs the training rewrite, and then scans the graph
    for (a) a FakeQuantWithMinMaxArgs op carrying the override's min/max and
    (b) an 'activate_quant' Const whose first int64 value equals the
    override's delay.
    """
    test_graph = ops.Graph()
    with test_graph.as_default():
      self._buildGraph()

      overrides = quant_overrides_pb2.QuantOverrides()
      override = overrides.quant_configs.add()
      override.op_name = 'test_graph/add_ab'
      override.quant_op_name = 'act_quant'
      override.fixed_range = True
      override.min = 0
      override.max = 6
      override.delay = 100

      rewriter_config = graph_rewriter_pb2.GraphRewriter()
      quantization = rewriter_config.quantization
      quantization.delay = 10
      quantization.weight_bits = 8
      quantization.activation_bits = 8

      rewrite_fn = graph_rewriter_builder.build(
          rewriter_config,
          quant_overrides_config=overrides,
          is_training=True)
      rewrite_fn()

      quant_op_name = override.quant_op_name
      act_quant_found = False
      quant_delay_found = False
      for op in test_graph.get_operations():
        # Both checks only concern ops named after the override's quant op.
        if quant_op_name not in op.name:
          continue
        if op.type == 'FakeQuantWithMinMaxArgs':
          act_quant_found = True
          range_min = op.get_attr('min')
          range_max = op.get_attr('max')
          self.assertEqual(range_min, override.min)
          self.assertEqual(range_max, override.max)
        elif op.type == 'Const' and 'activate_quant' in op.name:
          const_value = op.get_attr('value')
          if const_value.int64_val[0] == override.delay:
            quant_delay_found = True

      self.assertTrue(act_quant_found)
      self.assertTrue(quant_delay_found)

  def _buildGraph(self, scope='test_graph'):
    """Adds a tiny (a + b) * c graph to the current default graph.

    Args:
      scope: Name scope under which the ops are created.

    Returns:
      The output tensor of the final multiply op ('mul_ab_c').
    """
    with ops.name_scope(scope):
      input_a = tf.constant(10, dtype=dtypes.float32, name='input_a')
      input_b = tf.constant(20, dtype=dtypes.float32, name='input_b')
      sum_ab = tf.add(input_a, input_b, name='add_ab')
      input_c = tf.constant(30, dtype=dtypes.float32, name='input_c')
      product = tf.multiply(sum_ab, input_c, name='mul_ab_c')
    return product


# Hand control to the TensorFlow test runner when this file is executed
# directly rather than imported.
if __name__ == '__main__':
  tf.test.main()