# Source: tensorflow/models (View on GitHub)
# File: official/nlp/modeling/layers/text_layers_test.py
#
# Code-quality summary: Maintainability F (2 wks); Test Coverage.

# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests bert.text_layers."""

import os
import tempfile

import numpy as np
import tensorflow as tf, tf_keras
from tensorflow import estimator as tf_estimator

from sentencepiece import SentencePieceTrainer
from official.nlp.modeling.layers import text_layers


# This test covers the in-process behavior of a BertTokenizer layer.
# For saving, restoring, and the restored behavior (incl. shape inference),
# see nlp/tools/export_tfhub_lib_test.py.
class BertTokenizerTest(tf.test.TestCase):
  """Checks the in-process behavior of `text_layers.BertTokenizer`."""

  def _make_vocab_file(self, vocab, filename="vocab.txt"):
    """Writes a one-token-per-line vocab file and returns its path.

    A brand-new subdirectory of the test temp dir is created on every call,
    so repeated calls with the same `filename` never collide.

    Args:
      vocab: list of token strings, written in order; the tests below rely on
        line i holding the token with id i.
      filename: basename of the file to create.

    Returns:
      Path of the newly written vocab file.
    """
    fresh_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
    vocab_path = os.path.join(fresh_dir, filename)
    with tf.io.gfile.GFile(vocab_path, "w") as vocab_out:
      # Every token, including the last one, is followed by a newline.
      vocab_out.write("\n".join(vocab + [""]))
    return vocab_path

  def test_uncased(self):
    """Lower-casing maps "ABC DEF" onto the same ids as "abc def"."""
    vocab_path = self._make_vocab_file(
        ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "d", "##ef", "abc", "xy"])
    tokenizer = text_layers.BertTokenizer(
        vocab_file=vocab_path, lower_case=True)
    sentences = tf.constant(["abc def", "ABC DEF d"])

    expected_ids = tf.ragged.constant(
        [[[6], [4, 5]], [[6], [4, 5], [4]]])
    expected_starts = tf.ragged.constant(
        [[[0], [4, 5]], [[0], [4, 5], [8]]])
    expected_limits = tf.ragged.constant(
        [[[3], [5, 7]], [[3], [5, 7], [9]]])

    ids = tokenizer(sentences)
    self.assertAllEqual(ids, expected_ids)

    # Flipping the attribute after construction switches the output format
    # to (ids, start_offsets, limit_offsets).
    tokenizer.tokenize_with_offsets = True
    ids_again, starts, limits = tokenizer(sentences)
    self.assertAllEqual(ids, ids_again)
    self.assertAllEqual(starts, expected_starts)
    self.assertAllEqual(limits, expected_limits)
    self.assertEqual(tokenizer.vocab_size.numpy(), 8)

  def test_cased(self):
    """Same idea as test_uncased, but with lower_case=False case matters."""
    vocab_path = self._make_vocab_file(
        ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "d", "##ef", "abc", "ABC"])
    tokenizer = text_layers.BertTokenizer(
        vocab_file=vocab_path, lower_case=False, tokenize_with_offsets=True)

    ids, starts, limits = tokenizer(tf.constant(["abc def", "ABC DEF"]))

    # "DEF" has no cased match in the vocab, so it becomes [UNK] (id 1).
    expected_ids = tf.ragged.constant([[[6], [4, 5]], [[7], [1]]])
    expected_starts = tf.ragged.constant([[[0], [4, 5]], [[0], [4]]])
    expected_limits = tf.ragged.constant([[[3], [5, 7]], [[3], [7]]])
    self.assertAllEqual(ids, expected_ids)
    self.assertAllEqual(starts, expected_starts)
    self.assertAllEqual(limits, expected_limits)

  def test_special_tokens_complete(self):
    """Every special token present: ids follow their vocab line numbers."""
    vocab_path = self._make_vocab_file(
        ["foo", "[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]", "xy"])
    tokenizer = text_layers.BertTokenizer(
        vocab_file=vocab_path, lower_case=True)
    expected = {
        "padding_id": 1,
        "start_of_sequence_id": 3,
        "end_of_segment_id": 4,
        "mask_id": 5,
        "vocab_size": 7,
    }
    self.assertDictEqual(tokenizer.get_special_tokens_dict(), expected)

  def test_special_tokens_partial(self):
    """Tokens absent from the vocab are absent from the dict."""
    vocab_path = self._make_vocab_file(["[PAD]", "[CLS]", "[SEP]"])
    tokenizer = text_layers.BertTokenizer(
        vocab_file=vocab_path, lower_case=True)
    # There is no [MASK] in the vocab, hence no "mask_id" key.
    expected = {
        "padding_id": 0,
        "start_of_sequence_id": 1,
        "end_of_segment_id": 2,
        "vocab_size": 3,
    }
    self.assertDictEqual(tokenizer.get_special_tokens_dict(), expected)

  def test_special_tokens_in_estimator(self):
    """Tests getting special tokens without an Eager init context."""
    vocab_path = self._make_vocab_file(
        ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "d", "##ef", "abc", "xy"])

    def input_fn():
      # Premise of this test: not even the init scope is eager here.
      with tf.init_scope():
        self.assertFalse(tf.executing_eagerly())

      # Build a Keras preprocessing model: strings -> packed BERT inputs.
      sentence_input = tf_keras.layers.Input(shape=[], dtype=tf.string)
      tokenizer = text_layers.BertTokenizer(
          vocab_file=vocab_path, lower_case=True)
      special_ids = tokenizer.get_special_tokens_dict()
      for key, value in special_ids.items():
        self.assertIsInstance(value, int, "Unexpected type for {}".format(key))
      token_ids = tokenizer(sentence_input)
      packed = text_layers.BertPackInputs(
          4, special_tokens_dict=special_ids)(token_ids)
      preprocess_model = tf_keras.Model(sentence_input, packed)

      def apply_preprocessing(features, labels):
        return preprocess_model(features), labels

      dataset = tf.data.Dataset.from_tensors(
          (tf.constant(["abc", "DEF"]), tf.constant([0, 1])))
      return dataset.map(apply_preprocessing)

    # The Estimator inspects these argument names; keep them as they are.
    def model_fn(features, labels, mode):
      del labels  # Unused.
      return tf_estimator.EstimatorSpec(
          mode=mode, predictions=features["input_word_ids"])

    estimator = tf_estimator.Estimator(model_fn=model_fn)
    predicted_ids = list(estimator.predict(input_fn))
    # [CLS]=2, word-piece ids, [SEP]=3, then [PAD]=0 up to length 4.
    self.assertAllEqual(predicted_ids,
                        np.array([[2, 6, 3, 0], [2, 4, 5, 3]]))


