# official/projects/triviaqa/modeling.py
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Modeling for TriviaQA."""
import tensorflow as tf, tf_keras
from official.modeling import tf_utils
from official.nlp.configs import encoders
class TriviaQaHead(tf_keras.layers.Layer):
  """Feed-forward head producing two logits per token.

  The head applies dropout to the token embeddings, a dense projection to
  `intermediate_size` followed by an activation, a dense projection back to
  the embedding width, a second dropout, a residual connection with layer
  normalization and finally a 2-unit dense layer (presumably span start and
  span end scores). Logits at positions whose token id is 0 and at the first
  `question_lengths` positions of each sequence are lowered by 1e6.
  """

  def __init__(self,
               intermediate_size,
               intermediate_activation=tf_utils.get_activation('gelu'),
               dropout_rate=0.0,
               attention_dropout_rate=0.0,
               **kwargs):
    """Creates the sub-layers whose sizes do not depend on the input shape.

    Args:
      intermediate_size: Number of units of the intermediate dense layer.
      intermediate_activation: Activation applied after the intermediate
        dense layer.
      dropout_rate: Dropout rate applied after the output projection.
      attention_dropout_rate: Dropout rate applied to the incoming token
        embeddings.
      **kwargs: Forwarded to `tf_keras.layers.Layer`.
    """
    super().__init__(**kwargs)
    # Sub-layers are created in a fixed order so that auto-generated layer
    # names stay stable.
    self._attention_dropout = tf_keras.layers.Dropout(
        rate=attention_dropout_rate)
    self._intermediate_dense = tf_keras.layers.Dense(units=intermediate_size)
    self._intermediate_activation = tf_keras.layers.Activation(
        activation=intermediate_activation)
    self._output_dropout = tf_keras.layers.Dropout(rate=dropout_rate)
    self._output_layer_norm = tf_keras.layers.LayerNormalization()
    self._logits_dense = tf_keras.layers.Dense(units=2)

  def build(self, input_shape):
    """Creates the output projection, sized to the token embedding width.

    Args:
      input_shape: Mapping of input name to shape; only the last dimension of
        `input_shape['token_embeddings']` is read.
    """
    embedding_width = input_shape['token_embeddings'][-1]
    self._output_dense = tf_keras.layers.Dense(units=embedding_width)
    super().build(input_shape)

  def call(self, inputs, training=None):
    """Computes masked per-token logits.

    Args:
      inputs: Mapping with the keys `token_embeddings`, `token_ids` and
        `question_lengths`.
      training: Forwarded to the dropout layers.

    Returns:
      The output of the 2-unit dense layer, with 1e6 subtracted at positions
      where `token_ids` equals 0 or the position index is smaller than the
      corresponding entry of `question_lengths`.
    """
    embeddings = inputs['token_embeddings']
    ids = inputs['token_ids']
    lengths = inputs['question_lengths']

    # Position-wise feed-forward block with a residual connection.
    hidden = self._attention_dropout(embeddings, training=training)
    hidden = self._intermediate_activation(self._intermediate_dense(hidden))
    projected = self._output_dropout(
        self._output_dense(hidden), training=training)
    normalized = self._output_layer_norm(projected + embeddings)
    logits = self._logits_dense(normalized)

    # Push padding and question positions far below every other position.
    is_padding = tf.cast(tf.equal(ids, 0), tf.float32)
    is_question = tf.sequence_mask(
        lengths, logits.shape[-2], dtype=tf.float32)
    penalty = tf.expand_dims(is_padding + is_question, -1) * 1e6
    return logits - penalty
class TriviaQaModel(tf_keras.Model):
  """Functional Keras model: an encoder followed by a `TriviaQaHead`."""

  def __init__(self, model_config: encoders.EncoderConfig, sequence_length: int,
               **kwargs):
    """Wires the inputs through the encoder and the head.

    Args:
      model_config: Encoder configuration passed to `encoders.build_encoder`.
        The object returned by `model_config.get()` must also expose
        `intermediate_size`, `dropout_rate` and `attention_dropout_rate`,
        which configure the head.
      sequence_length: Fixed length of the `token_ids` input.
      **kwargs: Forwarded to `tf_keras.Model`.
    """
    token_ids = tf_keras.Input((sequence_length,), dtype=tf.int32)
    question_lengths = tf_keras.Input((), dtype=tf.int32)
    inputs = {'token_ids': token_ids, 'question_lengths': question_lengths}

    encoder = encoders.build_encoder(model_config)

    # Token id 0 is treated as padding; the first `question_lengths`
    # positions get segment id 0 and the remaining positions segment id 1.
    input_mask = tf.cast(token_ids > 0, tf.int32)
    input_type_ids = 1 - tf.sequence_mask(question_lengths, sequence_length,
                                          tf.int32)
    encoder_outputs = encoder({
        'input_word_ids': token_ids,
        'input_mask': input_mask,
        'input_type_ids': input_type_ids,
    })
    sequence_output = encoder_outputs['sequence_output']

    encoder_params = model_config.get()
    head = TriviaQaHead(
        encoder_params.intermediate_size,
        dropout_rate=encoder_params.dropout_rate,
        attention_dropout_rate=encoder_params.attention_dropout_rate)
    logits = head({
        'token_embeddings': sequence_output,
        'token_ids': token_ids,
        'question_lengths': question_lengths,
    })

    super().__init__(inputs, logits, **kwargs)
    self._encoder = encoder

  @property
  def encoder(self):
    """The encoder built from `model_config` in the constructor."""
    return self._encoder
class SpanOrCrossEntropyLoss(tf_keras.losses.Loss):
  """Cross entropy loss that accepts several correct positions at once.

  See https://arxiv.org/abs/1710.10723.
  """

  def call(self, y_true, y_pred):
    """Computes the loss for one batch.

    Args:
      y_true: Labels with the same shape as `y_pred`; entries below 0.5 are
        treated as incorrect positions.
      y_pred: Logits; the second-to-last axis is the one normalized over and
        the last axis is summed over.

    Returns:
      `logsumexp` over all positions minus `logsumexp` over the positions
      labelled correct, summed over the last axis.
    """
    # Incorrect positions are pushed down by 1e6 so that they contribute
    # nothing to the second log-sum-exp.
    incorrect = tf.cast(y_true < 0.5, tf.float32)
    correct_only_logits = y_pred - incorrect * 1e6

    log_norm_all = tf.math.reduce_logsumexp(y_pred, axis=-2)
    log_norm_correct = tf.math.reduce_logsumexp(correct_only_logits, axis=-2)
    return tf.math.reduce_sum(log_norm_all - log_norm_correct, -1)
def smooth_labels(label_smoothing, labels, question_lengths, token_ids):
  """Applies label smoothing over the non-padding, non-question positions.

  A position is eligible when its token id is not 0 and its index is not
  smaller than the corresponding entry of `question_lengths`. The smoothing
  mass is spread uniformly over the eligible positions of each sequence, and
  the labels of all other positions are set to zero.

  Args:
    label_smoothing: Fraction of the label mass that is redistributed
      uniformly over the eligible positions.
    labels: Float tensor whose second-to-last axis is the sequence axis
      (presumably `[batch, sequence, 2]`; confirm against callers).
    question_lengths: Integer tensor with one question length per sequence.
    token_ids: Integer tensor of token ids; id 0 marks padding.

  Returns:
    A tensor with the shape of `labels` holding the smoothed labels. A
    sequence without any eligible position yields all-zero labels.
  """
  excluded = (
      tf.cast(tf.equal(token_ids, 0), tf.float32) +
      tf.sequence_mask(question_lengths, labels.shape[-2], dtype=tf.float32))
  mask = 1. - excluded
  num_classes = tf.expand_dims(tf.math.reduce_sum(mask, -1, keepdims=True), -1)
  # A sequence with no eligible position has `num_classes == 0`. A plain
  # division would then give inf (or NaN when `label_smoothing` is 0), which
  # turns into NaN once multiplied by the zero mask below.
  uniform_mass = tf.math.divide_no_nan(label_smoothing, num_classes)
  labels = (1. - label_smoothing) * labels + uniform_mass
  return labels * tf.expand_dims(mask, -1)