tensorflow/models
official/legacy/detection/dataloader/input_reader.py

# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Data loader and input processing."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from typing import Optional, Text
import tensorflow as tf
import tf_keras
from official.legacy.detection.dataloader import factory
from official.legacy.detection.dataloader import mode_keys as ModeKeys
from official.modeling.hyperparams import params_dict


class InputFn(object):
  """Input function that creates dataset from files."""

  def __init__(self,
               file_pattern: Text,
               params: params_dict.ParamsDict,
               mode: Text,
               batch_size: int,
               num_examples: Optional[int] = -1):
    """Initialize.

    Args:
      file_pattern: the file pattern for the data examples (TFRecords).
      params: the parameter object for constructing example parser and model.
      mode: ModeKeys.TRAIN or ModeKeys.EVAL.
      batch_size: the data batch size.
      num_examples: If positive, only takes this number of examples and
        raises tf.errors.OutOfRangeError after that. If non-positive, it will
        be ignored.
    """
    assert file_pattern is not None
    assert mode is not None
    assert batch_size is not None
    self._file_pattern = file_pattern
    self._mode = mode
    self._is_training = (mode == ModeKeys.TRAIN)
    self._batch_size = batch_size
    self._num_examples = num_examples
    self._parser_fn = factory.parser_generator(params, mode)
    self._dataset_fn = tf.data.TFRecordDataset

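    # Shards input files only for evaluation by default; the params object
    # may override this via train.input_sharding or eval.input_sharding when
    # those fields are present.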
    self._input_sharding = (not self._is_training)
    try:
      if self._is_training:
        self._input_sharding = params.train.input_sharding
      else:
        self._input_sharding = params.eval.input_sharding
    except AttributeError:
      pass

  def __call__(self, ctx=None, batch_size: Optional[int] = None):
    """Provides tf.data.Dataset object.

    Args:
      ctx: the distribution input context (e.g. tf.distribute.InputContext)
        used to shard input files across input pipelines, if any.
      batch_size: the batch size for the returned dataset; if not provided,
        the batch size given at construction time is used.

    Returns:
      tf.data.Dataset object.
    """
    if not batch_size:
      batch_size = self._batch_size
    assert batch_size is not None
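    # Lists the files matching the pattern; shuffles the file order only
    # during training.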
    dataset = tf.data.Dataset.list_files(
        self._file_pattern, shuffle=self._is_training)

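    # If input sharding is enabled and there are multiple input pipelines,
    # assigns each pipeline a disjoint subset of the input files.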
    if self._input_sharding and ctx and ctx.num_input_pipelines > 1:
      dataset = dataset.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)
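    # Caches the (small) file-name dataset so repeated passes do not re-list
    # the files.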
    dataset = dataset.cache()

    if self._is_training:
      dataset = dataset.repeat()

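    # Reads records from multiple TFRecord files concurrently, interleaving
    # up to 32 files at a time.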
    dataset = dataset.interleave(
        map_func=self._dataset_fn,
        cycle_length=32,
        num_parallel_calls=tf.data.experimental.AUTOTUNE)

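    # Shuffles records during training and, if requested, caps the number of
    # examples read.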
    if self._is_training:
      dataset = dataset.shuffle(1000)
    if self._num_examples > 0:
      dataset = dataset.take(self._num_examples)

    # Parses the fetched records into input tensors for the model function.
    dataset = dataset.map(
        self._parser_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    dataset = dataset.batch(batch_size, drop_remainder=True)
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
    return dataset
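

# Illustrative usage sketch (not part of the original module). The file
# pattern and ParamsDict contents below are placeholders; they assume a
# params object that factory.parser_generator accepts for the chosen mode.
#
#   params = params_dict.ParamsDict({...})  # parser / model settings
#   train_input_fn = InputFn(
#       file_pattern='/path/to/train-*.tfrecord',
#       params=params,
#       mode=ModeKeys.TRAIN,
#       batch_size=64)
#   dataset = train_input_fn()  # tf.data.Dataset of parsed, batched examples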