rusty1s/embedded_gcnn

View on GitHub
lib/datasets/dataset.py

Summary

Maintainability
A
25 mins
Test Coverage
import numpy as np


class Datasets(object):
    """Container bundling the train / validation / test splits of a dataset.

    Subclasses must provide ``classes``, ``width``, ``height`` and
    ``num_channels``; the base implementations raise ``NotImplementedError``.
    """

    def __init__(self, train, val, test):
        # NOTE(review): presumably ``Dataset`` instances — confirm with the
        # concrete subclasses that build them.
        self.train = train
        self.val = val
        self.test = test

    @property
    def classes(self):
        """Indexable collection of class names by class id (subclass hook)."""
        raise NotImplementedError

    @property
    def width(self):
        """Width of the examples (subclass hook)."""
        raise NotImplementedError

    @property
    def height(self):
        """Height of the examples (subclass hook)."""
        raise NotImplementedError

    @property
    def num_channels(self):
        """Number of channels of the examples (subclass hook)."""
        raise NotImplementedError

    @property
    def num_classes(self):
        """Number of classes, i.e. ``len(self.classes)``."""
        return len(self.classes)

    def classnames(self, label):
        """Return the names of the classes whose entry in ``label`` equals 1.

        Args:
            label: One-/multi-hot label vector indexed by class id. Any
                array-like is accepted (NumPy array, list, tuple).

        Returns:
            List of class names in ascending class-id order.
        """
        # ``np.asarray`` makes plain Python sequences work: ``[0, 1] == 1``
        # would compare the whole list to 1 and produce a bare ``False``.
        idx = np.nonzero(np.asarray(label) == 1)[0]
        classes = self.classes  # evaluate the (possibly computed) property once
        return [classes[i] for i in idx]


class Dataset(object):
    """A single data split that serves mini-batches of ``(images, labels)``.

    Examples are served sequentially from an internal cursor. With
    ``shuffle`` enabled, the examples are randomly permuted before the first
    epoch and again every time an epoch is exhausted.

    Attributes:
        epochs_completed: Number of epoch boundaries crossed so far.
    """

    def __init__(self, images, labels):
        """Wrap ``images`` and ``labels``.

        Both are indexed along their first axis with slices and integer index
        arrays, so they are expected to be NumPy arrays with one row per
        example (TODO confirm with callers).
        """
        self.epochs_completed = 0

        self._images = images
        self._labels = labels
        # Position of the next unserved example within the current epoch.
        self._index_in_epoch = 0

    @property
    def num_examples(self):
        """Number of examples, taken from the first dimension of ``labels``."""
        return self._labels.shape[0]

    def _random_shuffle_examples(self):
        """Apply one shared random permutation to images and labels."""
        # A single permutation keeps every image paired with its label.
        perm = np.arange(self.num_examples)
        np.random.shuffle(perm)
        self._images = self._images[perm]
        self._labels = self._labels[perm]

    def next_batch(self, batch_size, shuffle=True):
        """Return the next ``batch_size`` examples as ``(images, labels)``.

        When the current epoch runs out, the batch is completed with examples
        from the following epoch(s) and ``epochs_completed`` is incremented
        once per epoch boundary crossed. ``batch_size`` may exceed
        ``num_examples``; whole epochs are then repeated to fill the batch.

        Args:
            batch_size: Number of examples to return (non-negative).
            shuffle: If True, permute the examples before the first epoch and
                at every epoch boundary.

        Returns:
            Tuple ``(images, labels)``, each with ``batch_size`` rows.

        Raises:
            ValueError: If ``batch_size`` is negative, or if it is positive
                while the dataset holds no examples.
        """
        if batch_size < 0:
            raise ValueError(
                'batch_size must be non-negative, got {}'.format(batch_size))

        num_examples = self.num_examples
        start = self._index_in_epoch

        # Shuffle for the first epoch.
        if self.epochs_completed == 0 and start == 0 and shuffle:
            self._random_shuffle_examples()

        if start + batch_size <= num_examples:
            # The batch fits in the current epoch: just slice it.
            end = start + batch_size
            self._index_in_epoch = end
            return self._images[start:end], self._labels[start:end]

        if num_examples == 0:
            # Nothing to wrap around to; the fill loop below would never end.
            raise ValueError('cannot draw a batch from an empty dataset')

        # Take the rest of the examples in the current epoch...
        images_parts = [self._images[start:num_examples]]
        labels_parts = [self._labels[start:num_examples]]
        remaining = batch_size - (num_examples - start)

        # ...then finish the epoch and start the next one.
        self.epochs_completed += 1
        if shuffle:
            self._random_shuffle_examples()

        # Batch larger than the dataset: consume whole epochs while needed.
        # Shuffling rebinds the arrays, so parts already collected are kept.
        while remaining > num_examples:
            images_parts.append(self._images[:num_examples])
            labels_parts.append(self._labels[:num_examples])
            remaining -= num_examples
            self.epochs_completed += 1
            if shuffle:
                self._random_shuffle_examples()

        # Head of the new epoch completes the batch.
        images_parts.append(self._images[:remaining])
        labels_parts.append(self._labels[:remaining])
        self._index_in_epoch = remaining

        images = np.concatenate(images_parts, axis=0)
        labels = np.concatenate(labels_parts, axis=0)
        return images, labels