tensorflow/models

View on GitHub
research/object_detection/dataset_tools/context_rcnn/add_context_to_examples.py

Summary

Maintainability
F
6 days
Test Coverage
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""A Beam job to add contextual memory banks to tf.Examples.

This tool groups images containing bounding boxes and embedded context features
by a key, either `image/location` or `image/seq_id`, and time horizon,
then uses these groups to build up a contextual memory bank from the embedded
context features from each image in the group and adds that context to the
output tf.Examples for each image in the group.

Steps to generate a dataset with context from one with bounding boxes and
embedded context features:
1. Use object_detection/export_inference_graph.py to get a `saved_model` for
  inference. The input node must accept a tf.Example proto.
2. Run this tool with `saved_model` from step 1 and a TFRecord of tf.Example
  protos containing images, bounding boxes, and embedded context features.
  The context features can be added to tf.Examples using
  generate_embedding_data.py.

Example Usage:
--------------
python add_context_to_examples.py \
  --input_tfrecord path/to/input_tfrecords* \
  --output_tfrecord path/to/output_tfrecords \
  --sequence_key image/location \
  --time_horizon month

"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import copy
import datetime
import io
import itertools
import json
import os
import numpy as np
import PIL.Image
import six
import tensorflow as tf

try:
  import apache_beam as beam  # pylint:disable=g-import-not-at-top
except ModuleNotFoundError:
  pass


class ReKeyDataFn(beam.DoFn):
  """Re-keys tfrecords by sequence_key.

  This Beam DoFn parses each serialized tf.Example, optionally resizes the
  stored image, and emits the example keyed by the value of a user-defined
  sequence_key feature, optionally suffixed with a time bucket derived from
  the `image/date_captured` feature.
  """

  def __init__(self, sequence_key, time_horizon,
               reduce_image_size, max_image_dimension):
    """Initialization function.

    Args:
      sequence_key: A feature name to use as a key for grouping sequences.
        Must point to a key of type bytes_list
      time_horizon: What length of time to use to partition the data when
        building the memory banks. Options: `year`, `month`, `week`, `day`,
        `hour`, `minute`, None
      reduce_image_size: Whether to reduce the sizes of the stored images.
      max_image_dimension: maximum dimension of reduced images

    Raises:
      ValueError: If time_horizon is not one of the supported options.
    """
    self._sequence_key = sequence_key
    if time_horizon is None or time_horizon in {'year', 'month', 'week', 'day',
                                                'hour', 'minute'}:
      self._time_horizon = time_horizon
    else:
      raise ValueError('Time horizon not supported.')
    self._reduce_image_size = reduce_image_size
    self._max_image_dimension = max_image_dimension
    self._session = None
    self._num_examples_processed = beam.metrics.Metrics.counter(
        'data_rekey', 'num_tf_examples_processed')
    self._num_images_resized = beam.metrics.Metrics.counter(
        'data_rekey', 'num_images_resized')
    self._num_images_read = beam.metrics.Metrics.counter(
        'data_rekey', 'num_images_read')
    # Registered under its own name; it previously reused 'num_images_read',
    # which folded two different events into a single metric.
    self._num_images_found = beam.metrics.Metrics.counter(
        'data_rekey', 'num_images_found')
    self._num_got_shape = beam.metrics.Metrics.counter(
        'data_rekey', 'num_images_got_shape')
    self._num_images_found_size = beam.metrics.Metrics.counter(
        'data_rekey', 'num_images_found_size')
    self._num_examples_cleared = beam.metrics.Metrics.counter(
        'data_rekey', 'num_examples_cleared')
    self._num_examples_updated = beam.metrics.Metrics.counter(
        'data_rekey', 'num_examples_updated')

  def process(self, tfrecord_entry):
    """Re-keys one serialized tf.Example.

    Args:
      tfrecord_entry: A serialized tf.Example.
    Returns:
      A single-element list holding a `(key, tf.Example)` tuple.
    """
    return self._rekey_examples(tfrecord_entry)

  def _largest_size_at_most(self, height, width, largest_side):
    """Computes new shape with the largest side equal to `largest_side`.

    The aspect ratio is preserved (up to integer truncation). The same scale
    is applied whether the image is larger or smaller than `largest_side`.

    Args:
      height: an int indicating the current height.
      width: an int indicating the current width.
      largest_side: A python integer indicating the size of
        the largest side after resize.
    Returns:
      new_height: an int indicating the new height.
      new_width: an int indicating the new width.
    """

    x_scale = float(largest_side) / float(width)
    y_scale = float(largest_side) / float(height)
    # The smaller scale is the one that maps the longer side to largest_side.
    scale = min(x_scale, y_scale)

    new_width = int(width * scale)
    new_height = int(height * scale)

    return new_height, new_width

  def _resize_image(self, input_example):
    """Resizes the image within input_example and updates the height and width.

    The image is decoded with PIL, resized, and re-encoded as JPEG. The
    `image/encoded`, `image/height` and `image/width` features are replaced in
    place.

    Args:
      input_example: A tf.Example that we want to update to contain a resized
        image.
    Returns:
      input_example: Updated tf.Example.
    """
    features = input_example.features.feature

    # The values read here are immutable (bytes / ints), so no copy is needed
    # before the repeated fields are cleared below.
    original_image = features['image/encoded'].bytes_list.value[0]
    self._num_images_read.inc(1)

    height = features['image/height'].int64_list.value[0]
    width = features['image/width'].int64_list.value[0]
    self._num_got_shape.inc(1)

    new_height, new_width = self._largest_size_at_most(
        height, width, self._max_image_dimension)
    self._num_images_found_size.inc(1)

    image = PIL.Image.open(io.BytesIO(original_image))
    resized_image = image.resize((new_width, new_height))

    with io.BytesIO() as output:
      resized_image.save(output, format='JPEG')
      encoded_resized_image = output.getvalue()

    self._num_images_resized.inc(1)

    del features['image/encoded'].bytes_list.value[:]
    del features['image/height'].int64_list.value[:]
    del features['image/width'].int64_list.value[:]
    self._num_examples_cleared.inc(1)

    features['image/encoded'].bytes_list.value.extend([encoded_resized_image])
    features['image/height'].int64_list.value.extend([new_height])
    features['image/width'].int64_list.value.extend([new_width])
    self._num_examples_updated.inc(1)

    return input_example

  def _rekey_examples(self, tfrecord_entry):
    """Parses a serialized tf.Example and computes its grouping key.

    Args:
      tfrecord_entry: A serialized tf.Example.
    Returns:
      A single-element list holding `(new_key, input_example)`, where new_key
      is the bytes value of the sequence_key feature, followed by a
      `/`-separated time bucket when a time horizon is configured.
    """
    input_example = tf.train.Example.FromString(tfrecord_entry)
    self._num_images_found.inc(1)

    if self._reduce_image_size:
      # _resize_image increments num_images_resized itself; incrementing it
      # again here would count every resized image twice.
      input_example = self._resize_image(input_example)

    new_key = input_example.features.feature[
        self._sequence_key].bytes_list.value[0]

    if self._time_horizon:
      date_captured = datetime.datetime.strptime(
          six.ensure_str(input_example.features.feature[
              'image/date_captured'].bytes_list.value[0]), '%Y-%m-%d %H:%M:%S')
      year = date_captured.year
      month = date_captured.month
      day = date_captured.day
      # NOTE: the week bucket is floor(day / 7) rendered through str() of a
      # numpy float (e.g. '1.0'). Kept as is so existing keys stay stable.
      week = np.floor(float(day) / float(7))
      hour = date_captured.hour
      minute = date_captured.minute

      # __init__ guarantees that a truthy time horizon is one of these keys.
      key_parts_by_horizon = {
          'year': (year,),
          'month': (year, month),
          'week': (year, month, week),
          'day': (year, month, day),
          'hour': (year, month, day, hour),
          'minute': (year, month, day, hour, minute),
      }
      suffix = ''.join(
          '/' + str(part) for part in key_parts_by_horizon[self._time_horizon])
      new_key = new_key + six.ensure_binary(suffix)

    self._num_examples_processed.inc(1)

    return [(new_key, input_example)]


