tensorflow/models

View on GitHub
research/vid2depth/dataset/gen_data.py

Summary

Maintainability
B
4 hrs
Test Coverage
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Generates data for training/validation and save it to disk."""

# Example usage:
#
# python dataset/gen_data.py \
#   --alsologtostderr \
#   --dataset_name kitti_raw_eigen \
#   --dataset_dir ~/vid2depth/dataset/kitti-raw-uncompressed \
#   --data_dir ~/vid2depth/data/kitti_raw_eigen_s3 \
#   --seq_length 3 \
#   --num_threads 12

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import itertools
import multiprocessing
import os
from absl import app
from absl import flags
from absl import logging
import dataset_loader
import numpy as np
import scipy.misc
import tensorflow as tf

# Shorthand aliases: TF's file API (works on local and remote filesystems)
# and the absl command-line flag registry.
gfile = tf.gfile
FLAGS = flags.FLAGS

# Valid values for --dataset_name; each corresponds to a loader class in
# dataset_loader (see _generate_data for the mapping).
DATASETS = [
    'kitti_raw_eigen', 'kitti_raw_stereo', 'kitti_odom', 'cityscapes', 'bike'
]

flags.DEFINE_enum('dataset_name', None, DATASETS, 'Dataset name.')
flags.DEFINE_string('dataset_dir', None, 'Location for dataset source files.')
flags.DEFINE_string('data_dir', None, 'Where to save the generated data.')
# Note: Training time grows linearly with sequence length.  Use 2 or 3.
flags.DEFINE_integer('seq_length', 3, 'Length of each training sequence.')
flags.DEFINE_integer('img_height', 128, 'Image height.')
flags.DEFINE_integer('img_width', 416, 'Image width.')
flags.DEFINE_integer(
    'num_threads', None, 'Number of worker threads. '
    'Defaults to number of CPU cores.')

# These flags have no sensible defaults, so the program exits early with a
# usage error if any of them is missing.
flags.mark_flag_as_required('dataset_name')
flags.mark_flag_as_required('dataset_dir')
flags.mark_flag_as_required('data_dir')

# Process data in chunks for reporting progress.
NUM_CHUNKS = 100


def _generate_data():
  """Extract sequences from dataset_dir and store them in data_dir.

  Builds the dataset loader selected by --dataset_name, then processes all
  training frames in NUM_CHUNKS parallel batches.  Each worker writes one
  stacked-image .jpg and one camera-intrinsics .txt per example (see
  _gen_example); this function additionally writes train.txt and val.txt
  index files, assigning ~10% of examples to validation.

  Raises:
    ValueError: if FLAGS.dataset_name is not one of the known datasets.
  """
  if not gfile.Exists(FLAGS.data_dir):
    gfile.MakeDirs(FLAGS.data_dir)

  # The loader lives in a module-level global so that worker processes
  # created by multiprocessing inherit it (it is read by _gen_example).
  global dataloader  # pylint: disable=global-variable-undefined
  # All loaders share the same sizing arguments; only the class (and, for
  # KITTI raw, the split) differs.
  common_kwargs = dict(img_height=FLAGS.img_height,
                       img_width=FLAGS.img_width,
                       seq_length=FLAGS.seq_length)
  if FLAGS.dataset_name == 'bike':
    dataloader = dataset_loader.Bike(FLAGS.dataset_dir, **common_kwargs)
  elif FLAGS.dataset_name == 'kitti_odom':
    dataloader = dataset_loader.KittiOdom(FLAGS.dataset_dir, **common_kwargs)
  elif FLAGS.dataset_name == 'kitti_raw_eigen':
    dataloader = dataset_loader.KittiRaw(FLAGS.dataset_dir, split='eigen',
                                         **common_kwargs)
  elif FLAGS.dataset_name == 'kitti_raw_stereo':
    dataloader = dataset_loader.KittiRaw(FLAGS.dataset_dir, split='stereo',
                                         **common_kwargs)
  elif FLAGS.dataset_name == 'cityscapes':
    dataloader = dataset_loader.Cityscapes(FLAGS.dataset_dir, **common_kwargs)
  else:
    raise ValueError('Unknown dataset')

  # The default loop below uses multiprocessing, which can make it difficult
  # to locate source of errors in data loader classes.
  # For easier debugging, replace it with this serial loop:
  #
  # all_examples = {}
  # for i in range(dataloader.num_train):
  #   _gen_example(i, all_examples)
  #   logging.info('Generated: %d', len(all_examples))

  all_frames = range(dataloader.num_train)
  # Chunking exists purely for progress reporting; each chunk is one
  # pool.map call.
  frame_chunks = np.array_split(all_frames, NUM_CHUNKS)

  manager = multiprocessing.Manager()
  all_examples = manager.dict()
  num_cores = multiprocessing.cpu_count()
  num_threads = num_cores if FLAGS.num_threads is None else FLAGS.num_threads
  pool = multiprocessing.Pool(num_threads)

  # Split into training/validation sets. Fixed seed for repeatability.
  np.random.seed(8964)

  with gfile.Open(os.path.join(FLAGS.data_dir, 'train.txt'), 'w') as train_f:
    with gfile.Open(os.path.join(FLAGS.data_dir, 'val.txt'), 'w') as val_f:
      logging.info('Generating data...')
      for index, frame_chunk in enumerate(frame_chunks):
        all_examples.clear()
        # Bug fix: itertools.izip does not exist on Python 3; the builtin
        # zip is equivalent here and works on both Python 2 and 3.
        pool.map(_gen_example_star,
                 zip(frame_chunk, itertools.repeat(all_examples)))
        logging.info('Chunk %d/%d: saving %s entries...', index + 1, NUM_CHUNKS,
                     len(all_examples))
        for _, example in all_examples.items():
          if example:
            s = example['folder_name']
            frame = example['file_name']
            # ~10% of examples go to validation; the fixed seed above makes
            # the split reproducible across runs.
            if np.random.random() < 0.1:
              val_f.write('%s %s\n' % (s, frame))
            else:
              train_f.write('%s %s\n' % (s, frame))
  pool.close()
  pool.join()


def _gen_example(i, all_examples):
  """Saves one example to file.  Also adds it to all_examples dict.

  Reads example i from the module-level `dataloader`, writes the horizontally
  stacked image sequence as <data_dir>/<folder>/<file>.jpg and the camera
  intrinsics as <folder>/<file>_cam.txt, then stores the example metadata in
  all_examples under the key '<folder>_<file>'.

  Args:
    i: Integer frame index into the dataloader's training set.
    all_examples: Dict (possibly a multiprocessing.Manager dict) to record
      the generated example in; left untouched if the loader returns nothing.
  """
  example = dataloader.get_example_with_index(i)
  if not example:
    return
  image_seq_stack = _stack_image_seq(example['image_seq'])
  example.pop('image_seq', None)  # Free up memory.
  # Intrinsics are a 3x3 matrix; extract focal lengths and principal point.
  intrinsics = example['intrinsics']
  fx = intrinsics[0, 0]
  fy = intrinsics[1, 1]
  cx = intrinsics[0, 2]
  cy = intrinsics[1, 2]
  save_dir = os.path.join(FLAGS.data_dir, example['folder_name'])
  if not gfile.Exists(save_dir):
    gfile.MakeDirs(save_dir)
  img_filepath = os.path.join(save_dir, '%s.jpg' % example['file_name'])
  scipy.misc.imsave(img_filepath, image_seq_stack.astype(np.uint8))
  cam_filepath = os.path.join(save_dir, '%s_cam.txt' % example['file_name'])
  # Row-major 3x3 intrinsics with zero skew, serialized as CSV.
  example['cam'] = '%f,0.,%f,0.,%f,%f,0.,0.,1.' % (fx, cx, fy, cy)
  # Consistency fix: use gfile.Open like every other file access in this
  # module, so non-local filesystems (e.g. GCS paths) also work here.
  with gfile.Open(cam_filepath, 'w') as cam_f:
    cam_f.write(example['cam'])

  key = example['folder_name'] + '_' + example['file_name']
  all_examples[key] = example


def _gen_example_star(params):
  """Adapter for Pool.map: unpacks an (index, all_examples) pair."""
  index, all_examples = params
  return _gen_example(index, all_examples)


def _stack_image_seq(seq):
  for i, im in enumerate(seq):
    if i == 0:
      res = im
    else:
      res = np.hstack((res, im))
  return res


def main(_):
  """App entry point: runs the full data-generation pipeline."""
  _generate_data()


if __name__ == '__main__':
  app.run(main)