# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for tfrecord_lib."""

import os

from absl import flags
from absl.testing import parameterized
import tensorflow as tf
import tf_keras

from official.vision.data import create_coco_tf_record as create_coco_tf_record_lib
from official.vision.data import tfrecord_lib


FLAGS = flags.FLAGS


def process_sample(x):
  """Builds a tf.train.Example holding `x` under the key 'x'; returns (example, 0)."""
  d = {'x': x}
  return tf.train.Example(features=tf.train.Features(feature=d)), 0


def parse_function(example_proto):
  """Parses a serialized example back into its single int64 feature 'x'."""
  feature_description = {
      'x': tf.io.FixedLenFeature([], tf.int64, default_value=-1)
  }
  return tf.io.parse_single_example(example_proto, feature_description)
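

# A minimal round-trip sketch (comments only; it mirrors what the tests below
# exercise, and the value 7 is purely illustrative):
#
#   feature = tfrecord_lib.convert_to_feature(7)    # -> Feature with int64_list
#   example, num_skipped = process_sample(feature)  # -> (tf.train.Example, 0)
#   parsed = parse_function(example.SerializeToString())
#   parsed['x'].numpy()                             # == 7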


class TfrecordLibTest(parameterized.TestCase):

  def test_write_tf_record_dataset(self):
    data = [(tfrecord_lib.convert_to_feature(i),) for i in range(17)]

    path = os.path.join(FLAGS.test_tmpdir, 'train')

    # Write the 17 examples across 3 shard files.
    tfrecord_lib.write_tf_record_dataset(
        path, data, process_sample, 3, multiple_processes=0)
    tfrecord_files = tf.io.gfile.glob(path + '*')

    self.assertLen(tfrecord_files, 3)

    dataset = tf.data.TFRecordDataset(tfrecord_files)
    dataset = dataset.map(parse_function)

    read_values = set(d['x'] for d in dataset.as_numpy_iterator())
    self.assertSetEqual(read_values, set(range(17)))

  def test_convert_to_feature_float(self):

    proto = tfrecord_lib.convert_to_feature(0.0)
    self.assertEqual(proto.float_list.value[0], 0.0)

  def test_convert_to_feature_int(self):

    proto = tfrecord_lib.convert_to_feature(0)
    self.assertEqual(proto.int64_list.value[0], 0)

  def test_convert_to_feature_bytes(self):

    proto = tfrecord_lib.convert_to_feature(b'123')
    self.assertEqual(proto.bytes_list.value[0], b'123')

  def test_convert_to_feature_float_list(self):

    proto = tfrecord_lib.convert_to_feature([0.0, 1.0])
    self.assertSequenceAlmostEqual(proto.float_list.value, [0.0, 1.0])

  def test_convert_to_feature_int_list(self):

    proto = tfrecord_lib.convert_to_feature([0, 1])
    self.assertSequenceEqual(proto.int64_list.value, [0, 1])

  def test_convert_to_feature_bytes_list(self):

    proto = tfrecord_lib.convert_to_feature([b'123', b'456'])
    self.assertSequenceEqual(proto.bytes_list.value, [b'123', b'456'])
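
  # Taken together, the cases above document how convert_to_feature dispatches
  # on the input type: scalars go to float_list / int64_list / bytes_list and
  # sequences go to the corresponding repeated fields. As a sketch (not an
  # exhaustive spec), convert_to_feature([0, 1]) is expected to be equivalent
  # to:
  #
  #   tf.train.Feature(int64_list=tf.train.Int64List(value=[0, 1]))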

  def test_obj_annotation_tf_example(self):
    # End-to-end check of the COCO conversion path: generate_annotations()
    # pairs each image with its object annotations, create_tf_example()
    # encodes them, and write_tf_record_dataset() shards the output.
    images = [
        {
            'id': 0,
            'file_name': 'example1.jpg',
            'height': 512,
            'width': 512,
        },
        {
            'id': 1,
            'file_name': 'example2.jpg',
            'height': 512,
            'width': 512,
        },
    ]
    img_to_obj_annotation = {
        0: [{
            'id': 0,
            'image_id': 0,
            'category_id': 1,
            'bbox': [3, 1, 511, 510],
            'area': 260610.00,
            'segmentation': [],
            'iscrowd': 0,
        }],
        1: [{
            'id': 1,
            'image_id': 1,
            'category_id': 1,
            'bbox': [1, 1, 100, 150],
            'area': 15000.00,
            'segmentation': [],
            'iscrowd': 0,
        }],
    }
    id_to_name_map = {
        0: 'Super-Class',
        1: 'Class-1',
    }

    temp_dir = FLAGS.test_tmpdir
    image_dir = os.path.join(temp_dir, 'data')
    if not os.path.exists(image_dir):
      os.mkdir(image_dir)
    for image in images:
      image_path = os.path.join(image_dir, image['file_name'])
      tf_keras.utils.save_img(
          image_path,
          tf.ones(shape=(image['height'], image['width'], 3)).numpy(),
      )

    output_path = os.path.join(image_dir, 'train')
    coco_annotations_iter = create_coco_tf_record_lib.generate_annotations(
        images=images,
        image_dirs=[image_dir],
        panoptic_masks_dir=None,
        img_to_obj_annotation=img_to_obj_annotation,
        img_to_caption_annotation=None,
        img_to_panoptic_annotation=None,
        is_category_thing=None,
        id_to_name_map=id_to_name_map,
        include_panoptic_masks=False,
        include_masks=False,
    )

    # Write both annotated images into a single TFRecord shard.
    tfrecord_lib.write_tf_record_dataset(
        output_path,
        coco_annotations_iter,
        create_coco_tf_record_lib.create_tf_example,
        1,
        multiple_processes=0,
    )
    tfrecord_files = tf.io.gfile.glob(output_path + '*')

    self.assertLen(tfrecord_files, 1)

    ds = tf.data.TFRecordDataset(tfrecord_files)
    # One record is expected per annotated image.
    record_count = sum(1 for _ in ds)

    self.assertEqual(record_count, 2)


if __name__ == '__main__':
  tf.test.main()