tensorflow/models

View on GitHub
research/adversarial_text/graphs_test.py

Summary

Maintainability
A
3 hrs
Test Coverage
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for graphs."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from collections import defaultdict
import operator
import os
import random
import shutil
import string
import tempfile

# Dependency imports

import tensorflow as tf

import graphs
from data import data_utils

flags = tf.app.flags
FLAGS = flags.FLAGS
data = data_utils

flags.DEFINE_integer('task', 0, 'Task id; needed for SyncReplicas test')


def _build_random_vocabulary(vocab_size=100):
  """Builds and returns a dict<term, id>."""
  vocab = set()
  while len(vocab) < (vocab_size - 1):
    rand_word = ''.join(
        random.choice(string.ascii_lowercase)
        for _ in range(random.randint(1, 10)))
    vocab.add(rand_word)

  vocab_ids = dict([(word, i) for i, word in enumerate(vocab)])
  vocab_ids[data.EOS_TOKEN] = vocab_size - 1
  return vocab_ids


def _build_random_sequence(vocab_ids):
  seq_len = random.randint(10, 200)
  ids = vocab_ids.values()
  seq = data.SequenceWrapper()
  for token_id in [random.choice(ids) for _ in range(seq_len)]:
    seq.add_timestep().set_token(token_id)
  return seq


def _build_vocab_frequencies(seqs, vocab_ids):
  vocab_freqs = defaultdict(int)
  ids_to_words = dict([(i, word) for word, i in vocab_ids.iteritems()])
  for seq in seqs:
    for timestep in seq:
      vocab_freqs[ids_to_words[timestep.token]] += 1

  vocab_freqs[data.EOS_TOKEN] = 0
  return vocab_freqs


class GraphsTest(tf.test.TestCase):
  """Test graph construction methods."""

  @classmethod
  def setUpClass(cls):
    # Make model small
    FLAGS.batch_size = 2
    FLAGS.num_timesteps = 3
    FLAGS.embedding_dims = 4
    FLAGS.rnn_num_layers = 2
    FLAGS.rnn_cell_size = 4
    FLAGS.cl_num_layers = 2
    FLAGS.cl_hidden_size = 4
    FLAGS.vocab_size = 10

    # Set input/output flags
    FLAGS.data_dir = tempfile.mkdtemp()

    # Build and write sequence files.
    vocab_ids = _build_random_vocabulary(FLAGS.vocab_size)
    seqs = [_build_random_sequence(vocab_ids) for _ in range(5)]
    seqs_label = [
        data.build_labeled_sequence(seq, random.choice([True, False]))
        for seq in seqs
    ]
    seqs_lm = [data.build_lm_sequence(seq) for seq in seqs]
    seqs_ae = [data.build_seq_ae_sequence(seq) for seq in seqs]
    seqs_rev = [data.build_reverse_sequence(seq) for seq in seqs]
    seqs_bidir = [
        data.build_bidirectional_seq(seq, rev)
        for seq, rev in zip(seqs, seqs_rev)
    ]
    seqs_bidir_label = [
        data.build_labeled_sequence(bd_seq, random.choice([True, False]))
        for bd_seq in seqs_bidir
    ]

    filenames = [
        data.TRAIN_CLASS, data.TRAIN_LM, data.TRAIN_SA, data.TEST_CLASS,
        data.TRAIN_REV_LM, data.TRAIN_BD_CLASS, data.TEST_BD_CLASS
    ]
    seq_lists = [
        seqs_label, seqs_lm, seqs_ae, seqs_label, seqs_rev, seqs_bidir,
        seqs_bidir_label
    ]
    for fname, seq_list in zip(filenames, seq_lists):
      with tf.python_io.TFRecordWriter(
          os.path.join(FLAGS.data_dir, fname)) as writer:
        for seq in seq_list:
          writer.write(seq.seq.SerializeToString())

    # Write vocab.txt and vocab_freq.txt
    vocab_freqs = _build_vocab_frequencies(seqs, vocab_ids)
    ordered_vocab_freqs = sorted(
        vocab_freqs.items(), key=operator.itemgetter(1), reverse=True)
    with open(os.path.join(FLAGS.data_dir, 'vocab.txt'), 'w') as vocab_f:
      with open(os.path.join(FLAGS.data_dir, 'vocab_freq.txt'), 'w') as freq_f:
        for word, freq in ordered_vocab_freqs:
          vocab_f.write('{}\n'.format(word))
          freq_f.write('{}\n'.format(freq))

  @classmethod
  def tearDownClass(cls):
    shutil.rmtree(FLAGS.data_dir)

  def setUp(self):
    # Reset FLAGS
    FLAGS.rnn_num_layers = 1
    FLAGS.sync_replicas = False
    FLAGS.adv_training_method = None
    FLAGS.num_candidate_samples = -1
    FLAGS.num_classes = 2
    FLAGS.use_seq2seq_autoencoder = False

    # Reset Graph
    tf.reset_default_graph()

  def testClassifierGraph(self):
    FLAGS.rnn_num_layers = 2
    model = graphs.VatxtModel()
    train_op, _, _ = model.classifier_training()
    # Pretrained vars: embedding + LSTM layers
    self.assertEqual(
        len(model.pretrained_variables), 1 + 2 * FLAGS.rnn_num_layers)
    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      tf.train.start_queue_runners(sess)
      sess.run(train_op)

  def testLanguageModelGraph(self):
    train_op, _, _ = graphs.VatxtModel().language_model_training()
    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      tf.train.start_queue_runners(sess)
      sess.run(train_op)

  def testMulticlass(self):
    FLAGS.num_classes = 10
    graphs.VatxtModel().classifier_graph()

  def testATMethods(self):
    at_methods = [None, 'rp', 'at', 'vat', 'atvat']
    for method in at_methods:
      FLAGS.adv_training_method = method
      with tf.Graph().as_default():
        graphs.VatxtModel().classifier_graph()

        # Ensure variables have been reused
        # Embedding + LSTM layers + hidden layers + logits layer
        expected_num_vars = 1 + 2 * FLAGS.rnn_num_layers + 2 * (
            FLAGS.cl_num_layers) + 2
        self.assertEqual(len(tf.trainable_variables()), expected_num_vars)

  def testSyncReplicas(self):
    FLAGS.sync_replicas = True
    graphs.VatxtModel().language_model_training()

  def testCandidateSampling(self):
    FLAGS.num_candidate_samples = 10
    graphs.VatxtModel().language_model_training()

  def testSeqAE(self):
    FLAGS.use_seq2seq_autoencoder = True
    graphs.VatxtModel().language_model_training()

  def testBidirLM(self):
    graphs.VatxtBidirModel().language_model_graph()

  def testBidirClassifier(self):
    at_methods = [None, 'rp', 'at', 'vat', 'atvat']
    for method in at_methods:
      FLAGS.adv_training_method = method
      with tf.Graph().as_default():
        graphs.VatxtBidirModel().classifier_graph()

        # Ensure variables have been reused
        # Embedding + 2 LSTM layers + hidden layers + logits layer
        expected_num_vars = 1 + 2 * 2 * FLAGS.rnn_num_layers + 2 * (
            FLAGS.cl_num_layers) + 2
        self.assertEqual(len(tf.trainable_variables()), expected_num_vars)

  def testEvalGraph(self):
    _, _ = graphs.VatxtModel().eval_graph()

  def testBidirEvalGraph(self):
    _, _ = graphs.VatxtBidirModel().eval_graph()


if __name__ == '__main__':
  tf.test.main()