# official/projects/lra/lra_dual_encoder_task.py
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Trainer network for dual encoder style models."""
# pylint: disable=g-classes-have-attributes
import dataclasses
from typing import List, Union, Optional
from absl import logging
import numpy as np
import orbit
from scipy import stats
from sklearn import metrics as sklearn_metrics
import tensorflow as tf, tf_keras
from official.core import base_task
from official.core import config_definitions as cfg
from official.core import task_factory
from official.modeling import tf_utils
from official.modeling.hyperparams import base_config
from official.nlp.configs import encoders
from official.nlp.data import data_loader_factory
from official.nlp.tasks import utils
from official.projects.lra import lra_dual_encoder
# Names accepted for a task's `metric_type`; any other value is rejected.
METRIC_TYPES = frozenset((
    'accuracy',
    'f1',
    'matthews_corrcoef',
    'pearson_spearman_corr',
))
@dataclasses.dataclass
class ModelConfig(base_config.Config):
  """A classifier/regressor configuration.

  Attributes:
    num_classes: Number of output classes.
    use_encoder_pooler: Forwarded to the dual encoder model as
      `use_encoder_pooler`.
    encoder: Configuration of the encoder network.
    max_seq_length: Forwarded to the dual encoder model as `max_seq_length`.
  """

  num_classes: int = 2
  use_encoder_pooler: bool = False
  # A factory gives every config its own encoder config object. A shared
  # instance default is also rejected by `dataclasses` on Python 3.11+ when the
  # default object is unhashable.
  encoder: encoders.EncoderConfig = dataclasses.field(
      default_factory=encoders.EncoderConfig
  )
  max_seq_length: int = 512
@dataclasses.dataclass
class DualEncoderConfig(cfg.TaskConfig):
  """The task config for the dual encoder task.

  Attributes:
    init_checkpoint: Checkpoint file or directory to initialize from.
    init_cls_pooler: Whether the pooler dense layer is also restored from the
      checkpoint.
    hub_module_url: URL of a hub module to use as the encoder.
    metric_type: Name of the evaluation metric; must be in `METRIC_TYPES`.
    model: The model configuration.
    train_data: Training data configuration.
    validation_data: Validation data configuration.
  """

  # At most one of `init_checkpoint` and `hub_module_url` can
  # be specified.
  init_checkpoint: str = ''
  init_cls_pooler: bool = False
  hub_module_url: str = ''
  metric_type: str = 'accuracy'
  # Defines the concrete model config at instantiation time.
  # Factories are used so each task config owns fresh sub-configs. A shared
  # instance default is also rejected by `dataclasses` on Python 3.11+ when the
  # default object is unhashable.
  model: ModelConfig = dataclasses.field(default_factory=ModelConfig)
  train_data: cfg.DataConfig = dataclasses.field(
      default_factory=cfg.DataConfig
  )
  validation_data: cfg.DataConfig = dataclasses.field(
      default_factory=cfg.DataConfig
  )
@task_factory.register_task_cls(DualEncoderConfig)
class DualEncoderTask(base_task.Task):
  """Task object for DualEncoderTask.

  Builds an `LRADualEncoder` model, its losses, inputs and metrics, and
  provides the validation/aggregation hooks for the configured `metric_type`.
  """

  def __init__(self, params: cfg.TaskConfig, logging_dir=None, name=None):
    """Initializes the task.

    Args:
      params: The task config. Its `metric_type` must be in `METRIC_TYPES`.
      logging_dir: Forwarded to the base task constructor.
      name: Forwarded to the base task constructor.

    Raises:
      ValueError: If `params.metric_type` is not in `METRIC_TYPES`.
    """
    super().__init__(params, logging_dir, name=name)
    if params.metric_type not in METRIC_TYPES:
      raise ValueError('Invalid metric_type: {}'.format(params.metric_type))
    self.metric_type = params.metric_type
    # Fall back to the default label key when the data config does not
    # declare one.
    if hasattr(params.train_data, 'label_field'):
      self.label_field = params.train_data.label_field
    else:
      self.label_field = 'label_ids'

  def build_model(self):
    """Builds the `LRADualEncoder` model from the task config.

    Returns:
      An `lra_dual_encoder.LRADualEncoder` instance.

    Raises:
      ValueError: If both `hub_module_url` and `init_checkpoint` are set.
    """
    if self.task_config.hub_module_url and self.task_config.init_checkpoint:
      raise ValueError(
          'At most one of `hub_module_url` and '
          '`init_checkpoint` can be specified.'
      )
    if self.task_config.hub_module_url:
      encoder_network = utils.get_encoder_from_hub(
          self.task_config.hub_module_url
      )
    else:
      encoder_network = encoders.build_encoder(self.task_config.model.encoder)
    encoder_cfg = self.task_config.model.encoder.get()
    return lra_dual_encoder.LRADualEncoder(
        network=encoder_network,
        max_seq_length=self.task_config.model.max_seq_length,
        num_classes=self.task_config.model.num_classes,
        initializer=tf_keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range
        ),
        use_encoder_pooler=self.task_config.model.use_encoder_pooler,
        inner_dim=encoder_cfg.hidden_size * 2,
    )

  def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
    """Computes the training/eval loss.

    Uses mean squared error when `num_classes == 1`, and sparse categorical
    cross-entropy on logits otherwise.

    Args:
      labels: A mapping that contains the labels under `self.label_field`.
      model_outputs: The model outputs (logits for classification).
      aux_losses: Optional list of auxiliary losses added to the main loss.

    Returns:
      The mean loss as computed by `tf_utils.safe_mean`.
    """
    label_ids = labels[self.label_field]
    if self.task_config.model.num_classes == 1:
      loss = tf_keras.losses.mean_squared_error(label_ids, model_outputs)
    else:
      loss = tf_keras.losses.sparse_categorical_crossentropy(
          label_ids, tf.cast(model_outputs, tf.float32), from_logits=True
      )
    if aux_losses:
      loss += tf.add_n(aux_losses)
    return tf_utils.safe_mean(loss)

  def build_inputs(self, params, input_context=None):
    """Returns tf.data.Dataset for sentence_prediction task.

    Args:
      params: The data config. An `input_path` of 'dummy' yields an endless
        dataset of all-zero examples instead of loading real data.
      input_context: Forwarded to the data loader's `load` method.

    Returns:
      A `tf.data.Dataset`.
    """
    if params.input_path == 'dummy':

      def dummy_data(_):
        dummy_ids = tf.zeros((1, params.seq_length), dtype=tf.int32)
        x = dict(
            left_word_ids=dummy_ids,
            left_mask=dummy_ids,
            right_word_ids=dummy_ids,
            right_mask=dummy_ids,
        )
        # Regression uses a float label, classification an integer class id.
        if self.task_config.model.num_classes == 1:
          y = tf.zeros((1,), dtype=tf.float32)
        else:
          y = tf.zeros((1, 1), dtype=tf.int32)
        x[self.label_field] = y
        return x

      dataset = tf.data.Dataset.range(1)
      dataset = dataset.repeat()
      dataset = dataset.map(
          dummy_data, num_parallel_calls=tf.data.experimental.AUTOTUNE
      )
      return dataset
    return data_loader_factory.get_data_loader(params).load(input_context)

  def build_metrics(self, training=None):
    """Builds the streaming metrics for the configured number of classes.

    Args:
      training: Unused.

    Returns:
      A list of Keras metrics: mean squared error for `num_classes == 1`,
      accuracy plus PR-AUC for two classes, and accuracy otherwise.
    """
    del training
    if self.task_config.model.num_classes == 1:
      metrics = [tf_keras.metrics.MeanSquaredError()]
    elif self.task_config.model.num_classes == 2:
      metrics = [
          tf_keras.metrics.SparseCategoricalAccuracy(name='cls_accuracy'),
          tf_keras.metrics.AUC(name='auc', curve='PR'),
      ]
    else:
      metrics = [
          tf_keras.metrics.SparseCategoricalAccuracy(name='cls_accuracy'),
      ]
    return metrics

  def process_metrics(self, metrics, labels, model_outputs):
    """Updates the metrics created by `build_metrics`.

    Args:
      metrics: The metrics to update, matched by name.
      labels: A mapping that contains the labels under `self.label_field`.
      model_outputs: The model outputs.
    """
    for metric in metrics:
      if metric.name == 'auc':
        # Convert the logit to probability and extract the probability of True.
        metric.update_state(
            labels[self.label_field],
            tf.expand_dims(tf.nn.softmax(model_outputs)[:, 1], axis=1),
        )
      elif metric.name in ('cls_accuracy', 'mean_squared_error'):
        # 'mean_squared_error' is the default name of the regression metric
        # created in `build_metrics`; without this branch it was never updated
        # and always reported its initial value.
        metric.update_state(labels[self.label_field], model_outputs)

  def process_compiled_metrics(self, compiled_metrics, labels, model_outputs):
    """Updates Keras compiled metrics with the task's labels and outputs."""
    compiled_metrics.update_state(labels[self.label_field], model_outputs)

  def validation_step(self, inputs, model: tf_keras.Model, metrics=None):
    """Runs one validation step.

    Args:
      inputs: A mapping holding both the features and the labels.
      model: The model to evaluate.
      metrics: Optional list of metrics to update.

    Returns:
      A dict with the loss, the current metric results, and the
      'sentence_prediction' / 'labels' entries consumed by `aggregate_logs`.
    """
    # Features and labels live in the same mapping.
    features, labels = inputs, inputs
    outputs = self.inference_step(features, model)
    loss = self.build_losses(
        labels=labels, model_outputs=outputs, aux_losses=model.losses
    )
    logs = {self.loss: loss}
    if metrics:
      self.process_metrics(metrics, labels, outputs)
    if model.compiled_metrics:
      self.process_compiled_metrics(model.compiled_metrics, labels, outputs)
      logs.update({m.name: m.result() for m in metrics or []})
      logs.update({m.name: m.result() for m in model.metrics})
    if self.metric_type == 'matthews_corrcoef':
      logs.update({
          'sentence_prediction': (
              tf.expand_dims(  # Ensure one prediction along batch dimension.
                  tf.math.argmax(outputs, axis=1), axis=1
              )
          ),
          'labels': labels[self.label_field],
      })
    else:
      logs.update({
          'sentence_prediction': outputs,
          'labels': labels[self.label_field],
      })
    return logs

  def aggregate_logs(self, state=None, step_outputs=None):
    """Accumulates per-step predictions and labels as numpy arrays.

    Args:
      state: The accumulated state so far, or None on the first call.
      step_outputs: The logs returned by `validation_step`; the
        'sentence_prediction' and 'labels' entries are iterated and each
        element is converted with `.numpy()`.

    Returns:
      The updated state, or None when `metric_type` is 'accuracy' (which needs
      no aggregation).
    """
    if self.metric_type == 'accuracy':
      return None
    if state is None:
      state = {'sentence_prediction': [], 'labels': []}
    state['sentence_prediction'].append(
        np.concatenate(
            [v.numpy() for v in step_outputs['sentence_prediction']], axis=0
        )
    )
    state['labels'].append(
        np.concatenate([v.numpy() for v in step_outputs['labels']], axis=0)
    )
    return state

  def reduce_aggregated_logs(self, aggregated_logs, global_step=None):
    """Computes the final metric from the aggregated predictions and labels.

    Args:
      aggregated_logs: The state produced by `aggregate_logs`.
      global_step: Unused.

    Returns:
      A dict mapping `self.metric_type` to its value, or None when
      `metric_type` is 'accuracy'.
    """
    if self.metric_type == 'accuracy':
      return None
    preds = np.concatenate(aggregated_logs['sentence_prediction'], axis=0)
    labels = np.concatenate(aggregated_logs['labels'], axis=0)
    if self.metric_type == 'f1':
      preds = np.argmax(preds, axis=1)
      return {self.metric_type: sklearn_metrics.f1_score(labels, preds)}
    elif self.metric_type == 'matthews_corrcoef':
      # Predictions were already reduced to class ids in `validation_step`.
      preds = np.reshape(preds, -1)
      labels = np.reshape(labels, -1)
      return {
          self.metric_type: sklearn_metrics.matthews_corrcoef(preds, labels)
      }
    elif self.metric_type == 'pearson_spearman_corr':
      preds = np.reshape(preds, -1)
      labels = np.reshape(labels, -1)
      pearson_corr = stats.pearsonr(preds, labels)[0]
      spearman_corr = stats.spearmanr(preds, labels)[0]
      corr_metric = (pearson_corr + spearman_corr) / 2
      return {self.metric_type: corr_metric}

  def initialize(self, model):
    """Load a pretrained checkpoint (if exists) and then train from iter 0.

    Args:
      model: The model to restore into. Its `checkpoint_items` must provide
        'encoder' and, when `init_cls_pooler` is set,
        'sentence_prediction.pooler_dense'.
    """
    init_checkpoint = self.task_config.init_checkpoint
    logging.info(
        'Trying to load pretrained checkpoint from %s', init_checkpoint
    )
    ckpt_dir_or_file = init_checkpoint
    if ckpt_dir_or_file and tf.io.gfile.isdir(ckpt_dir_or_file):
      ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
    if not ckpt_dir_or_file:
      # Report the configured location, not the (empty) lookup result.
      logging.info(
          'No checkpoint file found from %s. Will not load.', init_checkpoint
      )
      return
    pretrain2finetune_mapping = {
        'encoder': model.checkpoint_items['encoder'],
    }
    if self.task_config.init_cls_pooler:
      # This option is valid when use_encoder_pooler is false.
      pretrain2finetune_mapping['next_sentence.pooler_dense'] = (
          model.checkpoint_items['sentence_prediction.pooler_dense']
      )
    ckpt = tf.train.Checkpoint(**pretrain2finetune_mapping)
    status = ckpt.read(ckpt_dir_or_file)
    status.expect_partial().assert_existing_objects_matched()
    logging.info(
        'Finished loading pretrained checkpoint from %s', ckpt_dir_or_file
    )
