# tensorflow/models
#
# View on GitHub
# official/nlp/tools/export_tfhub.py
#
# Summary
#
# Maintainability
# A
# 3 hrs
# Test Coverage
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Exports a BERT-like encoder and its preprocessing as SavedModels for TF Hub.

This tool creates preprocessor and encoder SavedModels suitable for uploading
to https://tfhub.dev that implement the preprocessor and encoder APIs defined
at https://www.tensorflow.org/hub/common_saved_model_apis/text.

For a full usage guide, see
https://github.com/tensorflow/models/blob/master/official/nlp/docs/tfhub.md

Minimal usage examples:

1) Exporting an Encoder from checkpoint and config.

```
export_tfhub \
  --encoder_config_file=${BERT_DIR:?}/bert_encoder.yaml \
  --model_checkpoint_path=${BERT_DIR:?}/bert_model.ckpt \
  --vocab_file=${BERT_DIR:?}/vocab.txt \
  --export_type=model \
  --export_path=/tmp/bert_model
```

An --encoder_config_file can specify encoder types other than BERT.
For BERT, a --bert_config_file in the legacy JSON format can be passed instead.

Flag --vocab_file (and flag --do_lower_case, whose default value is guessed
from the vocab_file path) capture how BertTokenizer was used in pre-training.
Use flag --sp_model_file instead if SentencepieceTokenizer was used.

Changing --export_type to model_with_mlm additionally creates an `.mlm`
subobject on the exported SavedModel that can be called to produce
the logits of the Masked Language Model task from pretraining.
The help string for flag --model_checkpoint_path explains the checkpoint
formats required for each --export_type.


2) Exporting a preprocessor SavedModel

```
export_tfhub \
  --vocab_file ${BERT_DIR:?}/vocab.txt \
  --export_type preprocessing --export_path /tmp/bert_preprocessing
```

Be sure to use flag values that match the encoder and how it has been
pre-trained (see above for --vocab_file vs --sp_model_file).

If your encoder has been trained with text preprocessing for which tfhub.dev
already has a SavedModel, you could guide your users to reuse that one instead
of exporting and publishing your own.