class SortGroupedDataFn(beam.DoFn):
  """Sorts data within a keyed group.

  This Beam DoFn sorts the grouped list of image examples (by frame number,
  image id or capture date, depending on the sequence key) and splits groups
  that hold more context embeddings than the memory bank allows.
  """

  def __init__(self, sequence_key, sorted_image_ids,
               max_num_elements_in_context_features):
    """Initialization function.

    Args:
      sequence_key: A feature name to use as a key for grouping sequences.
        Must point to a key of type bytes_list
      sorted_image_ids: Whether the image ids are sortable to use as sorting
        tie-breakers
      max_num_elements_in_context_features: The maximum number of elements
        allowed in the memory bank
    """
    self._session = None
    self._num_examples_processed = beam.metrics.Metrics.counter(
        'sort_group', 'num_groups_sorted')
    self._too_many_elements = beam.metrics.Metrics.counter(
        'sort_group', 'too_many_elements')
    self._split_elements = beam.metrics.Metrics.counter(
        'sort_group', 'split_elements')
    self._sequence_key = six.ensure_binary(sequence_key)
    self._sorted_image_ids = sorted_image_ids
    self._max_num_elements_in_context_features = (
        max_num_elements_in_context_features)

  def process(self, grouped_entry):
    """Sorts (and, if needed, splits) one `(key, examples)` group.

    Args:
      grouped_entry: A `(key, iterable of tf.Example)` tuple.
    Returns:
      A list of `(key, list of tf.Example)` tuples.
    """
    return self._sort_image_examples(grouped_entry)

  def _sort_image_examples(self, grouped_entry):
    """Sorts a group and splits it when it exceeds the memory bank size.

    Args:
      grouped_entry: A `(key, iterable of tf.Example)` tuple, where key is
        bytes.
    Returns:
      `[(key, sorted_examples)]` when the summed `image/embedding_count` of
      the group fits in the memory bank. Otherwise a list of
      `(key + b'_<n>', subset)` tuples, where the subsets are consecutive
      runs of the sorted examples.
    Raises:
      ValueError: If the sequence key is neither `image/seq_id` nor
        `image/location`, since no sort order is defined for it.
    """
    key, example_collection = grouped_entry
    example_list = list(example_collection)

    def get_frame_num(example):
      return example.features.feature['image/seq_frame_num'].int64_list.value[0]

    def get_date_captured(example):
      return datetime.datetime.strptime(
          six.ensure_str(
              example.features.feature[
                  'image/date_captured'].bytes_list.value[0]),
          '%Y-%m-%d %H:%M:%S')

    def get_image_id(example):
      return example.features.feature['image/source_id'].bytes_list.value[0]

    if self._sequence_key == six.ensure_binary('image/seq_id'):
      sorting_fn = get_frame_num
    elif self._sequence_key == six.ensure_binary('image/location'):
      sorting_fn = get_image_id if self._sorted_image_ids else get_date_captured
    else:
      # Previously this fell through and failed later with an unbound local.
      raise ValueError(
          'Unsupported sequence key: expected `image/seq_id` or '
          '`image/location`.')

    sorted_example_list = sorted(example_list, key=sorting_fn)

    embedding_counts = [
        example.features.feature['image/embedding_count'].int64_list.value[0]
        for example in sorted_example_list
    ]

    self._num_examples_processed.inc(1)

    max_elements = self._max_num_elements_in_context_features
    if sum(embedding_counts) <= max_elements:
      return [(key, sorted_example_list)]

    # To handle cases where there are more context embeddings within
    # the time horizon than the specified maximum, we split the context group
    # into subsets sequentially in time, with each subset holding as many
    # examples as fit within the maximum and the final one holding the
    # remainder.
    self._too_many_elements.inc(1)
    subsets = []
    current_subset = []
    current_count = 0
    for example, count in zip(sorted_example_list, embedding_counts):
      # Only close a non-empty subset: an example that exceeds the maximum on
      # its own still gets a subset, which guarantees forward progress.
      if current_subset and current_count + count > max_elements:
        subsets.append(current_subset)
        self._split_elements.inc(1)
        current_subset = []
        current_count = 0
      current_subset.append(example)
      current_count += count
    # The group is non-empty here (its total exceeded the maximum), so the
    # trailing subset always holds at least one example.
    subsets.append(current_subset)

    return [(key + six.ensure_binary('_' + str(subset_num)), subset)
            for subset_num, subset in enumerate(subsets)]


