tensorflow/models

View on GitHub
research/rebar/rebar_train.py

Summary

Maintainability
C
7 hrs
Test Coverage
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import random
import sys
import os

import numpy as np
import tensorflow as tf

import rebar
import datasets
import logger as L

try:
  xrange          # Python 2
except NameError:
  xrange = range  # Python 3

gfile = tf.gfile

tf.app.flags.DEFINE_string("working_dir", "/tmp/rebar",
                           """Directory where to save data, write logs, etc.""")
tf.app.flags.DEFINE_string('hparams', '',
                           '''Comma separated list of name=value pairs.''')
tf.app.flags.DEFINE_integer('eval_freq', 20,
                           '''How often to run the evaluation step.''')
FLAGS = tf.flags.FLAGS

def manual_scalar_summary(name, value):
  value = tf.Summary.Value(tag=name, simple_value=value)
  summary_str = tf.Summary(value=[value])
  return summary_str

def eval(sbn, eval_xs, n_samples=100, batch_size=5):
  n = eval_xs.shape[0]
  i = 0
  res = []
  while i < n:
    batch_xs = eval_xs[i:min(i+batch_size, n)]
    res.append(sbn.partial_eval(batch_xs, n_samples))
    i += batch_size
  res = np.mean(res, axis=0)
  return res

def train(sbn, train_xs, valid_xs, test_xs, training_steps, debug=False):
  hparams = sorted(sbn.hparams.values().items())
  hparams = (map(str, x) for x in hparams)
  hparams = ('_'.join(x) for x in hparams)
  hparams_str = '.'.join(hparams)

  logger = L.Logger()

  # Create the experiment name from the hparams
  experiment_name = ([str(sbn.hparams.n_hidden) for i in xrange(sbn.hparams.n_layer)] +
                     [str(sbn.hparams.n_input)])
  if sbn.hparams.nonlinear:
    experiment_name = '~'.join(experiment_name)
  else:
    experiment_name = '-'.join(experiment_name)
  experiment_name = 'SBN_%s' % experiment_name
  rowkey = {'experiment': experiment_name,
            'model': hparams_str}

  # Create summary writer
  summ_dir = os.path.join(FLAGS.working_dir, hparams_str)
  summary_writer = tf.summary.FileWriter(
      summ_dir, flush_secs=15, max_queue=100)

  sv = tf.train.Supervisor(logdir=os.path.join(
      FLAGS.working_dir, hparams_str),
                     save_summaries_secs=0,
                     save_model_secs=1200,
                     summary_op=None,
                     recovery_wait_secs=30,
                     global_step=sbn.global_step)
  with sv.managed_session() as sess:
    # Dump hparams to file
    with gfile.Open(os.path.join(FLAGS.working_dir,
                                 hparams_str,
                                 'hparams.json'),
                    'w') as out:
      json.dump(sbn.hparams.values(), out)

    sbn.initialize(sess)
    batch_size = sbn.hparams.batch_size
    scores = []
    n = train_xs.shape[0]
    index = range(n)

    while not sv.should_stop():
      lHats = []
      grad_variances = []
      temperatures = []
      random.shuffle(index)
      i = 0
      while i < n:
        batch_index = index[i:min(i+batch_size, n)]
        batch_xs = train_xs[batch_index, :]

        if sbn.hparams.dynamic_b:
          # Dynamically binarize the batch data
          batch_xs = (np.random.rand(*batch_xs.shape) < batch_xs).astype(float)

        lHat, grad_variance, step, temperature = sbn.partial_fit(batch_xs,
                                                    sbn.hparams.n_samples)
        if debug:
          print(i, lHat)
          if i > 100:
            return
        lHats.append(lHat)
        grad_variances.append(grad_variance)
        temperatures.append(temperature)
        i += batch_size

      grad_variances = np.log(np.mean(grad_variances, axis=0)).tolist()
      summary_strings = []
      if isinstance(grad_variances, list):
        grad_variances = dict(zip([k for (k, v) in sbn.losses], map(float, grad_variances)))
        rowkey['step'] = step
        logger.log(rowkey, {'step': step,
                             'train': np.mean(lHats, axis=0)[0],
                             'grad_variances': grad_variances,
                             'temperature': np.mean(temperatures), })
        grad_variances = '\n'.join(map(str, sorted(grad_variances.iteritems())))
      else:
        rowkey['step'] = step
        logger.log(rowkey, {'step': step,
                             'train': np.mean(lHats, axis=0)[0],
                             'grad_variance': grad_variances,
                             'temperature': np.mean(temperatures), })
        summary_strings.append(manual_scalar_summary("log grad variance", grad_variances))

      print('Step %d: %s\n%s' % (step, str(np.mean(lHats, axis=0)), str(grad_variances)))

      # Every few epochs compute test and validation scores
      epoch = int(step / (train_xs.shape[0] / sbn.hparams.batch_size))
      if epoch % FLAGS.eval_freq == 0:
        valid_res = eval(sbn, valid_xs)
        test_res= eval(sbn, test_xs)

        print('\nValid %d: %s' % (step, str(valid_res)))
        print('Test %d: %s\n' % (step, str(test_res)))
        logger.log(rowkey, {'step': step,
                             'valid': valid_res[0],
                             'test': test_res[0]})
        logger.flush()  # Flush infrequently

      # Create summaries
      summary_strings.extend([
        manual_scalar_summary("Train ELBO", np.mean(lHats, axis=0)[0]),
        manual_scalar_summary("Temperature", np.mean(temperatures)),
      ])
      for summ_str in summary_strings:
        summary_writer.add_summary(summ_str, global_step=step)
      summary_writer.flush()

      sys.stdout.flush()
      scores.append(np.mean(lHats, axis=0))

      if step > training_steps:
        break

    return scores


def main():
  # Parse hyperparams
  hparams = rebar.default_hparams
  hparams.parse(FLAGS.hparams)
  print(hparams.values())

  train_xs, valid_xs, test_xs = datasets.load_data(hparams)
  mean_xs = np.mean(train_xs, axis=0)  # Compute mean centering on training

  training_steps = 2000000
  model = getattr(rebar, hparams.model)
  sbn = model(hparams, mean_xs=mean_xs)

  scores = train(sbn, train_xs, valid_xs, test_xs,
                 training_steps=training_steps, debug=False)

if __name__ == '__main__':
  main()