# This test covers the in-process behavior of a SentencepieceTokenizer layer.
class SentencepieceTokenizerTest(tf.test.TestCase):
  """In-process tests for `text_layers.SentencepieceTokenizer`."""

  def setUp(self):
    """Trains a tiny word-level sentencepiece model for each test."""
    super().setUp()
    # Make a sentencepiece model inside a fresh subdirectory. The result of
    # mkdtemp must be used; otherwise the subdirectory is created and ignored.
    tmp_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
    vocab = ["a", "b", "c", "d", "e", "abc", "def", "ABC", "DEF"]
    model_prefix = os.path.join(tmp_dir, "spm_model")
    input_text_file_path = os.path.join(tmp_dir, "train_input.txt")
    with tf.io.gfile.GFile(input_text_file_path, "w") as f:
      f.write(" ".join(vocab + ["\n"]))
    # Add 7 more tokens: <pad>, <unk>, [CLS], [SEP], [MASK], <s>, </s>.
    full_vocab_size = len(vocab) + 7
    flags = dict(
        model_prefix=model_prefix,
        model_type="word",
        input=input_text_file_path,
        pad_id=0, unk_id=1, control_symbols="[CLS],[SEP],[MASK]",
        vocab_size=full_vocab_size,
        # Put <s> and </s> at the very end of the id range.
        bos_id=full_vocab_size-2, eos_id=full_vocab_size-1)
    SentencePieceTrainer.Train(
        " ".join(["--{}={}".format(k, v) for k, v in flags.items()]))
    self._spm_path = model_prefix + ".model"

  def test_uncased(self):
    """Lower-casing maps "ABC DEF" onto the same ids as "abc def"."""
    sentencepiece_tokenizer = text_layers.SentencepieceTokenizer(
        model_file_path=self._spm_path, lower_case=True, nbest_size=0)

    inputs = tf.constant(["abc def", "ABC DEF d"])
    token_ids = sentencepiece_tokenizer(inputs)
    self.assertAllEqual(
        token_ids,
        tf.ragged.constant([[8, 12], [8, 12, 11]]))
    # Switching the attribute after construction changes the output format.
    sentencepiece_tokenizer.tokenize_with_offsets = True
    token_ids_2, start_offsets, limit_offsets = sentencepiece_tokenizer(inputs)
    self.assertAllEqual(token_ids, token_ids_2)
    self.assertAllEqual(
        start_offsets, tf.ragged.constant([[0, 3], [0, 3, 7]]))
    self.assertAllEqual(
        limit_offsets, tf.ragged.constant([[3, 7], [3, 7, 9]]))
    self.assertEqual(sentencepiece_tokenizer.vocab_size.numpy(), 16)

  def test_cased(self):
    """Same inputs as test_uncased, but case matters with lower_case=False."""
    sentencepiece_tokenizer = text_layers.SentencepieceTokenizer(
        model_file_path=self._spm_path,
        lower_case=False,
        nbest_size=0,
        tokenize_with_offsets=False)

    inputs = tf.constant(["abc def", "ABC DEF d"])
    token_ids = sentencepiece_tokenizer(inputs)
    self.assertAllEqual(
        token_ids,
        tf.ragged.constant([[8, 12], [5, 6, 11]]))
    sentencepiece_tokenizer.tokenize_with_offsets = True
    token_ids_2, start_offsets, limit_offsets = sentencepiece_tokenizer(inputs)
    self.assertAllEqual(token_ids, token_ids_2)
    self.assertAllEqual(
        start_offsets,
        tf.ragged.constant([[0, 3], [0, 3, 7]]))
    self.assertAllEqual(
        limit_offsets,
        tf.ragged.constant([[3, 7], [3, 7, 9]]))

  def test_special_tokens(self):
    """Special-token ids match the pad/unk/control ids chosen in setUp."""
    sentencepiece_tokenizer = text_layers.SentencepieceTokenizer(
        model_file_path=self._spm_path, lower_case=True, nbest_size=0)
    self.assertDictEqual(sentencepiece_tokenizer.get_special_tokens_dict(),
                         dict(padding_id=0,
                              start_of_sequence_id=2,
                              end_of_segment_id=3,
                              mask_id=4,
                              vocab_size=16))

  def test_special_tokens_in_estimator(self):
    """Tests getting special tokens without an Eager init context."""

    def input_fn():
      with tf.init_scope():
        self.assertFalse(tf.executing_eagerly())
      # Build a preprocessing Model.
      sentences = tf_keras.layers.Input(shape=[], dtype=tf.string)
      sentencepiece_tokenizer = text_layers.SentencepieceTokenizer(
          model_file_path=self._spm_path, lower_case=True, nbest_size=0)
      special_tokens_dict = sentencepiece_tokenizer.get_special_tokens_dict()
      for k, v in special_tokens_dict.items():
        self.assertIsInstance(v, int, "Unexpected type for {}".format(k))
      tokens = sentencepiece_tokenizer(sentences)
      packed_inputs = text_layers.BertPackInputs(
          4, special_tokens_dict=special_tokens_dict)(tokens)
      preprocessing = tf_keras.Model(sentences, packed_inputs)
      # Map the dataset.
      ds = tf.data.Dataset.from_tensors(
          (tf.constant(["abc", "DEF"]), tf.constant([0, 1])))
      ds = ds.map(lambda features, labels: (preprocessing(features), labels))
      return ds

    # The Estimator inspects these argument names; keep them as they are.
    def model_fn(features, labels, mode):
      del labels  # Unused.
      return tf_estimator.EstimatorSpec(mode=mode,
                                        predictions=features["input_word_ids"])

    estimator = tf_estimator.Estimator(model_fn=model_fn)
    outputs = list(estimator.predict(input_fn))
    self.assertAllEqual(outputs, np.array([[2, 8, 3, 0],
                                           [2, 12, 3, 0]]))

  def test_strip_diacritics(self):
    """Accented letters tokenize like their unaccented counterparts."""
    sentencepiece_tokenizer = text_layers.SentencepieceTokenizer(
        model_file_path=self._spm_path,
        lower_case=True,
        nbest_size=0,
        strip_diacritics=True)
    inputs = tf.constant(["a b c d e", "ă ḅ č ḓ é"])
    token_ids = sentencepiece_tokenizer(inputs)
    self.assertAllEqual(
        token_ids,
        tf.ragged.constant([[7, 9, 10, 11, 13], [7, 9, 10, 11, 13]]))

  def test_fail_on_tokenize_with_offsets_and_strip_diacritics(self):
    """Offsets plus strip_diacritics is rejected in __init__ and in call()."""
    # Raise an error in init().
    with self.assertRaises(ValueError):
      text_layers.SentencepieceTokenizer(
          model_file_path=self._spm_path,
          tokenize_with_offsets=True,
          lower_case=True,
          nbest_size=0,
          strip_diacritics=True)

    sentencepiece_tokenizer = text_layers.SentencepieceTokenizer(
        model_file_path=self._spm_path,
        lower_case=True,
        nbest_size=0,
        strip_diacritics=True)
    sentencepiece_tokenizer.tokenize_with_offsets = True

    # Raise an error in call():
    inputs = tf.constant(["abc def", "ABC DEF d", "Äffin"])
    with self.assertRaises(ValueError):
      sentencepiece_tokenizer(inputs)

  def test_serialize_deserialize(self):
    """Round-trips the layer through get_config()/from_config()."""
    self.skipTest("b/170480226")
    sentencepiece_tokenizer = text_layers.SentencepieceTokenizer(
        model_file_path=self._spm_path,
        lower_case=False,
        nbest_size=0,
        tokenize_with_offsets=False,
        name="sentencepiece_tokenizer_layer")
    config = sentencepiece_tokenizer.get_config()
    new_tokenizer = text_layers.SentencepieceTokenizer.from_config(config)
    self.assertEqual(config, new_tokenizer.get_config())
    inputs = tf.constant(["abc def", "ABC DEF d"])
    token_ids = sentencepiece_tokenizer(inputs)
    token_ids_2 = new_tokenizer(inputs)
    self.assertAllEqual(token_ids, token_ids_2)

  # TODO(b/170480226): Remove once tf_hub_export_lib_test.py covers saving.
  def test_saving(self):
    """A Keras model wrapping the tokenizer can be saved as a SavedModel."""
    sentencepiece_tokenizer = text_layers.SentencepieceTokenizer(
        model_file_path=self._spm_path, lower_case=True, nbest_size=0)
    inputs = tf_keras.layers.Input([], dtype=tf.string)
    outputs = sentencepiece_tokenizer(inputs)
    model = tf_keras.Model(inputs, outputs)
    export_path = tempfile.mkdtemp(dir=self.get_temp_dir())
    model.save(export_path, signatures={})
    # Beyond save() not raising, confirm a SavedModel was actually written.
    self.assertTrue(tf.saved_model.contains_saved_model(export_path))


