tensorflow/models

View on GitHub
research/adversarial_text/inputs.py

Summary

Maintainability
C
1 day
Test Coverage
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Input utils for virtual adversarial text classification."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

# Dependency imports

import tensorflow as tf

from data import data_utils


class VatxtInput(object):
  """Wrapper around NextQueuedSequenceBatch.

  Exposes the token, weight and label sequences of a batch, and reads/saves
  the per-index LSTM (c, h) states that the batch carries.
  """

  def __init__(self,
               batch,
               state_name=None,
               tokens=None,
               num_states=0,
               eos_id=None):
    """Construct VatxtInput.

    Args:
      batch: NextQueuedSequenceBatch.
      state_name: str, base name of the states to fetch and save.
      tokens: int Tensor, tokens. Defaults to batch's F_TOKEN_ID sequence.
      num_states: int, number of (c, h) LSTM state pairs to read/save
        (`inputs()` passes the number of LSTM layers here).
      eos_id: int, id of the end-of-sequence token. When not None,
        `eos_weights` is a float32 mask equal to 1.0 where tokens == eos_id.
    """
    self._batch = batch
    self._state_name = state_name
    self._tokens = (tokens if tokens is not None else
                    batch.sequences[data_utils.SequenceWrapper.F_TOKEN_ID])
    self._num_states = num_states

    self._weights = batch.sequences[data_utils.SequenceWrapper.F_WEIGHT]
    self._labels = batch.sequences[data_utils.SequenceWrapper.F_LABEL]

    # EOS weights. Compare against None (not truthiness) so that a vocabulary
    # in which EOS has id 0 still produces a mask.
    self._eos_weights = None
    if eos_id is not None:
      self._eos_weights = tf.cast(tf.equal(self._tokens, eos_id), tf.float32)

  @property
  def tokens(self):
    """Token id Tensor (the explicit `tokens` arg or batch's F_TOKEN_ID)."""
    return self._tokens

  @property
  def weights(self):
    """Batch's F_WEIGHT sequence."""
    return self._weights

  @property
  def eos_weights(self):
    """float32 mask of EOS positions, or None if no eos_id was given."""
    return self._eos_weights

  @property
  def labels(self):
    """Batch's F_LABEL sequence."""
    return self._labels

  @property
  def length(self):
    """Batch's `length` attribute."""
    return self._batch.length

  @property
  def state_name(self):
    """Base name used to build the saved-state keys."""
    return self._state_name

  @property
  def state(self):
    """Tuple of LSTMStateTuple(c, h), one per state index, read from batch."""
    # LSTM tuple states
    state_names = _get_tuple_state_names(self._num_states, self._state_name)
    return tuple([
        tf.contrib.rnn.LSTMStateTuple(
            self._batch.state(c_name), self._batch.state(h_name))
        for c_name, h_name in state_names
    ])

  def save_state(self, value):
    """Saves LSTM tuple states back into the batch.

    Args:
      value: sequence of (c, h) pairs, matched positionally with the state
        names built from `num_states` and `state_name`.

    Returns:
      A `tf.group` op that runs all of the individual save ops.
    """
    # LSTM tuple states
    state_names = _get_tuple_state_names(self._num_states, self._state_name)
    save_ops = []
    # NOTE(review): zip stops at the shorter input, so a length mismatch
    # between `value` and `num_states` silently skips saves.
    for (c_state, h_state), (c_name, h_name) in zip(value, state_names):
      save_ops.append(self._batch.save_state(c_name, c_state))
      save_ops.append(self._batch.save_state(h_name, h_state))
    return tf.group(*save_ops)


def _get_tuple_state_names(num_states, base_name):
  """Returns state names for use with LSTM tuple state."""
  state_names = [('{}_{}_c'.format(i, base_name), '{}_{}_h'.format(
      i, base_name)) for i in range(num_states)]
  return state_names


def _split_bidir_tokens(batch):
  """Splits stacked bidirectional token ids into forward and reverse streams.

  Args:
    batch: NextQueuedSequenceBatch whose F_TOKEN_ID sequence has a trailing
      dimension of size 2.

  Returns:
    Tuple (forward, reverse) of token Tensors with the trailing axis removed.
  """
  stacked = batch.sequences[data_utils.SequenceWrapper.F_TOKEN_ID]
  # Tokens have shape [batch, time, 2]; each half is [batch, time, 1].
  fwd_part, rev_part = tf.split(stacked, 2, axis=2)
  # Drop the singleton axis so each stream is [batch, time].
  return tf.squeeze(fwd_part, axis=[2]), tf.squeeze(rev_part, axis=[2])


