tensorflow/models

View on GitHub
research/adversarial_text/data/data_utils.py

Summary

Maintainability: A (3 hrs)
Test Coverage
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for generating/preprocessing data for adversarial text models."""

import operator
import os
import random
import re

# Dependency imports

import tensorflow as tf

EOS_TOKEN = '</s>'  # End-of-sequence token string.

# ---------------------------------------------------------------------------
# Data filenames, grouped by the kind of data each TFRecord file holds.
# ---------------------------------------------------------------------------

# Sequence autoencoder.
ALL_SA = 'all_sa.tfrecords'  # all data
TRAIN_SA = 'train_sa.tfrecords'  # train split
TEST_SA = 'test_sa.tfrecords'  # test split

# Language model.
ALL_LM = 'all_lm.tfrecords'  # all data
TRAIN_LM = 'train_lm.tfrecords'  # train split
TEST_LM = 'test_lm.tfrecords'  # test split

# Classification.
TRAIN_CLASS = 'train_classification.tfrecords'  # train split
TEST_CLASS = 'test_classification.tfrecords'  # test split
VALID_CLASS = 'validate_classification.tfrecords'  # validation split

# Reverse-direction language model, for the bidirectional LSTM.
TRAIN_REV_LM = 'train_reverse_lm.tfrecords'  # train split
TEST_REV_LM = 'test_reverse_lm.tfrecords'  # test split

# Classification with the bidirectional LSTM.
TRAIN_BD_CLASS = 'train_bidir_classification.tfrecords'  # train split
TEST_BD_CLASS = 'test_bidir_classification.tfrecords'  # test split
VALID_BD_CLASS = 'validate_bidir_classification.tfrecords'  # validation split


class ShufflingTFRecordWriter(object):
  """Thin wrapper around TFRecordWriter that shuffles records.

  Records passed to `write` are buffered in memory. `close` shuffles them with
  the module-level `random` generator and writes them all to `path` in one
  pass. The writer is also a context manager: leaving the `with` block closes
  it, and therefore writes the file, unless it was already closed explicitly.
  """

  def __init__(self, path):
    self._path = path
    self._records = []
    self._closed = False

  def write(self, record):
    """Buffers `record`; nothing is written to disk until `close`.

    Raises:
      ValueError: if the writer has already been closed.
    """
    if self._closed:
      # Explicit check instead of `assert`: `python -O` strips asserts, and
      # the record would then be silently discarded.
      raise ValueError('Cannot write to closed writer for %s' % self._path)
    self._records.append(record)

  def close(self):
    """Shuffles the buffered records and writes them to the output path.

    Raises:
      ValueError: if the writer has already been closed.
    """
    if self._closed:
      raise ValueError('Writer for %s is already closed' % self._path)
    random.shuffle(self._records)
    with tf.python_io.TFRecordWriter(self._path) as f:
      for record in self._records:
        f.write(record)
    self._closed = True
    # Drop the buffer: it holds every record written and is no longer needed.
    self._records = []

  def __enter__(self):
    return self

  def __exit__(self, unused_type, unused_value, unused_traceback):
    # Tolerate an explicit close() inside the `with` block.
    if not self._closed:
      self.close()


class Timestep(object):
  """Represents a single timestep in a SequenceWrapper.

  A Timestep is a view over three feature messages owned elsewhere:
    * one whose `int64_list` holds the token id(s),
    * one whose `int64_list` holds the label,
    * one whose `float_list` holds the weight.

  Setters write through to those messages and return `self`, so calls can be
  chained.

  Token modes:
    * Single-valued (the default): the token feature holds exactly one id,
      read with `token` and written with `set_token`.
    * Multivalent: the token feature holds any number of ids, read with
      `tokens` and appended with `add_token`.
  """

  def __init__(self, token, label, weight, multivalent_tokens=False):
    """Constructs Timestep from empty Features.

    Args:
      token: feature whose `int64_list` stores the token id(s).
      label: feature whose `int64_list` stores the label.
      weight: feature whose `float_list` stores the weight.
      multivalent_tokens: bool. If True, the timestep may hold several token
        ids, accessed via `tokens` / `add_token`.
    """
    self._token = token
    self._label = label
    self._weight = weight
    self._multivalent_tokens = multivalent_tokens
    self._fill_with_defaults()

  @property
  def token(self):
    """The single token id; raises TypeError in multivalent mode."""
    if self._multivalent_tokens:
      raise TypeError('Timestep may contain multiple values; use `tokens`')
    return self._token.int64_list.value[0]

  @property
  def tokens(self):
    """All token ids (the underlying value container, not a copy)."""
    return self._token.int64_list.value

  @property
  def label(self):
    return self._label.int64_list.value[0]

  @property
  def weight(self):
    return self._weight.float_list.value[0]

  def set_token(self, token):
    """Overwrites the single token id; raises TypeError if multivalent."""
    if self._multivalent_tokens:
      raise TypeError('Timestep may contain multiple values; use `add_token`')
    self._token.int64_list.value[0] = token
    return self

  def add_token(self, token):
    """Appends a token id; raises TypeError if not multivalent."""
    if not self._multivalent_tokens:
      # A single-valued timestep already holds one (default) id. Appending
      # would leave `token` returning that stale first value.
      raise TypeError('Timestep holds a single value; use `set_token`')
    self._token.int64_list.value.append(token)
    return self

  def set_label(self, label):
    self._label.int64_list.value[0] = label
    return self

  def set_weight(self, weight):
    self._weight.float_list.value[0] = weight
    return self

  def copy_from(self, timestep):
    """Copies token, label and weight from another single-valued timestep.

    Raises TypeError (via `token` / `set_token`) if either side is
    multivalent.
    """
    self.set_token(timestep.token).set_label(timestep.label).set_weight(
        timestep.weight)
    return self

  def _fill_with_defaults(self):
    # Seed one default value per field so the `[0]` writes in the setters
    # succeed. Multivalent token lists start empty and grow via add_token.
    if not self._multivalent_tokens:
      self._token.int64_list.value.append(0)
    self._label.int64_list.value.append(0)
    self._weight.float_list.value.append(0.0)


class SequenceWrapper(object):
  """Wrapper around tf.SequenceExample.

  Owns a `tf.train.SequenceExample` with three parallel feature lists (token
  ids, labels, weights). It keeps one `Timestep` view per position so the
  sequence can be iterated, indexed and measured with `len`.
  """

  F_TOKEN_ID = 'token_id'
  F_LABEL = 'label'
  F_WEIGHT = 'weight'

  def __init__(self, multivalent_tokens=False):
    self._multivalent_tokens = multivalent_tokens
    self._timesteps = []
    self._seq = tf.train.SequenceExample()
    self._flist = self._seq.feature_lists.feature_list

  @property
  def seq(self):
    """The underlying tf.train.SequenceExample."""
    return self._seq

  @property
  def multivalent_tokens(self):
    """Whether each timestep may carry several token ids."""
    return self._multivalent_tokens

  def _feature(self, key):
    # Repeated `feature` field of the feature list stored under `key`.
    return self._flist[key].feature

  @property
  def _tokens(self):
    return self._feature(SequenceWrapper.F_TOKEN_ID)

  @property
  def _labels(self):
    return self._feature(SequenceWrapper.F_LABEL)

  @property
  def _weights(self):
    return self._feature(SequenceWrapper.F_WEIGHT)

  def add_timestep(self):
    """Appends a new timestep to all three feature lists and returns it."""
    token_feat = self._tokens.add()
    label_feat = self._labels.add()
    weight_feat = self._weights.add()
    new_step = Timestep(token_feat, label_feat, weight_feat,
                        multivalent_tokens=self._multivalent_tokens)
    self._timesteps.append(new_step)
    return new_step

  def __iter__(self):
    return iter(self._timesteps)

  def __len__(self):
    return len(self._timesteps)

  def __getitem__(self, idx):
    return self._timesteps[idx]


def build_reverse_sequence(seq):
  """Builds a sequence with every timestep but the last reversed.

  The timesteps before the final one are copied in reverse order. The final
  timestep is then copied to the end unchanged, which keeps a trailing
  end-of-sequence token in place. For example, [a, b, c, d] becomes
  [c, b, a, d]. Token, label and weight are copied for every timestep.

  Args:
    seq: single-valued SequenceWrapper.

  Returns:
    New single-valued SequenceWrapper of the same length as `seq`; empty if
    `seq` is empty.
  """
  reverse_seq = SequenceWrapper()
  if not seq:
    # Guard: `seq[-1]` below would raise IndexError on an empty sequence.
    return reverse_seq

  # Copy all but last timestep, in reverse order.
  for timestep in reversed(seq[:-1]):
    reverse_seq.add_timestep().copy_from(timestep)

  # Copy final timestep.
  reverse_seq.add_timestep().copy_from(seq[-1])

  return reverse_seq


def build_bidirectional_seq(seq, rev_seq):
  """Builds a multivalent sequence pairing forward and reverse tokens.

  Timestep i of the result holds two token ids: `seq[i].token` followed by
  `rev_seq[i].token`. Labels and weights are left at their defaults.

  Args:
    seq: single-valued SequenceWrapper (forward direction).
    rev_seq: single-valued SequenceWrapper of the same length, e.g.
      `build_reverse_sequence(seq)`.

  Returns:
    SequenceWrapper with `multivalent_tokens=True`.

  Raises:
    ValueError: if `seq` and `rev_seq` differ in length.
  """
  if len(seq) != len(rev_seq):
    # zip() below would otherwise silently drop the unmatched tail.
    raise ValueError(
        'Forward and reverse sequences differ in length: %d vs %d' %
        (len(seq), len(rev_seq)))
  bidir_seq = SequenceWrapper(multivalent_tokens=True)
  for forward_ts, reverse_ts in zip(seq, rev_seq):
    bidir_seq.add_timestep().add_token(forward_ts.token).add_token(
        reverse_ts.token)

  return bidir_seq


def build_lm_sequence(seq):
  """Builds a next-token language-model sequence from `seq`.

  Each output timestep copies the input token at the same position. It is
  labeled with the token at the following position and weighted 1.0. The
  final timestep (the <eos> position) has no successor, so it is labeled
  with its own token and weighted 0.0.

  Args:
    seq: SequenceWrapper.

  Returns:
    New SequenceWrapper of the same length as `seq`.
  """
  lm_seq = SequenceWrapper()
  last_idx = len(seq) - 1
  for idx, current in enumerate(seq):
    is_last = idx == last_idx
    target = current if is_last else seq[idx + 1]
    step = lm_seq.add_timestep()
    step.set_token(current.token)
    step.set_label(target.token)
    step.set_weight(0.0 if is_last else 1.0)
  return lm_seq


def build_seq_ae_sequence(seq):
  """Builds a sequence-autoencoder sequence from `seq`.

  The output has two halves:
    * Encoder: every token of `seq`, unlabeled except for its last step.
    * Decoder: every token of `seq` except the last.

  The last encoder timestep doubles as the first decoder step, so the output
  length is `len(seq) * 2 - 1`. From that shared step onward, each timestep
  is labeled with the next token to reconstruct (starting with `seq[0]`) and
  weighted 1.0.

  Args:
    seq: SequenceWrapper.

  Returns:
    New SequenceWrapper; empty if `seq` is empty.
  """
  n = len(seq)
  seq_ae_seq = SequenceWrapper()

  # Encoder half. Only the final (transition) step gets a label: seq[0].
  for pos in range(n):
    step = seq_ae_seq.add_timestep().set_token(seq[pos].token)
    if pos == n - 1:
      step.set_label(seq[0].token).set_weight(1.0)

  # Decoder half: input seq[pos], predict seq[pos + 1].
  for pos in range(n - 1):
    seq_ae_seq.add_timestep().set_token(seq[pos].token).set_label(
        seq[pos + 1].token).set_weight(1.0)

  return seq_ae_seq


def build_labeled_sequence(seq, class_label, label_gain=False):
  """Builds labeled sequence from input sequence.

  Args:
    seq: SequenceWrapper, single- or multivalent; must be non-empty.
    class_label: integer, starting from 0.
    label_gain: bool. If True, class_label will be put on every timestep and
      weight will increase linearly from 0 to 1.

  Returns:
    SequenceWrapper with `seq` tokens copied in. The final timestep has
    `class_label` as its label and weight 1.0. Without `label_gain`, every
    other timestep keeps the default label 0 and weight 0.0.

  Raises:
    ValueError: if `seq` is empty, since there is no final timestep to label.
  """
  seq_len = len(seq)
  if seq_len == 0:
    raise ValueError('Cannot build a labeled sequence from an empty sequence')
  label = int(class_label)
  label_seq = SequenceWrapper(multivalent_tokens=seq.multivalent_tokens)

  # Copy the tokens; labels/weights are only set when label_gain is on.
  for i, timestep in enumerate(seq):
    label_timestep = label_seq.add_timestep()
    if seq.multivalent_tokens:
      for token in timestep.tokens:
        label_timestep.add_token(token)
    else:
      label_timestep.set_token(timestep.token)
    if label_gain:
      label_timestep.set_label(label)
      # Linear ramp: 0.0 at the first timestep up to 1.0 at the last.
      weight = 1.0 if seq_len < 2 else float(i) / (seq_len - 1)
      label_timestep.set_weight(weight)

  # The final timestep always carries the class label with full weight.
  label_seq[-1].set_label(label).set_weight(1.0)

  return label_seq


def split_by_punct(segment):
  """Splits str segment by punctuation, filters out empties and spaces.

  `segment` is split on every run of non-word characters. Punctuation and
  whitespace therefore both act as separators and are dropped. Underscores
  are word characters, so they stay inside tokens (e.g. 'foo_bar').

  Args:
    segment: str to tokenize.

  Returns:
    list of non-empty str tokens, in order of appearance.
  """
  return [s for s in re.split(r'\W+', segment) if s and not s.isspace()]


def sort_vocab_by_frequency(vocab_freq_map):
  """Orders vocabulary terms from most to least frequent.

  The sort is stable. Terms with equal counts therefore keep the order in
  which `vocab_freq_map.items()` yields them.

  Args:
    vocab_freq_map: dict<str term, int count>, vocabulary terms with counts.

  Returns:
    list<tuple<str term, int count>> sorted by count, descending.
  """
  ranked = list(vocab_freq_map.items())
  ranked.sort(key=lambda term_and_count: term_and_count[1], reverse=True)
  return ranked


def write_vocab_and_frequency(ordered_vocab_freqs, output_dir):
  """Writes ordered_vocab_freqs into vocab.txt and vocab_freq.txt.

  Line k of vocab.txt holds the k-th term. Line k of vocab_freq.txt holds
  that term's count.

  Args:
    ordered_vocab_freqs: iterable of (term, count) pairs, e.g. the output of
      `sort_vocab_by_frequency`.
    output_dir: directory to write into; created if it does not exist.
  """
  tf.gfile.MakeDirs(output_dir)
  vocab_path = os.path.join(output_dir, 'vocab.txt')
  freq_path = os.path.join(output_dir, 'vocab_freq.txt')
  # Write through gfile, like MakeDirs above, so any path gfile can create is
  # also writable. The builtin open() only handles local paths.
  # NOTE(review): GFile is expected to encode str writes as UTF-8, matching the
  # previous encoding='utf-8'; confirm for the TF version in use.
  with tf.gfile.Open(vocab_path, 'w') as vocab_f:
    with tf.gfile.Open(freq_path, 'w') as freq_f:
      for word, freq in ordered_vocab_freqs:
        vocab_f.write('{}\n'.format(word))
        freq_f.write('{}\n'.format(freq))