# official/projects/edgetpu/nlp/modeling/pretrainer.py
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BERT Pre-training model."""
# pylint: disable=g-classes-have-attributes
import copy
from typing import List, Optional
import tensorflow as tf, tf_keras
from official.nlp.modeling import layers
@tf_keras.utils.register_keras_serializable(package='Text')
class MobileBERTEdgeTPUPretrainer(tf_keras.Model):
  """BERT pretraining model V2.

  Adds the masked language model head and optional classification heads upon
  the transformer encoder.

  Args:
    encoder_network: A transformer network. Calling it on its own inputs must
      return either a list whose first element is the sequence output (or a
      list of per-layer outputs, the last of which is used as the sequence
      output) and whose second element is the pooled output, or a dict that
      contains at least the key `sequence_output`.
    mlm_activation: The activation (if any) to use in the masked LM network. If
      None, no activation will be used.
    mlm_initializer: The initializer (if any) to use in the masked LM. Default
      to a Glorot uniform initializer.
    classification_heads: A list of optional head layers to transform on encoder
      sequence outputs.
    customized_masked_lm: A customized masked_lm layer. If None, will create
      a standard layer from `layers.MaskedLM` using the encoder's embedding
      table; if not None, will use the specified masked_lm layer and the
      arguments `mlm_activation` and `mlm_initializer` will be ignored.
    name: The name of the model.
    **kwargs: Extra keyword arguments forwarded to the `tf_keras.Model`
      constructor. They are not recorded in `get_config()`.

  Inputs: The inputs of `encoder_network`, followed by an int32
    `masked_lm_positions` input.
  Outputs: A dictionary containing `mlm_logits`, the classification head
    outputs (keyed by head name, or by the head's own keys when a head returns
    a dict), and the outputs of `encoder_network`: `sequence_output`,
    `pooled_output` and `encoder_outputs` (if any) for list-style encoders, or
    every key of the encoder's output dict for dict-style encoders.

  Raises:
    ValueError: If the output of `encoder_network` is neither a list nor a
      dict.
  """

  def __init__(
      self,
      encoder_network: tf_keras.Model,
      mlm_activation=None,
      mlm_initializer='glorot_uniform',
      classification_heads: Optional[List[tf_keras.layers.Layer]] = None,
      customized_masked_lm: Optional[tf_keras.layers.Layer] = None,
      name: str = 'bert',
      **kwargs):
    # Copy the input list so appending `masked_lm_positions` below does not
    # modify the list owned by the encoder.
    inputs = copy.copy(encoder_network.inputs)
    outputs = {}
    encoder_network_outputs = encoder_network(inputs)
    if isinstance(encoder_network_outputs, list):
      outputs['pooled_output'] = encoder_network_outputs[1]
      if isinstance(encoder_network_outputs[0], list):
        # The first element holds one output per layer; the last one is the
        # final sequence output.
        outputs['encoder_outputs'] = encoder_network_outputs[0]
        outputs['sequence_output'] = encoder_network_outputs[0][-1]
      else:
        outputs['sequence_output'] = encoder_network_outputs[0]
    elif isinstance(encoder_network_outputs, dict):
      # Shallow-copy so the keys added below are not written into the mapping
      # returned by the encoder call.
      outputs = dict(encoder_network_outputs)
    else:
      # Wrap the value in a 1-tuple: a bare tuple on the right-hand side of `%`
      # would be unpacked as format arguments and raise TypeError instead of
      # the intended ValueError.
      raise ValueError('encoder_network\'s output should be either a list '
                       'or a dict, but got %s' % (encoder_network_outputs,))

    masked_lm_positions = tf_keras.layers.Input(
        shape=(None,), name='masked_lm_positions', dtype=tf.int32)
    inputs.append(masked_lm_positions)

    if customized_masked_lm is not None:
      masked_lm_layer = customized_masked_lm
    else:
      # The embedding table is only requested from the encoder when the
      # default masked LM head has to be built.
      masked_lm_layer = layers.MaskedLM(
          embedding_table=encoder_network.get_embedding_table(),
          activation=mlm_activation,
          initializer=mlm_initializer,
          name='cls/predictions')

    sequence_output = outputs['sequence_output']
    outputs['mlm_logits'] = masked_lm_layer(
        sequence_output, masked_positions=masked_lm_positions)

    classification_head_layers = classification_heads or []
    for cls_head in classification_head_layers:
      cls_outputs = cls_head(sequence_output)
      if isinstance(cls_outputs, dict):
        # A head returning a dict contributes its own keys directly.
        outputs.update(cls_outputs)
      else:
        outputs[cls_head.name] = cls_outputs

    super().__init__(
        inputs=inputs,
        outputs=outputs,
        name=name,
        **kwargs)

    # Keep the constructor arguments so `get_config`/`from_config` can rebuild
    # the model. Attributes are assigned after `super().__init__()`.
    self._config = {
        'encoder_network': encoder_network,
        'mlm_activation': mlm_activation,
        'mlm_initializer': mlm_initializer,
        'classification_heads': classification_heads,
        'customized_masked_lm': customized_masked_lm,
        'name': name,
    }
    self.encoder_network = encoder_network
    self.masked_lm = masked_lm_layer
    self.classification_heads = classification_head_layers

  @property
  def checkpoint_items(self):
    """Returns a dictionary of items to be additionally checkpointed.

    The dictionary always contains `encoder` and `masked_lm`. Every entry of
    each classification head's own `checkpoint_items` is added under the key
    `<head name>.<item key>`.
    """
    items = dict(encoder=self.encoder_network, masked_lm=self.masked_lm)
    for head in self.classification_heads:
      for key, item in head.checkpoint_items.items():
        items['.'.join([head.name, key])] = item
    return items

  def get_config(self):
    """Returns a shallow copy of the constructor arguments of this model."""
    # Return a copy so callers that edit the result cannot alter the stored
    # configuration of this instance.
    return dict(self._config)

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Builds a model from `config`, as produced by `get_config`.

    Args:
      config: A dictionary of constructor keyword arguments.
      custom_objects: Unused; accepted for signature compatibility.

    Returns:
      A new instance of this class.
    """
    return cls(**config)