def _filenames_for_data_spec(phase, bidir, pretrain, use_seq2seq):
  """Returns input filenames for configuration.

  Args:
    phase: str, 'train', 'test', or 'valid'.
    bidir: bool, bidirectional model.
    pretrain: bool, pretraining or classification.
    use_seq2seq: bool, seq2seq data, only valid if pretrain=True.

  Returns:
    Tuple of filenames. For bidirectional pretraining this is
    (forward LM file, reverse LM file); otherwise a 1-tuple.

  Raises:
    ValueError: if an invalid combination of arguments is provided that does not
      map to any data files (e.g. pretrain=False, use_seq2seq=True).
  """
  data_spec = (phase, bidir, pretrain, use_seq2seq)
  data_specs = {
      ('train', True, True, False): (data_utils.TRAIN_LM,
                                     data_utils.TRAIN_REV_LM),
      ('train', True, False, False): (data_utils.TRAIN_BD_CLASS,),
      ('train', False, True, False): (data_utils.TRAIN_LM,),
      ('train', False, True, True): (data_utils.TRAIN_SA,),
      ('train', False, False, False): (data_utils.TRAIN_CLASS,),
      # Both directions must come from the test split; previously the reverse
      # stream read TRAIN_REV_LM, mixing training data into evaluation.
      ('test', True, True, False): (data_utils.TEST_LM,
                                    data_utils.TEST_REV_LM),
      ('test', True, False, False): (data_utils.TEST_BD_CLASS,),
      ('test', False, True, False): (data_utils.TEST_LM,),
      ('test', False, True, True): (data_utils.TEST_SA,),
      ('test', False, False, False): (data_utils.TEST_CLASS,),
      ('valid', True, False, False): (data_utils.VALID_BD_CLASS,),
      ('valid', False, False, False): (data_utils.VALID_CLASS,),
  }
  if data_spec not in data_specs:
    raise ValueError(
        'Data specification (phase, bidir, pretrain, use_seq2seq) %s not '
        'supported' % str(data_spec))

  return data_specs[data_spec]


def _read_single_sequence_example(file_list, tokens_shape=None):
  """Reads and parses SequenceExamples from TFRecord-encoded file_list.

  Args:
    file_list: list of str, TFRecord paths fed to a filename queue.
    tokens_shape: optional list of int, per-timestep shape of the token
      feature. A falsy value (None or []) means one scalar token per step.

  Returns:
    Tuple (seq_key, context, sequences): the record key from the reader, the
    parsed context dict (no context features are requested), and the parsed
    sequence-feature dict keyed by F_TOKEN_ID, F_LABEL and F_WEIGHT.
  """
  tf.logging.info('Constructing TFRecordReader from files: %s', file_list)
  wrapper = data_utils.SequenceWrapper
  feature_spec = {
      wrapper.F_TOKEN_ID:
          tf.FixedLenSequenceFeature(tokens_shape or [], dtype=tf.int64),
      wrapper.F_LABEL:
          tf.FixedLenSequenceFeature([], dtype=tf.int64),
      wrapper.F_WEIGHT:
          tf.FixedLenSequenceFeature([], dtype=tf.float32),
  }
  filename_queue = tf.train.string_input_producer(file_list)
  record_reader = tf.TFRecordReader()
  key, record = record_reader.read(filename_queue)
  context, sequences = tf.parse_single_sequence_example(
      record, sequence_features=feature_spec)
  return key, context, sequences


def _read_and_batch(data_dir,
                    fname,
                    state_name,
                    state_size,
                    num_layers,
                    unroll_steps,
                    batch_size,
                    bidir_input=False):
  """Inputs for text model.

  Args:
    data_dir: str, directory containing TFRecord files of SequenceExample.
    fname: str, input file name.
    state_name: string, key for saved state of LSTM.
    state_size: int, size of LSTM state.
    num_layers: int, the number of layers in the LSTM.
    unroll_steps: int, number of timesteps to unroll for TBTT.
    batch_size: int, batch size.
    bidir_input: bool, whether the input is bidirectional. If True, tokens are
      parsed with a trailing dimension of 2 and a second set of states named
      from state_name + '_reverse' is created alongside state_name's.

  Returns:
    Instance of NextQueuedSequenceBatch

  Raises:
    ValueError: if file for input specification is not found.
  """
  data_path = os.path.join(data_dir, fname)
  if not tf.gfile.Exists(data_path):
    raise ValueError('Failed to find file: %s' % data_path)

  tokens_shape = [2] if bidir_input else []
  seq_key, ctx, sequence = _read_single_sequence_example(
      [data_path], tokens_shape=tokens_shape)

  # Set up stateful queue reader: zero-initialized (c, h) states for every
  # layer, plus a reverse-direction set when the input is bidirectional.
  state_bases = [state_name]
  if bidir_input:
    state_bases.append('{}_reverse'.format(state_name))
  initial_states = {}
  for base in state_bases:
    for c_state, h_state in _get_tuple_state_names(num_layers, base):
      initial_states[c_state] = tf.zeros(state_size)
      initial_states[h_state] = tf.zeros(state_size)

  # Sequence length is the time dimension of the token feature; use the
  # shared constant rather than a hard-coded key.
  token_ids = sequence[data_utils.SequenceWrapper.F_TOKEN_ID]
  batch = tf.contrib.training.batch_sequences_with_states(
      input_key=seq_key,
      input_sequences=sequence,
      input_context=ctx,
      input_length=tf.shape(token_ids)[0],
      initial_states=initial_states,
      num_unroll=unroll_steps,
      batch_size=batch_size,
      allow_small_batch=False,
      num_threads=4,
      capacity=batch_size * 10,
      make_keys_unique=True,
      make_keys_unique_seed=29392)
  return batch


