# Repository: tensorflow/models (view on GitHub)
# Path: research/adversarial_text/data/data_utils_test.py
#
# Code-quality summary: Maintainability D (estimated remediation: 1 day);
# Test Coverage: (no value given).
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for data_utils."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# Dependency imports

import tensorflow as tf

from data import data_utils

# Short alias for the module under test; every test case below refers to it
# as `data`.
data = data_utils


class SequenceWrapperTest(tf.test.TestCase):
  """Tests for `data_utils.SequenceWrapper` and its per-timestep accessors."""

  def testDefaultTimesteps(self):
    """Added timesteps are counted and start with zeroed token/label/weight."""
    seq = data.SequenceWrapper()
    t1 = seq.add_timestep()
    _ = seq.add_timestep()
    self.assertEqual(len(seq), 2)

    self.assertEqual(t1.weight, 0.0)
    self.assertEqual(t1.label, 0)
    self.assertEqual(t1.token, 0)

  def testSettersAndGetters(self):
    """Values written through set_* are read back through the attributes."""
    ts = data.SequenceWrapper().add_timestep()
    ts.set_token(3)
    ts.set_label(4)
    ts.set_weight(2.0)
    self.assertEqual(ts.token, 3)
    self.assertEqual(ts.label, 4)
    self.assertEqual(ts.weight, 2.0)

  def testTimestepIteration(self):
    """Iterating a SequenceWrapper yields every timestep in insertion order."""
    seq = data.SequenceWrapper()
    seq.add_timestep().set_token(0)
    seq.add_timestep().set_token(1)
    seq.add_timestep().set_token(2)
    # Compare the full list rather than asserting inside a loop: a loop would
    # pass vacuously if iteration yielded fewer timesteps than were added.
    self.assertListEqual([ts.token for ts in seq], [0, 1, 2])

  def testFillsSequenceExampleCorrectly(self):
    """Timestep values are mirrored into the proto exposed as `seq.seq`."""
    seq = data.SequenceWrapper()
    seq.add_timestep().set_token(1).set_label(2).set_weight(3.0)
    seq.add_timestep().set_token(10).set_label(20).set_weight(30.0)

    seq_ex = seq.seq
    fl = seq_ex.feature_lists.feature_list
    fl_token = fl[data.SequenceWrapper.F_TOKEN_ID].feature
    fl_label = fl[data.SequenceWrapper.F_LABEL].feature
    fl_weight = fl[data.SequenceWrapper.F_WEIGHT].feature
    # Each feature list must hold exactly one entry per timestep.
    for feature_list in (fl_token, fl_label, fl_weight):
      self.assertEqual(len(feature_list), 2)
    # Tokens and labels are read from int64_list, weights from float_list.
    self.assertAllEqual([f.int64_list.value[0] for f in fl_token], [1, 10])
    self.assertAllEqual([f.int64_list.value[0] for f in fl_label], [2, 20])
    self.assertAllEqual([f.float_list.value[0] for f in fl_weight], [3.0, 30.0])


