official/projects/triviaqa/inputs.py
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Input processing for TriviaQA."""
import os
from typing import Optional, Text, Union
import tensorflow as tf, tf_keras
import tensorflow_datasets as tfds
from official.modeling import tf_utils
from official.projects.triviaqa import dataset # pylint: disable=unused-import
def _flatten_dims(tensor: tf.Tensor,
first_dim: Optional[int] = 0,
last_dim: Optional[int] = -1,
name: Optional[Text] = None) -> tf.Tensor:
"""Flattens the given span of dimensions in `tensor`.
Args:
tensor: [..., first_dim_size, ...middle_dims..., last_dim_size, ...] shaped
Tensor.
first_dim: The first dimension to flatten (inclusive). Must be a valid index
for the rank of `tensor`. Default is 0.
last_dim: The last dimension to flatten (inclusive). Must be a valid index
for the rank of `tensor`. Default is -1.
name: A name for the operation (optional).
Returns:
Tensor of shape [..., flattened_dim_size, ...] where
flattened_dim_size = first_dim_size * ...middle_dims... * last_dim_size.
"""
with tf.name_scope(name or 'flatten_dims'):
tensor = tf.convert_to_tensor(tensor)
rank = tensor.shape.rank
if rank is None:
raise ValueError('Static rank of `tensor` must be known.')
if first_dim < 0: # pytype: disable=unsupported-operands
first_dim += rank
if first_dim < 0 or first_dim >= rank: # pytype: disable=unsupported-operands
raise ValueError('`first_dim` out of bounds for `tensor` rank.')
if last_dim < 0: # pytype: disable=unsupported-operands
last_dim += rank
if last_dim < 0 or last_dim >= rank: # pytype: disable=unsupported-operands
raise ValueError('`last_dim` out of bounds for `tensor` rank.')
if first_dim > last_dim: # pytype: disable=unsupported-operands
raise ValueError('`first_dim` must not be larger than `last_dim`.')
# Try to calculate static flattened dim size if all input sizes to flatten
# are statically known. Otherwise, just use -1.
flat_dims_shape = tensor.shape[first_dim:(last_dim + 1)].as_list()
flattened_dim_size = 1
for size in flat_dims_shape:
if size is None:
flattened_dim_size = -1
break
flattened_dim_size *= size
old_shape = tf.shape(tensor)
output_shape = tf.concat([
old_shape[:first_dim], [flattened_dim_size], old_shape[(last_dim + 1):]
], 0)
return tf.reshape(tensor, output_shape)
def _pad_to_multiple(tensor: tf.Tensor,
factor: Union[int, tf.Tensor],
axis: int,
mode: Optional[Text] = 'CONSTANT',
constant_values=0,
name: Optional[Text] = None) -> tf.Tensor:
"""Pads `tensor` on a given `axis` to be a multiple of `factor`.
Padding will be concatenated to the end of the axis only, not the beginning.
If the length along `axis` is already a multiple of `factor`, this is
effectively a no-op.
Args:
tensor: A Tensor with rank >= 1 to pad.
factor: Positive integer factor to pad for. If a Tensor, must be a scalar
int.
axis: A valid axis in `tensor` to pad.
mode: The padding mode to use according to `tf.pad`. Defaults to 'CONSTANT'.
constant_values: For 'CONSTANT' mode, the scalar pad value to use within
`tf.pad`. Defaults to 0. Must be same type as `tensor`.
name: A name for the operation (optional).
Returns:
The padded Tensor result.
"""
with tf.name_scope(name or 'pad_to_multiple'):
tensor = tf.convert_to_tensor(tensor)
if isinstance(factor, int) and factor < 1:
raise ValueError('`factor` must be positive.')
rank = tensor.shape.rank
if rank is None:
raise ValueError('Static rank of `tensor` must be known.')
if axis < 0:
axis += rank
if axis < 0 or axis >= rank:
raise ValueError('`axis` out of bounds for `tensor` rank.')
axis_len = tf_utils.get_shape_list(tensor)[axis]
pad_len = -axis_len % factor
paddings = pad_len * tf.one_hot([-1, axis], rank, axis=0, dtype=tf.int32)
return tf.pad(
tensor=tensor,
paddings=paddings,
mode=mode,
constant_values=constant_values)
def _skew_elements_right(tensor: tf.Tensor,
axis: int,
pad_value=0,
name: Optional[Text] = None) -> tf.Tensor:
"""Skews successive elements right along the given `axis`.
This changes an input like
[
[1, 2, 3],
[4, 5, 6],
[7, 8, 9]
]
into the following:
[
[1, 2, 3, 0, 0],
[0, 4, 5, 6, 0],
[0, 0, 7, 8, 9]
]
Args:
tensor: Tensor of shape [..., num_rows, axis_len, ...].
axis: A valid axis in `tensor` to skew along. It must not be the first axis
in `tensor`.
pad_value: The scalar pad value to use. Defaults to 0. Must be the same type
as `tensor`.
name: A name for the operation (optional).
Returns:
Tensor of shape [..., num_rows, axis_len + num_rows - 1, ...].
"""
with tf.name_scope(name or 'skew_elements_right'):
tensor = tf.convert_to_tensor(tensor)
rank = tensor.shape.rank
num_rows = tf_utils.get_shape_list(tensor)[axis - 1]
axis_len = tf_utils.get_shape_list(tensor)[axis]
if rank is None:
raise ValueError('Static rank of `tensor` must be known.')
if axis < 0:
axis += rank
if axis <= 0 or axis >= rank:
raise ValueError('`axis` out of bounds for `tensor` rank.')
output_len = axis_len + num_rows - 1
paddings = num_rows * tf.one_hot([-1, axis], rank, axis=0, dtype=tf.int32)
# [..., num_rows, axis_len + num_rows, ...]
padded_tensor = tf.pad(tensor, paddings, constant_values=pad_value)
# [..., num_rows * (axis_len + num_rows), ...]
flat_tensor = _flatten_dims(
padded_tensor, first_dim=axis - 1, last_dim=axis)
padded_tensor2 = _pad_to_multiple(
flat_tensor,
factor=output_len,
axis=axis - 1,
constant_values=pad_value)
# [..., num_rows + 1, output_len, ...]
new_shape = tf.concat([
tf.shape(tensor)[:(axis - 1)], [num_rows + 1, output_len],
tf.shape(tensor)[(axis + 1):]
], 0)
reshaped_tensor = tf.reshape(padded_tensor2, new_shape)
# [..., num_rows, output_len, ...]
output_shape = new_shape - tf.one_hot(axis - 1, depth=rank, dtype=tf.int32)
return tf.slice(
reshaped_tensor, begin=tf.zeros_like(output_shape), size=output_shape)
class RelativePositionGenerator(object):
"""Generates `relative_att_ids` for purely distance-based relative positions.
This implements the clipped relative position representations originally
described in https://arxiv.org/abs/1803.02155 .
Attributes:
max_distance: Integer passed from `__init__`.
ignore_direction: Bool passed from `__init__`.
relative_vocab_size: Integer representing the maximum number of unique ids
output from this generator.
left_pad_value: Integer id for all positions at or beyond max_distance to
the left.
right_pad_value: Integer id for all positions at or beyond max_distance to
the right.
"""
def __init__(self, max_distance: int, ignore_direction: bool = False):
"""Init.
Args:
max_distance: The maximum distance to represent. Must not be negative. All
larger distances will be clipped to this value.
ignore_direction: If True, both left and right position representations
will have the same ids based on absolute distance (resulting in
symmetric ids around the center token).
"""
if max_distance < 0:
raise ValueError('`max_distance` must not be negative.')
self.max_distance = max_distance
self.ignore_direction = ignore_direction
self.right_pad_value = max_distance
self.left_pad_value = max_distance if ignore_direction else 2 * max_distance
# 0 is the first id, so vocab size is 1 + the largest id (left pad value).
self.relative_vocab_size = self.left_pad_value + 1
def make_relative_att_ids(self,
seq_len: Union[int, tf.Tensor],
batch_size: Optional[Union[int, tf.Tensor]] = 1,
name: Optional[Text] = None) -> tf.Tensor:
"""Makes relative position ids for full self-attention.
For example, if `max_distance` is 3, `ignore_direction` is False, `seq_len`
is 6, and `batch_size` is 1, the result is the following:
[[
[0, 1, 2, 3, 3, 3],
[4, 0, 1, 2, 3, 3],
[5, 4, 0, 1, 2, 3],
[6, 5, 4, 0, 1, 2],
[6, 6, 5, 4, 0, 1],
[6, 6, 6, 5, 4, 0],
]]
Args:
seq_len: The sequence length to create ids for. Must be positive. If a
Tensor, must be a scalar int.
batch_size: The batch size of the result (default 1). Must be positive. If
a Tensor, must be a scalar int. All examples in the batch will have the
same id pattern.
name: A name for the operation (optional).
Returns:
<int32>[batch_size, seq_len, seq_len] Tensor of relative position ids.
"""
with tf.name_scope(name or 'make_relative_att_ids'):
if isinstance(seq_len, int) and seq_len < 1:
raise ValueError('`seq_len` must be positive.')
if isinstance(batch_size, int) and batch_size < 1:
raise ValueError('`batch_size` must be positive.')
# We need the id_pattern to cover all tokens to the left of the last token
# and all tokens to the right of the first token at the same time.
window_size = 2 * seq_len - 1
# [window_size]
id_pattern = self._make_relative_id_pattern(window_size)
# [seq_len, window_size]
id_tensor = tf.tile(id_pattern[tf.newaxis, :], [seq_len, 1])
# [seq_len, window_size + seq_len - 1]
id_tensor = _skew_elements_right(id_tensor, -1)
# [seq_len, seq_len]
id_tensor = tf.slice(id_tensor, [0, seq_len - 1], [seq_len, seq_len])
return tf.tile(id_tensor[tf.newaxis, :, :], [batch_size, 1, 1])
def make_local_relative_att_ids(self,
seq_len: Union[int, tf.Tensor],
local_radius: int,
batch_size: Optional[Union[int,
tf.Tensor]] = 1,
name: Optional[Text] = None) -> tf.Tensor:
"""Makes relative position ids for local self-attention.
The result can be used as `relative_att_ids` in
`layers.RelativeLocalSelfAttention`.
For example, if `max_distance` is 3, `ignore_direction` is False, `seq_len`
is 4, `local_radius` is 5, and `batch_size` is 1, the result is the
following:
[[
[6, 6, 6, 5, 4, 0, 1, 2, 3, 3, 3],
[6, 6, 6, 5, 4, 0, 1, 2, 3, 3, 3],
[6, 6, 6, 5, 4, 0, 1, 2, 3, 3, 3],
[6, 6, 6, 5, 4, 0, 1, 2, 3, 3, 3],
]]
Args:
seq_len: The sequence length to create ids for. Must be positive. If a
Tensor, must be a scalar int.
local_radius: The local radius as expected by
`layers.RelativeLocalSelfAttention`. Must be positive.
batch_size: The batch size of the result (default 1). Must be positive. If
a Tensor, must be a scalar int. All examples in the batch will have the
same id pattern.
name: A name for the operation (optional).
Returns:
<int32>[batch_size, seq_len, 2*local_radius + 1] Tensor of relative
position ids.
"""
with tf.name_scope(name or 'make_local_relative_att_ids'):
if isinstance(seq_len, int) and seq_len < 1:
raise ValueError('`seq_len` must be positive.')
if local_radius < 1:
raise ValueError('`local_radius` must be positive.')
if isinstance(batch_size, int) and batch_size < 1:
raise ValueError('`batch_size` must be positive.')
window_size = 2 * local_radius + 1
# [window_size]
id_pattern = self._make_relative_id_pattern(window_size)
return tf.tile(id_pattern[tf.newaxis, tf.newaxis, :],
[batch_size, seq_len, 1])
def _make_relative_id_pattern(
self, window_size: Union[int, tf.Tensor]) -> tf.Tensor:
"""Helper for making the relative id pattern for a particular window size.
For example, if `max_distance` is 3, `ignore_direction` is False, and
`window_size` is 11, the result is the following:
[6, 6, 6, 5, 4, 0, 1, 2, 3, 3, 3].
Args:
window_size: Window size to return relative ids for. Must be positive and
odd since ids will be relative to the center of the window. If a Tensor,
must be a scalar int.
Returns:
<int32>[window_size] Tensor of relative position ids.
"""
if isinstance(window_size, int):
if window_size < 1:
raise ValueError('`window_size` must be positive.')
if window_size % 2 != 1:
raise ValueError('`window_size` must be odd.')
x = tf.range(self.max_distance + 1, dtype=tf.int32)
x = tf.pad(x, [[self.max_distance, 0]], mode='REFLECT')
if not self.ignore_direction:
direction_adder = tf.concat([
tf.fill([self.max_distance], self.max_distance),
tf.zeros([self.max_distance + 1], dtype=tf.int32)
], 0)
x += direction_adder
len_x = x.shape.as_list()[0]
if len_x > window_size:
trim_amount = (len_x - window_size) // 2
return x[trim_amount:-trim_amount]
pad_amount = (window_size - len_x) // 2
result = tf.pad(x, [[pad_amount, 0]], constant_values=self.left_pad_value)
result = tf.pad(
result, [[0, pad_amount]], constant_values=self.right_pad_value)
return result
def read_batches(data_dir,
split,
batch_size,
include_answers=True,
shuffle=False,
drop_final_batch=False,
compression_type=''):
"""Read TriviaQA batches."""
features = {
'id': tf.io.FixedLenFeature([], tf.string),
'qid': tf.io.FixedLenFeature([], tf.string),
'context': tf.io.FixedLenFeature([], tf.string),
'question': tf.io.FixedLenFeature([], tf.string),
'global_token_ids': tf.io.RaggedFeature(tf.int64),
'token_ids': tf.io.RaggedFeature(tf.int64),
'segment_ids': tf.io.RaggedFeature(tf.int64),
'token_offsets': tf.io.RaggedFeature(tf.int64),
}
if include_answers:
features['answers'] = tf.io.RaggedFeature(
tf.int64, partitions=(tf.io.RaggedFeature.UniformRowLength(2),)) # pytype: disable=attribute-error
dataset_builder = tfds.builder(
'bigbird_trivia_qa/rc_wiki.preprocessed', data_dir=data_dir)
split_info = dataset_builder.info.splits[split]
return tf.data.experimental.make_batched_features_dataset(
[
os.path.join(dataset_builder.data_dir, filename)
for filename in split_info.filenames
],
batch_size=batch_size,
features=features,
reader=lambda path: tf.data.TFRecordDataset(path, compression_type),
label_key='answers' if include_answers else None,
num_epochs=1,
shuffle=shuffle,
shuffle_buffer_size=split_info.num_examples,
prefetch_buffer_size=tf.data.experimental.AUTOTUNE,
sloppy_ordering=True,
drop_final_batch=drop_final_batch,
reader_num_threads=8,
parser_num_threads=16)
def scatter_labels(labels, batch_size, sequence_length):
"""Create one hot labels."""
row_ids = labels.value_rowids()
indices = tf.concat(
(tf.stack((row_ids, tf.cast(labels.flat_values[:, 0],
tf.int32), tf.zeros_like(row_ids)), -1),
tf.stack((row_ids, tf.cast(labels.flat_values[:, 1],
tf.int32), tf.ones_like(row_ids)), -1)), 0)
one_hot_labels = tf.scatter_nd(indices,
tf.ones(tf.shape(indices)[0], tf.float32),
(batch_size, sequence_length, 2))
return tf.minimum(one_hot_labels, 1.)
def features_map_fn(features, local_radius, relative_pos_max_distance,
use_hard_g2l_mask, padding_id, eos_id, null_id, cls_id,
sep_id, sequence_length, global_sequence_length):
"""Make features."""
batch_size = tf.get_static_value(features['token_ids'].shape[0])
# sequence_lengths = features['token_ids'].row_lengths()
question_lengths = tf.argmax(
tf.equal(features['token_ids'].to_tensor(
shape=(batch_size, global_sequence_length)), sep_id), -1) + 1
mapped_features = dict(
token_ids=tf.cast(
features['token_ids'].to_tensor(shape=(batch_size, sequence_length)),
tf.int32),
global_token_ids=tf.cast(
features['global_token_ids'].to_tensor(
shape=(batch_size, global_sequence_length)), tf.int32),
segment_ids=tf.cast(
features['segment_ids'].to_tensor(
shape=(batch_size, sequence_length)), tf.int32),
)
relative_pos_generator = RelativePositionGenerator(
max_distance=relative_pos_max_distance)
# Only do long-to-long attention for non-null tokens.
# Let the null token attend to itself.
l2l_att_mask = tf.ones((batch_size, sequence_length, 2 * local_radius + 1),
tf.int32)
l2l_att_mask *= 1 - tf.cast(
tf.logical_or(
tf.equal(mapped_features['token_ids'], padding_id),
tf.equal(mapped_features['token_ids'], null_id)),
tf.int32)[:, :, tf.newaxis]
l2l_relative_att_ids = relative_pos_generator.make_local_relative_att_ids(
seq_len=sequence_length, local_radius=local_radius, batch_size=batch_size)
#
l2g_att_mask = tf.ones((batch_size, sequence_length, global_sequence_length),
tf.int32)
l2g_att_mask *= tf.cast(
tf.not_equal(mapped_features['token_ids'], padding_id),
tf.int32)[:, :, tf.newaxis]
l2g_att_mask *= tf.cast(
tf.not_equal(mapped_features['global_token_ids'], padding_id),
tf.int32)[:, tf.newaxis, :]
l2g_relative_att_ids = tf.fill(
(batch_size, sequence_length, global_sequence_length),
relative_pos_generator.relative_vocab_size + 1)
#
g2g_att_mask = tf.ones(
(batch_size, global_sequence_length, global_sequence_length), tf.int32)
g2g_att_mask *= tf.cast(
tf.not_equal(mapped_features['global_token_ids'], padding_id),
tf.int32)[:, :, tf.newaxis]
g2g_relative_att_ids = relative_pos_generator.make_relative_att_ids(
seq_len=global_sequence_length, batch_size=batch_size)
global_sentence_mask = tf.equal(mapped_features['global_token_ids'], eos_id)
global_question_mask = tf.logical_not(
tf.logical_or(
tf.logical_or(
tf.equal(mapped_features['global_token_ids'], cls_id),
tf.equal(mapped_features['global_token_ids'], eos_id)),
tf.equal(mapped_features['global_token_ids'], padding_id)))
g2g_question_mask = tf.logical_and(global_question_mask[:, tf.newaxis, :],
global_question_mask[:, :, tf.newaxis])
g2g_sentence_mask = tf.logical_and(global_sentence_mask[:, tf.newaxis, :],
global_sentence_mask[:, :, tf.newaxis])
g2g_local_mask = tf.cast(
tf.logical_or(g2g_question_mask, g2g_sentence_mask), tf.int32)
g2g_relative_att_ids *= g2g_local_mask
g2g_relative_att_ids += (1 - g2g_local_mask) * (
relative_pos_generator.relative_vocab_size + 2)
#
g2l_att_mask = tf.transpose(l2g_att_mask, [0, 2, 1])
if use_hard_g2l_mask:
global_range = tf.range(
global_sequence_length, dtype=mapped_features['global_token_ids'].dtype)
g2l_att_mask *= tf.cast(
tf.logical_or(
tf.equal(
mapped_features['global_token_ids'], cls_id)[:, :, tf.newaxis],
tf.equal(global_range[tf.newaxis, :, tf.newaxis],
mapped_features['segment_ids'][:, tf.newaxis, :])),
tf.int32)
g2l_relative_att_ids = tf.transpose(l2g_relative_att_ids, [0, 2, 1])
mapped_features.update(
dict(
l2l_att_mask=l2l_att_mask,
l2l_relative_att_ids=l2l_relative_att_ids,
l2g_att_mask=l2g_att_mask,
l2g_relative_att_ids=l2g_relative_att_ids,
g2g_att_mask=g2g_att_mask,
g2g_relative_att_ids=g2g_relative_att_ids,
g2l_att_mask=g2l_att_mask,
g2l_relative_att_ids=g2l_relative_att_ids,
question_lengths=question_lengths,
))
return mapped_features
def labels_map_fn(token_ids, labels, sequence_length):
batch_size = tf.get_static_value(labels.shape[0])
row_lengths = labels.row_lengths()
empty_token_index = token_ids.row_lengths() - 1
one_hot_labels = scatter_labels(labels, batch_size, sequence_length)
one_hot_labels += (tf.cast(row_lengths == 0, tf.float32)[:, tf.newaxis] *
tf.one_hot(empty_token_index, sequence_length))[:, :,
tf.newaxis]
return one_hot_labels