official/projects/teams/teams_pretrainer.py
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Trainer network for TEAMS models."""
# pylint: disable=g-classes-have-attributes
import tensorflow as tf, tf_keras
from official.modeling import tf_utils
from official.nlp.modeling import layers
from official.nlp.modeling import models
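# Subtracting this multiplier times a one-hot mask from a logits tensor
# effectively removes the marked token ids from the softmax distribution.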
_LOGIT_PENALTY_MULTIPLIER = 10000
class ReplacedTokenDetectionHead(tf_keras.layers.Layer):
"""Replaced token detection discriminator head.
Arguments:
encoder_cfg: Encoder config, used to create hidden layers and head.
num_task_agnostic_layers: Number of task agnostic layers in the
discriminator.
output: The output style for this network. Can be either 'logits' or
'predictions'.
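  Example:
    A minimal sketch with a hand-built config dict; the layer sizes and
    other hyperparameter values below are illustrative assumptions, not
    recommended settings.

    import tensorflow as tf
    encoder_cfg = {
        'embedding_cfg': {'hidden_size': 64},
        'num_hidden_instances': 4,
        'hidden_cfg': {
            'num_attention_heads': 4,
            'intermediate_size': 128,
            'intermediate_activation': 'relu',
            'dropout_rate': 0.1,
            'attention_dropout_rate': 0.1,
            'kernel_initializer': 'glorot_uniform',
        },
    }
    rtd_head = ReplacedTokenDetectionHead(
        encoder_cfg, num_task_agnostic_layers=2)
    logits = rtd_head(
        tf.random.normal([2, 8, 64]), tf.ones([2, 8], tf.int32))  # [2, 8]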
"""
def __init__(self,
encoder_cfg,
num_task_agnostic_layers,
output='logits',
name='rtd',
**kwargs):
super(ReplacedTokenDetectionHead, self).__init__(name=name, **kwargs)
self.num_task_agnostic_layers = num_task_agnostic_layers
self.hidden_size = encoder_cfg['embedding_cfg']['hidden_size']
self.num_hidden_instances = encoder_cfg['num_hidden_instances']
self.hidden_cfg = encoder_cfg['hidden_cfg']
self.activation = self.hidden_cfg['intermediate_activation']
self.initializer = self.hidden_cfg['kernel_initializer']
self.hidden_layers = []
for i in range(self.num_task_agnostic_layers, self.num_hidden_instances):
self.hidden_layers.append(
layers.Transformer(
num_attention_heads=self.hidden_cfg['num_attention_heads'],
intermediate_size=self.hidden_cfg['intermediate_size'],
intermediate_activation=self.activation,
dropout_rate=self.hidden_cfg['dropout_rate'],
attention_dropout_rate=self.hidden_cfg['attention_dropout_rate'],
kernel_initializer=tf_utils.clone_initializer(self.initializer),
name='transformer/layer_%d_rtd' % i))
self.dense = tf_keras.layers.Dense(
self.hidden_size,
activation=self.activation,
kernel_initializer=tf_utils.clone_initializer(self.initializer),
name='transform/rtd_dense')
self.rtd_head = tf_keras.layers.Dense(
units=1,
kernel_initializer=tf_utils.clone_initializer(self.initializer),
name='transform/rtd_head')
if output not in ('predictions', 'logits'):
raise ValueError(
('Unknown `output` value "%s". `output` can be either "logits" or '
'"predictions"') % output)
self._output_type = output
def call(self, sequence_data, input_mask):
"""Compute inner-products of hidden vectors with sampled element embeddings.
Args:
sequence_data: A [batch_size, seq_length, num_hidden] tensor.
input_mask: A [batch_size, seq_length] binary mask to separate the input
from the padding.
Returns:
A [batch_size, seq_length] tensor.
"""
attention_mask = layers.SelfAttentionMask()([sequence_data, input_mask])
data = sequence_data
for hidden_layer in self.hidden_layers:
      data = hidden_layer([data, attention_mask])
rtd_logits = self.rtd_head(self.dense(data))
return tf.squeeze(rtd_logits, axis=-1)
class MultiWordSelectionHead(tf_keras.layers.Layer):
"""Multi-word selection discriminator head.
Arguments:
embedding_table: The embedding table.
activation: The activation, if any, for the dense layer.
    initializer: The initializer for the dense layer. Defaults to a Glorot
      uniform initializer.
output: The output style for this network. Can be either 'logits' or
'predictions'.
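  Example:
    A minimal sketch with a randomly initialized embedding table; all shapes
    and values are illustrative assumptions.

    import tensorflow as tf
    embedding_table = tf.Variable(tf.random.normal([100, 16]))
    mws_head = MultiWordSelectionHead(embedding_table)
    sequence_data = tf.random.normal([2, 8, 16])
    masked_positions = tf.constant([[1, 3], [2, 5]])
    candidate_sets = tf.random.uniform([2, 2, 4], maxval=100, dtype=tf.int32)
    logits = mws_head(sequence_data, masked_positions, candidate_sets)
    # logits has shape [2, 2, 4]: one score per candidate per masked position.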
"""
def __init__(self,
embedding_table,
activation=None,
initializer='glorot_uniform',
output='logits',
name='mws',
**kwargs):
super(MultiWordSelectionHead, self).__init__(name=name, **kwargs)
self.embedding_table = embedding_table
self.activation = activation
self.initializer = tf_keras.initializers.get(initializer)
self._vocab_size, self.embed_size = self.embedding_table.shape
self.dense = tf_keras.layers.Dense(
self.embed_size,
activation=self.activation,
kernel_initializer=self.initializer,
name='transform/mws_dense')
self.layer_norm = tf_keras.layers.LayerNormalization(
axis=-1, epsilon=1e-12, name='transform/mws_layernorm')
if output not in ('predictions', 'logits'):
raise ValueError(
('Unknown `output` value "%s". `output` can be either "logits" or '
'"predictions"') % output)
self._output_type = output
def call(self, sequence_data, masked_positions, candidate_sets):
"""Compute inner-products of hidden vectors with sampled element embeddings.
Args:
sequence_data: A [batch_size, seq_length, num_hidden] tensor.
masked_positions: A [batch_size, num_prediction] tensor.
candidate_sets: A [batch_size, num_prediction, k] tensor.
Returns:
A [batch_size, num_prediction, k] tensor.
"""
# Gets shapes for later usage
candidate_set_shape = tf_utils.get_shape_list(candidate_sets)
num_prediction = candidate_set_shape[1]
# Gathers hidden vectors -> (batch_size, num_prediction, 1, embed_size)
masked_lm_input = self._gather_indexes(sequence_data, masked_positions)
lm_data = self.dense(masked_lm_input)
lm_data = self.layer_norm(lm_data)
lm_data = tf.expand_dims(
tf.reshape(lm_data, [-1, num_prediction, self.embed_size]), 2)
# Gathers embeddings -> (batch_size, num_prediction, embed_size, k)
flat_candidate_sets = tf.reshape(candidate_sets, [-1])
candidate_embeddings = tf.gather(self.embedding_table, flat_candidate_sets)
candidate_embeddings = tf.reshape(
candidate_embeddings,
tf.concat([tf.shape(candidate_sets), [self.embed_size]], axis=0)
)
candidate_embeddings.set_shape(
candidate_sets.shape.as_list() + [self.embed_size])
candidate_embeddings = tf.transpose(candidate_embeddings, [0, 1, 3, 2])
# matrix multiplication + squeeze -> (batch_size, num_prediction, k)
logits = tf.matmul(lm_data, candidate_embeddings)
logits = tf.squeeze(logits, 2)
if self._output_type == 'logits':
return logits
return tf.nn.log_softmax(logits)
def _gather_indexes(self, sequence_tensor, positions):
"""Gathers the vectors at the specific positions.
Args:
sequence_tensor: Sequence output of shape
(`batch_size`, `seq_length`, `num_hidden`) where `num_hidden` is
number of hidden units.
      positions: Position ids of the tokens in the batched sequences.
Returns:
Sequence tensor of shape (batch_size * num_predictions,
num_hidden).
"""
sequence_shape = tf_utils.get_shape_list(
sequence_tensor, name='sequence_output_tensor')
batch_size, seq_length, width = sequence_shape
flat_offsets = tf.reshape(
tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1])
flat_positions = tf.reshape(positions + flat_offsets, [-1])
flat_sequence_tensor = tf.reshape(sequence_tensor,
[batch_size * seq_length, width])
output_tensor = tf.gather(flat_sequence_tensor, flat_positions)
return output_tensor
@tf_keras.utils.register_keras_serializable(package='Text')
class TeamsPretrainer(tf_keras.Model):
"""TEAMS network training model.
This is an implementation of the network structure described in "Training
ELECTRA Augmented with Multi-word Selection"
(https://arxiv.org/abs/2106.00139).
  The TeamsPretrainer allows a user to pass in two transformer encoders, one
  for the generator, the other for the discriminator (multi-word selection).
  The pretrainer then instantiates a masked language model (on the generator
  side) and classification networks (including both the multi-word selection
  head and the replaced token detection head) that are used to create the
  training objectives.
*Note* that the model is constructed by Keras Subclass API, where layers are
defined inside `__init__` and `call()` implements the computation.
Args:
    generator_network: A transformer encoder for the generator; this network
      should output a sequence output.
    discriminator_mws_network: A transformer encoder for the multi-word
      selection discriminator; this network should return all encoder layer
      outputs (`encoder_outputs`), since an intermediate layer feeds the
      replaced token detection head.
    num_discriminator_task_agnostic_layers: Number of layers shared between
      the multi-word selection and replaced token detection discriminators.
    vocab_size: Size of the generator output vocabulary.
    candidate_size: Candidate set size for the multi-word selection task,
      including the correct word.
mlm_activation: The activation (if any) to use in the masked LM and
classification networks. If None, no activation will be used.
mlm_initializer: The initializer (if any) to use in the masked LM and
classification networks. Defaults to a Glorot uniform initializer.
output_type: The output style for this network. Can be either `logits` or
`predictions`.
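  Example:
    A schematic sketch of wiring up the pretrainer. `build_generator` and
    `build_discriminator` are hypothetical helpers standing in for encoder
    construction (for instance via `networks.EncoderScaffold` configured to
    return all layer outputs); the hyperparameter values are illustrative.

    generator_network = build_generator()          # hypothetical helper
    discriminator_network = build_discriminator()  # hypothetical helper
    pretrainer = TeamsPretrainer(
        generator_network=generator_network,
        discriminator_mws_network=discriminator_network,
        num_discriminator_task_agnostic_layers=3,
        vocab_size=30522,
        candidate_size=5,
        mlm_activation='gelu')
    outputs = pretrainer(inputs)  # `inputs`: a dict of standard BERT features.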
"""
def __init__(self,
generator_network,
discriminator_mws_network,
num_discriminator_task_agnostic_layers,
vocab_size,
candidate_size=5,
mlm_activation=None,
mlm_initializer='glorot_uniform',
output_type='logits',
**kwargs):
super().__init__()
self._config = {
'generator_network':
generator_network,
'discriminator_mws_network':
discriminator_mws_network,
'num_discriminator_task_agnostic_layers':
num_discriminator_task_agnostic_layers,
'vocab_size':
vocab_size,
'candidate_size':
candidate_size,
'mlm_activation':
mlm_activation,
'mlm_initializer':
mlm_initializer,
'output_type':
output_type,
}
for k, v in kwargs.items():
self._config[k] = v
self.generator_network = generator_network
self.discriminator_mws_network = discriminator_mws_network
self.vocab_size = vocab_size
self.candidate_size = candidate_size
self.mlm_activation = mlm_activation
self.mlm_initializer = mlm_initializer
self.output_type = output_type
self.masked_lm = layers.MaskedLM(
embedding_table=self.generator_network.embedding_network
.get_embedding_table(),
activation=mlm_activation,
initializer=mlm_initializer,
output=output_type,
name='generator_masked_lm')
discriminator_cfg = self.discriminator_mws_network.get_config()
self.num_task_agnostic_layers = num_discriminator_task_agnostic_layers
self.discriminator_rtd_head = ReplacedTokenDetectionHead(
encoder_cfg=discriminator_cfg,
num_task_agnostic_layers=self.num_task_agnostic_layers,
output=output_type,
name='discriminator_rtd')
hidden_cfg = discriminator_cfg['hidden_cfg']
self.discriminator_mws_head = MultiWordSelectionHead(
embedding_table=self.discriminator_mws_network.embedding_network
.get_embedding_table(),
activation=hidden_cfg['intermediate_activation'],
initializer=hidden_cfg['kernel_initializer'],
output=output_type,
name='discriminator_mws')
def call(self, inputs): # pytype: disable=signature-mismatch # overriding-parameter-count-checks
"""TEAMS forward pass.
Args:
      inputs: A dict of all inputs, same as the standard BERT model; it must
        include `input_word_ids`, `input_mask`, `input_type_ids`,
        `masked_lm_positions` and `masked_lm_ids`.
Returns:
outputs: A dict of pretrainer model outputs, including
(1) lm_outputs: A `[batch_size, num_token_predictions, vocab_size]`
tensor indicating logits on masked positions.
(2) disc_rtd_logits: A `[batch_size, sequence_length]` tensor indicating
logits for discriminator replaced token detection task.
(3) disc_rtd_label: A `[batch_size, sequence_length]` tensor indicating
target labels for discriminator replaced token detection task.
(4) disc_mws_logits: A `[batch_size, num_token_predictions,
candidate_size]` tensor indicating logits for discriminator multi-word
selection task.
        (5) disc_mws_label: A `[batch_size, num_token_predictions]` tensor
indicating target labels for discriminator multi-word selection task.
"""
input_word_ids = inputs['input_word_ids']
input_mask = inputs['input_mask']
input_type_ids = inputs['input_type_ids']
masked_lm_positions = inputs['masked_lm_positions']
# Runs generator.
sequence_output = self.generator_network(
[input_word_ids, input_mask, input_type_ids])['sequence_output']
lm_outputs = self.masked_lm(sequence_output, masked_lm_positions)
# Samples tokens from generator.
fake_data = self._get_fake_data(inputs, lm_outputs)
# Runs discriminator.
disc_input = fake_data['inputs']
disc_rtd_label = fake_data['is_fake_tokens']
disc_mws_candidates = fake_data['candidate_set']
mws_sequence_outputs = self.discriminator_mws_network([
disc_input['input_word_ids'], disc_input['input_mask'],
disc_input['input_type_ids']
])['encoder_outputs']
    # Applies replaced token detection to the output of the last
    # task-agnostic layer (index self.num_task_agnostic_layers - 1).
disc_rtd_logits = self.discriminator_rtd_head(
mws_sequence_outputs[self.num_task_agnostic_layers - 1], input_mask)
# Applies multi-word selection.
disc_mws_logits = self.discriminator_mws_head(mws_sequence_outputs[-1],
masked_lm_positions,
disc_mws_candidates)
disc_mws_label = tf.zeros_like(masked_lm_positions, dtype=tf.int32)
outputs = {
'lm_outputs': lm_outputs,
'disc_rtd_logits': disc_rtd_logits,
'disc_rtd_label': disc_rtd_label,
'disc_mws_logits': disc_mws_logits,
'disc_mws_label': disc_mws_label,
}
return outputs
def _get_fake_data(self, inputs, mlm_logits):
"""Generate corrupted data for discriminator.
Note it is poosible for sampled token to be the same as the correct one.
Args:
      inputs: A dict of all inputs, same as the input of the `call()` function.
      mlm_logits: The generator's output logits.

    Returns:
      A dict of generated fake data.
"""
inputs = models.electra_pretrainer.unmask(inputs, duplicate=True)
# Samples replaced token.
sampled_tokens = tf.stop_gradient(
models.electra_pretrainer.sample_from_softmax(
mlm_logits, disallow=None))
sampled_tokids = tf.argmax(sampled_tokens, axis=-1, output_type=tf.int32)
# Prepares input and label for replaced token detection task.
updated_input_ids, masked = models.electra_pretrainer.scatter_update(
inputs['input_word_ids'], sampled_tokids, inputs['masked_lm_positions'])
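    # A position is labeled fake (1) only if it was masked and the sampled
    # token differs from the original token at that position.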
rtd_labels = masked * (1 - tf.cast(
tf.equal(updated_input_ids, inputs['input_word_ids']), tf.int32))
updated_inputs = models.electra_pretrainer.get_updated_inputs(
inputs, duplicate=True, input_word_ids=updated_input_ids)
    # Samples (candidate_size - 1) negatives and concatenates them with the
    # true tokens.
disallow = tf.one_hot(
inputs['masked_lm_ids'], depth=self.vocab_size, dtype=tf.float32)
sampled_candidates = tf.stop_gradient(
sample_k_from_softmax(mlm_logits, k=self.candidate_size-1,
disallow=disallow))
true_token_id = tf.expand_dims(inputs['masked_lm_ids'], -1)
candidate_set = tf.concat([true_token_id, sampled_candidates], -1)
return {
'inputs': updated_inputs,
'is_fake_tokens': rtd_labels,
'sampled_tokens': sampled_tokens,
'candidate_set': candidate_set
}
@property
def checkpoint_items(self):
"""Returns a dictionary of items to be additionally checkpointed."""
items = dict(encoder=self.discriminator_mws_network)
return items
def get_config(self):
return self._config
@classmethod
def from_config(cls, config, custom_objects=None):
return cls(**config)
def sample_k_from_softmax(logits, k, disallow=None, use_topk=False):
"""Implement softmax sampling using gumbel softmax trick to select k items.
Args:
logits: A [batch_size, num_token_predictions, vocab_size] tensor indicating
the generator output logits for each masked position.
    k: Number of samples to draw per masked position.
    disallow: If `None`, tokens are sampled directly from the logits.
      Otherwise, a one-hot tensor of size [batch_size, num_token_predictions,
      vocab_size] marking the true word id in each masked position, which is
      excluded from sampling.
    use_topk: Whether to use tf.nn.top_k or an iterative argmax approach,
      where the latter is empirically faster.
Returns:
    sampled_tokens: A [batch_size, num_token_predictions, k] tensor indicating
      the sampled word ids in each masked position.
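  Example:
    A toy call with illustrative shapes:

    logits = tf.random.normal([2, 3, 100])
    sampled = sample_k_from_softmax(logits, k=5)  # [2, 3, 5] int32 ids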
"""
if use_topk:
if disallow is not None:
logits -= _LOGIT_PENALTY_MULTIPLIER * disallow
uniform_noise = tf.random.uniform(
tf_utils.get_shape_list(logits), minval=0, maxval=1)
gumbel_noise = -tf.math.log(-tf.math.log(uniform_noise + 1e-9) + 1e-9)
_, sampled_tokens = tf.nn.top_k(logits + gumbel_noise, k=k, sorted=False)
else:
sampled_tokens_list = []
vocab_size = tf_utils.get_shape_list(logits)[-1]
if disallow is not None:
logits -= _LOGIT_PENALTY_MULTIPLIER * disallow
uniform_noise = tf.random.uniform(
tf_utils.get_shape_list(logits), minval=0, maxval=1)
gumbel_noise = -tf.math.log(-tf.math.log(uniform_noise + 1e-9) + 1e-9)
logits += gumbel_noise
for _ in range(k):
token_ids = tf.argmax(logits, -1, output_type=tf.int32)
sampled_tokens_list.append(token_ids)
logits -= _LOGIT_PENALTY_MULTIPLIER * tf.one_hot(
token_ids, depth=vocab_size, dtype=tf.float32)
sampled_tokens = tf.stack(sampled_tokens_list, -1)
return sampled_tokens