# official/projects/edgetpu/nlp/serving/export_tflite_squad.py
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=line-too-long
r"""Export tflite for MobileBERT-EdgeTPU with SQUAD head.
Example usage:
python3 export_tflite_squad.py \
--config_file=official/projects/edgetpu/nlp/experiments/mobilebert_edgetpu_xs.yaml \
--export_path=/tmp/ \
--quantization_method=full-integer
"""
# pylint: enable=line-too-long
import os
import tempfile
from typing import Sequence
from absl import app
from absl import flags
from absl import logging
import orbit
import tensorflow as tf, tf_keras
from official.common import flags as tfm_flags
from official.nlp.data import data_loader_factory
from official.nlp.data import question_answering_dataloader
from official.nlp.modeling import models
from official.projects.edgetpu.nlp.configs import params
from official.projects.edgetpu.nlp.modeling import model_builder
from official.projects.edgetpu.nlp.utils import utils
FLAGS = flags.FLAGS

# TFRecord of SQUAD v1.1 training examples; used as the representative
# dataset when doing full-integer post-training quantization.
SQUAD_TRAIN_SPLIT = 'gs://**/tp/bert/squad_v1.1/train.tf_record'

# Directory where 'model.tflite' is written.
flags.DEFINE_string('export_path', '/tmp/',
                    'File path to store tflite model.')
# 'full-integer' uses int8 I/O plus a representative dataset; 'hybrid' only
# enables default optimizations; 'float' converts without quantization.
flags.DEFINE_enum('quantization_method', 'float',
                  ['full-integer', 'hybrid', 'float'], 'Quantization method.')
# TFLite models are exported with static shapes, so batch size is fixed.
flags.DEFINE_integer('batch_size', 1,
                     'Fixed batch size for exported TFLite model.')
flags.DEFINE_integer('sequence_length', 384,
                     'Fixed sequence length.')
flags.DEFINE_string('model_checkpoint', None,
                    'Checkpoint path for the model. Model will be initialized'
                    'with random weights if path is None.')
def build_model_for_serving(model: tf_keras.Model,
                            sequence_length: int = 384,
                            batch_size: int = 1) -> tf_keras.Model:
  """Wraps a span-labeler model with MLPerf-compatible input/output nodes.

  To run the model on device, the model input/output datatypes and node names
  need to match the MLPerf setup.

  Args:
    model: Input keras model.
    sequence_length: BERT model sequence length.
    batch_size: Inference batch size.

  Returns:
    Keras model with new input/output nodes.
  """

  def _static_input(name: str) -> tf_keras.Input:
    # All three BERT inputs share the same fully-static int32 shape.
    return tf_keras.Input(
        shape=(sequence_length,),
        batch_size=batch_size,
        dtype=tf.int32,
        name=name)

  word_ids = _static_input('input_word_ids')
  mask = _static_input('input_mask')
  type_ids = _static_input('input_type_ids')

  logits = model([word_ids, type_ids, mask])

  # Identity layers wrapped in lambdas pin the output tensor names that the
  # MLPerf harness expects.
  start_logits = tf_keras.layers.Lambda(
      tf.identity, name='start_positions')(logits[0])
  end_logits = tf_keras.layers.Lambda(
      tf.identity, name='end_positions')(logits[1])

  return tf_keras.Model(
      inputs=[word_ids, type_ids, mask],
      outputs=[start_logits, end_logits])
def build_inputs(data_params, input_context=None):
  """Returns a tf.data.Dataset built from the given data-loader params."""
  loader = data_loader_factory.get_data_loader(data_params)
  return loader.load(input_context)
def main(argv: Sequence[str]) -> None:
  """Builds the SQUAD serving model, converts it to TFLite and writes it out.

  Reads experiment config and export options from FLAGS, optionally restores
  model weights from a checkpoint, converts via a SavedModel round-trip, and
  writes 'model.tflite' under FLAGS.export_path.

  Raises:
    app.UsageError: If extra positional command-line arguments are passed.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  # Set up experiment params and load the configs from file/files.
  experiment_params = params.EdgeTPUBERTCustomParams()
  experiment_params = utils.config_override(experiment_params, FLAGS)

  # Change the input mask type to tf.float32 to avoid an additional cast op.
  experiment_params.student_model.encoder.mobilebert.input_mask_dtype = 'float32'

  # Experiments indicate using -120 as the mask value for Softmax is good
  # enough for both int8 and bfloat. So we set quantization_friendly to True
  # for both quant and float model.
  pretrainer_model = model_builder.build_bert_pretrainer(
      experiment_params.student_model,
      name='pretrainer',
      quantization_friendly=True)

  encoder_network = pretrainer_model.encoder_network
  model = models.BertSpanLabeler(
      network=encoder_network,
      initializer=tf_keras.initializers.TruncatedNormal(stddev=0.01))

  # Load model weights; the model keeps its random initialization when no
  # checkpoint path is given.
  if FLAGS.model_checkpoint is not None:
    checkpoint = tf.train.Checkpoint(model=model)
    checkpoint.restore(FLAGS.model_checkpoint).assert_existing_objects_matched()

  model_for_serving = build_model_for_serving(model, FLAGS.sequence_length,
                                              FLAGS.batch_size)
  model_for_serving.summary()

  def _representative_dataset():
    """Yields 100 real SQUAD examples for full-integer calibration."""
    dataset_params = question_answering_dataloader.QADataConfig()
    dataset_params.input_path = SQUAD_TRAIN_SPLIT
    dataset_params.drop_remainder = False
    dataset_params.global_batch_size = 1
    dataset_params.is_training = True

    dataset = orbit.utils.make_distributed_dataset(
        tf.distribute.get_strategy(), build_inputs, dataset_params)
    for example in dataset.take(100):
      inputs = example[0]
      yield [
          inputs['input_word_ids'],
          inputs['input_mask'],
          inputs['input_type_ids'],
      ]

  # TODO(b/194449109): Need to save the model to file and then convert tflite
  # with 'tf.lite.TFLiteConverter.from_saved_model()' to get the expected
  # accuracy.
  # Bug fix: the original code used `tempfile.TemporaryDirectory().name`,
  # which drops the TemporaryDirectory object immediately; its finalizer
  # deletes the directory on garbage collection, so the SavedModel could
  # vanish before the converter reads it back. The `with` block keeps the
  # directory alive for the whole save + convert sequence.
  with tempfile.TemporaryDirectory() as tmp_dir:
    model_for_serving.save(tmp_dir)

    converter = tf.lite.TFLiteConverter.from_saved_model(tmp_dir)
    if FLAGS.quantization_method in ['full-integer', 'hybrid']:
      converter.optimizations = [tf.lite.Optimize.DEFAULT]
    if FLAGS.quantization_method in ['full-integer']:
      converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
      converter.inference_input_type = tf.int8
      converter.inference_output_type = tf.float32
      converter.representative_dataset = _representative_dataset

    tflite_quant_model = converter.convert()

  export_model_path = os.path.join(FLAGS.export_path, 'model.tflite')
  with tf.io.gfile.GFile(export_model_path, 'wb') as f:
    f.write(tflite_quant_model)
  logging.info('Successfully save the tflite to %s', FLAGS.export_path)
if __name__ == '__main__':
  # Register the common TF Model Garden flags (e.g. --config_file,
  # --params_override) before absl parses the command line.
  tfm_flags.define_flags()
  app.run(main)