# tensorflow/models
#
# View on GitHub
# official/legacy/xlnet/preprocess_pretrain_data.py
#
# Summary
#
# Maintainability
# F
# 1 wk
# Test Coverage
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# -*- coding: utf-8 -*-
"""Script to pre-process pre-training data into tfrecords."""

import json
import os
import random

# Import libraries
from absl import app
from absl import flags
from absl import logging

import numpy as np

import tensorflow.compat.v1 as tf
import sentencepiece as spm
from official.legacy.xlnet import preprocess_utils

FLAGS = flags.FLAGS


# Maps each special symbol to its integer id. The id of a symbol is its
# position in the tuple below, so the order must not be changed.
special_symbols = {
    symbol: symbol_id
    for symbol_id, symbol in enumerate((
        "<unk>",   # 0
        "<s>",     # 1
        "</s>",    # 2
        "<cls>",   # 3
        "<sep>",   # 4
        "<pad>",   # 5
        "<mask>",  # 6
        "<eod>",   # 7
        "<eop>",   # 8
    ))
}

# Vocabulary size recorded in `corpus_info.json`.
VOCAB_SIZE = 32000

# Ids of the special symbols referenced directly by this module.
UNK_ID, CLS_ID, SEP_ID, MASK_ID, EOD_ID = (
    special_symbols[symbol]
    for symbol in ("<unk>", "<cls>", "<sep>", "<mask>", "<eod>"))


def _int64_feature(values):
  """Wraps `values` in a `tf.train.Feature` backed by an `Int64List`."""
  int64_list = tf.train.Int64List(value=values)
  return tf.train.Feature(int64_list=int64_list)


def _float_feature(values):
  """Wraps `values` in a `tf.train.Feature` backed by a `FloatList`."""
  float_list = tf.train.FloatList(value=values)
  return tf.train.Feature(float_list=float_list)


def format_filename(prefix, bsz_per_host, seq_len, bi_data, suffix,
                    mask_alpha=5, mask_beta=1, reuse_len=None, uncased=False,
                    fixed_num_predict=None):
  """Builds a file name that encodes the preprocessing configuration.

  The result has the form
  `<prefix>.bsz-<B>.seqlen-<L>.[reuse-<R>.][uncased.]<bi|uni>.alpha-<A>.`
  `beta-<Bt>.[fnp-<N>.]<suffix>`, where the bracketed parts are optional.

  Args:
    prefix: Leading part of the file name.
    bsz_per_host: Batch size per host, written after `bsz-`.
    seq_len: Sequence length, written after `seqlen-`.
    bi_data: If truthy the name contains `bi`, otherwise `uni`.
    suffix: Trailing part of the file name (e.g. an extension without dot).
    mask_alpha: Value written after `alpha-`.
    mask_beta: Value written after `beta-`.
    reuse_len: If not None, adds a `reuse-<reuse_len>.` component.
    uncased: If truthy, adds an `uncased.` component.
    fixed_num_predict: If not None, adds a `fnp-<fixed_num_predict>.`
      component.

  Returns:
    The formatted file name as a string.
  """
  reuse_len_str = "" if reuse_len is None else "reuse-{}.".format(reuse_len)
  uncased_str = "uncased." if uncased else ""
  bi_data_str = "bi" if bi_data else "uni"
  fnp_str = ("" if fixed_num_predict is None
             else "fnp-{}.".format(fixed_num_predict))

  pieces = [
      "{}.bsz-{}.seqlen-{}.".format(prefix, bsz_per_host, seq_len),
      reuse_len_str,
      uncased_str,
      bi_data_str,
      ".alpha-{}.beta-{}.".format(mask_alpha, mask_beta),
      fnp_str,
      "{}".format(suffix),
  ]
  return "".join(pieces)


def _create_data(idx, input_paths):
  """Encodes the input files of one task and writes them as one TFRecord file.

  Every file in `input_paths` is read line by line. Non-blank lines are either
  tokenized with the sentence-piece model at `FLAGS.sp_path` (when
  `FLAGS.from_raw_text`) or parsed as whitespace-separated integer ids. Each
  line gets a boolean sentence id that alternates from line to line. Blank
  lines are skipped, or turned into a single `EOD_ID` token when
  `FLAGS.use_eod` is set. The per-file shards are shuffled with a seed derived
  from `FLAGS.task` and `FLAGS.pass_id`, concatenated, and passed to
  `create_tfrecords`.

  Args:
    idx: Task index; used in log messages and in the output base name.
    input_paths: Paths of the input text files handled by this task.

  Returns:
    A dict with "filenames" (a list holding the written TFRecord path) and
    "num_batch" (the number of batches written).

  Raises:
    ValueError: If none of the input files yields any token.
  """
  # Load sentence-piece model
  sp = spm.SentencePieceProcessor()
  sp.Load(FLAGS.sp_path)

  input_shards = []
  total_line_cnt = 0
  for input_path in input_paths:
    input_data, sent_ids = [], []
    sent_id, line_cnt = True, 0
    logging.info("Processing %s", input_path)
    # Read through a context manager so the handle is closed even if
    # tokenization of a line raises.
    with tf.gfile.Open(input_path) as input_file:
      for line in input_file:
        if line_cnt % 100000 == 0:
          logging.info("Loading line %d", line_cnt)
        line_cnt += 1

        stripped = line.strip()
        if not stripped:
          if not FLAGS.use_eod:
            continue
          # Flip here and again below, so the EOD token takes the opposite id
          # of the line that would otherwise have followed.
          sent_id = not sent_id
          cur_sent = [EOD_ID]
        elif FLAGS.from_raw_text:
          cur_sent = preprocess_utils.preprocess_text(
              stripped, lower=FLAGS.uncased)
          cur_sent = preprocess_utils.encode_ids(sp, cur_sent)
        else:
          cur_sent = [int(token) for token in stripped.split()]

        input_data.extend(cur_sent)
        sent_ids.extend([sent_id] * len(cur_sent))
        sent_id = not sent_id

    logging.info("Finish with line %d", line_cnt)
    if line_cnt == 0:
      continue

    total_line_cnt += line_cnt
    if not input_data:
      # Only blank lines with `use_eod` off: there is no token to write, and
      # an empty shard would break the `sent_ids[0]` / `sent_ids[-1]`
      # boundary fix-up below.
      logging.info("Skip %s: no tokens found", input_path)
      continue

    input_data = np.array(input_data, dtype=np.int64)
    sent_ids = np.array(sent_ids, dtype=bool)
    input_shards.append((input_data, sent_ids))

  logging.info("[Task %d] Total number line: %d", idx, total_line_cnt)

  if not input_shards:
    # `np.concatenate` would fail on an empty list with an unhelpful message.
    raise ValueError(
        "[Task {}] No tokens found in input files: {}".format(
            idx, list(input_paths)))

  tfrecord_dir = os.path.join(FLAGS.save_dir, "tfrecords")

  filenames, num_batch = [], 0

  # Randomly shuffle input shards (with a fixed but distinct random seed)
  np.random.seed(100 * FLAGS.task + FLAGS.pass_id)

  perm_indices = np.random.permutation(len(input_shards))
  logging.info("Using perm indices %s for pass %d",
               perm_indices.tolist(), FLAGS.pass_id)

  input_data_list, sent_ids_list = [], []
  prev_sent_id = None
  for perm_idx in perm_indices:
    input_data, sent_ids = input_shards[perm_idx]
    # make sure that `sent_ids[0] == not prev_sent_id`, so that a sentence
    # boundary is visible where two shards are joined
    if prev_sent_id is not None and sent_ids[0] == prev_sent_id:
      sent_ids = np.logical_not(sent_ids)

    # append to temporary list
    input_data_list.append(input_data)
    sent_ids_list.append(sent_ids)

    # update `prev_sent_id`
    prev_sent_id = sent_ids[-1]

  input_data = np.concatenate(input_data_list)
  sent_ids = np.concatenate(sent_ids_list)

  file_name, cur_num_batch = create_tfrecords(
      save_dir=tfrecord_dir,
      basename="{}-{}-{}".format(FLAGS.split, idx, FLAGS.pass_id),
      data=[input_data, sent_ids],
      bsz_per_host=FLAGS.bsz_per_host,
      seq_len=FLAGS.seq_len,
      bi_data=FLAGS.bi_data,
      sp=sp,
  )

  filenames.append(file_name)
  num_batch += cur_num_batch

  record_info = {
      "filenames": filenames,
      "num_batch": num_batch
  }

  return record_info


