official/projects/triviaqa/prediction.py
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions for inference."""
import tensorflow as tf, tf_keras
def split_and_pad(strategy, batch_size, x):
  """Splits a batch across replicas and pads every shard to a fixed size.

  Each tensor in the (possibly nested) structure `x` is cut along its first
  dimension into consecutive shards of
  `batch_size // strategy.num_replicas_in_sync` rows, one per replica. A shard
  that runs past the end of the tensor is filled up with repeats of row 0,
  because the gather indices are zero-padded.

  Args:
    strategy: Distribution strategy; must provide `num_replicas_in_sync` and
      `experimental_distribute_values_from_function`.
    batch_size: Global batch size shared by all replicas.
    x: Structure of tensors to split. The first dimension is read through
      `.shape[0]` and compared in Python, so it has to be statically known
      (presumably it is the batch axis -- confirm with callers).

  Returns:
    A structure shaped like `x` that holds, for each tensor, the result of
    `strategy.experimental_distribute_values_from_function`.
  """
  shard_size = batch_size // strategy.num_replicas_in_sync

  def take_shard(tensor, replica_id):
    """Returns the rows owned by `replica_id`, padded to `shard_size` rows."""
    num_rows = tensor.shape[0]
    start = min(num_rows, replica_id * shard_size)
    stop = min(num_rows, (replica_id + 1) * shard_size)
    row_ids = tf.range(start, stop, dtype=tf.int32)
    # tf.pad appends zeros, so the missing rows are gathered from index 0.
    num_missing = shard_size - stop + start
    padded_row_ids = tf.pad(row_ids, [[0, num_missing]])
    return tf.gather(tensor, padded_row_ids)

  def distribute(tensor):
    """Builds the per-replica value for a single tensor of the structure."""

    def value_fn(ctx):
      return take_shard(tensor, ctx.replica_id_in_sync_group)

    return strategy.experimental_distribute_values_from_function(value_fn)

  return tf.nest.map_structure(distribute, x)
def decode_logits(top_k, max_size, logits, default):
  """Picks the best-scoring (begin, end) span among the top-k candidates.

  The `top_k` highest begin logits and the `top_k` highest end logits are
  paired in all `top_k * top_k` ways. A pair is admissible when
  `0 <= end - begin <= max_size`. Its score is the sum of the two logits;
  inadmissible pairs are scored -1e8. The highest-scoring pair is selected.

  Args:
    top_k: Number of begin candidates and of end candidates per example.
    max_size: Largest allowed value of `end - begin`.
    logits: Rank-3 tensor. After the [0, 2, 1] transpose, index 0 of axis 1 is
      read as the begin logits and index 1 as the end logits. Presumably
      shaped (batch, positions, 2) -- confirm against the model output.
    default: Value returned for begin and end when an example has no
      admissible pair.

  Returns:
    A tuple `(begin, end, score)` with one entry per example: the selected
    begin and end positions (or `default`), and the maximum pair score, which
    is -1e8 when no pair is admissible.
  """
  per_kind_logits = tf.transpose(logits, [0, 2, 1])
  top_values, top_positions = tf.math.top_k(per_kind_logits, top_k)
  begin_positions = top_positions[:, 0, :]
  end_positions = top_positions[:, 1, :]
  begin_values = top_values[:, 0, :]
  end_values = top_values[:, 1, :]

  # Entry [b, i, j] pairs begin candidate i with end candidate j.
  span_width = (
      tf.expand_dims(end_positions, -2) - tf.expand_dims(begin_positions, -1))
  admissible = tf.logical_and(span_width >= 0, span_width <= max_size)
  pair_scores = (
      tf.expand_dims(begin_values, -1) + tf.expand_dims(end_values, -2))
  pair_scores = tf.where(admissible, pair_scores, -1e8)

  # Row-major flattening: flat index = i * top_k + j, so floordiv recovers the
  # begin candidate and mod recovers the end candidate.
  best_pair = tf.argmax(tf.reshape(pair_scores, (-1, top_k * top_k)), -1)
  best_begin = tf.gather(
      begin_positions, tf.math.floordiv(best_pair, top_k), batch_dims=1)
  best_end = tf.gather(
      end_positions, tf.math.mod(best_pair, top_k), batch_dims=1)

  has_admissible = tf.math.reduce_any(admissible, [-1, -2])
  best_score = tf.math.reduce_max(pair_scores, [-1, -2])
  return (tf.where(has_admissible, best_begin, default),
          tf.where(has_admissible, best_end, default),
          best_score)
@tf.function
def decode_answer(context, begin, end, token_offsets, end_limit):
  """Cuts the answer substring out of `context` for a token span.

  The substring starts at the offset stored for token `begin`. It stops at the
  offset stored for the token following `end`, except when `end` equals
  `end_limit`: then it runs to the end of `context`.

  Args:
    context: String tensor to slice.
    begin: Index of the first token of the span, one per example.
    end: Index of the last token of the span, one per example.
    token_offsets: Per-example offsets into `context`, looked up by token
      index. Presumably byte offsets, since they are combined with the default
      units of `tf.strings.length` and `tf.strings.substr` -- confirm with the
      tokenizer.
    end_limit: Upper bound applied to `end + 1` before the offset lookup.

  Returns:
    String tensor holding the extracted substrings.
  """
  start_offset = tf.gather(token_offsets, begin, batch_dims=1)
  next_token = tf.minimum(end + 1, end_limit)
  stop_offset = tf.gather(token_offsets, next_token, batch_dims=1)
  # A span ending on the last allowed token extends to the end of the context.
  context_length = tf.cast(tf.strings.length(context), tf.int64)
  stop_offset = tf.where(end == end_limit, context_length, stop_offset)
  return tf.strings.substr(context, start_offset, stop_offset - start_offset)
def distributed_logits_fn(model, x):
  """Runs `model` in inference mode through its distribution strategy.

  Args:
    model: Callable model exposing a `distribute_strategy` attribute.
    x: Input passed to `model.distribute_strategy.run` as the single
      positional argument of the replica function.

  Returns:
    Whatever `model.distribute_strategy.run` returns for the call
    `model(inputs, training=False)`.
  """

  def replica_fn(inputs):
    return model(inputs, training=False)

  return model.distribute_strategy.run(replica_fn, args=(x,))