# research/vid2depth/dataset/dataset_loader.py
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Classes to load KITTI and Cityscapes data."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import glob
import json
import os
import re
from absl import logging
import numpy as np
import scipy.misc
CITYSCAPES_CROP_BOTTOM = True # Crop bottom 25% to remove the car hood.
CITYSCAPES_CROP_PCT = 0.75
CITYSCAPES_SAMPLE_EVERY = 2 # Sample every 2 frames to match KITTI frame rate.
BIKE_SAMPLE_EVERY = 6 # 5fps, since the bike's motion is slower.
class Bike(object):
  """Load bike video frames.

  Frames are expected under dataset_dir/<video>/*.jpg. Every source frame
  yields three vertically shifted crops whose ids are prefixed 'A', 'B' and
  'C', so each video acts as three pseudo-camera streams.
  """

  def __init__(self,
               dataset_dir,
               img_height=128,
               img_width=416,
               seq_length=3,
               sample_every=BIKE_SAMPLE_EVERY):
    self.dataset_dir = dataset_dir
    self.img_height = img_height
    self.img_width = img_width
    self.seq_length = seq_length
    self.sample_every = sample_every
    self.frames = self.collect_frames()
    self.num_frames = len(self.frames)
    # All collected frames are used for training; there is no test split.
    self.num_train = self.num_frames
    logging.info('Total frames collected: %d', self.num_frames)

  def collect_frames(self):
    """Create a list of unique ids for available frames."""
    video_list = os.listdir(self.dataset_dir)
    logging.info('video_list: %s', video_list)
    frames = []
    for video in video_list:
      im_files = glob.glob(os.path.join(self.dataset_dir, video, '*.jpg'))
      # Natural (numeric-aware) sort so frame_10 follows frame_9.
      im_files = sorted(im_files, key=natural_keys)
      # Adding 3 crops of the video.
      frames.extend(['A' + video + '/' + os.path.basename(f) for f in im_files])
      frames.extend(['B' + video + '/' + os.path.basename(f) for f in im_files])
      frames.extend(['C' + video + '/' + os.path.basename(f) for f in im_files])
    return frames

  def get_example_with_index(self, target_index):
    """Returns one example dict, or False if no valid sequence exists here."""
    if not self.is_valid_sample(target_index):
      return False
    example = self.load_example(target_index)
    return example

  def load_intrinsics(self, unused_frame_idx, cy):
    """Load intrinsics.

    Args:
      unused_frame_idx: Ignored; intrinsics are identical for all frames.
      cy: Principal-point y coordinate, which depends on the chosen crop.

    Returns:
      3x3 camera intrinsics matrix.
    """
    # https://www.wired.com/2013/05/calculating-the-angular-view-of-an-iphone/
    # https://codeyarns.com/2015/09/08/how-to-compute-intrinsic-camera-matrix-for-a-camera/
    # https://stackoverflow.com/questions/39992968/how-to-calculate-field-of-view-of-the-camera-from-camera-intrinsic-matrix
    # # iPhone: These numbers are for images with resolution 720 x 1280.
    # Assuming FOV = 50.9 => fx = (1280 // 2) / math.tan(fov / 2) = 1344.8
    intrinsics = np.array([[1344.8, 0, 1280 // 2],
                           [0, 1344.8, cy],
                           [0, 0, 1.0]])
    return intrinsics

  def is_valid_sample(self, target_index):
    """Checks whether we can find a valid sequence around this frame."""
    target_video, _ = self.frames[target_index].split('/')
    start_index, end_index = get_seq_start_end(target_index,
                                               self.seq_length,
                                               self.sample_every)
    if start_index < 0 or end_index >= self.num_frames:
      return False
    # The whole sequence must come from the same video (and crop prefix).
    start_video, _ = self.frames[start_index].split('/')
    end_video, _ = self.frames[end_index].split('/')
    if target_video == start_video and target_video == end_video:
      return True
    return False

  def load_image_raw(self, frame_id):
    """Reads the image and crops it according to first letter of frame_id."""
    crop_type = frame_id[0]
    img_file = os.path.join(self.dataset_dir, frame_id[1:])
    img = scipy.misc.imread(img_file)
    # Crop height that preserves the target aspect ratio after resizing to
    # (img_height, img_width).
    allowed_height = int(img.shape[1] * self.img_height / self.img_width)
    # Starting height for the middle crop.
    mid_crop_top = int(img.shape[0] / 2 - allowed_height / 2)
    # How much to go up or down to get the other two crops.
    height_var = int(mid_crop_top / 3)
    if crop_type == 'A':
      crop_top = mid_crop_top - height_var
      cy = allowed_height / 2 + height_var
    elif crop_type == 'B':
      crop_top = mid_crop_top
      cy = allowed_height / 2
    elif crop_type == 'C':
      crop_top = mid_crop_top + height_var
      cy = allowed_height / 2 - height_var
    else:
      raise ValueError('Unknown crop_type: %s' % crop_type)
    crop_bottom = crop_top + allowed_height + 1
    return img[crop_top:crop_bottom, :, :], cy

  def load_image_sequence(self, target_index):
    """Returns a list of images around target index."""
    start_index, end_index = get_seq_start_end(target_index,
                                               self.seq_length,
                                               self.sample_every)
    image_seq = []
    for idx in range(start_index, end_index + 1, self.sample_every):
      frame_id = self.frames[idx]
      img, cy = self.load_image_raw(frame_id)
      if idx == target_index:
        # Zoom factors are taken from the target frame only; assumes all
        # frames of a video share one resolution (zoom_x/zoom_y/cy would be
        # unbound if the target were not in range — guaranteed by caller).
        zoom_y = self.img_height / img.shape[0]
        zoom_x = self.img_width / img.shape[1]
      img = scipy.misc.imresize(img, (self.img_height, self.img_width))
      image_seq.append(img)
    return image_seq, zoom_x, zoom_y, cy

  def load_example(self, target_index):
    """Returns a sequence with requested target frame."""
    image_seq, zoom_x, zoom_y, cy = self.load_image_sequence(target_index)
    target_video, target_filename = self.frames[target_index].split('/')
    # Put A, B, C at the end for better shuffling.
    target_video = target_video[1:] + target_video[0]
    intrinsics = self.load_intrinsics(target_index, cy)
    intrinsics = self.scale_intrinsics(intrinsics, zoom_x, zoom_y)
    example = {}
    example['intrinsics'] = intrinsics
    example['image_seq'] = image_seq
    example['folder_name'] = target_video
    example['file_name'] = target_filename.split('.')[0]
    return example

  def scale_intrinsics(self, mat, sx, sy):
    """Returns a copy of mat with focal lengths/centers scaled by (sx, sy)."""
    out = np.copy(mat)
    out[0, 0] *= sx
    out[0, 2] *= sx
    out[1, 1] *= sy
    out[1, 2] *= sy
    return out
class KittiRaw(object):
  """Reads KITTI raw data files.

  Training frames are identified by 'drive cam_id frame_id' strings. Frames
  listed in static_frames.txt and drives listed in the split's test-scene
  file are excluded from training.
  """

  def __init__(self,
               dataset_dir,
               split,
               load_pose=False,
               img_height=128,
               img_width=416,
               seq_length=3):
    static_frames_file = 'dataset/kitti/static_frames.txt'
    test_scene_file = 'dataset/kitti/test_scenes_' + split + '.txt'
    with open(get_resource_path(test_scene_file), 'r') as f:
      test_scenes = f.readlines()
    # Strip the trailing newline from every scene name.
    self.test_scenes = [t[:-1] for t in test_scenes]
    self.dataset_dir = dataset_dir
    self.img_height = img_height
    self.img_width = img_width
    self.seq_length = seq_length
    self.load_pose = load_pose
    self.cam_ids = ['02', '03']  # Left and right color cameras.
    self.date_list = [
        '2011_09_26', '2011_09_28', '2011_09_29', '2011_09_30', '2011_10_03'
    ]
    self.collect_static_frames(static_frames_file)
    self.collect_train_frames()

  def collect_static_frames(self, static_frames_file):
    """Reads the list of frames where the car is not moving."""
    with open(get_resource_path(static_frames_file), 'r') as f:
      frames = f.readlines()
    self.static_frames = []
    for fr in frames:
      if fr == '\n':
        continue
      unused_date, drive, frame_id = fr.split(' ')
      # frame_id still carries its trailing newline, hence [:-1]. Builtin
      # int replaces the removed np.int alias (NumPy >= 1.24).
      fid = '%.10d' % int(frame_id[:-1])
      for cam_id in self.cam_ids:
        self.static_frames.append(drive + ' ' + cam_id + ' ' + fid)

  def collect_train_frames(self):
    """Creates a list of training frames."""
    all_frames = []
    for date in self.date_list:
      date_dir = os.path.join(self.dataset_dir, date)
      drive_set = os.listdir(date_dir)
      for dr in drive_set:
        drive_dir = os.path.join(date_dir, dr)
        if os.path.isdir(drive_dir):
          # dr[:-5] strips the '_sync' suffix to match test-scene names.
          if dr[:-5] in self.test_scenes:
            continue
          for cam in self.cam_ids:
            img_dir = os.path.join(drive_dir, 'image_' + cam, 'data')
            num_frames = len(glob.glob(img_dir + '/*[0-9].png'))
            for i in range(num_frames):
              frame_id = '%.10d' % i
              all_frames.append(dr + ' ' + cam + ' ' + frame_id)
    # Remove static frames; ignore those not present in this dataset copy.
    for s in self.static_frames:
      try:
        all_frames.remove(s)
      except ValueError:
        pass
    self.train_frames = all_frames
    self.num_train = len(self.train_frames)

  def is_valid_sample(self, frames, target_index):
    """Checks whether we can find a valid sequence around this frame."""
    num_frames = len(frames)
    target_drive, cam_id, _ = frames[target_index].split(' ')
    start_index, end_index = get_seq_start_end(target_index, self.seq_length)
    if start_index < 0 or end_index >= num_frames:
      return False
    # The whole sequence must come from one drive and one camera.
    start_drive, start_cam_id, _ = frames[start_index].split(' ')
    end_drive, end_cam_id, _ = frames[end_index].split(' ')
    if (target_drive == start_drive and target_drive == end_drive and
        cam_id == start_cam_id and cam_id == end_cam_id):
      return True
    return False

  def get_example_with_index(self, target_index):
    """Returns one example dict, or False if no valid sequence exists here."""
    if not self.is_valid_sample(self.train_frames, target_index):
      return False
    example = self.load_example(self.train_frames, target_index)
    return example

  def load_image_sequence(self, frames, target_index):
    """Returns resized images around the target, plus the target's zooms."""
    start_index, end_index = get_seq_start_end(target_index, self.seq_length)
    image_seq = []
    for index in range(start_index, end_index + 1):
      drive, cam_id, frame_id = frames[index].split(' ')
      img = self.load_image_raw(drive, cam_id, frame_id)
      if index == target_index:
        # Zoom factors come from the target frame only; all frames of a
        # drive/camera share one resolution.
        zoom_y = self.img_height / img.shape[0]
        zoom_x = self.img_width / img.shape[1]
      img = scipy.misc.imresize(img, (self.img_height, self.img_width))
      image_seq.append(img)
    return image_seq, zoom_x, zoom_y

  def load_pose_sequence(self, frames, target_index):
    """Returns a sequence of pose vectors for frames around the target frame."""
    target_drive, _, target_frame_id = frames[target_index].split(' ')
    target_pose = self.load_pose_raw(target_drive, target_frame_id)
    # Bug fix: sequence bounds must be computed from the target's *list
    # index*, not from its (string) frame id as before.
    start_index, end_index = get_seq_start_end(target_index, self.seq_length)
    pose_seq = []
    for index in range(start_index, end_index + 1):
      # Bug fix: compare against the index; the old comparison of an int
      # to a string never skipped the target frame.
      if index == target_index:
        continue
      drive, _, frame_id = frames[index].split(' ')
      pose = self.load_pose_raw(drive, frame_id)
      # From target to index.
      pose = np.dot(np.linalg.inv(pose), target_pose)
      pose_seq.append(pose)
    return pose_seq

  def load_example(self, frames, target_index):
    """Returns a sequence with requested target frame."""
    image_seq, zoom_x, zoom_y = self.load_image_sequence(frames, target_index)
    target_drive, target_cam_id, target_frame_id = (
        frames[target_index].split(' '))
    intrinsics = self.load_intrinsics_raw(target_drive, target_cam_id)
    intrinsics = self.scale_intrinsics(intrinsics, zoom_x, zoom_y)
    example = {}
    example['intrinsics'] = intrinsics
    example['image_seq'] = image_seq
    example['folder_name'] = target_drive + '_' + target_cam_id + '/'
    example['file_name'] = target_frame_id
    if self.load_pose:
      pose_seq = self.load_pose_sequence(frames, target_index)
      example['pose_seq'] = pose_seq
    return example

  def load_pose_raw(self, drive, frame_id):
    """Reads one 4x4 homogeneous pose matrix from the drive's poses dir."""
    date = drive[:10]
    pose_file = os.path.join(self.dataset_dir, date, drive, 'poses',
                             frame_id + '.txt')
    with open(pose_file, 'r') as f:
      pose = f.readline()
    pose = np.array(pose.split(' ')).astype(np.float32).reshape(3, 4)
    # Append the [0 0 0 1] row to make the transform invertible.
    pose = np.vstack((pose, np.array([0, 0, 0, 1]).reshape((1, 4))))
    return pose

  def load_image_raw(self, drive, cam_id, frame_id):
    """Reads one raw PNG frame for the given drive and camera."""
    date = drive[:10]
    img_file = os.path.join(self.dataset_dir, date, drive, 'image_' + cam_id,
                            'data', frame_id + '.png')
    img = scipy.misc.imread(img_file)
    return img

  def load_intrinsics_raw(self, drive, cam_id):
    """Reads the 3x3 intrinsics for cam_id from the date's calibration."""
    date = drive[:10]
    calib_file = os.path.join(self.dataset_dir, date, 'calib_cam_to_cam.txt')
    filedata = self.read_raw_calib_file(calib_file)
    p_rect = np.reshape(filedata['P_rect_' + cam_id], (3, 4))
    intrinsics = p_rect[:3, :3]
    return intrinsics

  # From https://github.com/utiasSTARS/pykitti/blob/master/pykitti/utils.py
  def read_raw_calib_file(self, filepath):
    """Read in a calibration file and parse into a dictionary."""
    data = {}
    with open(filepath, 'r') as f:
      for line in f:
        key, value = line.split(':', 1)
        # The only non-float values in these files are dates, which we don't
        # care about.
        try:
          data[key] = np.array([float(x) for x in value.split()])
        except ValueError:
          pass
    return data

  def scale_intrinsics(self, mat, sx, sy):
    """Returns a copy of mat with focal lengths/centers scaled by (sx, sy)."""
    out = np.copy(mat)
    out[0, 0] *= sx
    out[0, 2] *= sx
    out[1, 1] *= sy
    out[1, 2] *= sy
    return out
class KittiOdom(object):
  """Reads KITTI odometry data files.

  Frames are identified by 'sequence frame' id strings, e.g. '00 000123'.
  Sequences 0-8 are used for training and 9-10 for testing.
  """

  def __init__(self, dataset_dir, img_height=128, img_width=416, seq_length=3):
    self.dataset_dir = dataset_dir
    self.img_height = img_height
    self.img_width = img_width
    self.seq_length = seq_length
    self.train_seqs = [0, 1, 2, 3, 4, 5, 6, 7, 8]
    self.test_seqs = [9, 10]
    self.collect_test_frames()
    self.collect_train_frames()

  def collect_test_frames(self):
    """Builds the list of 'seq frame' ids for the test sequences."""
    frames = []
    for seq in self.test_seqs:
      seq_dir = os.path.join(self.dataset_dir, 'sequences', '%.2d' % seq)
      img_dir = os.path.join(seq_dir, 'image_2')
      frame_count = len(glob.glob(os.path.join(img_dir, '*.png')))
      frames.extend('%.2d %.6d' % (seq, n) for n in range(frame_count))
    self.test_frames = frames
    self.num_test = len(frames)

  def collect_train_frames(self):
    """Builds the list of 'seq frame' ids for the training sequences."""
    frames = []
    for seq in self.train_seqs:
      seq_dir = os.path.join(self.dataset_dir, 'sequences', '%.2d' % seq)
      img_dir = os.path.join(seq_dir, 'image_2')
      frame_count = len(glob.glob(img_dir + '/*.png'))
      frames.extend('%.2d %.6d' % (seq, n) for n in range(frame_count))
    self.train_frames = frames
    self.num_train = len(frames)

  def is_valid_sample(self, frames, target_frame_index):
    """Checks whether we can find a valid sequence around this frame."""
    start_index, end_index = get_seq_start_end(target_frame_index,
                                               self.seq_length)
    if start_index < 0 or end_index >= len(frames):
      return False
    # All frames of the sequence must belong to the same drive.
    target_drive = frames[target_frame_index].split(' ')[0]
    first_drive = frames[start_index].split(' ')[0]
    last_drive = frames[end_index].split(' ')[0]
    return target_drive == first_drive == last_drive

  def load_image_sequence(self, frames, target_frame_index):
    """Loads and resizes the frames around the target; returns zoom factors."""
    start_index, end_index = get_seq_start_end(target_frame_index,
                                               self.seq_length)
    image_seq = []
    for index in range(start_index, end_index + 1):
      drive, frame_id = frames[index].split(' ')
      image = self.load_image(drive, frame_id)
      if index == target_frame_index:
        # Zoom factors are taken from the target frame only.
        zoom_y = self.img_height / image.shape[0]
        zoom_x = self.img_width / image.shape[1]
      image = scipy.misc.imresize(image, (self.img_height, self.img_width))
      image_seq.append(image)
    return image_seq, zoom_x, zoom_y

  def load_example(self, frames, target_frame_index):
    """Returns a sequence with requested target frame."""
    image_seq, zoom_x, zoom_y = self.load_image_sequence(frames,
                                                         target_frame_index)
    drive, frame_id = frames[target_frame_index].split(' ')
    intrinsics = self.scale_intrinsics(
        self.load_intrinsics(drive, frame_id), zoom_x, zoom_y)
    return {
        'intrinsics': intrinsics,
        'image_seq': image_seq,
        'folder_name': drive,
        'file_name': frame_id,
    }

  def get_example_with_index(self, target_frame_index):
    """Returns one example dict, or False if no valid sequence exists here."""
    if not self.is_valid_sample(self.train_frames, target_frame_index):
      return False
    return self.load_example(self.train_frames, target_frame_index)

  def load_image(self, drive, frame_id):
    """Reads one left-color (image_2) frame."""
    img_file = os.path.join(self.dataset_dir, 'sequences',
                            '%s/image_2/%s.png' % (drive, frame_id))
    return scipy.misc.imread(img_file)

  def load_intrinsics(self, drive, unused_frame_id):
    """Reads the 3x3 intrinsics for camera 2 of the given sequence."""
    calib_file = os.path.join(self.dataset_dir, 'sequences',
                              '%s/calib.txt' % drive)
    proj_c2p, _ = self.read_calib_file(calib_file)
    return proj_c2p[:3, :3]

  def read_calib_file(self, filepath, cam_id=2):
    """Read in a calibration file and parse into a dictionary."""

    def parse_line(line, shape):
      # Drop the leading label token, keep the numeric payload.
      values = line.split()
      return np.array(values[1:]).reshape(shape).astype(np.float32)

    with open(filepath, 'r') as f:
      lines = f.readlines()
    proj_c2p = parse_line(lines[cam_id], shape=(3, 4))
    proj_v2c = parse_line(lines[-1], shape=(3, 4))
    # Append the [0 0 0 1] row to make the velo-to-cam transform 4x4.
    proj_v2c = np.concatenate(
        (proj_v2c, np.array([0, 0, 0, 1]).reshape((1, 4))), axis=0)
    return proj_c2p, proj_v2c

  def scale_intrinsics(self, mat, sx, sy):
    """Returns a copy of mat with focal lengths/centers scaled by (sx, sy)."""
    scaled = np.copy(mat)
    scaled[0, 0] *= sx
    scaled[0, 2] *= sx
    scaled[1, 1] *= sy
    scaled[1, 2] *= sy
    return scaled
class Cityscapes(object):
  """Reads Cityscapes data files.

  Frame ids have the form '<city>_<seq>_<frame>_'; images live under
  leftImg8bit_sequence/<split>/<city>/ and intrinsics under
  camera/<split>/<city>/.
  """

  def __init__(self,
               dataset_dir,
               split='train',
               crop_bottom=CITYSCAPES_CROP_BOTTOM,  # Crop the car logo.
               crop_pct=CITYSCAPES_CROP_PCT,
               sample_every=CITYSCAPES_SAMPLE_EVERY,
               img_height=128,
               img_width=416,
               seq_length=3):
    self.dataset_dir = dataset_dir
    self.split = split
    self.crop_bottom = crop_bottom
    self.crop_pct = crop_pct
    self.sample_every = sample_every
    self.img_height = img_height
    self.img_width = img_width
    self.seq_length = seq_length
    self.frames = self.collect_frames(split)
    self.num_frames = len(self.frames)
    if split == 'train':
      self.num_train = self.num_frames
    else:
      self.num_test = self.num_frames
    logging.info('Total frames collected: %d', self.num_frames)

  def collect_frames(self, split):
    """Lists the frame ids available in the given split."""
    img_dir = os.path.join(self.dataset_dir, 'leftImg8bit_sequence', split)
    city_list = os.listdir(img_dir)
    frames = []
    for city in city_list:
      img_files = glob.glob(os.path.join(img_dir, city, '*.png'))
      for f in img_files:
        # Keep everything before 'leftImg8bit', including the trailing '_'.
        frame_id = os.path.basename(f).split('leftImg8bit')[0]
        frames.append(frame_id)
    return frames

  def get_example_with_index(self, target_index):
    """Returns one example dict, or False if no valid sequence exists here."""
    target_frame_id = self.frames[target_index]
    if not self.is_valid_example(target_frame_id):
      return False
    example = self.load_example(self.frames[target_index])
    return example

  def load_intrinsics(self, frame_id, split):
    """Read intrinsics data for frame."""
    city, seq, _, _ = frame_id.split('_')
    # The camera file's frame number may differ from ours, so glob for it.
    camera_file = os.path.join(self.dataset_dir, 'camera', split, city,
                               city + '_' + seq + '_*_camera.json')
    camera_file = glob.glob(camera_file)[0]
    with open(camera_file, 'r') as f:
      camera = json.load(f)
    fx = camera['intrinsic']['fx']
    fy = camera['intrinsic']['fy']
    u0 = camera['intrinsic']['u0']
    v0 = camera['intrinsic']['v0']
    # Cropping the bottom of the image and then resizing it to the same
    # (height, width) amounts to stretching the image's height.
    if self.crop_bottom:
      fy *= 1.0 / self.crop_pct
    intrinsics = np.array([[fx, 0, u0],
                           [0, fy, v0],
                           [0, 0, 1]])
    return intrinsics

  def is_valid_example(self, target_frame_id):
    """Checks whether we can find a valid sequence around this frame."""
    city, snippet_id, target_local_frame_id, _ = target_frame_id.split('_')
    start_index, end_index = get_seq_start_end(
        int(target_local_frame_id), self.seq_length, self.sample_every)
    # Every frame of the sequence must exist on disk.
    for index in range(start_index, end_index + 1, self.sample_every):
      local_frame_id = '%.6d' % index
      frame_id = '%s_%s_%s_' % (city, snippet_id, local_frame_id)
      image_filepath = os.path.join(self.dataset_dir, 'leftImg8bit_sequence',
                                    self.split, city,
                                    frame_id + 'leftImg8bit.png')
      if not os.path.exists(image_filepath):
        return False
    return True

  def load_image_sequence(self, target_frame_id):
    """Returns a sequence with requested target frame."""
    city, snippet_id, target_local_frame_id, _ = target_frame_id.split('_')
    start_index, end_index = get_seq_start_end(
        int(target_local_frame_id), self.seq_length, self.sample_every)
    image_seq = []
    for index in range(start_index, end_index + 1, self.sample_every):
      local_frame_id = '%.6d' % index
      frame_id = '%s_%s_%s_' % (city, snippet_id, local_frame_id)
      image_filepath = os.path.join(self.dataset_dir, 'leftImg8bit_sequence',
                                    self.split, city,
                                    frame_id + 'leftImg8bit.png')
      img = scipy.misc.imread(image_filepath)
      # Drop the bottom (1 - crop_pct) of the image (car hood).
      if self.crop_bottom:
        ymax = int(img.shape[0] * self.crop_pct)
        img = img[:ymax]
      raw_shape = img.shape
      if index == int(target_local_frame_id):
        # Zoom factors are taken from the (cropped) target frame only.
        zoom_y = self.img_height / raw_shape[0]
        zoom_x = self.img_width / raw_shape[1]
      img = scipy.misc.imresize(img, (self.img_height, self.img_width))
      image_seq.append(img)
    return image_seq, zoom_x, zoom_y

  def load_example(self, target_frame_id):
    """Returns a sequence with requested target frame."""
    image_seq, zoom_x, zoom_y = self.load_image_sequence(target_frame_id)
    intrinsics = self.load_intrinsics(target_frame_id, self.split)
    intrinsics = self.scale_intrinsics(intrinsics, zoom_x, zoom_y)
    example = {}
    example['intrinsics'] = intrinsics
    example['image_seq'] = image_seq
    example['folder_name'] = target_frame_id.split('_')[0]
    # Strip the trailing '_' from the frame id.
    example['file_name'] = target_frame_id[:-1]
    return example

  def scale_intrinsics(self, mat, sx, sy):
    """Returns a copy of mat with focal lengths/centers scaled by (sx, sy)."""
    out = np.copy(mat)
    out[0, 0] *= sx
    out[0, 2] *= sx
    out[1, 1] *= sy
    out[1, 2] *= sy
    return out
def get_resource_path(relative_path):
  """Maps a bundled-resource path to a filesystem path (identity here)."""
  return relative_path
def get_seq_start_end(target_index, seq_length, sample_every=1):
  """Returns absolute seq start and end indices for a given target frame.

  The target sits at the center of the sequence (biased toward the end for
  even lengths); consecutive frames are sample_every indices apart.
  """
  trailing = int((seq_length - 1) / 2) * sample_every
  end_index = target_index + trailing
  return end_index - (seq_length - 1) * sample_every, end_index
def atoi(text):
  """Converts purely-numeric strings to int; returns others unchanged."""
  if text.isdigit():
    return int(text)
  return text
def natural_keys(text):
  """Builds a natural-sort key: digit runs become ints, the rest stay str."""
  chunks = re.split(r'(\d+)', text)
  return [int(chunk) if chunk.isdigit() else chunk for chunk in chunks]