def create_data(_):
  """Creates pretrain data for the task selected by `FLAGS.task`.

  Creates the output directories, writes `corpus_info.json` (only for task 0,
  pass 0), selects every `FLAGS.num_task`-th file matching
  `FLAGS.input_glob` starting at index `FLAGS.task`, converts those files to
  TFRecords via `_create_data`, and writes a `record_info-*.json` file
  describing the result.

  Args:
    _: Unused positional argument (this function is passed to `app.run`).
  """
  # Validate FLAGS. Apply the non-TPU override first: the value set here is
  # the one used everywhere below, so the divisibility check must see it
  # rather than a core count that is about to be discarded.
  if not FLAGS.use_tpu:
    FLAGS.num_core_per_host = 1  # forced to be one
  assert FLAGS.bsz_per_host % FLAGS.num_core_per_host == 0, (
      "bsz_per_host ({}) must be divisible by num_core_per_host ({})".format(
          FLAGS.bsz_per_host, FLAGS.num_core_per_host))

  # Make workdirs
  if not tf.gfile.Exists(FLAGS.save_dir):
    tf.gfile.MakeDirs(FLAGS.save_dir)

  tfrecord_dir = os.path.join(FLAGS.save_dir, "tfrecords")
  if not tf.gfile.Exists(tfrecord_dir):
    tf.gfile.MakeDirs(tfrecord_dir)

  # Create and dump corpus_info from task 0
  if FLAGS.task == 0 and FLAGS.pass_id == 0:
    corpus_info = {
        "vocab_size": VOCAB_SIZE,
        "bsz_per_host": FLAGS.bsz_per_host,
        "num_core_per_host": FLAGS.num_core_per_host,
        "seq_len": FLAGS.seq_len,
        "reuse_len": FLAGS.reuse_len,
        "uncased": FLAGS.uncased,
        "bi_data": FLAGS.bi_data,
        "mask_alpha": FLAGS.mask_alpha,
        "mask_beta": FLAGS.mask_beta,
        "num_predict": FLAGS.num_predict,
        "use_eod": FLAGS.use_eod,
        "sp_path": FLAGS.sp_path,
        "input_glob": FLAGS.input_glob,
    }
    corpus_info_path = os.path.join(FLAGS.save_dir, "corpus_info.json")
    with tf.gfile.Open(corpus_info_path, "w") as fp:
      json.dump(corpus_info, fp)

  # Interleavely split the work into FLAGS.num_task splits
  file_paths = sorted(tf.gfile.Glob(FLAGS.input_glob))
  logging.info("Use glob: %s", FLAGS.input_glob)
  logging.info("Find %d files: %s", len(file_paths), file_paths)

  task_file_paths = file_paths[FLAGS.task::FLAGS.num_task]
  if not task_file_paths:
    logging.info("Exit: task %d has no file to process.", FLAGS.task)
    return

  logging.info("Task %d process %d files: %s",
               FLAGS.task, len(task_file_paths), task_file_paths)
  record_info = _create_data(FLAGS.task, task_file_paths)

  # The record-info file name carries the same configuration components as
  # the TFRecord file name, so `get_input_fn` can glob for it.
  record_prefix = "record_info-{}-{}-{}".format(
      FLAGS.split, FLAGS.task, FLAGS.pass_id)
  record_name = format_filename(
      prefix=record_prefix,
      bsz_per_host=FLAGS.bsz_per_host,
      seq_len=FLAGS.seq_len,
      mask_alpha=FLAGS.mask_alpha,
      mask_beta=FLAGS.mask_beta,
      reuse_len=FLAGS.reuse_len,
      bi_data=FLAGS.bi_data,
      suffix="json",
      uncased=FLAGS.uncased,
      fixed_num_predict=FLAGS.num_predict)
  record_info_path = os.path.join(tfrecord_dir, record_name)

  with tf.gfile.Open(record_info_path, "w") as fp:
    json.dump(record_info, fp)


