tensorflow/models

View on GitHub
research/adversarial_text/data/document_generators.py

Summary

Maintainability
D
1 day
Test Coverage
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Input readers and document/token generators for datasets."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from collections import namedtuple
import csv
import os
import random

# Dependency imports

import tensorflow as tf

from data import data_utils

# Command-line flags read by the document/token generators in this module.
# Defining them here registers them at import time.
flags = tf.app.flags
FLAGS = flags.FLAGS

# Selects which per-dataset generator `documents` dispatches to; the values
# recognized there are 'imdb', 'dbpedia', 'rcv1' and 'rt'.
flags.DEFINE_string('dataset', '', 'Which dataset to generate data for')

# Preprocessing config
# `tokens` requires at least one of the three output_* flags to be true.
# When output_char is true it takes precedence and the word-level flags are
# not consulted.
flags.DEFINE_boolean('output_unigrams', True, 'Whether to output unigrams.')
flags.DEFINE_boolean('output_bigrams', False, 'Whether to output bigrams.')
flags.DEFINE_boolean('output_char', False, 'Whether to output characters.')
flags.DEFINE_boolean('lowercase', True, 'Whether to lowercase document terms.')

# IMDB
flags.DEFINE_string('imdb_input_dir', '', 'The input directory containing the '
                    'IMDB sentiment dataset.')
# Labeled IMDB files whose numeric filename prefix (the part before the first
# '_') is >= the matching start id are treated as validation documents.
flags.DEFINE_integer('imdb_validation_pos_start_id', 10621, 'File id of the '
                     'first file in the pos sentiment validation set.')
flags.DEFINE_integer('imdb_validation_neg_start_id', 10625, 'File id of the '
                     'first file in the neg sentiment validation set.')

# DBpedia
flags.DEFINE_string('dbpedia_input_dir', '',
                    'Path to DBpedia directory containing train.csv and '
                    'test.csv.')

# Reuters Corpus (rcv1)
flags.DEFINE_string('rcv1_input_dir', '',
                    'Path to rcv1 directory containing train.csv, unlab.csv, '
                    'and test.csv.')

# Rotten Tomatoes
flags.DEFINE_string('rt_input_dir', '',
                    'The Rotten Tomatoes dataset input directory.')

# The amazon reviews input file to use in either the RT or IMDB datasets.
flags.DEFINE_string('amazon_unlabeled_input_file', '',
                    'The unlabeled Amazon Reviews dataset input file. If set, '
                    'the input file is used to augment RT and IMDB vocab.')

# Record yielded by every generator in this module.
#   content: str, the raw document text.
#   is_validation: bool, whether the document was held out for validation.
#   is_test: bool, whether the document was held out for test (only
#     rt_documents ever sets this to True).
#   label: class label (bool for IMDB/RT, int for DBpedia/rcv1), or None for
#     unlabeled documents.
#   add_tokens: bool, False only for Amazon review documents in this file.
#     NOTE(review): presumably tells the vocab builder whether to add this
#     document's tokens -- the consumer is not in this module; confirm.
Document = namedtuple('Document',
                      'content is_validation is_test label add_tokens')


def documents(dataset='train',
              include_unlabeled=False,
              include_validation=False):
  """Yields Documents from whichever corpus FLAGS.dataset selects.

  This is a thin dispatcher: it validates the arguments, reseeds the `random`
  module, and then forwards all three arguments unchanged to the generator
  for the selected dataset.

  Args:
    dataset: str, the split to read (e.g. 'train' or 'test'); interpreted by
      the selected per-dataset generator.
    include_unlabeled: bool, whether to include unlabeled data. Only valid
      when dataset=train.
    include_validation: bool, whether to include validation data.

  Yields:
    Document

  Raises:
    ValueError: if include_unlabeled is true but dataset is not 'train', or
      if FLAGS.dataset is not one of 'imdb', 'dbpedia', 'rcv1', 'rt'.
  """
  if dataset != 'train' and include_unlabeled:
    raise ValueError('If include_unlabeled=True, must use train dataset')

  # A fixed seed keeps the randomly drawn validation split identical between
  # the gen_data and gen_vocab runs.
  random.seed(302)

  generators_by_name = {
      'imdb': imdb_documents,
      'dbpedia': dbpedia_documents,
      'rcv1': rcv1_documents,
      'rt': rt_documents,
  }
  make_docs = generators_by_name.get(FLAGS.dataset)
  if make_docs is None:
    raise ValueError('Unrecognized dataset %s' % FLAGS.dataset)

  for document in make_docs(dataset, include_unlabeled, include_validation):
    yield document