class BertPackInputsTest(tf.test.TestCase):
  """Tests packing of token ids into fixed-length BERT inputs."""

  def test_round_robin_correct_outputs(self):
    """The round_robin truncator trims segments alternately."""
    bpi = text_layers.BertPackInputs(
        10,
        start_of_sequence_id=1001,
        end_of_segment_id=1002,
        padding_id=999,
        truncator="round_robin")
    # Single input, rank 2.
    bert_inputs = bpi(
        tf.ragged.constant([[11, 12, 13],
                            [21, 22, 23, 24, 25, 26, 27, 28, 29, 30]]))
    self.assertAllEqual(
        bert_inputs["input_word_ids"],
        tf.constant([[1001, 11, 12, 13, 1002, 999, 999, 999, 999, 999],
                     [1001, 21, 22, 23, 24, 25, 26, 27, 28, 1002]]))
    self.assertAllEqual(
        bert_inputs["input_mask"],
        tf.constant([[1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
                     [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]))
    self.assertAllEqual(
        bert_inputs["input_type_ids"],
        tf.constant([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                     [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]))

    # Two inputs, rank 3. Truncation does not respect word boundaries.
    bert_inputs = bpi([
        tf.ragged.constant([[[111], [112, 113]],
                            [[121, 122, 123], [124, 125, 126], [127, 128]]]),
        tf.ragged.constant([[[211, 212], [213]],
                            [[221, 222], [223, 224, 225], [226, 227, 228]]])
    ])
    self.assertAllEqual(
        bert_inputs["input_word_ids"],
        tf.constant([[1001, 111, 112, 113, 1002, 211, 212, 213, 1002, 999],
                     [1001, 121, 122, 123, 124, 1002, 221, 222, 223, 1002]]))
    self.assertAllEqual(
        bert_inputs["input_mask"],
        tf.constant([[1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
                     [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]))
    self.assertAllEqual(
        bert_inputs["input_type_ids"],
        tf.constant([[0, 0, 0, 0, 0, 1, 1, 1, 1, 0],
                     [0, 0, 0, 0, 0, 0, 1, 1, 1, 1]]))

    # Three inputs, rank 3. Truncation does not respect word boundaries.
    bert_inputs = bpi([
        tf.ragged.constant([[[111], [112, 113]],
                            [[121, 122, 123], [124, 125, 126], [127, 128]]]),
        tf.ragged.constant([[[211, 212], [213]],
                            [[221, 222], [223, 224, 225], [226, 227, 228]]]),
        tf.ragged.constant([[[311, 312], [313]],
                            [[321, 322], [323, 324, 325], [326, 327, 328]]])
    ])
    self.assertAllEqual(
        bert_inputs["input_word_ids"],
        tf.constant([[1001, 111, 112, 1002, 211, 212, 1002, 311, 312, 1002],
                     [1001, 121, 122, 1002, 221, 222, 1002, 321, 322, 1002]]))
    # Both rows fill all 10 slots, so there is no padding.
    self.assertAllEqual(
        bert_inputs["input_mask"],
        tf.constant([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
                     [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]))
    # [CLS] and segment 0 (with its [SEP]) get type 0, segment 1 gets 1,
    # and segment 2 gets 2.
    self.assertAllEqual(
        bert_inputs["input_type_ids"],
        tf.constant([[0, 0, 0, 0, 1, 1, 1, 2, 2, 2],
                     [0, 0, 0, 0, 1, 1, 1, 2, 2, 2]]))

  def test_waterfall_correct_outputs(self):
    """The waterfall truncator fills earlier segments first."""
    bpi = text_layers.BertPackInputs(
        10,
        start_of_sequence_id=1001,
        end_of_segment_id=1002,
        padding_id=999,
        truncator="waterfall")
    # Single input, rank 2.
    bert_inputs = bpi(
        tf.ragged.constant([[11, 12, 13],
                            [21, 22, 23, 24, 25, 26, 27, 28, 29, 30]]))
    self.assertAllEqual(
        bert_inputs["input_word_ids"],
        tf.constant([[1001, 11, 12, 13, 1002, 999, 999, 999, 999, 999],
                     [1001, 21, 22, 23, 24, 25, 26, 27, 28, 1002]]))
    self.assertAllEqual(
        bert_inputs["input_mask"],
        tf.constant([[1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
                     [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]))
    self.assertAllEqual(
        bert_inputs["input_type_ids"],
        tf.constant([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                     [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]))

    # Two inputs, rank 3. Truncation does not respect word boundaries.
    bert_inputs = bpi([
        tf.ragged.constant([[[111], [112, 113]],
                            [[121, 122, 123], [124, 125, 126], [127, 128]]]),
        tf.ragged.constant([[[211, 212], [213]],
                            [[221, 222], [223, 224, 225], [226, 227, 228]]])
    ])
    self.assertAllEqual(
        bert_inputs["input_word_ids"],
        tf.constant([[1001, 111, 112, 113, 1002, 211, 212, 213, 1002, 999],
                     [1001, 121, 122, 123, 124, 125, 126, 127, 1002, 1002]]))
    self.assertAllEqual(
        bert_inputs["input_mask"],
        tf.constant([[1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
                     [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]))
    self.assertAllEqual(
        bert_inputs["input_type_ids"],
        tf.constant([[0, 0, 0, 0, 0, 1, 1, 1, 1, 0],
                     [0, 0, 0, 0, 0, 0, 0, 0, 0, 1]]))

    # Three inputs, rank 3. Truncation does not respect word boundaries.
    bert_inputs = bpi([
        tf.ragged.constant([[[111], [112, 113]],
                            [[121, 122, 123], [124, 125, 126], [127, 128]]]),
        tf.ragged.constant([[[211], [212]],
                            [[221, 222], [223, 224, 225], [226, 227, 228]]]),
        tf.ragged.constant([[[311, 312], [313]],
                            [[321, 322], [323, 324, 325], [326, 327]]])
    ])
    self.assertAllEqual(
        bert_inputs["input_word_ids"],
        tf.constant([[1001, 111, 112, 113, 1002, 211, 212, 1002, 311, 1002],
                     [1001, 121, 122, 123, 124, 125, 126, 1002, 1002, 1002]]))
    self.assertAllEqual(
        bert_inputs["input_mask"],
        tf.constant([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
                     [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]))
    self.assertAllEqual(
        bert_inputs["input_type_ids"],
        tf.constant([[0, 0, 0, 0, 0, 1, 1, 1, 2, 2],
                     [0, 0, 0, 0, 0, 0, 0, 0, 1, 2]]))

  def test_special_tokens_dict(self):
    """Special ids can come from a dict; unrelated keys are ignored."""
    special_tokens_dict = dict(start_of_sequence_id=1001,
                               end_of_segment_id=1002,
                               padding_id=999,
                               extraneous_key=666)
    bpi = text_layers.BertPackInputs(10,
                                     special_tokens_dict=special_tokens_dict)
    bert_inputs = bpi(
        tf.ragged.constant([[11, 12, 13],
                            [21, 22, 23, 24, 25, 26, 27, 28, 29, 30]]))
    self.assertAllEqual(
        bert_inputs["input_word_ids"],
        tf.constant([[1001, 11, 12, 13, 1002, 999, 999, 999, 999, 999],
                     [1001, 21, 22, 23, 24, 25, 26, 27, 28, 1002]]))


# This test covers the in-process behavior of FastWordpieceBertTokenizer layer.
class FastWordPieceBertTokenizerTest(tf.test.TestCase):
  """Checks the in-process behavior of `FastWordpieceBertTokenizer`."""

  def _make_vocab_file(self, vocab, filename="vocab.txt"):
    """Writes `vocab` (one token per line) to a fresh temp subdirectory.

    Args:
      vocab: list of token strings; line i of the file holds token id i.
      filename: basename of the file to create.

    Returns:
      Path of the written file.
    """
    target_dir = tempfile.mkdtemp(dir=self.get_temp_dir())  # Unique per call.
    target_path = os.path.join(target_dir, filename)
    with tf.io.gfile.GFile(target_path, "w") as writer:
      writer.write("\n".join(vocab + [""]))
    return target_path

  def test_uncased(self):
    """With lower_case=True, upper- and lower-case words share ids."""
    tokens_path = self._make_vocab_file(
        ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "d", "##ef", "abc", "xy"])
    layer = text_layers.FastWordpieceBertTokenizer(
        vocab_file=tokens_path, lower_case=True)
    text = tf.constant(["abc def", "ABC DEF d"])

    first_ids = layer(text)
    self.assertAllEqual(
        first_ids, tf.ragged.constant([[[6], [4, 5]], [[6], [4, 5], [4]]]))

    # Enabling offsets afterwards yields (ids, starts, limits) instead.
    layer.tokenize_with_offsets = True
    second_ids, begin, end = layer(text)
    self.assertAllEqual(first_ids, second_ids)
    self.assertAllEqual(
        begin, tf.ragged.constant([[[0], [4, 5]], [[0], [4, 5], [8]]]))
    self.assertAllEqual(
        end, tf.ragged.constant([[[3], [5, 7]], [[3], [5, 7], [9]]]))
    self.assertEqual(layer.vocab_size, 8)

  def test_cased(self):
    """Same inputs as test_uncased, but case matters with lower_case=False."""
    tokens_path = self._make_vocab_file(
        ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "d", "##ef", "abc", "ABC"])
    layer = text_layers.FastWordpieceBertTokenizer(
        vocab_file=tokens_path, lower_case=False, tokenize_with_offsets=True)
    ids, begin, end = layer(tf.constant(["abc def", "ABC DEF"]))
    # "DEF" has no cased match in the vocab, so it becomes [UNK] (id 1).
    self.assertAllEqual(
        ids, tf.ragged.constant([[[6], [4, 5]], [[7], [1]]]))
    self.assertAllEqual(
        begin, tf.ragged.constant([[[0], [4, 5]], [[0], [4]]]))
    self.assertAllEqual(
        end, tf.ragged.constant([[[3], [5, 7]], [[3], [7]]]))

  def test_special_tokens_complete(self):
    """All special tokens present: ids equal their vocab line numbers."""
    tokens_path = self._make_vocab_file(
        ["foo", "[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]", "xy"])
    layer = text_layers.FastWordpieceBertTokenizer(
        vocab_file=tokens_path, lower_case=True)
    self.assertDictEqual(
        layer.get_special_tokens_dict(),
        {"padding_id": 1, "start_of_sequence_id": 3, "end_of_segment_id": 4,
         "mask_id": 5, "vocab_size": 7})

  def test_special_tokens_partial(self):
    """A vocab without [MASK] yields a dict without mask_id."""
    # [UNK] token is required by fast wordpiece tokenizer.
    tokens_path = self._make_vocab_file(["[PAD]", "[CLS]", "[SEP]", "[UNK]"])
    layer = text_layers.FastWordpieceBertTokenizer(
        vocab_file=tokens_path, lower_case=True)
    self.assertDictEqual(
        layer.get_special_tokens_dict(),
        {"padding_id": 0, "start_of_sequence_id": 1, "end_of_segment_id": 2,
         "vocab_size": 4})

  def test_special_tokens_in_estimator(self):
    """Tests getting special tokens without an Eager init context."""
    tokens_path = self._make_vocab_file(
        ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "d", "##ef", "abc", "xy"])

    def input_fn():
      # Premise of this test: not even the init scope is eager here.
      with tf.init_scope():
        self.assertFalse(tf.executing_eagerly())

      raw_text = tf_keras.layers.Input(shape=[], dtype=tf.string)
      layer = text_layers.FastWordpieceBertTokenizer(
          vocab_file=tokens_path, lower_case=True)
      specials = layer.get_special_tokens_dict()
      for name, token_id in specials.items():
        self.assertIsInstance(token_id, int,
                              "Unexpected type for {}".format(name))
      word_pieces = layer(raw_text)
      bert_inputs = text_layers.BertPackInputs(
          4, special_tokens_dict=specials)(word_pieces)
      pipeline = tf_keras.Model(raw_text, bert_inputs)

      def to_model_inputs(features, labels):
        return pipeline(features), labels

      examples = tf.data.Dataset.from_tensors(
          (tf.constant(["abc", "DEF"]), tf.constant([0, 1])))
      return examples.map(to_model_inputs)

    # The Estimator inspects these argument names; keep them as they are.
    def model_fn(features, labels, mode):
      del labels  # Unused.
      return tf_estimator.EstimatorSpec(
          mode=mode, predictions=features["input_word_ids"])

    runner = tf_estimator.Estimator(model_fn=model_fn)
    results = list(runner.predict(input_fn))
    self.assertAllEqual(results, np.array([[2, 6, 3, 0], [2, 4, 5, 3]]))


# Run every test case in this module via TensorFlow's test runner.
if __name__ == "__main__":
  tf.test.main()