class DataUtilsTest(tf.test.TestCase):
  """Tests for the tokenizer and sequence-building helpers in `data_utils`."""

  def testSplitByPunct(self):
    """Punctuation and whitespace characters both act as word separators."""
    output = data.split_by_punct(
        'hello! world, i\'ve been\nwaiting\tfor\ryou for.a long time')
    expected = [
        'hello', 'world', 'i', 've', 'been', 'waiting', 'for', 'you', 'for',
        'a', 'long', 'time'
    ]
    self.assertListEqual(output, expected)

  def _buildDummySequence(self):
    """Returns a SequenceWrapper of 10 timesteps with tokens 0, 1, ..., 9."""
    seq = data.SequenceWrapper()
    for i in range(10):
      seq.add_timestep().set_token(i)
    return seq

  def testBuildLMSeq(self):
    """Each label is the next token; the final timestep is unweighted."""
    seq = self._buildDummySequence()
    lm_seq = data.build_lm_sequence(seq)
    last_index = len(lm_seq) - 1
    for i, ts in enumerate(lm_seq):
      self.assertEqual(ts.token, i)
      if i == last_index:
        # For end of sequence, the token and label should be same, and weight
        # should be 0.0.
        self.assertEqual(ts.label, i)
        self.assertEqual(ts.weight, 0.0)
      else:
        self.assertEqual(ts.label, i + 1)
        self.assertEqual(ts.weight, 1.0)

  def testBuildSAESeq(self):
    """Sequence-autoencoder input replays the sequence and targets its copy."""
    seq = self._buildDummySequence()
    sa_seq = data.build_seq_ae_sequence(seq)
    # Derive sizes from the dummy sequence instead of hard-coding its length,
    # so the test stays valid if _buildDummySequence changes.
    seq_len = len(seq)
    prefix_len = seq_len - 1

    self.assertEqual(len(sa_seq), seq_len * 2 - 1)

    # Tokens should be the sequence twice, minus the EOS token at the end.
    for i, ts in enumerate(sa_seq):
      self.assertEqual(ts.token, seq[i % seq_len].token)

    # The first seq_len - 1 timesteps have label 0 and weight 0.0.
    for ts in sa_seq[:prefix_len]:
      self.assertEqual(ts.label, 0)
      self.assertEqual(ts.weight, 0.0)

    # The remaining seq_len timesteps have weight 1.0 and are labeled with
    # the original sequence's tokens, in order. zip() cannot truncate here:
    # both sides have seq_len items (the total length is asserted above).
    for ts, original in zip(sa_seq[prefix_len:], seq):
      self.assertEqual(ts.label, original.token)
      self.assertEqual(ts.weight, 1.0)

  def testBuildLabelSeq(self):
    """Only the final timestep carries the label and a nonzero weight."""
    seq = self._buildDummySequence()
    # The dummy sequence's last token doubles as the EOS id in this test.
    eos_id = len(seq) - 1
    # NOTE(review): the positional `True` appears to be the sequence's class
    # label (the final label is asserted to be 1 below) -- confirm.
    label_seq = data.build_labeled_sequence(seq, True)
    for i, ts in enumerate(label_seq[:-1]):
      self.assertEqual(ts.token, i)
      self.assertEqual(ts.label, 0)
      self.assertEqual(ts.weight, 0.0)

    final_timestep = label_seq[-1]
    self.assertEqual(final_timestep.token, eos_id)
    self.assertEqual(final_timestep.label, 1)
    self.assertEqual(final_timestep.weight, 1.0)

  def testBuildBidirLabelSeq(self):
    """Labeling a bidirectional sequence marks only its final timestep."""
    seq = self._buildDummySequence()
    reverse_seq = data.build_reverse_sequence(seq)
    bidir_seq = data.build_bidirectional_seq(seq, reverse_seq)
    label_seq = data.build_labeled_sequence(bidir_seq, True)

    # Pair forward position i with the matching token of the reversed
    # (EOS-excluded) sequence, which counts down from len(seq) - 2 to 0.
    for (i, ts), reverse_token in zip(
        enumerate(label_seq[:-1]), reversed(range(len(seq) - 1))):
      self.assertAllEqual(ts.tokens, [i, reverse_token])
      self.assertEqual(ts.label, 0)
      self.assertEqual(ts.weight, 0.0)

    final_timestep = label_seq[-1]
    eos_id = len(seq) - 1
    self.assertAllEqual(final_timestep.tokens, [eos_id, eos_id])
    self.assertEqual(final_timestep.label, 1)
    self.assertEqual(final_timestep.weight, 1.0)

  def testReverseSeq(self):
    """All tokens but the last are reversed; the last (EOS) stays at the end."""
    seq = self._buildDummySequence()
    reverse_seq = data.build_reverse_sequence(seq)
    for i, ts in enumerate(reversed(reverse_seq[:-1])):
      self.assertEqual(ts.token, i)
      self.assertEqual(ts.label, 0)
      self.assertEqual(ts.weight, 0.0)

    final_timestep = reverse_seq[-1]
    eos_id = len(seq) - 1
    self.assertEqual(final_timestep.token, eos_id)
    self.assertEqual(final_timestep.label, 0)
    self.assertEqual(final_timestep.weight, 0.0)

  def testBidirSeq(self):
    """Bidirectional timesteps pair forward and reversed tokens, ending at EOS."""
    seq = self._buildDummySequence()
    reverse_seq = data.build_reverse_sequence(seq)
    bidir_seq = data.build_bidirectional_seq(seq, reverse_seq)
    # Forward position i pairs with reversed token len(seq) - 2 - i.
    for (i, ts), reverse_token in zip(
        enumerate(bidir_seq[:-1]), reversed(range(len(seq) - 1))):
      self.assertAllEqual(ts.tokens, [i, reverse_token])
      self.assertEqual(ts.label, 0)
      self.assertEqual(ts.weight, 0.0)

    final_timestep = bidir_seq[-1]
    eos_id = len(seq) - 1
    self.assertAllEqual(final_timestep.tokens, [eos_id, eos_id])
    self.assertEqual(final_timestep.label, 0)
    self.assertEqual(final_timestep.weight, 0.0)

  def testLabelGain(self):
    """With label_gain=True every step is labeled; weight at i is i/(len-1)."""
    seq = self._buildDummySequence()
    label_seq = data.build_labeled_sequence(seq, True, label_gain=True)
    denominator = float(len(seq) - 1)
    for i, ts in enumerate(label_seq):
      self.assertEqual(ts.token, i)
      self.assertEqual(ts.label, 1)
      self.assertNear(ts.weight, i / denominator, 1e-3)


if __name__ == '__main__':
  # Run every test case defined in this module through TensorFlow's runner.
  tf.test.main()