def tokens(doc):
  """Yields the tokens of a Document, at character or word level.

  The document content is stripped and, if FLAGS.lowercase is set, lowercased.
  When FLAGS.output_char is set, every character is yielded and the word-level
  flags are ignored. Otherwise the content is split with
  data_utils.split_by_punct and, per word, the unigram and/or the bigram
  ending at that word is yielded. Bigrams are joined with '_' and padded with
  data_utils.EOS_TOKEN at both ends of the document.

  Args:
    doc: Document to produce tokens from.

  Yields:
    token

  Raises:
    ValueError: if all FLAGS.{output_unigrams, output_bigrams, output_char}
      are False.
  """
  if (not FLAGS.output_unigrams and not FLAGS.output_bigrams and
      not FLAGS.output_char):
    raise ValueError(
        'At least one of {FLAGS.output_unigrams, FLAGS.output_bigrams, '
        'FLAGS.output_char} must be true')

  text = doc.content.strip()
  if FLAGS.lowercase:
    text = text.lower()

  # Character mode takes precedence over the word-level modes.
  if FLAGS.output_char:
    for ch in text:
      yield ch
    return

  words = data_utils.split_by_punct(text)
  for idx, word in enumerate(words):
    if FLAGS.output_unigrams:
      yield word

    if FLAGS.output_bigrams:
      # The first word is paired with EOS on its left.
      left = words[idx - 1] if idx > 0 else data_utils.EOS_TOKEN
      yield '_'.join([left, word])
      # The last word additionally closes the document with a trailing EOS.
      if len(words) == idx + 1:
        yield '_'.join([word, data_utils.EOS_TOKEN])


def imdb_documents(dataset='train',
                   include_unlabeled=False,
                   include_validation=False):
  """Yields Documents read from the IMDB sentiment dataset.

  Data from http://ai.stanford.edu/~amaas/data/sentiment/

  Reads every file under <dataset>/pos and <dataset>/neg (and train/unsup
  when include_unlabeled is set) below FLAGS.imdb_input_dir. If
  FLAGS.amazon_unlabeled_input_file is set and include_unlabeled is true,
  every line of that file is then yielded as an unlabeled Document with
  add_tokens=False.

  Args:
    dataset: str, identifies folder within IMDB data directory, test or train.
    include_unlabeled: bool, whether to include the unsup directory. Only valid
      when dataset=train.
    include_validation: bool, whether to include validation data.

  Yields:
    Document

  Raises:
    ValueError: if FLAGS.imdb_input_dir is empty.
  """
  if not FLAGS.imdb_input_dir:
    raise ValueError('Must provide FLAGS.imdb_input_dir')

  tf.logging.info('Generating IMDB documents...')

  def in_validation_range(fname, sentiment):
    """Returns whether a labeled file's id falls in the validation range."""
    # Unlabeled files are never part of the validation set.
    if sentiment is None:
      return False
    # The file id is the numeric prefix before the first underscore.
    file_id = int(fname.split('_')[0])
    if sentiment:
      return file_id >= FLAGS.imdb_validation_pos_start_id
    return file_id >= FLAGS.imdb_validation_neg_start_id

  labeled_dirs = [(dataset + '/pos', True), (dataset + '/neg', False)]
  unlabeled_dirs = [('train/unsup', None)] if include_unlabeled else []

  for subdir, sentiment in labeled_dirs + unlabeled_dirs:
    subdir_path = os.path.join(FLAGS.imdb_input_dir, subdir)
    for fname in os.listdir(subdir_path):
      held_out = in_validation_range(fname, sentiment)
      if include_validation or not held_out:
        review_path = os.path.join(subdir_path, fname)
        with open(review_path, encoding='utf-8') as review_f:
          text = review_f.read()
        yield Document(
            content=text,
            is_validation=held_out,
            is_test=False,
            label=sentiment,
            add_tokens=True)

  # Optional Amazon reviews: one unlabeled Document per line of the file.
  if not (include_unlabeled and FLAGS.amazon_unlabeled_input_file):
    return
  with open(FLAGS.amazon_unlabeled_input_file, encoding='utf-8') as amazon_f:
    for line in amazon_f:
      yield Document(
          content=line,
          is_validation=False,
          is_test=False,
          label=None,
          add_tokens=False)