def get_sliding_window(example_list, max_clip_length, stride_length):
  """Yields clips obtained by sliding a window over example_list.

  The window is max_clip_length wide and advances stride_length items at a
  time; windows that run past the end of the list are truncated.

  Args:
    example_list: A list of examples.
    max_clip_length: The maximum length of each clip.
    stride_length: The stride between each clip.

  Yields:
    example_list itself when it is shorter than max_clip_length; otherwise
    tuples of at most max_clip_length consecutive examples.
  """
  total = len(example_list)

  # Nothing to slide over: hand back the input unchanged.
  if total < max_clip_length:
    yield example_list
    return

  for window_num in range(total):
    start = window_num * stride_length
    if start >= total:
      continue
    stop = min(start + max_clip_length, total)
    yield tuple(itertools.islice(example_list, start, stop))


class GenerateContextFn(beam.DoFn):
  """Generates context data for camera trap images.

  This Beam DoFn builds up contextual memory banks from groups of images and
  stores them in the output tf.Example or tf.SequenceExample for each image.
  """

  def __init__(self, sequence_key, add_context_features, image_ids_to_keep,
               keep_context_features_image_id_list=False,
               subsample_context_features_rate=0,
               keep_only_positives=False,
               context_features_score_threshold=0.7,
               keep_only_positives_gt=False,
               max_num_elements_in_context_features=5000,
               pad_context_features=False,
               output_type='tf_example', max_clip_length=None,
               context_feature_length=2057):
    """Initialization function.

    Args:
      sequence_key: A feature name to use as a key for grouping sequences.
      add_context_features: Whether to keep and store the contextual memory
        bank.
      image_ids_to_keep: Either the string `All`, or the path of a JSON file
        holding the image ids to save, to use to build data subsets for
        evaluation.
      keep_context_features_image_id_list: Whether to save an ordered list of
        the ids of the images in the contextual memory bank. Only has an
        effect when add_context_features is set.
      subsample_context_features_rate: What rate to subsample images for the
        contextual memory bank.
      keep_only_positives: Whether to only keep high scoring
        (>context_features_score_threshold) features in the contextual memory
        bank.
      context_features_score_threshold: What threshold to use for keeping
        features.
      keep_only_positives_gt: Whether to only keep features from images that
        contain objects based on the ground truth (for training).
      max_num_elements_in_context_features: the maximum number of elements in
        the memory bank
      pad_context_features: Whether to pad the list of context feature image
        ids to max_num_elements_in_context_features entries.
      output_type: What type of output, tf_example or tf_sequence_example
      max_clip_length: The maximum length of a sequence example, before
        splitting into multiple
      context_feature_length: The length of the context feature embeddings
        stored in the input data.
    """
    self._session = None
    self._num_examples_processed = beam.metrics.Metrics.counter(
        'sequence_data_generation', 'num_seq_examples_processed')
    self._num_keys_processed = beam.metrics.Metrics.counter(
        'sequence_data_generation', 'num_keys_processed')
    self._sequence_key = sequence_key
    self._add_context_features = add_context_features
    self._pad_context_features = pad_context_features
    self._output_type = output_type
    self._max_clip_length = max_clip_length
    if six.ensure_str(image_ids_to_keep) == 'All':
      self._image_ids_to_keep = None
    else:
      with tf.io.gfile.GFile(image_ids_to_keep) as f:
        self._image_ids_to_keep = json.load(f)
    self._keep_context_features_image_id_list = (
        keep_context_features_image_id_list)
    self._subsample_context_features_rate = subsample_context_features_rate
    self._keep_only_positives = keep_only_positives
    self._keep_only_positives_gt = keep_only_positives_gt
    self._context_features_score_threshold = context_features_score_threshold
    self._max_num_elements_in_context_features = (
        max_num_elements_in_context_features)
    self._context_feature_length = context_feature_length

    self._images_kept = beam.metrics.Metrics.counter(
        'sequence_data_generation', 'images_kept')
    self._images_loaded = beam.metrics.Metrics.counter(
        'sequence_data_generation', 'images_loaded')

  def process(self, grouped_entry):
    """Adds context to one `(key, examples)` group.

    Args:
      grouped_entry: A `(key, iterable of tf.Example)` tuple.
    Returns:
      A list of tf.Example or tf.SequenceExample protos.
    """
    # The examples are mutated while the memory bank is built, so work on a
    # private copy of the incoming element.
    return self._add_context_to_example(copy.deepcopy(grouped_entry))

  def _is_excluded_from_context(self, idx, example):
    """Returns whether `example` is left out of the memory bank.

    Args:
      idx: Position of the example within its group.
      example: The tf.Example to check.
    Returns:
      True if the example is dropped by subsampling, by the embedding score
      threshold, or because it has no ground truth boxes (each check applies
      only when the corresponding option is enabled).
    """
    if (self._subsample_context_features_rate > 0 and
        (idx % self._subsample_context_features_rate) != 0):
      return True
    if self._keep_only_positives and (
        example.features.feature['image/embedding_score'].float_list.value[0]
        < self._context_features_score_threshold):
      return True
    if self._keep_only_positives_gt and len(
        example.features.feature[
            'image/object/bbox/xmin'].float_list.value) < 1:
      return True
    return False

  def _build_context_features(self, example_list):
    """Builds the memory bank for a group and indexes each example into it.

    Every example gets a `context_features_idx` feature: the positions of its
    embeddings in the memory bank, or
    `max_num_elements_in_context_features + 1` when it is excluded.

    Args:
      example_list: A list of tf.Example protos; modified in place.
    Returns:
      context_features: A flat list of floats, the concatenated
        `image/embedding` values of the included examples.
      feature_length: `image/embedding_length` of the first example when that
        feature is present, otherwise the configured context_feature_length.
      context_features_image_id_list: One image id (bytes) per embedding in
        the memory bank, padded with empty bytes when padding is enabled.
    """
    context_features = []
    context_features_image_id_list = []
    count = 0
    excluded_idx = self._max_num_elements_in_context_features + 1

    for idx, example in enumerate(example_list):
      feature_map = example.features.feature
      if self._is_excluded_from_context(idx, example):
        feature_map['context_features_idx'].int64_list.value.append(
            excluded_idx)
        continue

      context_features.extend(feature_map['image/embedding'].float_list.value)
      num_embeddings = feature_map['image/embedding_count'].int64_list.value[0]
      example_image_id = feature_map['image/source_id'].bytes_list.value[0]
      for _ in range(num_embeddings):
        feature_map['context_features_idx'].int64_list.value.append(count)
        count += 1
        context_features_image_id_list.append(example_image_id)

    feature_length = self._context_feature_length

    # If the example_list is not empty and image/embedding_length is in the
    # feature dict, feature_length will be assigned to that. Otherwise, it will
    # be kept as default.
    if example_list and (
        'image/embedding_length' in example_list[0].features.feature):
      feature_length = example_list[0].features.feature[
          'image/embedding_length'].int64_list.value[0]

    if self._pad_context_features:
      num_padding = (self._max_num_elements_in_context_features -
                     len(context_features_image_id_list))
      if num_padding > 0:
        # The ids are stored in a bytes_list, so pad with bytes, not str.
        context_features_image_id_list.extend([b''] * num_padding)

    return context_features, feature_length, context_features_image_id_list

  def _build_sequence_example(self, key, clip_num, clip_list, context_features,
                              feature_length, context_features_image_id_list):
    """Builds one tf.SequenceExample from a clip of tf.Examples.

    Args:
      key: The group key (bytes or str).
      clip_num: Index of the clip within its group; used in the media id.
      clip_list: A non-empty sequence of tf.Example protos.
      context_features: Flat list of memory bank floats.
      feature_length: Length of a single context feature.
      context_features_image_id_list: Image ids matching the memory bank.
    Returns:
      The populated tf.SequenceExample.
    """
    seq_example = tf.train.SequenceExample()
    context = seq_example.context.feature

    video_id = six.ensure_str(key)+'_'+ str(clip_num)
    context['clip/media_id'].bytes_list.value.append(video_id.encode('utf8'))
    context['clip/frames'].int64_list.value.append(len(clip_list))

    context['clip/start/timestamp'].int64_list.value.append(0)
    context['clip/end/timestamp'].int64_list.value.append(len(clip_list))
    context['image/format'].bytes_list.value.append(six.ensure_binary('JPG'))
    context['image/channels'].int64_list.value.append(3)
    # Height and width are taken from the first frame of the clip.
    context_example = clip_list[0]
    context['image/height'].int64_list.value.append(
        context_example.features.feature['image/height'].int64_list.value[0])
    context['image/width'].int64_list.value.append(
        context_example.features.feature['image/width'].int64_list.value[0])

    # The memory bank only exists when context features were requested.
    if self._add_context_features:
      context['image/context_feature_length'].int64_list.value.append(
          feature_length)
      context['image/context_features'].float_list.value.extend(
          context_features)
      if self._keep_context_features_image_id_list:
        context['image/context_features_image_id_list'].bytes_list.value.extend(
            context_features_image_id_list)

    feature_list = seq_example.feature_lists.feature_list
    encoded_image_list = feature_list['image/encoded']
    timestamps_list = feature_list['image/timestamp']
    context_features_idx_list = feature_list['image/context_features_idx']
    date_captured_list = feature_list['image/date_captured']
    unix_time_list = feature_list['image/unix_time']
    location_list = feature_list['image/location']
    image_ids_list = feature_list['image/source_id']
    gt_xmin_list = feature_list['region/bbox/xmin']
    gt_xmax_list = feature_list['region/bbox/xmax']
    gt_ymin_list = feature_list['region/bbox/ymin']
    gt_ymax_list = feature_list['region/bbox/ymax']
    gt_type_list = feature_list['region/label/index']
    gt_type_string_list = feature_list['region/label/string']
    gt_is_annotated_list = feature_list['region/is_annotated']

    for idx, example in enumerate(clip_list):
      frame = example.features.feature

      encoded_image = encoded_image_list.feature.add()
      encoded_image.bytes_list.value.extend(
          frame['image/encoded'].bytes_list.value)

      image_id = image_ids_list.feature.add()
      image_id.bytes_list.value.append(
          frame['image/source_id'].bytes_list.value[0])

      timestamp = timestamps_list.feature.add()
      # Timestamp is currently order in the list.
      timestamp.int64_list.value.extend([idx])

      context_features_idx = context_features_idx_list.feature.add()
      context_features_idx.int64_list.value.extend(
          frame['context_features_idx'].int64_list.value)

      date_captured = date_captured_list.feature.add()
      date_captured.bytes_list.value.extend(
          frame['image/date_captured'].bytes_list.value)
      unix_time = unix_time_list.feature.add()
      unix_time.float_list.value.extend(
          frame['image/unix_time'].float_list.value)
      location = location_list.feature.add()
      location.bytes_list.value.extend(
          frame['image/location'].bytes_list.value)

      gt_xmin = gt_xmin_list.feature.add()
      gt_xmax = gt_xmax_list.feature.add()
      gt_ymin = gt_ymin_list.feature.add()
      gt_ymax = gt_ymax_list.feature.add()
      gt_type = gt_type_list.feature.add()
      gt_type_str = gt_type_string_list.feature.add()

      gt_is_annotated = gt_is_annotated_list.feature.add()
      gt_is_annotated.int64_list.value.append(1)

      gt_xmin.float_list.value.extend(
          frame['image/object/bbox/xmin'].float_list.value)
      gt_xmax.float_list.value.extend(
          frame['image/object/bbox/xmax'].float_list.value)
      gt_ymin.float_list.value.extend(
          frame['image/object/bbox/ymin'].float_list.value)
      gt_ymax.float_list.value.extend(
          frame['image/object/bbox/ymax'].float_list.value)

      gt_type.int64_list.value.extend(
          frame['image/object/class/label'].int64_list.value)
      gt_type_str.bytes_list.value.extend(
          frame['image/object/class/text'].bytes_list.value)

    return seq_example

  def _add_context_to_example(self, grouped_entry):
    """Builds the output protos for one group.

    Args:
      grouped_entry: A `(key, iterable of tf.Example)` tuple. The examples are
        modified in place.
    Returns:
      A list of tf.SequenceExample protos (one per clip) when the output type
      is `tf_sequence_example`, a list of the updated tf.Example protos when it
      is `tf_example`, and an empty list when no example of the group is in
      the configured list of image ids to keep.
    """
    key, example_collection = grouped_entry
    list_of_examples = []

    example_list = list(example_collection)

    # Defaults keep these names bound when context features are disabled; the
    # values are only written to the output when add_context_features is set.
    context_features = []
    feature_length = None
    context_features_image_id_list = []
    if self._add_context_features:
      context_features, feature_length, context_features_image_id_list = (
          self._build_context_features(example_list))

    # The memory bank above is built from the whole group; only the emitted
    # examples are restricted to the requested image ids.
    if self._image_ids_to_keep is not None:
      new_example_list = []
      for example in example_list:
        im_id = example.features.feature['image/source_id'].bytes_list.value[0]
        self._images_loaded.inc(1)
        if six.ensure_str(im_id) in self._image_ids_to_keep:
          self._images_kept.inc(1)
          new_example_list.append(example)
      if not new_example_list:
        return []
      example_list = new_example_list

    if self._output_type == 'tf_sequence_example':
      if self._max_clip_length is not None:
        # For now, no overlap
        clips = get_sliding_window(
            example_list, self._max_clip_length, self._max_clip_length)
      else:
        clips = [example_list]

      for clip_num, clip_list in enumerate(clips):
        seq_example = self._build_sequence_example(
            key, clip_num, clip_list, context_features, feature_length,
            context_features_image_id_list)
        self._num_examples_processed.inc(1)
        list_of_examples.append(seq_example)

    elif self._output_type == 'tf_example':

      for example in example_list:
        if self._add_context_features:
          example.features.feature[
              'image/context_features'].float_list.value.extend(
                  context_features)
          example.features.feature[
              'image/context_feature_length'].int64_list.value.append(
                  feature_length)

          if self._keep_context_features_image_id_list:
            example.features.feature[
                'image/context_features_image_id_list'].bytes_list.value.extend(
                    context_features_image_id_list)

        self._num_examples_processed.inc(1)
        list_of_examples.append(example)

    return list_of_examples


