tensorflow/models

View on GitHub
official/nlp/data/pretrain_dataloader.py

Summary

Maintainability
C
1 day
Test Coverage
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Loads dataset for the BERT pretraining task."""
import dataclasses
from typing import Mapping, Optional

from absl import logging

import numpy as np
import tensorflow as tf, tf_keras
from official.common import dataset_fn
from official.core import config_definitions as cfg
from official.core import input_reader
from official.nlp.data import data_loader
from official.nlp.data import data_loader_factory


@dataclasses.dataclass
class BertPretrainDataConfig(cfg.DataConfig):
  """Data config for BERT pretraining task (tasks/masked_lm).

  Attributes:
    input_path: See base class.
    global_batch_size: See base class.
    is_training: See base class.
    seq_length: Length of the per-token features read from each tf.Example
      (`input_mask`, the word/type id features and, if enabled,
      `position_ids`).
    max_predictions_per_seq: Length of the `masked_lm_positions`,
      `masked_lm_ids` and `masked_lm_weights` features.
    use_next_sentence_label: If True, a `next_sentence_labels` feature of
      length 1 is parsed and passed through to the model.
    use_position_id: If True, a `position_ids` feature of length `seq_length`
      is parsed and passed through to the model.
    use_v2_feature_names: If True, word/type ids are read from
      `input_word_ids`/`input_type_ids`; otherwise from
      `input_ids`/`segment_ids`. See the comment on the field below.
    file_type: Passed to `dataset_fn.pick_dataset_fn` to choose the dataset
      constructor used to read `input_path`.
  """
  input_path: str = ''
  global_batch_size: int = 512
  is_training: bool = True
  seq_length: int = 512
  max_predictions_per_seq: int = 76
  use_next_sentence_label: bool = True
  use_position_id: bool = False
  # Historically, BERT implementations take `input_ids` and `segment_ids` as
  # feature names. Inside the TF Model Garden implementation, the Keras model
  # inputs are set as `input_word_ids` and `input_type_ids`. When
  # v2_feature_names is True, the data loader assumes the tf.Examples use
  # `input_word_ids` and `input_type_ids` as keys.
  use_v2_feature_names: bool = False
  file_type: str = 'tfrecord'


@data_loader_factory.register_data_loader_cls(BertPretrainDataConfig)
class BertPretrainDataLoader(data_loader.DataLoader):
  """Loads BERT masked-LM pretraining examples from serialized tf.Examples.

  Next-sentence labels and explicit position ids are included only when the
  corresponding `BertPretrainDataConfig` flags are enabled.
  """

  def __init__(self, params):
    """Stores the config and caches the fields consulted while parsing.

    Args:
      params: A `BertPretrainDataConfig` object.
    """
    self._params = params
    self._seq_length = params.seq_length
    self._max_predictions_per_seq = params.max_predictions_per_seq
    self._use_next_sentence_label = params.use_next_sentence_label
    self._use_position_id = params.use_position_id

  def _id_feature_keys(self):
    """Returns the tf.Example keys holding (word ids, type ids)."""
    if self._params.use_v2_feature_names:
      return 'input_word_ids', 'input_type_ids'
    return 'input_ids', 'segment_ids'

  def _name_to_features(self):
    """Builds the `tf.io.parse_single_example` spec for this config."""
    seq_len = self._seq_length
    num_preds = self._max_predictions_per_seq
    word_key, type_key = self._id_feature_keys()

    spec = {
        'input_mask': tf.io.FixedLenFeature([seq_len], tf.int64),
        'masked_lm_positions': tf.io.FixedLenFeature([num_preds], tf.int64),
        'masked_lm_ids': tf.io.FixedLenFeature([num_preds], tf.int64),
        'masked_lm_weights': tf.io.FixedLenFeature([num_preds], tf.float32),
        word_key: tf.io.FixedLenFeature([seq_len], tf.int64),
        type_key: tf.io.FixedLenFeature([seq_len], tf.int64),
    }
    if self._use_next_sentence_label:
      spec['next_sentence_labels'] = tf.io.FixedLenFeature([1], tf.int64)
    if self._use_position_id:
      spec['position_ids'] = tf.io.FixedLenFeature([seq_len], tf.int64)
    return spec

  def _decode(self, record: tf.Tensor):
    """Parses one serialized tf.Example, narrowing int64 features to int32."""
    parsed = tf.io.parse_single_example(record, self._name_to_features())
    # tf.Example can only store int64 while TPUs only handle int32, so every
    # int64 feature is cast down right after parsing.
    return {
        key: tf.cast(value, tf.int32) if value.dtype == tf.int64 else value
        for key, value in parsed.items()
    }

  def _parse(self, record: Mapping[str, tf.Tensor]):
    """Maps decoded features onto the input names the model expects."""
    word_key, type_key = self._id_feature_keys()
    passthrough = ('input_mask', 'masked_lm_positions', 'masked_lm_ids',
                   'masked_lm_weights')
    features = {name: record[name] for name in passthrough}
    # The model always consumes the v2 names, whatever the storage keys are.
    features['input_word_ids'] = record[word_key]
    features['input_type_ids'] = record[type_key]
    if self._use_next_sentence_label:
      features['next_sentence_labels'] = record['next_sentence_labels']
    if self._use_position_id:
      features['position_ids'] = record['position_ids']
    return features

  def load(self, input_context: Optional[tf.distribute.InputContext] = None):
    """Builds the `tf.data.Dataset` through an `InputReader`."""
    read_fn = dataset_fn.pick_dataset_fn(self._params.file_type)
    return input_reader.InputReader(
        params=self._params,
        dataset_fn=read_fn,
        decoder_fn=self._decode,
        parser_fn=self._parse).read(input_context)


@dataclasses.dataclass
class XLNetPretrainDataConfig(cfg.DataConfig):
  """Data config for XLNet pretraining task.

  Attributes:
    input_path: See base class.
    global_batch_size: See base class.
    is_training: See base class.
    seq_length: The length of each sequence.
    max_predictions_per_seq: The maximum number of predictions per sequence.
      Shorter prediction sets are zero-padded up to this length.
    reuse_length: The number of tokens in a previous segment to reuse. This
      should be the same value used during pretrain data creation. A value of
      0 disables the memory mechanism and the whole sequence is permuted
      together.
    sample_strategy: The strategy used to sample factorization permutations.
      Possible values: 'single_token', 'whole_word', 'token_span', 'word_span'.
    min_num_tokens: The minimum number of tokens to sample in a span. This is
      used when `sample_strategy` is 'token_span'.
    max_num_tokens: The maximum number of tokens to sample in a span. This is
      used when `sample_strategy` is 'token_span'.
    min_num_words: The minimum number of words to sample in a span. This is used
      when `sample_strategy` is 'word_span'.
    max_num_words: The maximum number of words to sample in a span. This is used
      when `sample_strategy` is 'word_span'.
    permutation_size: The length of the longest permutation. This can be set to
      `reuse_length`. This should NOT be greater than `reuse_length`, otherwise
      this may introduce data leaks. Both `reuse_length` and
      `seq_length - reuse_length` (or `seq_length` alone when `reuse_length`
      is 0) must be multiples of this value.
    leak_ratio: The fraction (0.0 to 1.0) of masked tokens that are leaked,
      i.e. allowed to attend to themselves.
    segment_sep_id: The ID of the SEP token used when preprocessing the dataset.
    segment_cls_id: The ID of the CLS token used when preprocessing the dataset.
  """
  input_path: str = ''
  global_batch_size: int = 512
  is_training: bool = True
  seq_length: int = 512
  max_predictions_per_seq: int = 76
  reuse_length: int = 256
  sample_strategy: str = 'word_span'
  min_num_tokens: int = 1
  max_num_tokens: int = 5
  min_num_words: int = 1
  max_num_words: int = 5
  permutation_size: int = 256
  leak_ratio: float = 0.1
  segment_sep_id: int = 4
  segment_cls_id: int = 3


@data_loader_factory.register_data_loader_cls(XLNetPretrainDataConfig)
class XLNetPretrainDataLoader(data_loader.DataLoader):
  """A class to load dataset for xlnet pretraining task.

  Each tf.Example is decoded into `input_word_ids`, `input_type_ids` and
  variable-length `boundary_indices`. It is then augmented online with a
  sampled prediction mask, a factorization-order permutation mask and the
  prediction targets.
  """

  def __init__(self, params: XLNetPretrainDataConfig):
    """Inits `XLNetPretrainDataLoader` class.

    Args:
      params: A `XLNetPretrainDataConfig` object.
    """
    self._params = params
    self._seq_length = params.seq_length
    self._max_predictions_per_seq = params.max_predictions_per_seq
    self._reuse_length = params.reuse_length
    # Recorded by `load` from the input context; not otherwise read here.
    self._num_replicas_in_sync = None
    self._permutation_size = params.permutation_size
    self._sep_id = params.segment_sep_id
    self._cls_id = params.segment_cls_id
    self._sample_strategy = params.sample_strategy
    self._leak_ratio = params.leak_ratio

  def _decode(self, record: tf.Tensor):
    """Decodes a serialized tf.Example.

    Args:
      record: A scalar string Tensor holding one serialized tf.Example.

    Returns:
      A dict with dense `input_word_ids` and `input_type_ids` of length
      `seq_length`, and `boundary_indices` as a `tf.SparseTensor`. All of them
      are cast to int32.
    """
    name_to_features = {
        'input_word_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
        'input_type_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
        'boundary_indices': tf.io.VarLenFeature(tf.int64),
    }
    example = tf.io.parse_single_example(record, name_to_features)

    # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
    # So cast all int64 to int32.
    return {
        name: tf.cast(t, tf.int32) if t.dtype == tf.int64 else t
        for name, t in example.items()
    }

  def _parse(self, record: Mapping[str, tf.Tensor]):
    """Parses raw tensors into a dict of tensors to be consumed by the model.

    Args:
      record: The decoded features produced by `_decode`.

    Returns:
      A dict with `input_word_ids`, `input_type_ids`, `masked_tokens`,
      `permutation_mask` ([seq_length, seq_length]), `target` and
      `target_mask`. It also contains `target_mapping`
      ([max_predictions_per_seq, seq_length]) when `max_predictions_per_seq`
      is set.

    Raises:
      ValueError: if the sequence lengths are not multiples of
        `permutation_size`, or from `_online_sample_mask`.
      NotImplementedError: from `_online_sample_mask` for an unknown
        `sample_strategy`.
    """
    x = {}

    inputs = record['input_word_ids']
    x['input_type_ids'] = record['input_type_ids']

    # Word boundaries are only needed by the word-based sample strategies.
    if self._sample_strategy in ['whole_word', 'word_span']:
      boundary = tf.sparse.to_dense(record['boundary_indices'])
    else:
      boundary = None

    input_mask = self._online_sample_mask(inputs=inputs, boundary=boundary)

    if self._reuse_length > 0:
      if self._permutation_size > self._reuse_length:
        logging.warning(
            '`permutation_size` is greater than `reuse_length` (%d > %d). '
            'This may introduce data leakage.', self._permutation_size,
            self._reuse_length)

      # Enable the memory mechanism.
      # Permute the reuse and non-reuse segments separately.
      non_reuse_len = self._seq_length - self._reuse_length
      if not (self._reuse_length % self._permutation_size == 0 and
              non_reuse_len % self._permutation_size == 0):
        raise ValueError('`reuse_length` and `seq_length` should both be '
                         'a multiple of `permutation_size`.')

      # Creates permutation mask and target mask for the first reuse_len tokens.
      # The tokens in this part are reused from the last sequence.
      perm_mask_0, target_mask_0, tokens_0, masked_0 = self._get_factorization(
          inputs=inputs[:self._reuse_length],
          input_mask=input_mask[:self._reuse_length])

      # Creates permutation mask and target mask for the rest of tokens in
      # current example, which are concatenation of two new segments.
      perm_mask_1, target_mask_1, tokens_1, masked_1 = self._get_factorization(
          inputs=inputs[self._reuse_length:],
          input_mask=input_mask[self._reuse_length:])

      # Reused tokens never attend to the new tokens, while every new token
      # may attend to all reused tokens.
      perm_mask_0 = tf.concat(
          [perm_mask_0,
           tf.zeros([self._reuse_length, non_reuse_len], dtype=tf.int32)],
          axis=1)
      perm_mask_1 = tf.concat(
          [tf.ones([non_reuse_len, self._reuse_length], dtype=tf.int32),
           perm_mask_1],
          axis=1)
      perm_mask = tf.concat([perm_mask_0, perm_mask_1], axis=0)
      target_mask = tf.concat([target_mask_0, target_mask_1], axis=0)
      tokens = tf.concat([tokens_0, tokens_1], axis=0)
      masked_tokens = tf.concat([masked_0, masked_1], axis=0)
    else:
      # Disable the memory mechanism.
      if self._seq_length % self._permutation_size != 0:
        raise ValueError('`seq_length` should be a multiple of '
                         '`permutation_size`.')
      # Permute the entire sequence together
      perm_mask, target_mask, tokens, masked_tokens = self._get_factorization(
          inputs=inputs, input_mask=input_mask)
    x['permutation_mask'] = tf.reshape(perm_mask,
                                       [self._seq_length, self._seq_length])
    x['input_word_ids'] = tokens
    x['masked_tokens'] = masked_tokens

    target = tokens
    if self._max_predictions_per_seq is not None:
      indices = tf.range(self._seq_length, dtype=tf.int32)
      bool_target_mask = tf.cast(target_mask, tf.bool)
      indices = tf.boolean_mask(indices, bool_target_mask)

      # account for extra padding due to CLS/SEP.
      actual_num_predict = tf.shape(indices)[0]
      pad_len = self._max_predictions_per_seq - actual_num_predict

      # One one-hot row per predicted position, zero rows as padding.
      target_mapping = tf.one_hot(indices, self._seq_length, dtype=tf.int32)
      paddings = tf.zeros([pad_len, self._seq_length],
                          dtype=target_mapping.dtype)
      target_mapping = tf.concat([target_mapping, paddings], axis=0)
      x['target_mapping'] = tf.reshape(
          target_mapping, [self._max_predictions_per_seq, self._seq_length])

      target = tf.boolean_mask(target, bool_target_mask)
      paddings = tf.zeros([pad_len], dtype=target.dtype)
      target = tf.concat([target, paddings], axis=0)
      x['target'] = tf.reshape(target, [self._max_predictions_per_seq])

      target_mask = tf.concat([
          tf.ones([actual_num_predict], dtype=tf.int32),
          tf.zeros([pad_len], dtype=tf.int32)
      ],
                              axis=0)
      x['target_mask'] = tf.reshape(target_mask,
                                    [self._max_predictions_per_seq])
    else:
      # NOTE: currently unreachable, because `_online_sample_mask` above
      # raises when `max_predictions_per_seq` is None.
      x['target'] = tf.reshape(target, [self._seq_length])
      x['target_mask'] = tf.reshape(target_mask, [self._seq_length])
    return x

  def _index_pair_to_mask(self, begin_indices: tf.Tensor,
                          end_indices: tf.Tensor,
                          inputs: tf.Tensor) -> tf.Tensor:
    """Converts beginning and end indices into an actual mask.

    Args:
      begin_indices: 1-D int32 Tensor of span starts (inclusive).
      end_indices: 1-D int32 Tensor of span ends (exclusive), aligned with
        `begin_indices`.
      inputs: The input tokens, of length `seq_length`.

    Returns:
      A `bool` Tensor of length `seq_length`. Positions are True when covered
      by a span and not SEP/CLS. At most `max_predictions_per_seq` candidate
      positions are kept, counted in span order.
    """
    non_func_mask = tf.logical_and(
        tf.not_equal(inputs, self._sep_id), tf.not_equal(inputs, self._cls_id))
    # SEP/CLS positions get index -1 so they fall outside spans that start at
    # a non-negative position.
    all_indices = tf.where(
        non_func_mask, tf.range(self._seq_length, dtype=tf.int32),
        tf.constant(-1, shape=[self._seq_length], dtype=tf.int32))
    candidate_matrix = tf.cast(
        tf.logical_and(all_indices[None, :] >= begin_indices[:, None],
                       all_indices[None, :] < end_indices[:, None]), tf.float32)
    # Running count over spans (row-major) caps the number of selections.
    cumsum_matrix = tf.reshape(
        tf.cumsum(tf.reshape(candidate_matrix, [-1])), [-1, self._seq_length])
    masked_matrix = tf.cast(cumsum_matrix <= self._max_predictions_per_seq,
                            tf.float32)
    target_mask = tf.reduce_sum(candidate_matrix * masked_matrix, axis=0)
    return tf.cast(target_mask, tf.bool)

  def _single_token_mask(self, inputs: tf.Tensor) -> tf.Tensor:
    """Samples individual tokens as prediction targets.

    Args:
      inputs: The input tokens, of length `seq_length`.

    Returns:
      A `bool` Tensor of length `seq_length` with up to
      `max_predictions_per_seq` randomly chosen non-SEP/CLS positions set.
    """
    all_indices = tf.range(self._seq_length, dtype=tf.int32)
    non_func_mask = tf.logical_and(
        tf.not_equal(inputs, self._sep_id), tf.not_equal(inputs, self._cls_id))
    non_func_indices = tf.boolean_mask(all_indices, non_func_mask)

    masked_pos = tf.random.shuffle(non_func_indices)
    # SparseTensor indices must be in canonical (sorted) order.
    masked_pos = tf.sort(masked_pos[:self._max_predictions_per_seq])

    sparse_indices = tf.stack([tf.zeros_like(masked_pos), masked_pos], axis=-1)
    sparse_indices = tf.cast(sparse_indices, tf.int64)

    sparse_indices = tf.sparse.SparseTensor(
        sparse_indices,
        values=tf.ones_like(masked_pos),
        dense_shape=(1, self._seq_length))

    target_mask = tf.sparse.to_dense(sp_input=sparse_indices, default_value=0)

    # Squeeze only the leading batch axis so that a sequence of length 1
    # stays rank 1 instead of collapsing to a scalar.
    return tf.squeeze(tf.cast(target_mask, tf.bool), axis=0)

  def _whole_word_mask(self, inputs: tf.Tensor,
                       boundary: tf.Tensor) -> tf.Tensor:
    """Samples whole words as prediction targets.

    Consecutive `boundary` entries delimit one word. Up to
    `max_predictions_per_seq` words are drawn at random, and the total number
    of selected tokens is capped by `_index_pair_to_mask`.
    """
    pair_indices = tf.concat([boundary[:-1, None], boundary[1:, None]], axis=1)
    cand_pair_indices = tf.random.shuffle(
        pair_indices)[:self._max_predictions_per_seq]
    begin_indices = cand_pair_indices[:, 0]
    end_indices = cand_pair_indices[:, 1]

    return self._index_pair_to_mask(
        begin_indices=begin_indices, end_indices=end_indices, inputs=inputs)

  def _sample_span_bounds(self, min_span_len: int, max_span_len: int,
                          mask_alpha: float):
    """Samples `max_predictions_per_seq` consecutive spans with context gaps.

    Span lengths come from a zipf-like distribution (probability proportional
    to 1 / (length + 1)) over [min_span_len, max_span_len]. A span of length
    L occupies a window of about L * mask_alpha units, and a random share of
    the remaining context is placed to its left.

    Args:
      min_span_len: Smallest span length to sample.
      max_span_len: Largest span length to sample.
      mask_alpha: Ratio of window length to span length.

    Returns:
      A `(begin_indices, end_indices)` pair of int32 Tensors of length
      `max_predictions_per_seq`. `end_indices` is exclusive. The unit (tokens
      or words) is decided by the caller, which also drops out-of-range spans.
    """
    round_to_int = lambda x: tf.cast(tf.round(x), tf.int32)

    # Sample span lengths from a zipf distribution
    span_len_seq = np.arange(min_span_len, max_span_len + 1)
    probs = np.array([1.0 / (i + 1) for i in span_len_seq])
    probs /= np.sum(probs)
    logits = tf.constant(np.log(probs), dtype=tf.float32)

    # Note that this over-samples; the caller filters out-of-range spans.
    span_lens = tf.random.categorical(
        logits=logits[None],
        num_samples=self._max_predictions_per_seq,
        dtype=tf.int32,
    )[0] + min_span_len

    # Sample the ratio [0.0, 1.0) of left context lengths
    span_lens_float = tf.cast(span_lens, tf.float32)
    left_ratio = tf.random.uniform(
        shape=[self._max_predictions_per_seq], minval=0.0, maxval=1.0)
    left_ctx_len = round_to_int(left_ratio * span_lens_float * (mask_alpha - 1))

    # Compute the offset from left start to the right end
    right_offset = round_to_int(span_lens_float * mask_alpha) - left_ctx_len

    # Get the actual begin and end indices
    begin_indices = (
        tf.cumsum(left_ctx_len) + tf.cumsum(right_offset, exclusive=True))
    end_indices = begin_indices + span_lens
    return begin_indices, end_indices

  def _shuffle_index_pairs(self, begin_indices: tf.Tensor,
                           end_indices: tf.Tensor):
    """Applies one random permutation to both aligned index tensors."""
    num_valid = tf.shape(begin_indices)[0]
    order = tf.random.shuffle(tf.range(num_valid, dtype=tf.int32))
    return tf.gather(begin_indices, order), tf.gather(end_indices, order)

  def _token_span_mask(self, inputs: tf.Tensor) -> tf.Tensor:
    """Samples token spans as prediction targets."""
    mask_alpha = self._seq_length / self._max_predictions_per_seq
    begin_indices, end_indices = self._sample_span_bounds(
        self._params.min_num_tokens, self._params.max_num_tokens, mask_alpha)

    # Remove out of range indices
    valid_idx_mask = end_indices < self._seq_length
    begin_indices = tf.boolean_mask(begin_indices, valid_idx_mask)
    end_indices = tf.boolean_mask(end_indices, valid_idx_mask)

    # Shuffle so the prediction cap does not always favor early spans.
    begin_indices, end_indices = self._shuffle_index_pairs(
        begin_indices, end_indices)

    return self._index_pair_to_mask(
        begin_indices=begin_indices, end_indices=end_indices, inputs=inputs)

  def _word_span_mask(self, inputs: tf.Tensor, boundary: tf.Tensor):
    """Sample whole word spans as prediction targets."""
    # Note: 1.2 is the token-to-word ratio
    mask_alpha = self._seq_length / self._max_predictions_per_seq / 1.2
    # Spans are sampled in word units and mapped to token positions below.
    begin_indices, end_indices = self._sample_span_bounds(
        self._params.min_num_words, self._params.max_num_words, mask_alpha)

    # Remove out of range indices
    max_boundary_index = tf.shape(boundary)[0] - 1
    valid_idx_mask = end_indices < max_boundary_index
    begin_indices = tf.boolean_mask(begin_indices, valid_idx_mask)
    end_indices = tf.boolean_mask(end_indices, valid_idx_mask)

    # Convert word indices to token positions.
    begin_indices = tf.gather(boundary, begin_indices)
    end_indices = tf.gather(boundary, end_indices)

    # Shuffle valid indices
    begin_indices, end_indices = self._shuffle_index_pairs(
        begin_indices, end_indices)

    return self._index_pair_to_mask(
        begin_indices=begin_indices, end_indices=end_indices, inputs=inputs)

  def _online_sample_mask(self, inputs: tf.Tensor,
                          boundary: tf.Tensor) -> tf.Tensor:
    """Samples target positions for predictions.

    Descriptions of each strategy:
      - 'single_token': Samples individual tokens as prediction targets.
      - 'token_span': Samples spans of tokens as prediction targets.
      - 'whole_word': Samples individual words as prediction targets.
      - 'word_span': Samples spans of words as prediction targets.

    Args:
      inputs: The input tokens.
      boundary: The `int` Tensor of indices indicating whole word boundaries.
        This is used in 'whole_word' and 'word_span'

    Returns:
      The sampled `bool` input mask.

    Raises:
      ValueError: if `max_predictions_per_seq` is not set or if boundary is
        not provided for 'whole_word' and 'word_span' sample strategies.
      NotImplementedError: if `sample_strategy` is not one of the above.
    """
    if self._max_predictions_per_seq is None:
      raise ValueError('`max_predictions_per_seq` must be set.')

    if boundary is None and 'word' in self._sample_strategy:
      raise ValueError('`boundary` must be provided for {} strategy'.format(
          self._sample_strategy))

    if self._sample_strategy == 'single_token':
      return self._single_token_mask(inputs)
    elif self._sample_strategy == 'token_span':
      return self._token_span_mask(inputs)
    elif self._sample_strategy == 'whole_word':
      return self._whole_word_mask(inputs, boundary)
    elif self._sample_strategy == 'word_span':
      return self._word_span_mask(inputs, boundary)
    else:
      raise NotImplementedError('Invalid sample strategy.')

  def _get_factorization(self, inputs: tf.Tensor, input_mask: tf.Tensor):
    """Samples a permutation of the factorization order.

    Positions are grouped into consecutive blocks of `permutation_size`. One
    random permutation is drawn and applied inside every block, and the
    blocks themselves keep their left-to-right order.

    Args:
      inputs: the input tokens; length must be a multiple of
        `permutation_size`.
      input_mask: the `bool` Tensor of the same shape as `inputs`. If `True`,
        then this means select for partial prediction.

    Returns:
      perm_mask: An `int32` Tensor of shape [L, L] (L = length of `inputs`)
        consisting of 0s and 1s. If perm_mask[i][j] == 1, then the i-th token
        (in original order) can attend to the j-th token; 0 means it cannot.
      target_mask: An `int32` Tensor of shape [L] consisting of 0s and 1s.
        If target_mask[i] == 1, then the i-th token needs to be predicted and
        the mask will be used as input. This token will be included in the loss.
        If target_mask[i] == 0, then the token (or [SEP], [CLS]) will be used as
        input. This token will not be included in the loss.
      tokens: `inputs`, returned unchanged.
      masked_tokens: `bool` Tensor of shape [L]; True where `input_mask` is
        set and the token is neither SEP nor CLS.
    """
    factorization_length = tf.shape(inputs)[0]
    # Generate permutation indices: shuffling rows of the transposed
    # [permutation_size, num_blocks] matrix permutes the within-block offsets.
    index = tf.range(factorization_length, dtype=tf.int32)
    index = tf.transpose(tf.reshape(index, [-1, self._permutation_size]))
    index = tf.random.shuffle(index)
    index = tf.reshape(tf.transpose(index), [-1])

    input_mask = tf.cast(input_mask, tf.bool)

    # non-functional tokens
    non_func_tokens = tf.logical_not(
        tf.logical_or(
            tf.equal(inputs, self._sep_id), tf.equal(inputs, self._cls_id)))
    masked_tokens = tf.logical_and(input_mask, non_func_tokens)
    non_masked_or_func_tokens = tf.logical_not(masked_tokens)

    smallest_index = -2 * tf.ones([factorization_length], dtype=tf.int32)

    # Similar to BERT, randomly leak some masked tokens
    if self._leak_ratio > 0:
      leak_tokens = tf.logical_and(
          masked_tokens,
          tf.random.uniform([factorization_length], maxval=1.0) <
          self._leak_ratio)
      can_attend_self = tf.logical_or(non_masked_or_func_tokens, leak_tokens)
    else:
      can_attend_self = non_masked_or_func_tokens
    # Context tokens get to=-2 / from=-1: they attend to all context tokens,
    # but to no masked token.
    to_index = tf.where(can_attend_self, smallest_index, index)
    from_index = tf.where(can_attend_self, to_index + 1, to_index)

    # For masked tokens, can attend if i > j
    # For context tokens, always can attend each other
    can_attend = from_index[:, None] > to_index[None, :]

    perm_mask = tf.cast(can_attend, tf.int32)

    # Only masked tokens are included in the loss
    target_mask = tf.cast(masked_tokens, tf.int32)

    return perm_mask, target_mask, inputs, masked_tokens

  def load(self, input_context: Optional[tf.distribute.InputContext] = None):
    """Returns a tf.dataset.Dataset."""
    if input_context:
      self._num_replicas_in_sync = input_context.num_replicas_in_sync
    reader = input_reader.InputReader(
        params=self._params, decoder_fn=self._decode, parser_fn=self._parse)
    return reader.read(input_context)