TODO(b/175369555): When exporting to users of TensorFlow 2.4, add flag
`--experimental_disable_assert_in_preprocessing`.
"""

from absl import app
from absl import flags
import gin

from official.legacy.bert import configs
from official.modeling import hyperparams
from official.nlp.configs import encoders
from official.nlp.tools import export_tfhub_lib

FLAGS = flags.FLAGS

# Command-line flags. Which of them are needed depends on --export_type; each
# help string states the export types it applies to.

# What to export and where to write it.
flags.DEFINE_enum(
    "export_type", "model",
    ["model", "model_with_mlm", "preprocessing"],
    "The overall type of SavedModel to export. Flags "
    "--bert_config_file/--encoder_config_file and --vocab_file/--sp_model_file "
    "control which particular encoder model and preprocessing are exported.")
flags.DEFINE_string(
    "export_path", None,
    "Directory to which the SavedModel is written.")

# Encoder definition and weights (model and model_with_mlm exports).
flags.DEFINE_string(
    "encoder_config_file", None,
    "A yaml file representing `encoders.EncoderConfig` to define the encoder "
    "(BERT or other). "
    "Exactly one of --bert_config_file and --encoder_config_file can be set. "
    "Needed for --export_type model and model_with_mlm.")
flags.DEFINE_string(
    "bert_config_file", None,
    "A JSON file with a legacy BERT configuration to define the BERT encoder. "
    "Exactly one of --bert_config_file and --encoder_config_file can be set. "
    "Needed for --export_type model and model_with_mlm.")
flags.DEFINE_bool(
    "copy_pooler_dense_to_encoder", False,
    "When the model is trained using `BertPretrainerV2`, the pool layer "
    "of next sentence prediction task exists in `ClassificationHead` passed "
    "to `BertPretrainerV2`. If True, we will copy this pooler's dense layer "
    "to the encoder that is exported by this tool (as in classic BERT). "
    "Using `BertPretrainerV2` and leaving this False exports an untrained "
    "(randomly initialized) pooling layer, which some authors recommend for "
    "subsequent fine-tuning.")
flags.DEFINE_string(
    "model_checkpoint_path", None,
    "File path to a pre-trained model checkpoint. "
    "For --export_type model, this has to be an object-based (TF2) checkpoint "
    "that can be restored to `tf.train.Checkpoint(encoder=encoder)` "
    "for the `encoder` defined by the config file. "
    "(Legacy checkpoints with `model=` instead of `encoder=` are also "
    "supported for now.) "
    "For --export_type model_with_mlm, it must be restorable to "
    "`tf.train.Checkpoint(**BertPretrainerV2(...).checkpoint_items)`. "
    "(For now, `tf.train.Checkpoint(pretrainer=BertPretrainerV2(...))` is also "
    "accepted.)")

# Tokenization settings (all export types).
flags.DEFINE_string(
    "vocab_file", None,
    "For encoders trained on BertTokenizer input: "
    "the vocabulary file that the encoder model was trained with. "
    "Exactly one of --vocab_file and --sp_model_file can be set. "
    "Needed for --export_type model, model_with_mlm and preprocessing.")
flags.DEFINE_string(
    "sp_model_file", None,
    "For encoders trained on SentencepieceTokenizer input: "
    "the SentencePiece .model file that the encoder model was trained with. "
    "Exactly one of --vocab_file and --sp_model_file can be set. "
    "Needed for --export_type model, model_with_mlm and preprocessing.")
flags.DEFINE_bool(
    "do_lower_case", None,
    "Whether to lowercase before tokenization. "
    "If left as None, and --vocab_file is set, do_lower_case will be enabled "
    "if 'uncased' appears in the name of --vocab_file. "
    "If left as None, and --sp_model_file set, do_lower_case defaults to true. "
    "Needed for --export_type model, model_with_mlm and preprocessing.")

# Options that only apply to preprocessing exports.
flags.DEFINE_integer(
    "default_seq_length", 128,
    "The sequence length of preprocessing results from "
    "top-level preprocess method. This is also the default "
    "sequence length for the bert_pack_inputs subobject. "
    "Needed for --export_type preprocessing.")
flags.DEFINE_bool(
    "tokenize_with_offsets", False,  # TODO(b/181866850)
    "Whether to export a .tokenize_with_offsets subobject for "
    "--export_type preprocessing.")

# Gin configuration, parsed at the start of main().
flags.DEFINE_multi_string(
    "gin_file", default=None,
    help="List of paths to the config files.")
flags.DEFINE_multi_string(
    "gin_params", default=None,
    help="List of Gin bindings.")
flags.DEFINE_bool(  # TODO(b/175369555): Remove this flag and its use.
    "experimental_disable_assert_in_preprocessing", False,
    "Export a preprocessing model without tf.Assert ops. "
    "Usually, that would be a bad idea, except TF2.4 has an issue with "
    "Assert ops in tf.functions used in Dataset.map() on a TPU worker, "
    "and omitting the Assert ops lets SavedModels avoid the issue.")


def main(argv):
  """Validates the flags and exports the SavedModel selected by --export_type.

  Args:
    argv: Positional command-line arguments left over after flag parsing.
      Only the program name itself is accepted.

  Raises:
    app.UsageError: If extra positional arguments are given, if --export_path
      is not set, or if --export_type has an unknown value.
    ValueError: If not exactly one of --vocab_file and --sp_model_file is set,
      or (for the model export types) if not exactly one of --bert_config_file
      and --encoder_config_file is set.
  """
  if len(argv) > 1:
    raise app.UsageError("Too many command-line arguments.")
  # --export_path defaults to None; reject that here, before any config is
  # loaded, instead of handing None to the export library.
  if not FLAGS.export_path:
    raise app.UsageError("Flag --export_path is required.")
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)

  # Comparing the truthiness of both flags catches "both set" as well as
  # "neither set" with a single test.
  if bool(FLAGS.vocab_file) == bool(FLAGS.sp_model_file):
    raise ValueError("Exactly one of `vocab_file` and `sp_model_file` "
                     "can be specified, but got %s and %s." %
                     (FLAGS.vocab_file, FLAGS.sp_model_file))
  # --do_lower_case may be left as None; the library helper resolves the
  # effective value from the flag and the tokenizer files (see the flag's
  # help string for the documented defaults).
  do_lower_case = export_tfhub_lib.get_do_lower_case(
      FLAGS.do_lower_case, FLAGS.vocab_file, FLAGS.sp_model_file)

  if FLAGS.export_type in ("model", "model_with_mlm"):
    # Same exactly-one-of test as above, for the two encoder config formats.
    if bool(FLAGS.bert_config_file) == bool(FLAGS.encoder_config_file):
      raise ValueError("Exactly one of `bert_config_file` and "
                       "`encoder_config_file` can be specified, but got "
                       "%s and %s." %
                       (FLAGS.bert_config_file, FLAGS.encoder_config_file))
    if FLAGS.bert_config_file:
      # Legacy JSON configuration for a BERT encoder.
      bert_config = configs.BertConfig.from_json_file(FLAGS.bert_config_file)
      encoder_config = None
    else:
      # Start from a default EncoderConfig and override it with the contents
      # of the file given by --encoder_config_file.
      bert_config = None
      encoder_config = encoders.EncoderConfig()
      encoder_config = hyperparams.override_params_dict(
          encoder_config, FLAGS.encoder_config_file, is_strict=True)
    export_tfhub_lib.export_model(
        FLAGS.export_path,
        bert_config=bert_config,
        encoder_config=encoder_config,
        model_checkpoint_path=FLAGS.model_checkpoint_path,
        vocab_file=FLAGS.vocab_file,
        sp_model_file=FLAGS.sp_model_file,
        do_lower_case=do_lower_case,
        with_mlm=FLAGS.export_type == "model_with_mlm",
        copy_pooler_dense_to_encoder=FLAGS.copy_pooler_dense_to_encoder)

  elif FLAGS.export_type == "preprocessing":
    export_tfhub_lib.export_preprocessing(
        FLAGS.export_path,
        vocab_file=FLAGS.vocab_file,
        sp_model_file=FLAGS.sp_model_file,
        do_lower_case=do_lower_case,
        default_seq_length=FLAGS.default_seq_length,
        tokenize_with_offsets=FLAGS.tokenize_with_offsets,
        experimental_disable_assert=
        FLAGS.experimental_disable_assert_in_preprocessing)

  else:
    # Defensive: --export_type is an enum flag, so this is only reached if a
    # new enum value is added without a matching branch here.
    raise app.UsageError(
        "Unknown value '%s' for flag --export_type" % FLAGS.export_type)


# Script entry point: hand main() to absl's app.run, which parses the
# command-line flags defined in this module before invoking it.
if __name__ == "__main__":
  app.run(main)