def construct_pipeline(pipeline,
                       input_tfrecord,
                       output_tfrecord,
                       sequence_key,
                       time_horizon=None,
                       subsample_context_features_rate=0,
                       reduce_image_size=True,
                       max_image_dimension=1024,
                       add_context_features=True,
                       sorted_image_ids=True,
                       image_ids_to_keep='All',
                       keep_context_features_image_id_list=False,
                       keep_only_positives=False,
                       context_features_score_threshold=0.7,
                       keep_only_positives_gt=False,
                       max_num_elements_in_context_features=5000,
                       num_shards=0,
                       output_type='tf_example',
                       max_clip_length=None,
                       context_feature_length=2057):
  """Attaches the context-building transforms to a beam pipeline.

  The stages are: read the input TFRecord, re-key the examples, group them by
  key, sort (and split) each group, add the context to each group, and write
  the result to the output TFRecord. Nothing is returned; the transforms are
  added to `pipeline`.

  Args:
    pipeline: Initialized beam pipeline.
    input_tfrecord: A TFRecord of tf.train.Example protos containing images.
    output_tfrecord: The TFRecord path to write the output protos to.
    sequence_key: A feature name to use as a key for grouping sequences.
    time_horizon: What length of time to use to partition the data when building
      the memory banks. Options: `year`, `month`, `week`, `day`, `hour`,
      `minute`, None.
    subsample_context_features_rate: What rate to subsample images for the
        contextual memory bank.
    reduce_image_size: Whether to reduce the size of the stored images.
    max_image_dimension: The maximum image dimension to use for resizing.
    add_context_features: Whether to keep and store the contextual memory bank.
    sorted_image_ids: Whether the image ids are sortable, and can be used as
      datetime tie-breakers when building memory banks.
    image_ids_to_keep: A list of image ids to save, to use to build data subsets
      for evaluation.
    keep_context_features_image_id_list: Whether to save an ordered list of the
      ids of the images in the contextual memory bank.
    keep_only_positives: Whether to only keep high scoring
      (>context_features_score_threshold) features in the contextual memory
      bank.
    context_features_score_threshold: What threshold to use for keeping
      features.
    keep_only_positives_gt: Whether to only keep features from images that
      contain objects based on the ground truth (for training).
    max_num_elements_in_context_features: the maximum number of elements in the
      memory bank
    num_shards: The number of output shards.
    output_type: What type of output, tf_example or tf_sequence_example
    max_clip_length: The maximum length of a sequence example, before
      splitting into multiple
    context_feature_length: The length of the context feature embeddings stored
      in the input data.

  Raises:
    ValueError: If output_type is neither `tf_example` nor
      `tf_sequence_example`.
  """
  # Pick the proto type that matches the records written at the end.
  if output_type == 'tf_example':
    output_proto_type = tf.train.Example
  elif output_type == 'tf_sequence_example':
    output_proto_type = tf.train.SequenceExample
  else:
    raise ValueError('Unsupported output type.')
  coder = beam.coders.ProtoCoder(output_proto_type)

  # Read, re-key, group and sort the input examples.
  sorted_groups = (
      pipeline
      | 'ReadInputTFRecord' >> beam.io.tfrecordio.ReadFromTFRecord(
          input_tfrecord, coder=beam.coders.BytesCoder())
      | 'RekeyExamples' >> beam.ParDo(
          ReKeyDataFn(sequence_key, time_horizon,
                      reduce_image_size, max_image_dimension))
      | 'GroupBySequenceKey' >> beam.GroupByKey()
      | 'ReshuffleGroups' >> beam.Reshuffle()
      | 'OrderByFrameNumber' >> beam.ParDo(
          SortGroupedDataFn(sequence_key, sorted_image_ids,
                            max_num_elements_in_context_features))
      | 'ReshuffleSortedGroups' >> beam.Reshuffle())

  context_fn = GenerateContextFn(
      sequence_key, add_context_features, image_ids_to_keep,
      keep_context_features_image_id_list=keep_context_features_image_id_list,
      subsample_context_features_rate=subsample_context_features_rate,
      keep_only_positives=keep_only_positives,
      keep_only_positives_gt=keep_only_positives_gt,
      context_features_score_threshold=context_features_score_threshold,
      max_num_elements_in_context_features=(
          max_num_elements_in_context_features),
      output_type=output_type,
      max_clip_length=max_clip_length,
      context_feature_length=context_feature_length)

  # Build the memory banks and write the augmented examples.
  _ = (
      sorted_groups
      | 'AddContextToExamples' >> beam.ParDo(context_fn)
      | 'ReshuffleExamples' >> beam.Reshuffle()
      | 'WritetoDisk' >> beam.io.tfrecordio.WriteToTFRecord(
          output_tfrecord,
          num_shards=num_shards,
          coder=coder))