def predict(
    task: DualEncoderTask,
    params: cfg.DataConfig,
    model: tf_keras.Model,
    params_aug: Optional[cfg.DataConfig] = None,
    test_time_aug_wgt: float = 0.3,
) -> List[Union[int, float]]:
  """Runs inference over a dataset, optionally with test-time augmentation.

  Args:
    task: A `DualEncoderTask` object.
    params: A `cfg.DataConfig` object describing the data to predict on.
    model: A keras.Model.
    params_aug: An optional `cfg.DataConfig` object for augmented data. When
      given, its predictions are blended with the original predictions.
    test_time_aug_wgt: Weight of the augmented prediction. The blended score is
      (1. - test_time_aug_wgt) times the original prediction plus
      test_time_aug_wgt times the augmented prediction.

  Returns:
    A list with one entry per example, ordered by `example_id`. For a
    regression task (`num_classes == 1`) each entry is the predicted score;
    otherwise each entry is the argmax class id of the (blended) prediction.
  """

  def predict_step(inputs):
    """Replicated prediction calculation."""
    example_id = inputs.pop('example_id')
    return dict(
        example_id=example_id,
        predictions=task.inference_step(inputs, model),
    )

  def aggregate_fn(state, outputs):
    """Concatenates model's outputs."""
    collected = [] if state is None else state
    replica_pairs = zip(outputs['example_id'], outputs['predictions'])
    for replica_ids, replica_predictions in replica_pairs:
      collected.extend(zip(replica_ids, replica_predictions))
    return collected

  def run_sorted(data_params):
    """Predicts on `data_params` and returns (id, prediction) pairs by id."""
    dataset = orbit.utils.make_distributed_dataset(
        tf.distribute.get_strategy(), task.build_inputs, data_params
    )
    pairs = utils.predict(predict_step, aggregate_fn, dataset)
    # When running on TPU POD, the order of output cannot be maintained,
    # so we need to sort by example_id.
    return sorted(pairs, key=lambda pair: pair[0])

  outputs = run_sorted(params)
  is_regression = task.task_config.model.num_classes == 1

  if params_aug is None:
    scores = [pair[1] for pair in outputs]
  else:
    outputs_aug = run_sorted(params_aug)
    scores = [
        (1.0 - test_time_aug_wgt) * original[1]
        + test_time_aug_wgt * augmented[1]
        for original, augmented in zip(outputs, outputs_aug)
    ]

  if is_regression:
    return scores
  return [tf.argmax(score, axis=-1) for score in scores]