official/projects/mobilebert/distillation.py
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Progressive distillation for MobileBERT student model."""
import dataclasses
from typing import List, Optional
from absl import logging
import orbit
import tensorflow as tf, tf_keras
from official.core import base_task
from official.core import config_definitions as cfg
from official.modeling import optimization
from official.modeling import tf_utils
from official.modeling.fast_training.progressive import policies
from official.modeling.hyperparams import base_config
from official.nlp import modeling
from official.nlp.configs import bert
from official.nlp.configs import encoders
from official.nlp.data import data_loader_factory
from official.nlp.modeling import layers
from official.nlp.modeling import models


@dataclasses.dataclass
class LayerWiseDistillConfig(base_config.Config):
"""Defines the behavior of layerwise distillation."""
num_steps: int = 10000
warmup_steps: int = 0
initial_learning_rate: float = 1.5e-3
end_learning_rate: float = 1.5e-3
decay_steps: int = 10000
hidden_distill_factor: float = 100.0
beta_distill_factor: float = 5000.0
gamma_distill_factor: float = 5.0
if_transfer_attention: bool = True
attention_distill_factor: float = 1.0
if_freeze_previous_layers: bool = False
# The ids of teacher layers that will be mapped to the student model.
# For example, if you want to compress a 24 layer teacher to a 6 layer
# student, you can set it to [3, 7, 11, 15, 19, 23] (the index starts from 0).
# If `None`, we assume teacher and student have the same number of layers,
# and each layer of teacher model will be mapped to student's corresponding
# layer.
transfer_teacher_layers: Optional[List[int]] = None


@dataclasses.dataclass
class PretrainDistillConfig(base_config.Config):
"""Defines the behavior of pretrain distillation."""
num_steps: int = 500000
warmup_steps: int = 10000
initial_learning_rate: float = 1.5e-3
end_learning_rate: float = 1.5e-7
decay_steps: int = 500000
if_use_nsp_loss: bool = True
distill_ground_truth_ratio: float = 0.5


@dataclasses.dataclass
class BertDistillationProgressiveConfig(policies.ProgressiveConfig):
"""Defines the specific distillation behavior."""
if_copy_embeddings: bool = True
layer_wise_distill_config: LayerWiseDistillConfig = dataclasses.field(
default_factory=LayerWiseDistillConfig
)
pretrain_distill_config: PretrainDistillConfig = dataclasses.field(
default_factory=PretrainDistillConfig
)


@dataclasses.dataclass
class BertDistillationTaskConfig(cfg.TaskConfig):
"""Defines the teacher/student model architecture and training data."""
teacher_model: bert.PretrainerConfig = dataclasses.field(
default_factory=lambda: bert.PretrainerConfig( # pylint: disable=g-long-lambda
encoder=encoders.EncoderConfig(type='mobilebert')
)
)
student_model: bert.PretrainerConfig = dataclasses.field(
default_factory=lambda: bert.PretrainerConfig( # pylint: disable=g-long-lambda
encoder=encoders.EncoderConfig(type='mobilebert')
)
)
# The path to the teacher model checkpoint or its directory.
teacher_model_init_checkpoint: str = ''
train_data: cfg.DataConfig = dataclasses.field(default_factory=cfg.DataConfig)
validation_data: cfg.DataConfig = dataclasses.field(
default_factory=cfg.DataConfig
)


def build_sub_encoder(encoder, target_layer_id):
  """Builds an encoder that only computes the first few transformer layers."""
input_ids = encoder.inputs[0]
input_mask = encoder.inputs[1]
type_ids = encoder.inputs[2]
attention_mask = modeling.layers.SelfAttentionMask()(
inputs=input_ids, to_mask=input_mask)
embedding_output = encoder.embedding_layer(input_ids, type_ids)
layer_output = embedding_output
attention_score = None
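  # Run the embeddings through transformer layers 0..target_layer_id; only the
  # hidden states and attention scores of the final layer are returned.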
for layer_idx in range(target_layer_id + 1):
layer_output, attention_score = encoder.transformer_layers[layer_idx](
layer_output, attention_mask, return_attention_scores=True)
return tf_keras.Model(
inputs=[input_ids, input_mask, type_ids],
outputs=[layer_output, attention_score])


class BertDistillationTask(policies.ProgressivePolicy, base_task.Task):
  """Progressive distillation task for MobileBERT pretraining."""
def __init__(self,
strategy,
progressive: BertDistillationProgressiveConfig,
optimizer_config: optimization.OptimizationConfig,
task_config: BertDistillationTaskConfig,
logging_dir=None):
self._strategy = strategy
self._task_config = task_config
self._progressive_config = progressive
self._optimizer_config = optimizer_config
self._train_data_config = task_config.train_data
self._eval_data_config = task_config.validation_data
self._the_only_train_dataset = None
self._the_only_eval_dataset = None
layer_wise_config = self._progressive_config.layer_wise_distill_config
transfer_teacher_layers = layer_wise_config.transfer_teacher_layers
num_teacher_layers = (
self._task_config.teacher_model.encoder.mobilebert.num_blocks)
num_student_layers = (
self._task_config.student_model.encoder.mobilebert.num_blocks)
if transfer_teacher_layers and len(
transfer_teacher_layers) != num_student_layers:
      raise ValueError('The number of `transfer_teacher_layers` %s does not '
                       'match the number of student layers %d.' %
                       (transfer_teacher_layers, num_student_layers))
if not transfer_teacher_layers and (num_teacher_layers !=
num_student_layers):
raise ValueError('`transfer_teacher_layers` is not specified, and the '
'number of teacher layers does not match '
'the number of student layers.')
ratio = progressive.pretrain_distill_config.distill_ground_truth_ratio
if ratio < 0 or ratio > 1:
raise ValueError('distill_ground_truth_ratio has to be within [0, 1].')
    # A non-trainable layer normalization used for the feature transfer loss.
self._layer_norm = tf_keras.layers.LayerNormalization(
axis=-1,
beta_initializer='zeros',
gamma_initializer='ones',
trainable=False)
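    # This normalization (fixed beta=0, gamma=1) is applied to both teacher and
    # student features in `build_losses`, so the feature transfer MSE is
    # computed on normalized activations.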
# Build the teacher and student pretrainer model.
self._teacher_pretrainer = self._build_pretrainer(
self._task_config.teacher_model, name='teacher')
self._student_pretrainer = self._build_pretrainer(
self._task_config.student_model, name='student')
base_task.Task.__init__(
self, params=task_config, logging_dir=logging_dir)
policies.ProgressivePolicy.__init__(self)
def _build_pretrainer(self, pretrainer_cfg: bert.PretrainerConfig, name: str):
"""Builds pretrainer from config and encoder."""
encoder = encoders.build_encoder(pretrainer_cfg.encoder)
if pretrainer_cfg.cls_heads:
cls_heads = [
layers.ClassificationHead(**cfg.as_dict())
for cfg in pretrainer_cfg.cls_heads
]
else:
cls_heads = []
masked_lm = layers.MobileBertMaskedLM(
embedding_table=encoder.get_embedding_table(),
activation=tf_utils.get_activation(pretrainer_cfg.mlm_activation),
initializer=tf_keras.initializers.TruncatedNormal(
stddev=pretrainer_cfg.mlm_initializer_range),
name='cls/predictions')
pretrainer = models.BertPretrainerV2(
encoder_network=encoder,
classification_heads=cls_heads,
customized_masked_lm=masked_lm,
name=name)
return pretrainer
# override policies.ProgressivePolicy
def num_stages(self):
    # One stage for each layer, plus an additional stage for pre-training.
return self._task_config.student_model.encoder.mobilebert.num_blocks + 1
# override policies.ProgressivePolicy
def num_steps(self, stage_id) -> int:
"""Return the total number of steps in this stage."""
if stage_id + 1 < self.num_stages():
return self._progressive_config.layer_wise_distill_config.num_steps
else:
return self._progressive_config.pretrain_distill_config.num_steps
# override policies.ProgressivePolicy
def get_model(self, stage_id, old_model=None) -> tf_keras.Model:
del old_model
return self.build_model(stage_id)
# override policies.ProgressivePolicy
def get_optimizer(self, stage_id):
"""Build optimizer for each stage."""
if stage_id + 1 < self.num_stages():
distill_config = self._progressive_config.layer_wise_distill_config
else:
distill_config = self._progressive_config.pretrain_distill_config
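    # Each stage builds a fresh optimizer; its polynomial decay schedule and
    # linear warmup come from the corresponding distillation config.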
params = self._optimizer_config.replace(
learning_rate={
'polynomial': {
'decay_steps':
distill_config.decay_steps,
'initial_learning_rate':
distill_config.initial_learning_rate,
'end_learning_rate':
distill_config.end_learning_rate,
}
},
warmup={
'linear':
{'warmup_steps':
distill_config.warmup_steps,
}
})
opt_factory = optimization.OptimizerFactory(params)
optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())
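    # Convert a Keras experimental optimizer back to its legacy counterpart,
    # presumably because the progressive training utilities still expect the
    # legacy optimizer interface.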
if isinstance(optimizer, tf_keras.optimizers.experimental.Optimizer):
optimizer = tf_keras.__internal__.optimizers.convert_to_legacy_optimizer(
optimizer)
return optimizer
# override policies.ProgressivePolicy
def get_train_dataset(self, stage_id: int) -> tf.data.Dataset:
"""Return Dataset for this stage."""
del stage_id
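    # All stages share the same training data, so the distributed dataset is
    # built once and cached.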
if self._the_only_train_dataset is None:
self._the_only_train_dataset = orbit.utils.make_distributed_dataset(
self._strategy, self.build_inputs, self._train_data_config)
return self._the_only_train_dataset
# overrides policies.ProgressivePolicy
def get_eval_dataset(self, stage_id):
del stage_id
if self._the_only_eval_dataset is None:
self._the_only_eval_dataset = orbit.utils.make_distributed_dataset(
self._strategy, self.build_inputs, self._eval_data_config)
return self._the_only_eval_dataset
  # override base_task.Task
def build_model(self, stage_id) -> tf_keras.Model:
"""Build teacher/student keras models with outputs for current stage."""
# Freeze the teacher model.
self._teacher_pretrainer.trainable = False
layer_wise_config = self._progressive_config.layer_wise_distill_config
freeze_previous_layers = layer_wise_config.if_freeze_previous_layers
student_encoder = self._student_pretrainer.encoder_network
if stage_id != self.num_stages() - 1:
# Build a model that outputs teacher's and student's transformer outputs.
inputs = student_encoder.inputs
student_sub_encoder = build_sub_encoder(
encoder=student_encoder, target_layer_id=stage_id)
student_output_feature, student_attention_score = student_sub_encoder(
inputs)
if layer_wise_config.transfer_teacher_layers:
teacher_layer_id = layer_wise_config.transfer_teacher_layers[stage_id]
else:
teacher_layer_id = stage_id
teacher_sub_encoder = build_sub_encoder(
encoder=self._teacher_pretrainer.encoder_network,
target_layer_id=teacher_layer_id)
teacher_output_feature, teacher_attention_score = teacher_sub_encoder(
inputs)
if freeze_previous_layers:
student_encoder.embedding_layer.trainable = False
for i in range(stage_id):
student_encoder.transformer_layers[i].trainable = False
return tf_keras.Model(
inputs=inputs,
outputs=dict(
student_output_feature=student_output_feature,
student_attention_score=student_attention_score,
teacher_output_feature=teacher_output_feature,
teacher_attention_score=teacher_attention_score))
else:
# Build a model that outputs teacher's and student's MLM/NSP outputs.
inputs = self._student_pretrainer.inputs
student_pretrainer_output = self._student_pretrainer(inputs)
teacher_pretrainer_output = self._teacher_pretrainer(inputs)
# Set all student's transformer blocks to trainable.
if freeze_previous_layers:
student_encoder.embedding_layer.trainable = True
for layer in student_encoder.transformer_layers:
layer.trainable = True
model = tf_keras.Model(
inputs=inputs,
outputs=dict(
student_pretrainer_output=student_pretrainer_output,
teacher_pretrainer_output=teacher_pretrainer_output,
))
# Checkpoint the student encoder which is the goal of distillation.
model.checkpoint_items = self._student_pretrainer.checkpoint_items
return model
# overrides base_task.Task
def build_inputs(self, params, input_context=None):
"""Returns tf.data.Dataset for pretraining."""
    # Copied from masked_lm.py: build a dummy dataset for testing.
if params.input_path == 'dummy':
def dummy_data(_):
dummy_ids = tf.zeros((1, params.seq_length), dtype=tf.int32)
dummy_lm = tf.zeros((1, params.max_predictions_per_seq), dtype=tf.int32)
return dict(
input_word_ids=dummy_ids,
input_mask=dummy_ids,
input_type_ids=dummy_ids,
masked_lm_positions=dummy_lm,
masked_lm_ids=dummy_lm,
masked_lm_weights=tf.cast(dummy_lm, dtype=tf.float32),
next_sentence_labels=tf.zeros((1, 1), dtype=tf.int32))
dataset = tf.data.Dataset.range(1)
dataset = dataset.repeat()
dataset = dataset.map(
dummy_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
return dataset
return data_loader_factory.get_data_loader(params).load(input_context)
def _get_distribution_losses(self, teacher, student):
"""Return the beta and gamma distall losses for feature distribution."""
teacher_mean = tf.math.reduce_mean(teacher, axis=-1, keepdims=True)
student_mean = tf.math.reduce_mean(student, axis=-1, keepdims=True)
teacher_var = tf.math.reduce_variance(teacher, axis=-1, keepdims=True)
student_var = tf.math.reduce_variance(student, axis=-1, keepdims=True)
beta_loss = tf.math.squared_difference(student_mean, teacher_mean)
beta_loss = tf.math.reduce_mean(beta_loss, axis=None, keepdims=False)
gamma_loss = tf.math.abs(student_var - teacher_var)
gamma_loss = tf.math.reduce_mean(gamma_loss, axis=None, keepdims=False)
return beta_loss, gamma_loss
def _get_attention_loss(self, teacher_score, student_score):
# Note that the definition of KLDivergence here is a little different from
# the original one (tf_keras.losses.KLDivergence). We adopt this approach
# to stay consistent with the TF1 implementation.
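    # The value below is the teacher-student cross entropy over attention
    # distributions, i.e. KL(teacher || student) plus the (constant) teacher
    # entropy, averaged over batch, heads and positions.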
teacher_weight = tf_keras.activations.softmax(teacher_score, axis=-1)
student_log_weight = tf.nn.log_softmax(student_score, axis=-1)
kl_divergence = -(teacher_weight * student_log_weight)
kl_divergence = tf.math.reduce_sum(kl_divergence, axis=-1, keepdims=True)
kl_divergence = tf.math.reduce_mean(kl_divergence, axis=None,
keepdims=False)
return kl_divergence
def build_losses(self, labels, outputs, metrics) -> tf.Tensor:
"""Builds losses and update loss-related metrics for the current stage."""
last_stage = 'student_pretrainer_output' in outputs
# Layer-wise warmup stage
if not last_stage:
distill_config = self._progressive_config.layer_wise_distill_config
teacher_feature = outputs['teacher_output_feature']
student_feature = outputs['student_output_feature']
feature_transfer_loss = tf_keras.losses.mean_squared_error(
self._layer_norm(teacher_feature), self._layer_norm(student_feature))
feature_transfer_loss *= distill_config.hidden_distill_factor
beta_loss, gamma_loss = self._get_distribution_losses(teacher_feature,
student_feature)
beta_loss *= distill_config.beta_distill_factor
gamma_loss *= distill_config.gamma_distill_factor
total_loss = feature_transfer_loss + beta_loss + gamma_loss
if distill_config.if_transfer_attention:
teacher_attention = outputs['teacher_attention_score']
student_attention = outputs['student_attention_score']
attention_loss = self._get_attention_loss(teacher_attention,
student_attention)
attention_loss *= distill_config.attention_distill_factor
total_loss += attention_loss
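      # Scale the loss by the number of layers built so far, presumably to keep
      # its magnitude comparable across stages.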
total_loss /= tf.cast((self._stage_id + 1), tf.float32)
# Last stage to distill pretraining layer.
else:
distill_config = self._progressive_config.pretrain_distill_config
lm_label = labels['masked_lm_ids']
vocab_size = (
self._task_config.student_model.encoder.mobilebert.word_vocab_size)
# Shape: [batch, max_predictions_per_seq, vocab_size]
lm_label = tf.one_hot(indices=lm_label, depth=vocab_size, on_value=1.0,
off_value=0.0, axis=-1, dtype=tf.float32)
gt_ratio = distill_config.distill_ground_truth_ratio
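      # Mix the one-hot ground-truth labels with the teacher's soft
      # predictions: label = ratio * one_hot + (1 - ratio) * softmax(teacher).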
if gt_ratio != 1.0:
teacher_mlm_logits = outputs['teacher_pretrainer_output']['mlm_logits']
teacher_labels = tf.nn.softmax(teacher_mlm_logits, axis=-1)
lm_label = gt_ratio * lm_label + (1-gt_ratio) * teacher_labels
student_pretrainer_output = outputs['student_pretrainer_output']
# Shape: [batch, max_predictions_per_seq, vocab_size]
student_lm_log_probs = tf.nn.log_softmax(
student_pretrainer_output['mlm_logits'], axis=-1)
# Shape: [batch * max_predictions_per_seq]
per_example_loss = tf.reshape(
-tf.reduce_sum(student_lm_log_probs * lm_label, axis=[-1]), [-1])
lm_label_weights = tf.reshape(labels['masked_lm_weights'], [-1])
lm_numerator_loss = tf.reduce_sum(per_example_loss * lm_label_weights)
lm_denominator_loss = tf.reduce_sum(lm_label_weights)
mlm_loss = tf.math.divide_no_nan(lm_numerator_loss, lm_denominator_loss)
total_loss = mlm_loss
if 'next_sentence_labels' in labels:
sentence_labels = labels['next_sentence_labels']
sentence_outputs = tf.cast(
student_pretrainer_output['next_sentence'], dtype=tf.float32)
sentence_loss = tf.reduce_mean(
tf_keras.losses.sparse_categorical_crossentropy(
sentence_labels, sentence_outputs, from_logits=True))
total_loss += sentence_loss
# Also update loss-related metrics here, instead of in `process_metrics`.
metrics = dict([(metric.name, metric) for metric in metrics])
if not last_stage:
metrics['feature_transfer_mse'].update_state(feature_transfer_loss)
metrics['beta_transfer_loss'].update_state(beta_loss)
metrics['gamma_transfer_loss'].update_state(gamma_loss)
layer_wise_config = self._progressive_config.layer_wise_distill_config
if layer_wise_config.if_transfer_attention:
metrics['attention_transfer_loss'].update_state(attention_loss)
else:
metrics['lm_example_loss'].update_state(mlm_loss)
if 'next_sentence_labels' in labels:
metrics['next_sentence_loss'].update_state(sentence_loss)
metrics['total_loss'].update_state(total_loss)
return total_loss
# overrides base_task.Task
def build_metrics(self, training=None):
del training
metrics = [
tf_keras.metrics.Mean(name='feature_transfer_mse'),
tf_keras.metrics.Mean(name='beta_transfer_loss'),
tf_keras.metrics.Mean(name='gamma_transfer_loss'),
tf_keras.metrics.SparseCategoricalAccuracy(name='masked_lm_accuracy'),
tf_keras.metrics.Mean(name='lm_example_loss'),
tf_keras.metrics.Mean(name='total_loss')]
if self._progressive_config.layer_wise_distill_config.if_transfer_attention:
metrics.append(tf_keras.metrics.Mean(name='attention_transfer_loss'))
if self._task_config.train_data.use_next_sentence_label:
metrics.append(tf_keras.metrics.SparseCategoricalAccuracy(
name='next_sentence_accuracy'))
metrics.append(tf_keras.metrics.Mean(name='next_sentence_loss'))
return metrics
# overrides base_task.Task
  # Processes non-loss metrics.
def process_metrics(self, metrics, labels, student_pretrainer_output):
metrics = dict([(metric.name, metric) for metric in metrics])
# Final pretrainer layer distillation stage.
if student_pretrainer_output is not None:
if 'masked_lm_accuracy' in metrics:
metrics['masked_lm_accuracy'].update_state(
labels['masked_lm_ids'], student_pretrainer_output['mlm_logits'],
labels['masked_lm_weights'])
if 'next_sentence_accuracy' in metrics:
metrics['next_sentence_accuracy'].update_state(
labels['next_sentence_labels'],
student_pretrainer_output['next_sentence'])
# overrides base_task.Task
def train_step(self, inputs, model: tf_keras.Model,
optimizer: tf_keras.optimizers.Optimizer, metrics):
"""Does forward and backward.
Args:
inputs: a dictionary of input tensors.
model: the model, forward pass definition.
optimizer: the optimizer for this training step.
metrics: a nested structure of metrics objects.
Returns:
A dictionary of logs.
"""
with tf.GradientTape() as tape:
outputs = model(inputs, training=True)
# Computes per-replica loss.
loss = self.build_losses(
labels=inputs,
outputs=outputs,
metrics=metrics)
# Scales loss as the default gradients allreduce performs sum inside the
# optimizer.
# TODO(b/154564893): enable loss scaling.
# scaled_loss = loss / tf.distribute.get_strategy().num_replicas_in_sync
# get trainable variables for current stage
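    # The teacher pretrainer and any student layers frozen in `build_model`
    # are already excluded from `model.trainable_variables`.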
tvars = model.trainable_variables
last_stage = 'student_pretrainer_output' in outputs
grads = tape.gradient(loss, tvars)
optimizer.apply_gradients(list(zip(grads, tvars)))
self.process_metrics(
metrics, inputs,
outputs['student_pretrainer_output'] if last_stage else None)
return {self.loss: loss}
# overrides base_task.Task
def validation_step(self, inputs, model: tf_keras.Model, metrics):
"""Validatation step.
Args:
inputs: a dictionary of input tensors.
model: the keras.Model.
metrics: a nested structure of metrics objects.
Returns:
A dictionary of logs.
"""
outputs = model(inputs, training=False)
# Computes per-replica loss.
loss = self.build_losses(labels=inputs, outputs=outputs, metrics=metrics)
last_stage = 'student_pretrainer_output' in outputs
self.process_metrics(
metrics, inputs,
outputs['student_pretrainer_output'] if last_stage else None)
return {self.loss: loss}
@property
def cur_checkpoint_items(self):
"""Checkpoints for model, stage_id, optimizer for preemption handling."""
return dict(
stage_id=self._stage_id,
volatiles=self._volatiles,
student_pretrainer=self._student_pretrainer,
teacher_pretrainer=self._teacher_pretrainer,
encoder=self._student_pretrainer.encoder_network)
def initialize(self, model):
"""Loads teacher's pretrained checkpoint and copy student's embedding."""
# This function will be called when no checkpoint found for the model,
# i.e., when the training starts (not preemption case).
# The weights of teacher pretrainer and student pretrainer will be
# initialized, rather than the passed-in `model`.
del model
logging.info('Begin to load checkpoint for teacher pretrainer model.')
ckpt_dir_or_file = self._task_config.teacher_model_init_checkpoint
if not ckpt_dir_or_file:
raise ValueError('`teacher_model_init_checkpoint` is not specified.')
if tf.io.gfile.isdir(ckpt_dir_or_file):
ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
# Makes sure the teacher pretrainer variables are created.
_ = self._teacher_pretrainer(self._teacher_pretrainer.inputs)
teacher_checkpoint = tf.train.Checkpoint(
**self._teacher_pretrainer.checkpoint_items)
teacher_checkpoint.read(ckpt_dir_or_file).assert_existing_objects_matched()
logging.info('Begin to copy word embedding from teacher model to student.')
teacher_encoder = self._teacher_pretrainer.encoder_network
student_encoder = self._student_pretrainer.encoder_network
embedding_weights = teacher_encoder.embedding_layer.get_weights()
student_encoder.embedding_layer.set_weights(embedding_weights)