def _parse_bool_flag(value):
  """Converts a command-line string such as `True` or `false` to a bool.

  Args:
    value: The raw flag value. A bool is returned unchanged.
  Returns:
    The parsed boolean.
  Raises:
    argparse.ArgumentTypeError: If `value` is not a recognized boolean string.
  """
  if isinstance(value, bool):
    return value
  normalized = value.strip().lower()
  if normalized in ('true', 't', 'yes', 'y', '1'):
    return True
  # The empty string is accepted because, before values were parsed, it was
  # the only command-line value that evaluated as false.
  if normalized in ('false', 'f', 'no', 'n', '0', ''):
    return False
  raise argparse.ArgumentTypeError(
      'Expected a boolean value (e.g. `True` or `False`), got %r.' % value)


def parse_args(argv):
  """Command-line argument parser.

  Numeric and boolean options declare an explicit `type`, so values supplied
  on the command line are converted to `int`, `float` or `bool` instead of
  being passed on as strings (a raw string such as 'False' is truthy).
  Defaults are unchanged.

  Args:
    argv: command line arguments
  Returns:
    beam_args: Arguments for the beam pipeline.
    pipeline_args: Arguments for the pipeline options, such as runner type.
  """
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--input_tfrecord',
      dest='input_tfrecord',
      required=True,
      help='TFRecord containing images in tf.Example format for object '
      'detection, with bounding boxes and contextual feature embeddings.')
  parser.add_argument(
      '--output_tfrecord',
      dest='output_tfrecord',
      required=True,
      help='TFRecord containing images in tf.Example format, with added '
      'contextual memory banks.')
  parser.add_argument(
      '--sequence_key',
      dest='sequence_key',
      default='image/location',
      help='Key to use when grouping sequences: so far supports `image/seq_id` '
      'and `image/location`.')
  parser.add_argument(
      '--context_feature_length',
      dest='context_feature_length',
      type=int,
      default=2057,
      help='The length of the context feature embeddings stored in the input '
      'data.')
  parser.add_argument(
      '--time_horizon',
      dest='time_horizon',
      default=None,
      help='What time horizon to use when splitting the data, if any. Options '
      'are: `year`, `month`, `week`, `day `, `hour`, `minute`, `None`.')
  parser.add_argument(
      '--subsample_context_features_rate',
      dest='subsample_context_features_rate',
      type=int,
      default=0,
      help='Whether to subsample the context_features, and if so how many to '
      'sample. If the rate is set to X, it will sample context from 1 out of '
      'every X images. Default is sampling from every image, which is X=0.')
  parser.add_argument(
      '--reduce_image_size',
      dest='reduce_image_size',
      type=_parse_bool_flag,
      default=True,
      help='downsamples images to have longest side max_image_dimension, '
      'maintaining aspect ratio')
  parser.add_argument(
      '--max_image_dimension',
      dest='max_image_dimension',
      type=int,
      default=1024,
      help='Sets max image dimension for resizing.')
  parser.add_argument(
      '--add_context_features',
      dest='add_context_features',
      type=_parse_bool_flag,
      default=True,
      help='Adds a memory bank of embeddings to each clip')
  parser.add_argument(
      '--sorted_image_ids',
      dest='sorted_image_ids',
      type=_parse_bool_flag,
      default=True,
      help='Whether the image source_ids are sortable to deal with '
      'date_captured tie-breaks.')
  parser.add_argument(
      '--image_ids_to_keep',
      dest='image_ids_to_keep',
      default='All',
      help='Path to .json list of image ids to keep, used for ground truth '
      'eval creation.')
  parser.add_argument(
      '--keep_context_features_image_id_list',
      dest='keep_context_features_image_id_list',
      type=_parse_bool_flag,
      default=False,
      help='Whether or not to keep a list of the image_ids corresponding to '
      'the memory bank.')
  parser.add_argument(
      '--keep_only_positives',
      dest='keep_only_positives',
      type=_parse_bool_flag,
      default=False,
      help='Whether or not to keep only positive boxes based on score.')
  parser.add_argument(
      '--context_features_score_threshold',
      dest='context_features_score_threshold',
      type=float,
      default=0.7,
      help='What score threshold to use for boxes in context_features, when '
      '`keep_only_positives` is set to `True`.')
  parser.add_argument(
      '--keep_only_positives_gt',
      dest='keep_only_positives_gt',
      type=_parse_bool_flag,
      default=False,
      help='Whether or not to keep only positive boxes based on gt class.')
  parser.add_argument(
      '--max_num_elements_in_context_features',
      dest='max_num_elements_in_context_features',
      type=int,
      default=2000,
      help='Sets max number of context feature elements per memory bank. '
      'If the number of images in the context group is greater than '
      '`max_num_elements_in_context_features`, the context group will be split.'
      )
  parser.add_argument(
      '--output_type',
      dest='output_type',
      default='tf_example',
      help='Output type, one of `tf_example`, `tf_sequence_example`.')
  parser.add_argument(
      '--max_clip_length',
      dest='max_clip_length',
      type=int,
      default=None,
      help='Max length for sequence example outputs.')
  parser.add_argument(
      '--num_shards',
      dest='num_shards',
      type=int,
      default=0,
      help='Number of output shards.')
  # Flags this parser does not recognize are returned separately so they can
  # be handed to the Beam pipeline options.
  beam_args, pipeline_args = parser.parse_known_args(argv)
  return beam_args, pipeline_args