def batchify(data, bsz_per_host, sent_ids=None):
  """Reshapes a flat stream into `bsz_per_host` rows of equal length.

  Trailing elements that do not fill a complete column are dropped.

  Args:
    data: 1-D array to reshape.
    bsz_per_host: Number of rows in the result.
    sent_ids: Optional 1-D array that is truncated and reshaped exactly like
      `data`.

  Returns:
    The reshaped `data` when `sent_ids` is None, otherwise the tuple
    `(reshaped data, reshaped sent_ids)`.
  """
  steps_per_row = len(data) // bsz_per_host
  usable_len = bsz_per_host * steps_per_row

  batched_data = data[:usable_len].reshape(bsz_per_host, steps_per_row)
  if sent_ids is None:
    return batched_data

  batched_sent_ids = sent_ids[:usable_len].reshape(bsz_per_host, steps_per_row)
  return batched_data, batched_sent_ids


def _split_a_and_b(data, sent_ids, begin_idx, tot_len, extend_target=False):
  """Draws a segment pair (A, B) from `data`, with A starting at `begin_idx`.

  Sentence boundaries are the positions where `sent_ids` changes value. With
  label 1, B is the text that directly follows A. With label 0, B starts at a
  random position and is widened to sentence boundaries. The pair is then
  trimmed from the right until the combined length is at most `tot_len`.

  Args:
    data: 1-D array of token ids.
    sent_ids: 1-D array aligned with `data`; a change of value marks a
      sentence boundary.
    begin_idx: Index in `data` where segment A starts.
    tot_len: Maximum combined length of A and B.
    extend_target: If True, also return the target slices for A and B.

  Returns:
    None if `begin_idx + tot_len >= len(data)`, or if `extend_target` is set
    and a target slice would run past the end of `data`. Otherwise the list
    `[a_data, b_data, label, new_begin]`, extended by `[a_target, b_target]`
    when `extend_target` is True.
  """
  data_len = data.shape[0]
  if begin_idx + tot_len >= data_len:
    logging.info("[_split_a_and_b] returns None: "
                 "begin_idx %d + tot_len %d >= data_len %d",
                 begin_idx, tot_len, data_len)
    return None

  # Collect the sentence boundaries to the right of `begin_idx`, stopping at
  # the first boundary that is at least `tot_len` away.
  boundaries = []
  end_idx = begin_idx + 1
  while end_idx < data_len:
    at_boundary = sent_ids[end_idx] != sent_ids[end_idx - 1]
    if at_boundary and end_idx - begin_idx >= tot_len:
      break
    if at_boundary:
      boundaries.append(end_idx)
    end_idx += 1

  a_begin = begin_idx
  # Without any boundary there is no way to form a consecutive pair, so B is
  # always sampled; otherwise it is sampled with probability 0.5.
  sample_random_b = not boundaries or random.random() < 0.5
  if sample_random_b:
    label = 0
    a_end = random.choice(boundaries) if boundaries else end_idx

    b_len = max(1, tot_len - (a_end - a_begin))
    # (zihangd): `data_len - 1` to account for extend_target
    last_idx = data_len - 1
    b_begin = random.randint(0, last_idx - b_len)
    b_end = b_begin + b_len
    # Widen B on both sides until it starts and ends on a boundary.
    while b_begin > 0 and sent_ids[b_begin - 1] == sent_ids[b_begin]:
      b_begin -= 1
    while b_end < last_idx and sent_ids[b_end - 1] == sent_ids[b_end]:
      b_end += 1

    new_begin = a_end
  else:
    label = 1
    a_end = random.choice(boundaries)
    b_begin, b_end = a_end, end_idx

    new_begin = b_end

  # Trim the longer segment from the right, one token at a time.
  while (a_end - a_begin) + (b_end - b_begin) > tot_len:
    if a_end - a_begin > b_end - b_begin:
      # delete the right side only for the LM objective
      a_end -= 1
    else:
      b_end -= 1

  result = [data[a_begin:a_end], data[b_begin:b_end], label, new_begin]
  if not extend_target:
    return result

  if a_end >= data_len or b_end >= data_len:
    logging.info("[_split_a_and_b] returns None: "
                 "a_end %d or b_end %d >= data_len %d",
                 a_end, b_end, data_len)
    return None

  # A's targets are A shifted by one; B's target slice keeps B's start and
  # extends one position past its end.
  result.append(data[a_begin + 1:a_end + 1])
  result.append(data[b_begin:b_end + 1])
  return result


def _is_start_piece(piece):
  special_pieces = set(list('!"#$%&\"()*+,-./:;?@[\\]^_`{|}~'))
  if (piece.startswith("▁") or piece.startswith("<")
      or piece in special_pieces):
    return True
  else:
    return False