def dbpedia_documents(dataset='train',
                      include_unlabeled=False,
                      include_validation=False):
  """Generates Documents for DBpedia dataset.

  Dataset linked to at https://github.com/zhangxiangxiao/Crepe.

  Each CSV row is read as (label, field 1, field 2); the document content is
  the two text fields joined by a single space, and the label is shifted to
  start from 0.

  Args:
    dataset: str, identifies the csv file within the DBpedia data directory,
      test or train.
    include_unlabeled: bool, unused.
    include_validation: bool, whether to include validation data, which is a
      randomly selected 10% of the data.

  Yields:
    Document

  Raises:
    ValueError: if FLAGS.dbpedia_input_dir is empty.
  """
  del include_unlabeled

  if not FLAGS.dbpedia_input_dir:
    raise ValueError('Must provide FLAGS.dbpedia_input_dir')

  tf.logging.info('Generating DBpedia documents...')

  csv_path = os.path.join(FLAGS.dbpedia_input_dir, dataset + '.csv')
  # newline='' lets the csv module handle line endings itself, which is
  # required for newlines embedded in quoted fields. The explicit encoding
  # makes decoding independent of the platform locale, matching how
  # imdb_documents opens its input files.
  with open(csv_path, encoding='utf-8', newline='') as db_f:
    reader = csv.reader(db_f)
    for row in reader:
      # 10% of the data is randomly held out. The draw is made for every row,
      # including skipped ones, so the split depends only on the row order
      # and the seed of the `random` module.
      is_validation = random.randint(1, 10) == 1
      if is_validation and not include_validation:
        continue

      content = row[1] + ' ' + row[2]
      yield Document(
          content=content,
          is_validation=is_validation,
          is_test=False,
          label=int(row[0]) - 1,  # Labels should start from 0
          add_tokens=True)


