# Source: rusty1s/embedded_gcnn — lib/datasets/mnist_test.py (view on GitHub)
#
# Summary:
#   Maintainability: F (1 wk)
#   Test Coverage: (no value given)
from unittest import TestCase

import numpy as np

from .mnist import MNIST

# A single dataset instance is created when this module is imported and is
# shared by every test in the module. The relative directory is presumably
# resolved against the working directory the tests are launched from.
_DATA_DIR = 'data/mnist'
_VAL_SIZE = 10000
data = MNIST(_DATA_DIR, val_size=_VAL_SIZE)


class MNISTTest(TestCase):
    """Sanity checks for the split sizes, shapes, dtypes and class names.

    Every test reads from the shared module-level ``data`` object through
    ``next_batch(..., shuffle=False)``. A test that reads only part of a split
    reads the remainder of that split afterwards, presumably so the split's
    read position wraps back to the start for the next test. That rewind is
    done in a ``finally`` block so a failing assertion cannot leave a split
    mid-epoch and cause unrelated tests to fail.
    """

    @staticmethod
    def _splits():
        """Return ``(name, split)`` pairs for the train, val and test splits."""
        return (('train', data.train), ('val', data.val), ('test', data.test))

    def test_init(self):
        """Each split holds the expected number of examples."""
        expected = {'train': 50000, 'val': 10000, 'test': 10000}
        for name, split in self._splits():
            self.assertEqual(split.num_examples, expected[name], msg=name)

    def test_shapes(self):
        """Batches are (N, 28, 28, 1) images with (N, 10) label rows."""
        batch_size = 32
        for name, split in self._splits():
            images, labels = split.next_batch(batch_size, shuffle=False)
            try:
                self.assertEqual(images.shape, (batch_size, 28, 28, 1),
                                 msg=name)
                self.assertEqual(labels.shape, (batch_size, 10), msg=name)
            finally:
                # Read the rest of the epoch even if an assertion failed, so
                # later tests see the split from its beginning again.
                split.next_batch(split.num_examples - batch_size,
                                 shuffle=False)

    def test_images(self):
        """Images are float32 with every value in [0, 1]."""
        for name, split in self._splits():
            # A full epoch is read, so no rewind step is needed afterwards.
            images, _ = split.next_batch(split.num_examples, shuffle=False)
            self.assertEqual(images.dtype, np.float32, msg=name)
            self.assertLessEqual(images.max(), 1, msg=name)
            self.assertGreaterEqual(images.min(), 0, msg=name)

    def test_labels(self):
        """Labels are stored as uint8."""
        for name, split in self._splits():
            # A full epoch is read, so no rewind step is needed afterwards.
            _, labels = split.next_batch(split.num_examples, shuffle=False)
            self.assertEqual(labels.dtype, np.uint8, msg=name)

    def test_class_functions(self):
        """Class metadata and ``classnames`` agree with the first test labels."""
        self.assertEqual(data.classes,
                         ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'])
        self.assertEqual(data.num_classes, 10)

        # Class names of the first five test examples, read in order.
        expected = ['7', '2', '1', '0', '4']
        _, labels = data.test.next_batch(len(expected), shuffle=False)
        try:
            # Guard against a short batch silently truncating the zip below.
            self.assertEqual(len(labels), len(expected))
            for label, name in zip(labels, expected):
                self.assertEqual(data.classnames(label), [name])
        finally:
            # Rewind the test split even if an assertion above failed.
            data.test.next_batch(data.test.num_examples - len(expected),
                                 shuffle=False)