def _sample_mask(sp, seg, reverse=False, max_gram=5, goal_num_predict=None):
  """Samples `goal_num_predict` tokens for partial prediction.

  Walks through `seg` and repeatedly draws a span length `n` from
  1..`max_gram` (with probability proportional to 1/n). Each span sits in a
  context window of `(n * FLAGS.mask_alpha) // FLAGS.mask_beta` positions that
  is split randomly into a left and a right part. The span starts at the first
  token-start piece after the left context and covers `n` consecutive
  positions. A span that would reach the end of `seg` is discarded and ends
  the walk. If fewer than `goal_num_predict` positions were selected, random
  unselected positions are added until the goal is met.

  Args:
    sp: Object providing `IdToPiece(int)`, used to find token-start pieces.
    seg: 1-D numpy array of piece ids.
    reverse: If True, `seg` is traversed back to front; the returned mask is
      aligned with the original order of `seg`.
    max_gram: Largest span length that can be drawn.
    goal_num_predict: Exact number of positions to select, or None to apply
      no target.

  Returns:
    A boolean numpy array of length `len(seg)`; True marks selected positions.

  Raises:
    ValueError: If `goal_num_predict` is larger than `len(seg)`.
  """
  seg_len = len(seg)
  if goal_num_predict is not None and goal_num_predict > seg_len:
    # The top-up loop below only stops once `goal_num_predict` positions are
    # selected, which can never happen here; fail instead of spinning forever.
    raise ValueError(
        "goal_num_predict ({}) exceeds segment length ({})".format(
            goal_num_predict, seg_len))

  mask = np.zeros(seg_len, dtype=bool)

  num_predict = 0

  ngrams = np.arange(1, max_gram + 1, dtype=np.int64)
  pvals = 1. / np.arange(1, max_gram + 1)
  pvals /= pvals.sum(keepdims=True)

  if reverse:
    seg = np.flip(seg, 0)

  cur_len = 0
  while cur_len < seg_len:
    if goal_num_predict is not None and num_predict >= goal_num_predict:
      break

    n = np.random.choice(ngrams, p=pvals)
    if goal_num_predict is not None:
      # Never select more than what is still missing.
      n = min(n, goal_num_predict - num_predict)
    ctx_size = (n * FLAGS.mask_alpha) // FLAGS.mask_beta
    l_ctx = np.random.choice(ctx_size)
    r_ctx = ctx_size - l_ctx

    # Find the start position of a complete token
    beg = cur_len + l_ctx
    while beg < seg_len and not _is_start_piece(sp.IdToPiece(seg[beg].item())):
      beg += 1
    if beg >= seg_len:
      break

    # Find the end position of the n-gram (start pos of the n+1-th gram)
    end = beg + 1
    cnt_ngram = 1
    while end < seg_len:
      cnt_ngram += 1
      if cnt_ngram > n:
        break
      end += 1
    if end >= seg_len:
      break

    # Update
    mask[beg:end] = True
    num_predict += end - beg

    # Skip the right context before looking for the next span.
    cur_len = end + r_ctx

  # Top up with random, not yet selected positions until the goal is reached.
  while goal_num_predict is not None and num_predict < goal_num_predict:
    i = np.random.randint(seg_len)
    if not mask[i]:
      mask[i] = True
      num_predict += 1

  if reverse:
    mask = np.flip(mask, 0)

  return mask


def _sample_mask_ngram(sp, seg, reverse=False, max_gram=5,
                       goal_num_predict=None):
  """Sample `goal_num_predict` tokens for partial prediction.

  Works like `_sample_mask`, except that a span of length `n` covers `n`
  whole tokens: starting at a token-start piece, pieces are selected until the
  start piece of the (n+1)-th token is reached, the end of `seg` is reached,
  or `goal_num_predict` positions have been selected.

  Args:
    sp: Object providing `IdToPiece(int)`, used to find token-start pieces.
    seg: 1-D numpy array of piece ids.
    reverse: If True, `seg` is traversed back to front; the returned mask is
      aligned with the original order of `seg`.
    max_gram: Largest span length (in tokens) that can be drawn.
    goal_num_predict: Exact number of positions to select, or None to apply
      no target.

  Returns:
    A boolean numpy array of length `len(seg)`; True marks selected positions.

  Raises:
    ValueError: If `goal_num_predict` is larger than `len(seg)`.
  """

  seg_len = len(seg)
  if goal_num_predict is not None and goal_num_predict > seg_len:
    # The top-up loop below only stops once `goal_num_predict` positions are
    # selected, which can never happen here; fail instead of spinning forever.
    raise ValueError(
        "goal_num_predict ({}) exceeds segment length ({})".format(
            goal_num_predict, seg_len))

  mask = np.zeros(seg_len, dtype=bool)

  num_predict = 0

  ngrams = np.arange(1, max_gram + 1, dtype=np.int64)
  pvals = 1. / np.arange(1, max_gram + 1)
  pvals /= pvals.sum(keepdims=True)

  if reverse:
    seg = np.flip(seg, 0)

  cur_len = 0
  while cur_len < seg_len:
    if goal_num_predict is not None and num_predict >= goal_num_predict:
      break

    n = np.random.choice(ngrams, p=pvals)
    if goal_num_predict is not None:
      # Never aim for more than what is still missing.
      n = min(n, goal_num_predict - num_predict)
    ctx_size = (n * FLAGS.mask_alpha) // FLAGS.mask_beta
    l_ctx = np.random.choice(ctx_size)
    r_ctx = ctx_size - l_ctx

    # Find the start position of a complete token
    beg = cur_len + l_ctx
    while beg < seg_len and not _is_start_piece(sp.IdToPiece(seg[beg].item())):
      beg += 1
    if beg >= seg_len:
      break

    # Find the end position of the n-gram (start pos of the n+1-th gram)
    end = beg
    cnt_ngram = 0
    while end < seg_len:
      if _is_start_piece(sp.IdToPiece(seg[end].item())):
        cnt_ngram += 1
        if cnt_ngram > n:
          break

      # select current piece
      mask[end] = True

      # update the end pointer and increment num_predict
      end += 1
      num_predict += 1

      if goal_num_predict is not None and num_predict >= goal_num_predict:
        break

    # Skip the right context before looking for the next span.
    cur_len = end + r_ctx

  # Top up with random, not yet selected positions until the goal is reached.
  while goal_num_predict is not None and num_predict < goal_num_predict:
    i = np.random.randint(seg_len)
    if not mask[i]:
      mask[i] = True
      num_predict += 1

  if reverse:
    mask = np.flip(mask, 0)

  return mask


