nideep/datasets/nyudv2/test_nyudv2_to_lmdb.py
from nose.tools import assert_false, assert_greater, assert_is_instance, \
assert_list_equal, assert_raises, assert_true, assert_equal
from mock import patch
import os
import tempfile
import shutil
import numpy as np
from scipy import io
import h5py
import nideep.datasets.nyudv2.nyudv2_to_lmdb as n2l
import nideep.iow.to_lmdb as tol
import nideep.iow.read_lmdb as rl
import nideep.iow.copy_lmdb as cl
import caffe
class TestHandlingSplitsFile:
@classmethod
def setup_class(self):
self.dir_tmp = tempfile.mkdtemp()
self.path_splits = os.path.join(self.dir_tmp, 'foo.mat')
data = {'testNdxs': np.array([[2], [4], [10]])}
io.savemat(self.path_splits, data, oned_as='column')
self.path_other = os.path.join(self.dir_tmp, 'bar.mat')
data = {'foo': np.array([[2], [4]])}
io.savemat(self.path_other, data, oned_as='column')
@classmethod
def teardown_class(self):
shutil.rmtree(self.dir_tmp)
def test_invalid_path_dir(self):
assert_raises(IOError, n2l.split_matfile_to_val_list, os.curdir)
def test_invalid_path(self):
assert_raises(IOError, n2l.split_matfile_to_val_list, '/foo/bar.mat')
def test_invalid_ext(self):
fpath = os.path.join(self.dir_tmp, 'foo.txt')
with open(fpath, 'w') as f:
f.write('hello')
assert_true(os.path.isfile(fpath))
assert_raises(IOError, n2l.split_matfile_to_val_list, fpath)
def test_val_list(self):
val_list = n2l.split_matfile_to_val_list(self.path_splits)
assert_is_instance(val_list, list)
assert_list_equal(val_list, [1, 3, 9])
def test_val_list_other(self):
assert_raises(KeyError, n2l.split_matfile_to_val_list, self.path_other)
class TestBigArrToArrs:
def test_big_arr_to_arrs_single(self):
x = np.array([[[ 1, 2, 3],
[ 4, 5, 6]
],
[[ 7, 8, 9],
[10, 11, 12]
],
[[13, 14, 15],
[16, 17, 18],
],
[[19, 20, 21],
[22, 23, 24]
]
])
y = np.expand_dims(x, axis=0)
z = n2l.big_arr_to_arrs(y)
assert_is_instance(z, list)
assert_equal(len(z), 1)
for i in range(3):
for j in range(4):
for k in range(2):
assert_equal(z[0][j][i][k], x[j][k][i])
class TestNYUDV2ToLMDB:
@classmethod
def setup_class(self):
self.dir_tmp = tempfile.mkdtemp()
@classmethod
def teardown_class(self):
shutil.rmtree(self.dir_tmp)
def test_validate_path_mat_dir_exists(self):
assert_raises(IOError, n2l.nyudv2_to_lmdb, os.curdir, "", self.dir_tmp)
def test_validate_path_mat_ext(self):
p = os.path.join(self.dir_tmp, "foo.txt")
with open(p, 'w') as f:
f.write('foo')
assert_true(os.path.isfile(p))
assert_raises(IOError, n2l.nyudv2_to_lmdb, p, "", self.dir_tmp)
@patch('nideep.datasets.nyudv2.nyudv2_to_lmdb.to_lmdb.caffe')
@patch('nideep.datasets.nyudv2.nyudv2_to_lmdb.to_lmdb.caffe.proto.caffe_pb2.Datum')
def test_nyudv2_to_lmdb_info(self, mock_dat, mock_caffe):
# mock caffe calls made by our module
mock_dat.return_value.SerializeToString.return_value = 'x'
mock_caffe.io.array_to_datum.return_value = caffe.proto.caffe_pb2.Datum()
x = np.array([[[ 1, 2, 3],
[ 4, 5, 6]
],
[[ 7, 8, 9],
[10, 11, 12]
],
[[13, 14, 15],
[16, 17, 18],
],
[[19, 20, 21],
[22, 23, 24]
]
])
imgs = np.expand_dims(x, axis=1).astype(float)
imgs = np.tile(imgs, (3, 1, 1))
dat = {n2l.NYUDV2DataType.IMAGES : imgs,
n2l.NYUDV2DataType.LABELS : x.astype(int) + 1,
n2l.NYUDV2DataType.DEPTHS : x.astype(float) + 2
}
p = os.path.join(self.dir_tmp, 'foo.mat')
io.savemat(p, dat)
prefix = 'xyz_'
lmdb_info = n2l.nyudv2_to_lmdb(p, prefix, self.dir_tmp)
assert_is_instance(lmdb_info, list)
for info_ in lmdb_info:
n = info_[0]
plmdb = info_[-1]
assert_true(os.path.isdir(plmdb))
if 'val' in os.path.basename(plmdb):
assert_equal(n, 0)
@patch('nideep.datasets.nyudv2.nyudv2_to_lmdb.to_lmdb.caffe')
@patch('nideep.datasets.nyudv2.nyudv2_to_lmdb.to_lmdb.caffe.proto.caffe_pb2.Datum')
def test_nyudv2_to_lmdb_info_mat73(self, mock_dat, mock_caffe):
# mock caffe calls made by our module
mock_dat.return_value.SerializeToString.return_value = 'x'
mock_caffe.io.array_to_datum.return_value = caffe.proto.caffe_pb2.Datum()
x = np.array([[[ 1, 2, 3],
[ 4, 5, 6]
],
[[ 7, 8, 9],
[10, 11, 12]
],
[[13, 14, 15],
[16, 17, 18],
],
[[19, 20, 21],
[22, 23, 24]
]
])
imgs = np.expand_dims(x, axis=1).astype(float)
imgs = np.tile(imgs, (3, 1, 1))
p = os.path.join(self.dir_tmp, 'foo.mat')
with h5py.File(p, "w") as f:
f.create_dataset(n2l.NYUDV2DataType.IMAGES, data=imgs)
f.create_dataset(n2l.NYUDV2DataType.LABELS, data=x.astype(int) + 1)
f.create_dataset(n2l.NYUDV2DataType.DEPTHS, data=x.astype(float) + 2)
prefix = 'xyz_'
lmdb_info = n2l.nyudv2_to_lmdb(p, prefix, self.dir_tmp)
assert_is_instance(lmdb_info, list)
for info_ in lmdb_info:
n = info_[0]
plmdb = info_[-1]
assert_true(os.path.isdir(plmdb))
if 'val' in os.path.basename(plmdb):
assert_equal(n, 0)
@patch('nideep.datasets.nyudv2.nyudv2_to_lmdb.to_lmdb.caffe')
@patch('nideep.datasets.nyudv2.nyudv2_to_lmdb.to_lmdb.caffe.proto.caffe_pb2.Datum')
def test_nyudv2_to_lmdb_info_hdf5(self, mock_dat, mock_caffe):
# mock caffe calls made by our module
mock_dat.return_value.SerializeToString.return_value = 'x'
mock_caffe.io.array_to_datum.return_value = caffe.proto.caffe_pb2.Datum()
x = np.array([[[ 1, 2, 3],
[ 4, 5, 6]
],
[[ 7, 8, 9],
[10, 11, 12]
],
[[13, 14, 15],
[16, 17, 18],
],
[[19, 20, 21],
[22, 23, 24]
]
])
imgs = np.expand_dims(x, axis=1).astype(float)
imgs = np.tile(imgs, (3, 1, 1))
p = os.path.join(self.dir_tmp, 'foo.h5')
with h5py.File(p, "w") as f:
f.create_dataset(n2l.NYUDV2DataType.IMAGES, data=imgs)
f.create_dataset(n2l.NYUDV2DataType.LABELS, data=x.astype(int) + 1)
f.create_dataset(n2l.NYUDV2DataType.DEPTHS, data=x.astype(float) + 2)
prefix = 'xyz_'
lmdb_info = n2l.nyudv2_to_lmdb(p, prefix, self.dir_tmp)
assert_is_instance(lmdb_info, list)
for info_ in lmdb_info:
n = info_[0]
plmdb = info_[-1]
assert_true(os.path.isdir(plmdb))
if 'val' in os.path.basename(plmdb):
assert_equal(n, 0)
@patch('nideep.datasets.nyudv2.nyudv2_to_lmdb.to_lmdb.caffe')
@patch('nideep.datasets.nyudv2.nyudv2_to_lmdb.to_lmdb.caffe.proto.caffe_pb2.Datum')
def test_nyudv2_to_lmdb_info_hdf5_2(self, mock_dat, mock_caffe):
# mock caffe calls made by our module
mock_dat.return_value.SerializeToString.return_value = 'x'
mock_caffe.io.array_to_datum.return_value = caffe.proto.caffe_pb2.Datum()
x = np.array([[[ 1, 2, 3],
[ 4, 5, 6]
],
[[ 7, 8, 9],
[10, 11, 12]
],
[[13, 14, 15],
[16, 17, 18],
],
[[19, 20, 21],
[22, 23, 24]
]
])
imgs = np.expand_dims(x, axis=1).astype(float)
imgs = np.tile(imgs, (3, 1, 1))
p = os.path.join(self.dir_tmp, 'foo.hdf5')
with h5py.File(p, "w") as f:
f.create_dataset(n2l.NYUDV2DataType.IMAGES, data=imgs)
f.create_dataset(n2l.NYUDV2DataType.LABELS, data=x.astype(int) + 1)
f.create_dataset(n2l.NYUDV2DataType.DEPTHS, data=x.astype(float) + 2)
prefix = 'xyz_'
lmdb_info = n2l.nyudv2_to_lmdb(p, prefix, self.dir_tmp)
assert_is_instance(lmdb_info, list)
for info_ in lmdb_info:
n = info_[0]
plmdb = info_[-1]
assert_true(os.path.isdir(plmdb))
if 'val' in os.path.basename(plmdb):
assert_equal(n, 0)
class TestShiftLabelLMDB:
@classmethod
def setup_class(self):
self.dir_tmp = tempfile.mkdtemp()
x = np.array([[[ 0, 2, 3],
[ 4, 5, 6]
],
[[ 7, 8, 9],
[10, 11, 12]
],
[[13, 14, 15],
[16, 17, 18],
],
[[19, 20, 21],
[22, 23, 0]
]
])
tol.arrays_to_lmdb([y for y in x], os.path.join(self.dir_tmp, 'x_lmdb'))
@classmethod
def teardown_class(self):
shutil.rmtree(self.dir_tmp)
def test_shift_label_lmdb(self):
path_src = os.path.join(self.dir_tmp, 'x_lmdb')
x = rl.read_values(path_src)
assert_greater(len(x), 0, "This test needs non empty data.")
path_dst = os.path.join(self.dir_tmp, 'test_shift_label_lmdb')
keys = range(0, len(x), 2)
assert_greater(len(keys), 0, "This test needs a non-empty subset.")
assert_greater(len(x), len(keys), "Need subset, not all elements.")
n2l.shift_label_lmdb(path_src, path_dst)
assert_true(os.path.isdir(path_dst), "failed to save LMDB")
y = rl.read_values(path_dst)
assert_equal(len(x), len(y), "Wrong number of elements copied.")
# for x_val, x_label, y_val, y_label in [a+b for a, b in zip(x, y)]:
for (x_val, x_label), (y_val, y_label) in zip(x, y):
assert_true(np.all(x_val[x_val != 0] - 1 == y_val[x_val != 0]), "Wrong content copied for non-void label.")
assert_true(np.all(255 == y_val[x_val == 0]), "Wrong content copied for void label.")
assert_true(np.all(x_label == y_label), "Wrong content copied for label field.")