def rcv1_documents(dataset='train',
                   include_unlabeled=True,
                   include_validation=False):
  # pylint:disable=line-too-long
  """Generates Documents for Reuters Corpus (rcv1) dataset.

  Dataset described at
  http://www.ai.mit.edu/projects/jmlr/papers/volume5/lewis04a/lyrl2004_rcv1v2_README.htm

  Each CSV row is read as (label, content). Rows of the unlab file are parsed
  exactly like labeled rows, so their first column must also be an integer.

  Args:
    dataset: str, identifies the csv file within the rcv1 data directory.
    include_unlabeled: bool, whether to include the unlab file. Only has an
      effect when dataset=train.
    include_validation: bool, whether to include validation data, which is a
      randomly selected 10% of the data.

  Yields:
    Document

  Raises:
    ValueError: if FLAGS.rcv1_input_dir is empty.
  """
  # pylint:enable=line-too-long

  if not FLAGS.rcv1_input_dir:
    raise ValueError('Must provide FLAGS.rcv1_input_dir')

  tf.logging.info('Generating rcv1 documents...')

  datasets = [dataset]
  if include_unlabeled and dataset == 'train':
    datasets.append('unlab')

  for dset in datasets:
    csv_path = os.path.join(FLAGS.rcv1_input_dir, dset + '.csv')
    # newline='' lets the csv module handle line endings itself, which is
    # required for newlines embedded in quoted fields. The explicit encoding
    # makes decoding independent of the platform locale, matching how
    # imdb_documents opens its input files.
    with open(csv_path, encoding='utf-8', newline='') as db_f:
      reader = csv.reader(db_f)
      for row in reader:
        # 10% of the data is randomly held out. The draw is made for every
        # row, including skipped ones, so the split depends only on the row
        # order and the seed of the `random` module.
        is_validation = random.randint(1, 10) == 1
        if is_validation and not include_validation:
          continue

        content = row[1]
        yield Document(
            content=content,
            is_validation=is_validation,
            is_test=False,
            label=int(row[0]),
            add_tokens=True)


def rt_documents(dataset='train',
                 include_unlabeled=True,
                 include_validation=False):
  # pylint:disable=line-too-long
  """Generates Documents for the Rotten Tomatoes dataset.

  Dataset available at http://www.cs.cornell.edu/people/pabo/movie-review-data/
  In this dataset, amazon reviews are used for the unlabeled data.

  Every line of each '.pos' / '.neg' file in FLAGS.rt_input_dir is one labeled
  document. Labeled lines are randomly assigned to validation (10%) or test
  (10%); test lines are only yielded when dataset is 'test'. Input files are
  processed in sorted filename order so that the random split is the same on
  every run and every filesystem.

  Args:
    dataset: str, 'test' to include the held-out test lines; any other value
      excludes them.
    include_unlabeled: bool, whether to include the unlabeled data. Only valid
      when dataset=train.
    include_validation: bool, whether to include validation data, which is a
      randomly selected 10% of the data.

  Yields:
    Document

  Raises:
    ValueError: if FLAGS.rt_input_dir is empty.
  """
  # pylint:enable=line-too-long

  if not FLAGS.rt_input_dir:
    raise ValueError('Must provide FLAGS.rt_input_dir')

  tf.logging.info('Generating rt documents...')

  data_files = []
  # os.listdir returns entries in arbitrary order. The validation/test split
  # below consumes one random draw per labeled line, so the file order must
  # be fixed for the seeded split to be reproducible across runs.
  for inp_fname in sorted(os.listdir(FLAGS.rt_input_dir)):
    if inp_fname.endswith('.pos'):
      data_files.append((os.path.join(FLAGS.rt_input_dir, inp_fname), True))
    elif inp_fname.endswith('.neg'):
      data_files.append((os.path.join(FLAGS.rt_input_dir, inp_fname), False))
  if include_unlabeled and FLAGS.amazon_unlabeled_input_file:
    data_files.append((FLAGS.amazon_unlabeled_input_file, None))

  for filename, class_label in data_files:
    with open(filename) as rt_f:
      for content in rt_f:
        if class_label is None:
          # Process Amazon Review data for unlabeled dataset: only the
          # 'review/text' lines are kept, and the whole line is the content.
          if content.startswith('review/text'):
            yield Document(
                content=content,
                is_validation=False,
                is_test=False,
                label=None,
                add_tokens=False)
        else:
          # 10% of the data is randomly held out for the validation set and
          # another 10% of it is randomly held out for the test set
          random_int = random.randint(1, 10)
          is_validation = random_int == 1
          is_test = random_int == 2
          if (is_test and dataset != 'test') or (is_validation and
                                                 not include_validation):
            continue

          yield Document(
              content=content,
              is_validation=is_validation,
              is_test=is_test,
              label=class_label,
              add_tokens=True)