def main(argv=None, save_main_session=True):
  """Builds the context-adding Beam pipeline from command-line flags and runs it.

  Args:
    argv: Command line arguments, forwarded to `parse_args`.
    save_main_session: Value assigned to the `save_main_session` setup option
      of the pipeline.
  """
  flags, runner_flags = parse_args(argv)

  # Flags not consumed by `parse_args` configure the Beam pipeline itself.
  options_lib = beam.options.pipeline_options
  options = options_lib.PipelineOptions(runner_flags)
  setup_options = options.view_as(options_lib.SetupOptions)
  setup_options.save_main_session = save_main_session

  # Create the directory portion of the output path before writing to it.
  output_dir = os.path.dirname(flags.output_tfrecord)
  tf.io.gfile.makedirs(output_dir)

  pipeline = beam.Pipeline(options=options)
  construct_pipeline(
      pipeline,
      flags.input_tfrecord,
      flags.output_tfrecord,
      flags.sequence_key,
      flags.time_horizon,
      flags.subsample_context_features_rate,
      flags.reduce_image_size,
      flags.max_image_dimension,
      flags.add_context_features,
      flags.sorted_image_ids,
      flags.image_ids_to_keep,
      flags.keep_context_features_image_id_list,
      flags.keep_only_positives,
      flags.context_features_score_threshold,
      flags.keep_only_positives_gt,
      flags.max_num_elements_in_context_features,
      flags.num_shards,
      flags.output_type,
      flags.max_clip_length,
      flags.context_feature_length)
  pipeline.run()


# Run the job with default arguments when this file is executed as a script;
# `main` forwards `argv=None` to the argument parser.
if __name__ == '__main__':
  main()