rusty1s/embedded_gcnn
lib/model/train.py
from __future__ import print_function
from __future__ import division

import os
import time
from six.moves import xrange
from sklearn.preprocessing import StandardScaler

from ..datasets import PreprocessQueue
from .placeholder import feed_dict_with_batch
from ..pipeline import PreprocessedDataset, FileQueue, augment_batch


def train(model,
          data,
          preprocess_algorithm,
          batch_size,
          dropout,
          augment,
          max_steps,
          preprocess_first=None,
          display_step=10,
          save_step=250):
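    """Run the training loop for `model` on `data`.

    Batches are dequeued from preprocessing queues, standard scaled and
    fed into the model. Every `display_step` steps the training and
    validation loss/accuracy are logged; every `save_step` steps the
    model is checkpointed via `model.save()`. After `max_steps` steps
    (or on KeyboardInterrupt) the model is evaluated on the test set.

    If `preprocess_first` points to a directory, all datasets are
    preprocessed once with `preprocess_algorithm` and stored there;
    otherwise examples are preprocessed on the fly.
    """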

    train_queue, val_queue, test_queue = _generate_queues(
        data, preprocess_first, preprocess_algorithm, augment, batch_size)

    model.build()
    global_step = model.initialize()

    try:
        for step in xrange(global_step, max_steps):
            t_pre = time.process_time()
            batch = train_queue.dequeue()

            # Augment preprocessed batches here; when preprocessing happens
            # on the fly, the PreprocessQueue applies augmentation itself.
            if augment and preprocess_first is not None:
                batch = augment_batch(batch)

            batch = _standard_scale_batch(batch)
            feed_dict = feed_dict_with_batch(model.placeholders, batch,
                                             dropout)
            t_pre = time.process_time() - t_pre

            t_train = model.train(feed_dict, step)

            if step % display_step == 0:
                # Evaluate on training and validation set with zero dropout.
                feed_dict.update({model.placeholders['dropout']: 0})
                batch = val_queue.dequeue()
                batch = _standard_scale_batch(batch)
                val_feed_dict = feed_dict_with_batch(model.placeholders, batch)

                train_info = model.evaluate(feed_dict, step, 'train')
                val_info = model.evaluate(val_feed_dict, step, 'val')

                log = 'step={}, '.format(step)
                log += 'time={:.2f}s + {:.2f}s, '.format(t_pre, t_train)
                log += 'train_loss={:.5f}, '.format(train_info[0])
                log += 'train_acc={:.5f}, '.format(train_info[1])
                log += 'val_loss={:.5f}, '.format(val_info[0])
                log += 'val_acc={:.5f}'.format(val_info[1])

                print(log)

            if step % save_step == 0:
                model.save()

    except KeyboardInterrupt:
        print()

    print('Optimization finished!')
    print('Evaluate on test set. This can take a few minutes.')

    try:
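        # Floor division drops the last partial batch, so up to
        # batch_size - 1 test examples are excluded from the reported
        # averages.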
        num_steps = data.test.num_examples // batch_size
        test_info = [0, 0]

        for i in xrange(num_steps):
            batch = test_queue.dequeue()
            batch = _standard_scale_batch(batch)
            feed_dict = feed_dict_with_batch(model.placeholders, batch)

            batch_info = model.evaluate(feed_dict)
            test_info = [a + b for a, b in zip(test_info, batch_info)]

        log = 'Test results: '
        log += 'loss={:.5f}, '.format(test_info[0] / num_steps)
        log += 'acc={:.5f}'.format(test_info[1] / num_steps)

        print(log)

    except KeyboardInterrupt:
        print()
        print('Test evaluation aborted.')

    finally:
        train_queue.close()
        val_queue.close()
        test_queue.close()


def _preprocess_data(data, data_dir, preprocess_algorithm):
    data.train = PreprocessedDataset(
        os.path.join(data_dir, 'train'), data.train, preprocess_algorithm)
    data.val = PreprocessedDataset(
        os.path.join(data_dir, 'val'), data.val, preprocess_algorithm)
    data.test = PreprocessedDataset(
        os.path.join(data_dir, 'test'), data.test, preprocess_algorithm)
    return data


def _generate_queues(data, preprocess_first, preprocess_algorithm, augment,
                     batch_size):
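    # Buffer up to ten batches ahead of the consuming training loop.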
    capacity = 10 * batch_size

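    # With a preprocessing directory the datasets are converted once via
    # PreprocessedDataset and served from FileQueues; otherwise examples
    # are preprocessed (and optionally augmented) on the fly inside
    # PreprocessQueues.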
    if preprocess_first is not None:
        data = _preprocess_data(data, preprocess_first, preprocess_algorithm)

        train_queue = FileQueue(data.train, batch_size, capacity, shuffle=True)
        val_queue = FileQueue(data.val, batch_size, capacity, shuffle=True)
        test_queue = FileQueue(data.test, batch_size, capacity, shuffle=False)
    else:
        train_queue = PreprocessQueue(
            data.train,
            preprocess_algorithm,
            augment,
            batch_size,
            capacity,
            shuffle=True)

        val_queue = PreprocessQueue(
            data.val,
            preprocess_algorithm,
            augment,
            batch_size,
            capacity,
            shuffle=True)

        test_queue = PreprocessQueue(
            data.test,
            preprocess_algorithm,
            augment,
            batch_size,
            capacity,
            shuffle=False)

    return train_queue, val_queue, test_queue


def _standard_scale_batch(batch):
    # Scale the feature matrix of each example to zero mean and unit
    # variance. `copy=False` makes the scaler operate in place on
    # example[0] (provided the array's dtype and layout allow it), so the
    # return value of fit_transform is intentionally discarded.
    for example in batch:
        StandardScaler(copy=False).fit_transform(example[0])
    return batch
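

# Example usage (sketch): the concrete classes below are hypothetical
# placeholders, since this module only relies on the interface used
# above, i.e. model.build(), model.initialize(), model.train(),
# model.evaluate(), model.save(), model.placeholders, and a `data`
# object exposing train/val/test datasets (with test.num_examples).
#
#     model = MyModel()                    # hypothetical model class
#     data = MyDatasets('data/')           # hypothetical dataset container
#     train(model, data,
#           preprocess_algorithm=my_algo,  # hypothetical preprocessing fn
#           batch_size=64,
#           dropout=0.5,
#           augment=True,
#           max_steps=20000)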