def create_tfrecords(save_dir, basename, data, bsz_per_host, seq_len,
                     bi_data, sp):
  """Creates TFRecords.

  The token stream is laid out as `bsz_per_host` rows. A window then slides
  over the columns in steps of `FLAGS.reuse_len`. For every row and window
  position one example of length `seq_len` is built from `reuse_len` stream
  tokens followed by a segment pair from `_split_a_and_b` and the
  `[SEP]`/`[SEP]`/`[CLS]` markers. A batch is written only if all of its
  `bsz_per_host` examples could be built; writing stops at the first batch
  for which that fails.

  Args:
    save_dir: Directory in which the TFRecord file is written.
    basename: Prefix of the output file name.
    data: Sequence `(token ids, sentence ids)` of two aligned 1-D arrays.
    bsz_per_host: Number of rows, i.e. examples per batch.
    seq_len: Length of every written example.
    bi_data: If True, half of each core's rows hold the stream in reversed
      order.
    sp: Sentence-piece processor passed on to `_sample_mask`.

  Returns:
    The tuple `(save_path, num_batch)`: the path of the written file and the
    number of batches it contains.
  """
  data, sent_ids = data[0], data[1]

  num_core = FLAGS.num_core_per_host
  bsz_per_core = bsz_per_host // num_core

  if bi_data:
    assert bsz_per_host % (2 * FLAGS.num_core_per_host) == 0
    fwd_data, fwd_sent_ids = batchify(data, bsz_per_host // 2, sent_ids)

    fwd_data = fwd_data.reshape(num_core, 1, bsz_per_core // 2, -1)
    fwd_sent_ids = fwd_sent_ids.reshape(num_core, 1, bsz_per_core // 2, -1)

    bwd_data = fwd_data[:, :, :, ::-1]
    bwd_sent_ids = fwd_sent_ids[:, :, :, ::-1]

    data = np.concatenate(
        [fwd_data, bwd_data], 1).reshape(bsz_per_host, -1)
    sent_ids = np.concatenate(
        [fwd_sent_ids, bwd_sent_ids], 1).reshape(bsz_per_host, -1)
  else:
    data, sent_ids = batchify(data, bsz_per_host, sent_ids)

  logging.info("Raw data shape %s.", data.shape)

  # Validate before the output file is created, so that a bad configuration
  # does not leave an empty file behind.
  reuse_len = FLAGS.reuse_len
  non_reuse_len = seq_len - reuse_len

  # [sep] x 2 + [cls]
  assert reuse_len < seq_len - 3

  file_name = format_filename(
      prefix=basename,
      bsz_per_host=bsz_per_host,
      seq_len=seq_len,
      bi_data=bi_data,
      suffix="tfrecords",
      mask_alpha=FLAGS.mask_alpha,
      mask_beta=FLAGS.mask_beta,
      reuse_len=FLAGS.reuse_len,
      uncased=FLAGS.uncased,
      fixed_num_predict=FLAGS.num_predict
  )
  save_path = os.path.join(save_dir, file_name)

  # Split of the prediction budget between the reused part (mask_0) and the
  # segment-pair part (mask_1); it is the same for every example.
  if FLAGS.num_predict is None:
    num_predict_0 = num_predict_1 = None
  else:
    num_predict_1 = FLAGS.num_predict // 2
    num_predict_0 = FLAGS.num_predict - num_predict_1

  data_len = data.shape[1]
  sep_array = np.array([SEP_ID], dtype=np.int64)
  cls_array = np.array([CLS_ID], dtype=np.int64)

  num_batch = 0
  record_writer = tf.python_io.TFRecordWriter(save_path)
  logging.info("Start writing %s.", save_path)

  # `try`/`finally` so the writer is closed even if an example fails one of
  # the assertions below.
  try:
    i = 0
    while i + seq_len <= data_len:
      if num_batch % 500 == 0:
        logging.info("Processing batch %d", num_batch)

      all_ok = True
      features = []
      for idx in range(bsz_per_host):
        inp = data[idx, i: i + reuse_len]
        tgt = data[idx, i + 1: i + reuse_len + 1]

        results = _split_a_and_b(
            data[idx],
            sent_ids[idx],
            begin_idx=i + reuse_len,
            tot_len=seq_len - reuse_len - 3,
            extend_target=True)
        if results is None:
          logging.info("Break out with seq idx %d", i)
          all_ok = False
          break

        # unpack the results
        (a_data, b_data, label, _, a_target, b_target) = tuple(results)

        # sample ngram spans to predict
        reverse = bi_data and (idx // (bsz_per_core // 2)) % 2 == 1
        mask_0 = _sample_mask(sp, inp, reverse=reverse,
                              goal_num_predict=num_predict_0)
        mask_1 = _sample_mask(sp, np.concatenate([a_data, sep_array, b_data,
                                                  sep_array, cls_array]),
                              reverse=reverse, goal_num_predict=num_predict_1)

        # concatenate data
        cat_data = np.concatenate([inp, a_data, sep_array, b_data,
                                   sep_array, cls_array])
        seg_id = ([0] * (reuse_len + a_data.shape[0]) + [0] +
                  [1] * b_data.shape[0] + [1] + [2])
        assert cat_data.shape[0] == seq_len
        # mask_0 covers the reused tokens and mask_1 the rest. Comparing both
        # against `seq_len // 2` would reject every configuration in which
        # `reuse_len` is not exactly half of `seq_len`.
        assert mask_0.shape[0] == reuse_len
        assert mask_1.shape[0] == non_reuse_len

        # the last two CLS's are not used, just for padding purposes
        tgt = np.concatenate([tgt, a_target, b_target, cls_array, cls_array])
        assert tgt.shape[0] == seq_len

        is_masked = np.concatenate([mask_0, mask_1], 0)
        if FLAGS.num_predict is not None:
          assert np.sum(is_masked) == FLAGS.num_predict

        feature = {
            "input": _int64_feature(cat_data),
            "is_masked": _int64_feature(is_masked),
            "target": _int64_feature(tgt),
            "seg_id": _int64_feature(seg_id),
            "label": _int64_feature([label]),
        }
        features.append(feature)

      if not all_ok:
        break

      assert len(features) == bsz_per_host
      for feature in features:
        example = tf.train.Example(features=tf.train.Features(feature=feature))
        record_writer.write(example.SerializeToString())
      num_batch += 1

      i += reuse_len
  finally:
    record_writer.close()

  logging.info("Done writing %s. Num of batches: %d", save_path, num_batch)

  return save_path, num_batch


################
# get_input_fn #
################
def _convert_example(example, use_bfloat16):
  """Cast int64 into int32 and float32 to bfloat16 if use_bfloat16.

  Sparse tensors are converted to dense tensors first. `example` is modified
  in place.

  Args:
    example: Dict mapping feature names to tensors.
    use_bfloat16: If True, float32 tensors are cast to bfloat16.
  """
  for key in list(example.keys()):
    val = example[key]
    # Check the type directly: this module never imports `tf_keras`, so the
    # former `tf_keras.backend.is_sparse(val)` call raised a NameError.
    if isinstance(val, tf.SparseTensor):
      val = tf.sparse.to_dense(val)
    if val.dtype == tf.int64:
      val = tf.cast(val, tf.int32)
    if use_bfloat16 and val.dtype == tf.float32:
      val = tf.cast(val, tf.bfloat16)

    example[key] = val


def parse_files_to_dataset(parser, file_names, split, num_batch, num_hosts,
                           host_id, num_core_per_host, bsz_per_core):
  """Builds the batched, repeating dataset read by one host.

  Args:
    parser: Function mapped over every serialized record.
    file_names: List of all TFRecord file paths.
    split: Name of the data split; must be "train".
    num_batch: Unused.
    num_hosts: Total number of hosts sharing `file_names`.
    host_id: Index of the host this dataset is built for.
    num_core_per_host: Number of cores per host; scales the prefetch buffer.
    bsz_per_core: Batch size of the returned dataset.

  Returns:
    A `tf.data` dataset of parsed examples in batches of `bsz_per_core`.
  """
  del num_batch

  # Every host reads a contiguous slice of the file list; the last host also
  # takes the files left over by the integer division.
  total_files = len(file_names)
  files_per_host = total_files // num_hosts
  first_file_id = host_id * files_per_host
  if host_id == num_hosts - 1:
    last_file_id = total_files
  else:
    last_file_id = first_file_id + files_per_host
  file_paths = file_names[first_file_id:last_file_id]
  logging.info("Host %d handles %d files", host_id, len(file_paths))

  assert split == "train"
  dataset = tf.data.Dataset.from_tensor_slices(file_paths)

  # Shuffle at file level only. Shuffling individual samples would violate
  # the requirement that the data stream stays consecutive.
  if len(file_paths) > 1:
    dataset = dataset.shuffle(len(file_paths))

  dataset = tf.data.TFRecordDataset(dataset)

  # Parsing happens online, so the parsed result of the same record differs
  # each time it is read. Caching parsed data is therefore not helpful: it
  # would use a lot of memory and can lead to container OOM. The raw,
  # unparsed records are cached instead, i.e. `cache` comes before `map`.
  dataset = dataset.cache()
  dataset = dataset.map(parser)
  dataset = dataset.repeat()
  dataset = dataset.batch(bsz_per_core, drop_remainder=True)
  return dataset.prefetch(num_core_per_host * bsz_per_core)


def _local_perm(inputs, targets, is_masked, perm_size, seq_len):
  """Samples a factorization order and derives the attention/target masks.

  Args:
    inputs: int64 Tensor in shape [seq_len], input ids.
    targets: int64 Tensor in shape [seq_len], target ids.
    is_masked: bool Tensor in shape [seq_len]. True means being selected
      for partial prediction.
    perm_size: the length of longest permutation. Could be set to be reuse_len.
      Should not be larger than reuse_len or there will be data leaks.
    seq_len: int, sequence length.

  Returns:
    A 5-tuple `(perm_mask, new_targets, target_mask, inputs_k, inputs_q)`:
    the float32 permutation mask in shape [seq_len, seq_len], the new
    targets, the float32 target mask, `inputs` unchanged (`inputs_k`) and the
    target mask again (`inputs_q`).
  """

  # Sample the factorization order: arrange the positions in blocks of
  # `perm_size`, shuffle along the within-block axis, and flatten again.
  positions = tf.range(seq_len, dtype=tf.int64)
  blocks = tf.transpose(tf.reshape(positions, [-1, perm_size]))
  shuffled_blocks = tf.random_shuffle(blocks)
  perm_index = tf.reshape(tf.transpose(shuffled_blocks), [-1])

  # Functional tokens are [SEP] and [CLS]; everything else is non-functional.
  is_sep = tf.equal(inputs, SEP_ID)
  is_cls = tf.equal(inputs, CLS_ID)
  non_func_tokens = tf.logical_not(tf.logical_or(is_sep, is_cls))

  not_masked = tf.logical_not(is_masked)
  non_mask_tokens = tf.logical_and(not_masked, non_func_tokens)
  masked_or_func_tokens = tf.logical_not(non_mask_tokens)

  # Non-masked (and non-functional) tokens get the smallest index (-1):
  # (1) they can be seen by all other positions
  # (2) they cannot see masked positions, so there won't be information leak
  smallest_index = -tf.ones([seq_len], dtype=tf.int64)
  rev_index = tf.where(non_mask_tokens, smallest_index, perm_index)

  # `target_mask` marks the masked, non-functional tokens:
  # 1: use mask as input and have loss
  # 0: use token (or [SEP], [CLS]) as input and do not have loss
  target_tokens = tf.logical_and(masked_or_func_tokens, non_func_tokens)
  target_mask = tf.cast(target_tokens, tf.float32)

  # Target tokens must not see themselves, so they keep their own index while
  # every other position is compared with its index plus one.
  self_rev_index = tf.where(target_tokens, rev_index, rev_index + 1)

  # `perm_mask`:
  # 1: cannot attend if i <= j and j is not non-masked (masked_or_func_tokens)
  # 0: can attend if i > j or j is non-masked
  cannot_attend = tf.logical_and(
      self_rev_index[:, None] <= rev_index[None, :],
      masked_or_func_tokens)
  perm_mask = tf.cast(cannot_attend, tf.float32)

  # new target: [next token] for LM and [curr token] (self) for PLM
  first_input = inputs[0: 1]
  shifted_targets = targets[: -1]
  new_targets = tf.concat([first_input, shifted_targets], axis=0)

  # `inputs_k` is the unchanged input; `inputs_q` is the target mask.
  return perm_mask, new_targets, target_mask, inputs, target_mask


def get_dataset(params, num_hosts, num_core_per_host, split, file_names,
                num_batch, seq_len, reuse_len, perm_size, mask_alpha,
                mask_beta, use_bfloat16=False, num_predict=None):
  """Gets the dataset.

  Args:
    params: Dict holding "batch_size" (the per-core batch size) and, when
      `num_hosts > 1`, "context" with a `current_host` attribute.
    num_hosts: Total number of hosts.
    num_core_per_host: Number of cores per host.
    split: Name of the data split.
    file_names: List of TFRecord file paths.
    num_batch: Passed on to `parse_files_to_dataset`.
    seq_len: Length of every stored example.
    reuse_len: Number of leading positions that form the reused part.
    perm_size: Block size of the sampled permutation; must not exceed
      `reuse_len` or `seq_len - reuse_len`.
    mask_alpha: Unused.
    mask_beta: Unused.
    use_bfloat16: If True, float32 features are cast to bfloat16.
    num_predict: If not None, targets are gathered into fixed-size tensors
      with `num_predict` entries.

  Returns:
    The dataset built by `parse_files_to_dataset`.
  """

  del mask_alpha, mask_beta
  bsz_per_core = params["batch_size"]
  host_id = params["context"].current_host if num_hosts > 1 else 0

  def parser(record):
    """Parses one serialized example and builds the model inputs."""

    record_spec = {
        "input": tf.FixedLenFeature([seq_len], tf.int64),
        "target": tf.FixedLenFeature([seq_len], tf.int64),
        "seg_id": tf.FixedLenFeature([seq_len], tf.int64),
        "label": tf.FixedLenFeature([1], tf.int64),
        "is_masked": tf.FixedLenFeature([seq_len], tf.int64),
    }

    # retrieve serialized example
    example = tf.parse_single_example(serialized=record, features=record_spec)

    inputs = example.pop("input")
    target = example.pop("target")
    is_masked = tf.cast(example.pop("is_masked"), tf.bool)

    non_reuse_len = seq_len - reuse_len
    assert perm_size <= reuse_len and perm_size <= non_reuse_len

    # The reused part and the remaining part are permuted independently.
    (reuse_perm_mask, reuse_target, reuse_target_mask, reuse_input_k,
     reuse_input_q) = _local_perm(
         inputs[:reuse_len],
         target[:reuse_len],
         is_masked[:reuse_len],
         perm_size,
         reuse_len)

    (rest_perm_mask, rest_target, rest_target_mask, rest_input_k,
     rest_input_q) = _local_perm(
         inputs[reuse_len:],
         target[reuse_len:],
         is_masked[reuse_len:],
         perm_size,
         non_reuse_len)

    # Assemble the full mask: rows of the reused part are padded with ones on
    # the right, rows of the remaining part with zeros on the left.
    ones_block = tf.ones([reuse_len, non_reuse_len])
    zeros_block = tf.zeros([non_reuse_len, reuse_len])
    top_rows = tf.concat([reuse_perm_mask, ones_block], axis=1)
    bottom_rows = tf.concat([zeros_block, rest_perm_mask], axis=1)
    perm_mask = tf.concat([top_rows, bottom_rows], axis=0)

    target = tf.concat([reuse_target, rest_target], axis=0)
    target_mask = tf.concat([reuse_target_mask, rest_target_mask], axis=0)
    input_k = tf.concat([reuse_input_k, rest_input_k], axis=0)
    input_q = tf.concat([reuse_input_q, rest_input_q], axis=0)

    if num_predict is None:
      example["target"] = tf.reshape(target, [seq_len])
      example["target_mask"] = tf.reshape(target_mask, [seq_len])
    else:
      bool_target_mask = tf.cast(target_mask, tf.bool)
      all_positions = tf.range(seq_len, dtype=tf.int64)
      target_positions = tf.boolean_mask(all_positions, bool_target_mask)

      # Extra padding due to CLS/SEP introduced after preprocessing.
      actual_num_predict = tf.shape(target_positions)[0]
      pad_len = num_predict - actual_num_predict

      # target_mapping: one one-hot row per target, then all-zero rows.
      one_hot_rows = tf.one_hot(target_positions, seq_len, dtype=tf.float32)
      mapping_pad = tf.zeros([pad_len, seq_len], dtype=one_hot_rows.dtype)
      target_mapping = tf.concat([one_hot_rows, mapping_pad], axis=0)
      example["target_mapping"] = tf.reshape(target_mapping,
                                             [num_predict, seq_len])

      # target: the selected target ids, padded with zeros.
      picked_target = tf.boolean_mask(target, bool_target_mask)
      target_pad = tf.zeros([pad_len], dtype=picked_target.dtype)
      target = tf.concat([picked_target, target_pad], axis=0)
      example["target"] = tf.reshape(target, [num_predict])

      # target mask: ones for real targets, zeros for the padding.
      real_part = tf.ones([actual_num_predict], dtype=tf.float32)
      pad_part = tf.zeros([pad_len], dtype=tf.float32)
      target_mask = tf.concat([real_part, pad_part], axis=0)
      example["target_mask"] = tf.reshape(target_mask, [num_predict])

    # reshape back to fixed shape
    example["perm_mask"] = tf.reshape(perm_mask, [seq_len, seq_len])
    example["input_k"] = tf.reshape(input_k, [seq_len])
    example["input_q"] = tf.reshape(input_q, [seq_len])

    _convert_example(example, use_bfloat16)

    for name, tensor in example.items():
      logging.info("%s: %s", name, tensor)

    return example

  # Get dataset
  return parse_files_to_dataset(
      parser=parser,
      file_names=file_names,
      split=split,
      num_batch=num_batch,
      num_hosts=num_hosts,
      host_id=host_id,
      num_core_per_host=num_core_per_host,
      bsz_per_core=bsz_per_core)


def get_input_fn(
    tfrecord_dir,
    split,
    bsz_per_host,
    seq_len,
    reuse_len,
    bi_data,
    num_hosts=1,
    num_core_per_host=1,
    perm_size=None,
    mask_alpha=None,
    mask_beta=None,
    uncased=False,
    num_passes=None,
    use_bfloat16=False,
    num_predict=None):
  """Gets the input function.

  Collects all `record_info-<split>-*` JSON files that match the given
  configuration in every directory of `tfrecord_dir`, merges their batch
  counts and file lists, and returns an input function reading those files.

  Args:
    tfrecord_dir: One directory or several comma-separated directories.
    split: Name of the data split.
    bsz_per_host: Batch size per host.
    seq_len: Sequence length.
    reuse_len: Number of reused tokens.
    bi_data: Whether the data was created as bidirectional data.
    num_hosts: Total number of hosts.
    num_core_per_host: Number of cores per host.
    perm_size: Permutation block size passed on to `get_dataset`.
    mask_alpha: Part of the file name; also passed on to `get_dataset`.
    mask_beta: Part of the file name; also passed on to `get_dataset`.
    uncased: Whether the data was created uncased.
    num_passes: If not None, limits how many passes/files are used.
    use_bfloat16: Passed on to `get_dataset`.
    num_predict: Number of tokens to predict.

  Returns:
    The tuple `(input_fn, record_info)`, where `record_info` is a dict with
    the merged "num_batch" and "filenames".
  """

  # Merge all record infos into a single one
  record_glob_base = format_filename(
      prefix="record_info-{}-*".format(split),
      bsz_per_host=bsz_per_host,
      seq_len=seq_len,
      bi_data=bi_data,
      suffix="json",
      mask_alpha=mask_alpha,
      mask_beta=mask_beta,
      reuse_len=reuse_len,
      uncased=uncased,
      fixed_num_predict=num_predict)

  record_info = {"num_batch": 0, "filenames": []}

  tfrecord_dirs = tfrecord_dir.split(",")
  logging.info("Use the following tfrecord dirs: %s", tfrecord_dirs)

  for idx, record_dir in enumerate(tfrecord_dirs):
    record_glob = os.path.join(record_dir, record_glob_base)
    logging.info("[%d] Record glob: %s", idx, record_glob)

    record_paths = sorted(tf.gfile.Glob(record_glob))
    logging.info("[%d] Num of record info path: %d", idx, len(record_paths))

    dir_num_batch = 0
    dir_filenames = []

    for record_info_path in record_paths:
      if num_passes is not None:
        record_info_name = os.path.basename(record_info_path)
        fields = record_info_name.split(".")[0].split("-")
        pass_id = int(fields[-1])
        # NOTE(review): `create_data` in this file names these files
        # "record_info-<split>-<task>-<pass_id>", which gives four
        # dash-separated fields, so this skip looks unreachable for them.
        # Confirm whether `len(fields) == 5` is intended.
        if len(fields) == 5 and pass_id >= num_passes:
          logging.info("Skip pass %d: %s", pass_id, record_info_name)
          continue

      with tf.gfile.Open(record_info_path, "r") as fp:
        info = json.load(fp)

      if num_passes is None:
        dir_num_batch += info["num_batch"]
        dir_filenames.extend(info["filenames"])
      else:
        # Keep at most `num_passes` files and scale the batch count by the
        # fraction of files that is kept.
        num_listed = len(info["filenames"])
        eff_num_passes = min(num_passes, num_listed)
        ratio = eff_num_passes / num_listed
        dir_num_batch += int(info["num_batch"] * ratio)
        dir_filenames.extend(info["filenames"][:eff_num_passes])

    # The stored paths may point to another location; keep only the base
    # name and place it in the directory that is being read.
    dir_filenames = [
        os.path.join(record_dir, os.path.basename(filename))
        for filename in dir_filenames
    ]

    logging.info("[Dir %d] Number of chosen batches: %s",
                 idx, dir_num_batch)
    logging.info("[Dir %d] Number of chosen files: %s",
                 idx, len(dir_filenames))
    logging.info(dir_filenames)

    # add the results of this directory to the global `record_info`
    record_info["num_batch"] += dir_num_batch
    record_info["filenames"].extend(dir_filenames)

  logging.info("Total number of batches: %d", record_info["num_batch"])
  logging.info("Total number of files: %d", len(record_info["filenames"]))
  logging.info(record_info["filenames"])

  def input_fn(params):
    """Builds the dataset for the per-core batch size in `params`."""
    assert params["batch_size"] * num_core_per_host == bsz_per_host

    return get_dataset(
        params=params,
        num_hosts=num_hosts,
        num_core_per_host=num_core_per_host,
        split=split,
        file_names=record_info["filenames"],
        num_batch=record_info["num_batch"],
        seq_len=seq_len,
        reuse_len=reuse_len,
        perm_size=perm_size,
        mask_alpha=mask_alpha,
        mask_beta=mask_beta,
        use_bfloat16=use_bfloat16,
        num_predict=num_predict)

  return input_fn, record_info


def define_flags():
  """Defines relevant flags.

  Registers the command-line flags that `create_data` and its helpers read
  through the module-level `FLAGS` object.
  """
  flags.DEFINE_bool("use_tpu", True, help="whether to use TPUs")
  flags.DEFINE_integer("bsz_per_host", 32, help="batch size per host.")
  flags.DEFINE_integer("num_core_per_host", 8, help="num TPU cores per host.")

  flags.DEFINE_integer("seq_len", 512,
                       help="Sequence length.")
  flags.DEFINE_integer("reuse_len", 256,
                       help="Number of token that can be reused as memory. "
                       "Could be half of `seq_len`.")
  flags.DEFINE_bool("uncased", False, help="Use uncased inputs or not.")
  flags.DEFINE_bool("bi_data", True,
                    help="whether to create bidirectional data")
  flags.DEFINE_integer("mask_alpha", default=6,
                       help="How many tokens to form a group.")
  flags.DEFINE_integer("mask_beta", default=1,
                       help="How many tokens to mask within each group.")
  flags.DEFINE_bool("use_eod", True,
                    help="whether to append EOD at the end of a doc.")
  flags.DEFINE_bool("from_raw_text", True,
                    help="Whether the input is raw text or encoded ids.")
  flags.DEFINE_integer("num_predict", default=85,
                       help="Num of tokens to predict.")

  flags.DEFINE_string("input_glob", "data/example/*.txt",
                      help="Input file glob.")
  flags.DEFINE_string("sp_path", "", help="Path to the sentence piece model.")
  flags.DEFINE_string("save_dir", "proc_data/example",
                      help="Directory for saving the processed data.")
  flags.DEFINE_enum("split", "train", ["train", "dev", "test"],
                    help="Save the data as which split.")

  # The trailing space after "pass." keeps the two concatenated string
  # literals from running together in the help output.
  flags.DEFINE_integer("pass_id", 0, help="ID of the current pass. "
                       "Different passes sample different negative segment.")
  flags.DEFINE_integer("num_task", 1, help="Number of total tasks.")
  flags.DEFINE_integer("task", 0, help="The Task ID. This value is used when "
                       "using multiple workers to identify each worker.")


# Script entry point: register this module's flags, set the log verbosity to
# INFO, and hand `create_data` to absl's `app.run` as the main function.
if __name__ == "__main__":
  define_flags()
  logging.set_verbosity(logging.INFO)
  app.run(create_data)