tensorflow/models

View on GitHub
research/attention_ocr/python/model_test.py

Summary

Maintainability
B
4 hrs
Test Coverage
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the model."""
import string

import numpy as np
import tensorflow as tf
from tensorflow.contrib import slim

import model
import data_provider


def create_fake_charset(num_char_classes):
  """Builds a synthetic class-id -> character mapping for tests.

  Ids 0 .. num_char_classes - 1 are assigned the characters of
  `string.printable` in order, wrapping around (modulo its length) once the
  printable characters run out.

  Args:
    num_char_classes: number of class ids to generate.

  Returns:
    A dict mapping each integer id to a one-character string.
  """
  alphabet = string.printable
  alphabet_size = len(alphabet)
  return {
      class_id: alphabet[class_id % alphabet_size]
      for class_id in range(num_char_classes)
  }


class ModelTest(tf.test.TestCase):
  """Shape, range and runnability tests for `model.Model`."""

  def setUp(self):
    super(ModelTest, self).setUp()

    self.rng = np.random.RandomState([11, 23, 50])

    self.batch_size = 4
    self.image_width = 600
    self.image_height = 30
    self.seq_length = 40
    self.num_char_classes = 72
    self.null_code = 62
    self.num_views = 4

    feature_size = 288
    # Expected conv tower output shape for the image size above (checked by
    # test_conv_tower_shape).
    self.conv_tower_shape = (self.batch_size, 1, 72, feature_size)
    self.features_shape = (self.batch_size, self.seq_length, feature_size)
    self.chars_logit_shape = (self.batch_size, self.seq_length,
                              self.num_char_classes)
    self.length_logit_shape = (self.batch_size, self.seq_length + 1)
    # Placeholder knows image dimensions, but not batch size.
    self.input_images = tf.compat.v1.placeholder(
        tf.float32,
        shape=(None, self.image_height, self.image_width, 3),
        name='input_node')

    self.initialize_fakes()

  def initialize_fakes(self):
    """Creates random fixtures sized from the shapes configured in setUp."""
    self.images_shape = (self.batch_size, self.image_height, self.image_width,
                         3)
    # Integer pixel values in [0, 255) stored as float32.
    self.fake_images = self.rng.randint(
        low=0, high=255, size=self.images_shape).astype('float32')
    self.fake_conv_tower_np = self.rng.randn(*self.conv_tower_shape).astype(
        'float32')
    self.fake_conv_tower = tf.constant(self.fake_conv_tower_np)
    self.fake_logits = tf.constant(
        self.rng.randn(*self.chars_logit_shape).astype('float32'))
    self.fake_labels = tf.constant(
        self.rng.randint(
            low=0,
            high=self.num_char_classes,
            size=(self.batch_size, self.seq_length)).astype('int64'))

  def create_model(self, charset=None):
    """Returns a `model.Model` configured with this test's dimensions.

    Args:
      charset: optional charset forwarded to the model unchanged.
    """
    # Use the values configured in setUp rather than repeating the literals,
    # so the two cannot drift apart.
    return model.Model(
        self.num_char_classes,
        self.seq_length,
        num_views=self.num_views,
        null_code=self.null_code,
        charset=charset)

  def test_char_related_shapes(self):
    charset = create_fake_charset(self.num_char_classes)
    ocr_model = self.create_model(charset=charset)
    with self.test_session() as sess:
      endpoints_tf = ocr_model.create_base(
          images=self.input_images, labels_one_hot=None)
      sess.run(tf.compat.v1.global_variables_initializer())
      tf.compat.v1.tables_initializer().run()
      endpoints = sess.run(
          endpoints_tf, feed_dict={self.input_images: self.fake_images})

      self.assertEqual(
          (self.batch_size, self.seq_length, self.num_char_classes),
          endpoints.chars_logit.shape)
      self.assertEqual(
          (self.batch_size, self.seq_length, self.num_char_classes),
          endpoints.chars_log_prob.shape)
      self.assertEqual((self.batch_size, self.seq_length),
                       endpoints.predicted_chars.shape)
      self.assertEqual((self.batch_size, self.seq_length),
                       endpoints.predicted_scores.shape)
      self.assertEqual((self.batch_size,), endpoints.predicted_text.shape)
      self.assertEqual((self.batch_size,), endpoints.predicted_conf.shape)
      self.assertEqual((self.batch_size,), endpoints.normalized_seq_conf.shape)

  def test_predicted_scores_are_within_range(self):
    ocr_model = self.create_model()

    _, _, scores = ocr_model.char_predictions(self.fake_logits)
    with self.test_session() as sess:
      # `scores` is built from the constant fake logits only; the image
      # placeholder is not involved, so nothing needs to be fed.
      scores_np = sess.run(scores)

    values_in_range = (scores_np >= 0.0) & (scores_np <= 1.0)
    self.assertTrue(
        np.all(values_in_range),
        msg=('Scores contains out of the range values %s' %
             scores_np[np.logical_not(values_in_range)]))

  def test_conv_tower_shape(self):
    with self.test_session() as sess:
      ocr_model = self.create_model()
      conv_tower = ocr_model.conv_tower_fn(self.input_images)

      sess.run(tf.compat.v1.global_variables_initializer())
      conv_tower_np = sess.run(
          conv_tower, feed_dict={self.input_images: self.fake_images})

      self.assertEqual(self.conv_tower_shape, conv_tower_np.shape)

  def test_model_size_less_then1_gb(self):
    # NOTE: The actual amount of memory occupied by TF during training will be
    # at least 4x bigger, because space is needed to store the original
    # weights, updates, gradients and variances. It also depends on the type
    # of optimizer used.
    ocr_model = self.create_model()
    ocr_model.create_base(images=self.input_images, labels_one_hot=None)
    with self.test_session() as sess:
      tfprof_root = tf.compat.v1.profiler.profile(
          sess.graph,
          options=tf.compat.v1.profiler.ProfileOptionBuilder
          .trainable_variables_parameter())

      # Counts 4 bytes per trainable parameter (float32 weights).
      model_size_bytes = 4 * tfprof_root.total_parameters
      self.assertLess(model_size_bytes, 1 * 2**30)

  def test_create_summaries_is_runnable(self):
    ocr_model = self.create_model()
    data = data_provider.InputEndpoints(
        images=self.fake_images,
        images_orig=self.fake_images,
        labels=self.fake_labels,
        labels_one_hot=slim.one_hot_encoding(self.fake_labels,
                                             self.num_char_classes))
    endpoints = ocr_model.create_base(
        images=self.fake_images, labels_one_hot=None)
    charset = create_fake_charset(self.num_char_classes)
    summaries = ocr_model.create_summaries(
        data, endpoints, charset, is_training=False)
    with self.test_session() as sess:
      sess.run(tf.compat.v1.global_variables_initializer())
      sess.run(tf.compat.v1.local_variables_initializer())
      tf.compat.v1.tables_initializer().run()
      sess.run(summaries)  # just check it is runnable

  def test_sequence_loss_function_without_label_smoothing(self):
    ocr_model = self.create_model()
    ocr_model.set_mparam('sequence_loss_fn', label_smoothing=0)

    loss = ocr_model.sequence_loss_fn(self.fake_logits, self.fake_labels)
    with self.test_session() as sess:
      loss_np = sess.run(loss)

    # This test checks that the loss function is 'runnable'.
    self.assertEqual(loss_np.shape, tuple())

  def encode_coordinates_alt(self, net):
    """An alternative implementation of the coordinate encoding.

    Appends a one-hot encoding of each location's row index and column index
    to the features at that location.

    Args:
      net: a tensor of shape=[batch_size, height, width, num_features]; height
        and width must be statically known.

    Returns:
      A single tensor of shape
      [batch_size, height, width, num_features + height + width]: `net`
      followed by the one-hot row encoding (height channels) and the one-hot
      column encoding (width channels).
    """
    batch_size = tf.shape(input=net)[0]
    _, h, w, _ = net.shape.as_list()
    # h_loc[r, c, i] == 1 iff r == i.
    h_loc = [
        tf.tile(
            tf.reshape(
                tf.contrib.layers.one_hot_encoding(
                    tf.constant([i]), num_classes=h), [h, 1]), [1, w])
        for i in range(h)
    ]
    h_loc = tf.concat([tf.expand_dims(t, 2) for t in h_loc], 2)
    # w_loc[r, c, i] == 1 iff c == i.
    w_loc = [
        tf.tile(
            tf.contrib.layers.one_hot_encoding(
                tf.constant([i]), num_classes=w),
            [h, 1]) for i in range(w)
    ]
    w_loc = tf.concat([tf.expand_dims(t, 2) for t in w_loc], 2)
    loc = tf.concat([h_loc, w_loc], 2)
    loc = tf.tile(tf.expand_dims(loc, 0), [batch_size, 1, 1, 1])
    return tf.concat([net, loc], 3)

  def test_encoded_coordinates_have_correct_shape(self):
    ocr_model = self.create_model()
    ocr_model.set_mparam('encode_coordinates_fn', enabled=True)
    conv_w_coords_tf = ocr_model.encode_coordinates_fn(self.fake_conv_tower)

    with self.test_session() as sess:
      conv_w_coords = sess.run(conv_w_coords_tf)

    batch_size, height, width, feature_size = self.conv_tower_shape
    self.assertEqual(conv_w_coords.shape,
                     (batch_size, height, width, feature_size + height + width))

  def test_disabled_coordinate_encoding_returns_features_unchanged(self):
    ocr_model = self.create_model()
    ocr_model.set_mparam('encode_coordinates_fn', enabled=False)
    conv_w_coords_tf = ocr_model.encode_coordinates_fn(self.fake_conv_tower)

    with self.test_session() as sess:
      conv_w_coords = sess.run(conv_w_coords_tf)

    self.assertAllEqual(conv_w_coords, self.fake_conv_tower_np)

  def test_coordinate_encoding_is_correct_for_simple_example(self):
    shape = (1, 2, 3, 4)  # batch_size, height, width, feature_size
    fake_conv_tower = tf.constant(2 * np.ones(shape), dtype=tf.float32)
    ocr_model = self.create_model()
    ocr_model.set_mparam('encode_coordinates_fn', enabled=True)
    conv_w_coords_tf = ocr_model.encode_coordinates_fn(fake_conv_tower)

    with self.test_session() as sess:
      conv_w_coords = sess.run(conv_w_coords_tf)

    # Original features
    self.assertAllEqual(conv_w_coords[0, :, :, :4],
                        [[[2, 2, 2, 2], [2, 2, 2, 2], [2, 2, 2, 2]],
                         [[2, 2, 2, 2], [2, 2, 2, 2], [2, 2, 2, 2]]])
    # Encoded coordinates
    self.assertAllEqual(conv_w_coords[0, :, :, 4:],
                        [[[1, 0, 1, 0, 0], [1, 0, 0, 1, 0], [1, 0, 0, 0, 1]],
                         [[0, 1, 1, 0, 0], [0, 1, 0, 1, 0], [0, 1, 0, 0, 1]]])

  def test_alt_implementation_of_coordinate_encoding_returns_same_values(self):
    ocr_model = self.create_model()
    ocr_model.set_mparam('encode_coordinates_fn', enabled=True)
    conv_w_coords_tf = ocr_model.encode_coordinates_fn(self.fake_conv_tower)
    conv_w_coords_alt_tf = self.encode_coordinates_alt(self.fake_conv_tower)

    with self.test_session() as sess:
      conv_w_coords_np, conv_w_coords_alt_np = sess.run(
          [conv_w_coords_tf, conv_w_coords_alt_tf])

    self.assertAllEqual(conv_w_coords_np, conv_w_coords_alt_np)

  def test_predicted_text_has_correct_shape_w_charset(self):
    charset = create_fake_charset(self.num_char_classes)
    ocr_model = self.create_model(charset=charset)

    with self.test_session() as sess:
      endpoints_tf = ocr_model.create_base(
          images=self.fake_images, labels_one_hot=None)

      sess.run(tf.compat.v1.global_variables_initializer())
      tf.compat.v1.tables_initializer().run()
      endpoints = sess.run(endpoints_tf)

      self.assertEqual(endpoints.predicted_text.shape, (self.batch_size,))
      self.assertEqual(len(endpoints.predicted_text[0]), self.seq_length)


class CharsetMapperTest(tf.test.TestCase):
  """Checks that `model.CharsetMapper` turns id sequences into text."""

  def test_text_corresponds_to_ids(self):
    # With the fake charset, id i maps to string.printable[i], so these two
    # rows spell 'hello' and 'world'.
    mapper = model.CharsetMapper(create_fake_charset(36))
    char_ids = tf.constant(
        [[17, 14, 21, 21, 24], [32, 24, 27, 21, 13]], dtype=tf.int64)

    with self.test_session() as sess:
      tf.compat.v1.tables_initializer().run()
      decoded = sess.run(mapper.get_text(char_ids))

    self.assertAllEqual(decoded, [b'hello', b'world'])


if __name__ == '__main__':
  # Entry point when this file is executed directly: hand control to
  # TensorFlow's test runner.
  tf.test.main()