def inputs(data_dir=None,
           phase='train',
           bidir=False,
           pretrain=False,
           use_seq2seq=False,
           state_name='lstm',
           state_size=None,
           num_layers=0,
           batch_size=32,
           unroll_steps=100,
           eos_id=None):
  """Inputs for text model.

  Args:
    data_dir: str, directory containing TFRecord files of SequenceExample.
    phase: str, dataset for evaluation {'train', 'valid', 'test'}.
    bidir: bool, bidirectional LSTM.
    pretrain: bool, whether to read pretraining data or classification data.
    use_seq2seq: bool, whether to read seq2seq data or the language model data.
    state_name: string, key for saved state of LSTM.
    state_size: int, size of LSTM state.
    num_layers: int, the number of LSTM layers.
    batch_size: int, batch size.
    unroll_steps: int, number of timesteps to unroll for TBTT.
    eos_id: int, id of end of sequence. Used to build `eos_weights` (the KL
      weights for VAT) on every returned VatxtInput.

  Returns:
    A VatxtInput, or a (forward, reverse) tuple of VatxtInput when bidir=True.

  Raises:
    ValueError: if the (phase, bidir, pretrain, use_seq2seq) combination is
      unsupported, or an input file is missing.
  """
  with tf.name_scope('inputs'):
    filenames = _filenames_for_data_spec(phase, bidir, pretrain, use_seq2seq)
    state_name_rev = state_name + '_reverse'

    if bidir and pretrain:
      # Bidirectional pretraining
      # Requires separate forward and reverse language model data.
      forward_fname, reverse_fname = filenames
      forward_batch = _read_and_batch(data_dir, forward_fname, state_name,
                                      state_size, num_layers, unroll_steps,
                                      batch_size)
      reverse_batch = _read_and_batch(data_dir, reverse_fname, state_name_rev,
                                      state_size, num_layers, unroll_steps,
                                      batch_size)
      forward_input = VatxtInput(
          forward_batch,
          state_name=state_name,
          num_states=num_layers,
          eos_id=eos_id)
      reverse_input = VatxtInput(
          reverse_batch,
          state_name=state_name_rev,
          num_states=num_layers,
          eos_id=eos_id)
      return forward_input, reverse_input

    elif bidir:
      # Classifier bidirectional LSTM
      # Shared data source, but separate token/state streams. Both inputs
      # wrap the same batch; they differ in tokens and state names.
      fname, = filenames
      batch = _read_and_batch(
          data_dir,
          fname,
          state_name,
          state_size,
          num_layers,
          unroll_steps,
          batch_size,
          bidir_input=True)
      forward_tokens, reverse_tokens = _split_bidir_tokens(batch)
      # eos_id is forwarded here as in the other branches; previously it was
      # dropped, leaving eos_weights None for bidirectional classification.
      forward_input = VatxtInput(
          batch,
          state_name=state_name,
          tokens=forward_tokens,
          num_states=num_layers,
          eos_id=eos_id)
      reverse_input = VatxtInput(
          batch,
          state_name=state_name_rev,
          tokens=reverse_tokens,
          num_states=num_layers,
          eos_id=eos_id)
      return forward_input, reverse_input
    else:
      # Unidirectional LM or classifier
      fname, = filenames
      batch = _read_and_batch(
          data_dir,
          fname,
          state_name,
          state_size,
          num_layers,
          unroll_steps,
          batch_size,
          bidir_input=False)
      return VatxtInput(
          batch, state_name=state_name, num_states=num_layers, eos_id=eos_id)