{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Tce3stUlHN0L"
      },
      "source": [
        "##### Copyright 2020 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "tuOe1ymfHZPu"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qFdPvlXBOdUN"
      },
      "source": [
        "# Training with Orbit"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MfBg1C5NB3X0"
      },
      "source": [
        "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/tfmodels/orbit\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/models/blob/master/docs/orbit/index.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/models/blob/master/docs/orbit/index.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView on GitHub\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/models/docs/orbit/index.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "\n",
        "\u003c/table\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "456h0idS2Xcq"
      },
      "source": [
        "This example walks through fine-tuning a BERT model using the [Orbit](https://www.tensorflow.org/api_docs/python/orbit) training library.\n",
        "\n",
        "Orbit is a flexible, lightweight library designed to make it easy to write [custom training loops](https://www.tensorflow.org/tutorials/distribute/custom_training) in TensorFlow. Orbit handles common model training tasks such as saving checkpoints, running model evaluations, and setting up summary writing, while giving users full control over implementing the inner training loop. It integrates with `tf.distribute` and supports running on different device types (CPU, GPU, and TPU).\n",
        "\n",
        "Most examples on [tensorflow.org](https://www.tensorflow.org/) use custom training loops or [`model.fit()`](https://www.tensorflow.org/api_docs/python/tf/keras/Model) from Keras. Orbit is a good alternative to `model.fit` if your model is complex and your training loop requires more flexibility, control, or customization. Using Orbit can also simplify the code when many different model architectures share the same custom training loop.\n",
        "\n",
        "This tutorial focuses on setting up and using Orbit, rather than details about BERT, model construction, and data processing. For more in-depth tutorials on these topics, refer to the following tutorials:\n",
        "\n",
        "* [Fine tune BERT](https://www.tensorflow.org/text/tutorials/fine_tune_bert) - which goes into detail on these sub-topics.\n",
        "* [Fine tune BERT for GLUE on TPU](https://www.tensorflow.org/text/tutorials/bert_glue) - which generalizes the code to run any BERT configuration on any [GLUE](https://www.tensorflow.org/datasets/catalog/glue) sub-task, and runs on TPU."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TJ4m3khW3p_W"
      },
      "source": [
        "## Install the TensorFlow Models package\n",
        "\n",
        "Install and import the necessary packages, then configure all the objects necessary for training a model.\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FZlj0U8Aq9Gt"
      },
      "outputs": [],
      "source": [
        "!pip install -q opencv-python\n",
        "!pip install \"tensorflow>=2.9.0\" tf-models-official"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MEJkRrmapr16"
      },
      "source": [
        "The `tf-models-official` package contains both the `orbit` and `tensorflow_models` modules."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "dUVPW84Zucuq"
      },
      "outputs": [],
      "source": [
        "import tensorflow_models as tfm\n",
        "import orbit"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "18Icocf3lwYD"
      },
      "source": [
        "## Setup for training\n",
        "\n",
        "This tutorial does not focus on configuring the environment, building the model and optimizer, or loading data. These techniques are covered in more detail in the [Fine tune BERT](https://www.tensorflow.org/text/tutorials/fine_tune_bert) and [Fine tune BERT for GLUE on TPU](https://www.tensorflow.org/text/tutorials/bert_glue) tutorials.\n",
        "\n",
        "To view how the training is set up for this tutorial, expand the rest of this section.\n",
        "\n",
        "  \u003c!-- \u003cdiv class=\"tfo-display-only-on-site\"\u003e\u003cdevsite-expandable\u003e\n",
        "  \u003cbutton type=\"button\" class=\"button-red button expand-control\"\u003eExpand Section\u003c/button\u003e --\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Ljy0z-i3okCS"
      },
      "source": [
        "### Import the necessary packages\n",
        "\n",
        "Import the BERT model and dataset-building library from the [TensorFlow Model Garden](https://github.com/tensorflow/models)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gCBo6wxA2b5n"
      },
      "outputs": [],
      "source": [
        "import glob\n",
        "import os\n",
        "import pathlib\n",
        "import tempfile\n",
        "import time\n",
        "\n",
        "import numpy as np\n",
        "\n",
        "import tensorflow as tf"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PG1kwhnvq3VC"
      },
      "outputs": [],
      "source": [
        "from official.nlp.data import sentence_prediction_dataloader\n",
        "from official.nlp import optimization"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PsbhUV_p3wxN"
      },
      "source": [
        "### Configure the distribution strategy\n",
        "\n",
        "While `tf.distribute` won't speed up training if you're running on a single device, it's necessary for TPUs. Setting up a distribution strategy allows you to use the same code regardless of the configuration."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PG702dqstXIk"
      },
      "outputs": [],
      "source": [
        "logical_device_names = [logical_device.name for logical_device in tf.config.list_logical_devices()]\n",
        "\n",
        "if 'GPU' in ''.join(logical_device_names):\n",
        "  strategy = tf.distribute.MirroredStrategy()\n",
        "elif 'TPU' in ''.join(logical_device_names):\n",
        "  resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')\n",
        "  tf.config.experimental_connect_to_cluster(resolver)\n",
        "  tf.tpu.experimental.initialize_tpu_system(resolver)\n",
        "  strategy = tf.distribute.TPUStrategy(resolver)\n",
        "else:\n",
        "  strategy = tf.distribute.OneDeviceStrategy(logical_device_names[0])\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eaQgM98deAMu"
      },
      "source": [
        "For more information about the TPU setup, refer to the [TPU guide](https://www.tensorflow.org/guide/tpu)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7aOxMLLV32Zm"
      },
      "source": [
        "### Create a model and an optimizer"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YRdWzOfK3_56"
      },
      "outputs": [],
      "source": [
        "max_seq_length = 128\n",
        "learning_rate = 3e-5\n",
        "num_train_epochs = 3\n",
        "train_batch_size = 32\n",
        "eval_batch_size = 64\n",
        "\n",
        "train_data_size = 3668\n",
        "steps_per_epoch = int(train_data_size / train_batch_size)\n",
        "\n",
        "train_steps = steps_per_epoch * num_train_epochs\n",
        "warmup_steps = int(train_steps * 0.1)\n",
        "\n",
        "print(\"train batch size: \", train_batch_size)\n",
        "print(\"train epochs:     \", num_train_epochs)\n",
        "print(\"steps_per_epoch:  \", steps_per_epoch)"
      ]
    },
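    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a quick check of the arithmetic, the following sketch recomputes the schedule in plain Python from the same values used above (3,668 training examples, a batch size of 32, 3 epochs, and a 10% warmup). The `check_` variable names are introduced here only for illustration."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Illustrative recomputation of the schedule; the check_ names are hypothetical.\n",
        "check_train_data_size = 3668\n",
        "check_batch_size = 32\n",
        "check_epochs = 3\n",
        "\n",
        "# Count only the full batches in each epoch.\n",
        "check_steps_per_epoch = check_train_data_size // check_batch_size\n",
        "check_train_steps = check_steps_per_epoch * check_epochs\n",
        "# Warm up the learning rate over the first 10% of the training steps.\n",
        "check_warmup_steps = int(check_train_steps * 0.1)\n",
        "\n",
        "print('steps per epoch:  ', check_steps_per_epoch)\n",
        "print('total train steps:', check_train_steps)\n",
        "print('warmup steps:     ', check_warmup_steps)"
      ]
    },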
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "BVw3886Ysse6"
      },
      "outputs": [],
      "source": [
        "model_dir = pathlib.Path(tempfile.mkdtemp())\n",
        "print(model_dir)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mu9cV7ew-cVe"
      },
      "source": [
        "\n",
        "Create a BERT Classifier model and a simple optimizer. They must be created inside `strategy.scope` so that the variables can be distributed. "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gmwtX0cp-mj5"
      },
      "outputs": [],
      "source": [
        "with strategy.scope():\n",
        "  encoder_network = tfm.nlp.encoders.build_encoder(\n",
        "      tfm.nlp.encoders.EncoderConfig(type=\"bert\"))\n",
        "  classifier_model = tfm.nlp.models.BertClassifier(\n",
        "      network=encoder_network, num_classes=2)\n",
        "\n",
        "  # Reuse the hyperparameters defined above instead of repeating literals.\n",
        "  optimizer = optimization.create_optimizer(\n",
        "      init_lr=learning_rate,\n",
        "      num_train_steps=train_steps,\n",
        "      num_warmup_steps=warmup_steps,\n",
        "      end_lr=0.0,\n",
        "      optimizer_type='adamw')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jwJSfewG5jVV"
      },
      "outputs": [],
      "source": [
        "tf.keras.utils.plot_model(classifier_model)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IQy5pYgAf8Ft"
      },
      "source": [
        "### Initialize from a Checkpoint"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6CE14GEybgRR"
      },
      "outputs": [],
      "source": [
        "bert_dir = 'gs://cloud-tpu-checkpoints/bert/v3/uncased_L-12_H-768_A-12/'\n",
        "tf.io.gfile.listdir(bert_dir)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "x7fwxz9xidKt"
      },
      "outputs": [],
      "source": [
        "bert_checkpoint = bert_dir + 'bert_model.ckpt'"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "q7EfwVCRe7N_"
      },
      "outputs": [],
      "source": [
        "def init_from_ckpt_fn():\n",
        "  init_checkpoint = tf.train.Checkpoint(**classifier_model.checkpoint_items)\n",
        "  with strategy.scope():\n",
        "    (init_checkpoint\n",
        "     .read(bert_checkpoint)\n",
        "     .expect_partial()\n",
        "     .assert_existing_objects_matched())"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "M0LUMlsde-2f"
      },
      "outputs": [],
      "source": [
        "with strategy.scope():\n",
        "  init_from_ckpt_fn()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gAuns4vN_IYV"
      },
      "source": [
        "\n",
        "To let Orbit save and restore checkpoints, create a `tf.train.CheckpointManager` object."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "i7NwM1Jq_MX7"
      },
      "outputs": [],
      "source": [
        "checkpoint = tf.train.Checkpoint(model=classifier_model, optimizer=optimizer)\n",
        "checkpoint_manager = tf.train.CheckpointManager(\n",
        "    checkpoint,\n",
        "    directory=model_dir,\n",
        "    max_to_keep=5,\n",
        "    step_counter=optimizer.iterations,\n",
        "    checkpoint_interval=steps_per_epoch,\n",
        "    init_fn=init_from_ckpt_fn)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nzeiAFhcCOAo"
      },
      "source": [
        "### Create distributed datasets\n",
        "\n",
        "As a shortcut for this tutorial, the [GLUE/MRPC dataset](https://www.tensorflow.org/datasets/catalog/glue#gluemrpc) has been converted to a pair of [TFRecord](https://www.tensorflow.org/tutorials/load_data/tfrecord) files containing serialized `tf.train.Example` protos.\n",
        "\n",
        "The data was converted using [this script](https://github.com/tensorflow/models/blob/r2.9.0/official/nlp/data/create_finetuning_data.py).\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ZVfbiT1dCnDk"
      },
      "outputs": [],
      "source": [
        "train_data_path = \"gs://download.tensorflow.org/data/model_garden_colab/mrpc_train.tf_record\"\n",
        "eval_data_path = \"gs://download.tensorflow.org/data/model_garden_colab/mrpc_eval.tf_record\"\n",
        "\n",
        "def _dataset_fn(input_file_pattern, \n",
        "                global_batch_size, \n",
        "                is_training, \n",
        "                input_context=None):\n",
        "  data_config = sentence_prediction_dataloader.SentencePredictionDataConfig(\n",
        "    input_path=input_file_pattern,\n",
        "    seq_length=max_seq_length,\n",
        "    global_batch_size=global_batch_size,\n",
        "    is_training=is_training)\n",
        "  return sentence_prediction_dataloader.SentencePredictionDataLoader(\n",
        "      data_config).load(input_context=input_context)\n",
        "\n",
        "train_dataset = orbit.utils.make_distributed_dataset(\n",
        "    strategy, _dataset_fn, input_file_pattern=train_data_path,\n",
        "    global_batch_size=train_batch_size, is_training=True)\n",
        "eval_dataset = orbit.utils.make_distributed_dataset(\n",
        "    strategy, _dataset_fn, input_file_pattern=eval_data_path,\n",
        "    global_batch_size=eval_batch_size, is_training=False)\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dPgiDBQCjsXW"
      },
      "source": [
        "### Create a loss function\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7MCUmmo2jvXl"
      },
      "outputs": [],
      "source": [
        "def loss_fn(labels, logits):\n",
        "  \"\"\"Classification loss.\"\"\"\n",
        "  labels = tf.squeeze(labels)\n",
        "  log_probs = tf.nn.log_softmax(logits, axis=-1)\n",
        "  one_hot_labels = tf.one_hot(\n",
        "      tf.cast(labels, dtype=tf.int32), depth=2, dtype=tf.float32)\n",
        "  per_example_loss = -tf.reduce_sum(\n",
        "      tf.cast(one_hot_labels, dtype=tf.float32) * log_probs, axis=-1)\n",
        "  return tf.reduce_mean(per_example_loss)"
      ]
    },
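    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "For a single example, `loss_fn` reduces to the negative log-probability that the softmax assigns to the true class. The sketch below reproduces that calculation using only the Python standard library, so the numbers are easy to verify by hand. The `cross_entropy` helper is hypothetical: it is written for this illustration and is not part of TensorFlow or the Model Garden."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import math\n",
        "\n",
        "def cross_entropy(label, logits):\n",
        "  # Hypothetical helper: negative log-softmax probability of the true class.\n",
        "  # Subtract the largest logit before exponentiating for numerical stability.\n",
        "  max_logit = max(logits)\n",
        "  log_sum_exp = max_logit + math.log(\n",
        "      sum(math.exp(x - max_logit) for x in logits))\n",
        "  return log_sum_exp - logits[label]\n",
        "\n",
        "# Equal logits give each of the two classes probability 0.5, so the loss is ln(2).\n",
        "print(cross_entropy(0, [0.0, 0.0]))\n",
        "# A confident, correct prediction gives a loss close to zero.\n",
        "print(cross_entropy(1, [-5.0, 5.0]))"
      ]
    },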
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ohlO-8FQkwsr"
      },
      "source": [
        " \u003c/devsite-expandable\u003e\u003c/div\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ymhbvPaEJ96T"
      },
      "source": [
        "## Controllers, Trainers and Evaluators\n",
        "\n",
        "When using Orbit, the `orbit.Controller` class drives the training. The Controller handles the details of distribution strategies, step counting, TensorBoard summaries, and checkpointing.\n",
        "\n",
        "To implement the training and evaluation, pass a `trainer` and an `evaluator`, which are instances of subclasses of `orbit.AbstractTrainer` and `orbit.AbstractEvaluator`. In keeping with Orbit's lightweight design, these two classes have a minimal interface.\n",
        "\n",
        "The Controller drives training and evaluation by calling `trainer.train(num_steps)` and `evaluator.evaluate(num_steps)`. These `train` and `evaluate` methods return a dictionary of results for logging.\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "a6sU2vBeyXtu"
      },
      "source": [
        "Training is broken into chunks of length `num_steps`. This is set by the Controller's [`steps_per_loop`](https://tensorflow.org/api_docs/python/orbit/Controller#args) argument. With the trainer and evaluator abstract base classes, the meaning of `num_steps` is entirely determined by the implementer.\n",
        "\n",
        "Some common examples include:\n",
        "\n",
        "* Having the chunks represent dataset-epoch boundaries, like the default Keras setup.\n",
        "* Using it to more efficiently dispatch a number of training steps to an accelerator with a single `tf.function` call (like the `steps_per_execution` argument to `Model.compile`).\n",
        "* Subdividing into smaller chunks as needed.\n"
      ]
    },
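    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the chunking concrete, the sketch below splits a fixed number of training steps into loops of at most `steps_per_loop` steps. It is a plain-Python illustration, not Orbit's implementation: `split_into_loops` is a hypothetical helper, and it ignores evaluation and checkpoint intervals, which may also affect where a loop ends."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "def split_into_loops(total_steps, steps_per_loop):\n",
        "  # Hypothetical helper: return the num_steps value for each successive loop.\n",
        "  loop_sizes = []\n",
        "  remaining = total_steps\n",
        "  while remaining > 0:\n",
        "    num_steps = min(steps_per_loop, remaining)\n",
        "    loop_sizes.append(num_steps)\n",
        "    remaining -= num_steps\n",
        "  return loop_sizes\n",
        "\n",
        "# 342 total steps in loops of 20: seventeen full loops and a final loop of 2.\n",
        "example_loops = split_into_loops(total_steps=342, steps_per_loop=20)\n",
        "print(len(example_loops), example_loops[-1])"
      ]
    },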
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "p4mXGIRJsf1j"
      },
      "source": [
        "### StandardTrainer and StandardEvaluator\n",
        "\n",
        "Orbit provides two additional classes, `orbit.StandardTrainer` and `orbit.StandardEvaluator`, to give more structure around the training and evaluation loops.\n",
        "\n",
        "With StandardTrainer, you only need to set `train_loop_begin`, `train_step`, and `train_loop_end`. The base class handles the loops, dataset logic, and `tf.function` (according to the options set by their `orbit.StandardTrainerOptions`). This is simpler than `orbit.AbstractTrainer`, which requires you to handle the entire loop. StandardEvaluator has a similar structure and simplification to StandardTrainer.\n",
        "\n",
        "This is effectively an implementation of the `steps_per_execution` approach used by Keras."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-hvZ8PvohmR5"
      },
      "source": [
        "Contrast this with Keras, where training is divided both into epochs (a single pass over the dataset) and `steps_per_execution` (set within [`Model.compile`](https://www.tensorflow.org/api_docs/python/tf/keras/Model#compile)). In Keras, metric averages are typically accumulated over an epoch, then reported and reset between epochs. For efficiency, `steps_per_execution` only controls the number of training steps made per call.\n",
        "\n",
        "In this simple case, `steps_per_loop` (within `StandardTrainer`) will handle both the metric resets and the number of steps per call.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NoDFN1L-1jIu"
      },
      "source": [
        "The minimal setup when using these base classes is to implement the methods as follows:\n",
        "\n",
        "1. `StandardTrainer.train_loop_begin` - Reset your training metrics.\n",
        "2. `StandardTrainer.train_step` - Apply a single gradient update.\n",
        "3. `StandardTrainer.train_loop_end` - Report your training metrics.\n",
        "\n",
        "and\n",
        "\n",
        "4. `StandardEvaluator.eval_begin` - Reset your evaluation metrics.\n",
        "5. `StandardEvaluator.eval_step` - Run a single evaluation step.\n",
        "6. `StandardEvaluator.eval_reduce` - This is not necessary in this simple setup.\n",
        "7. `StandardEvaluator.eval_end` - Report your evaluation metrics.\n",
        "\n",
        "Depending on the settings, the base class may wrap the `train_step` and `eval_step` code in `tf.function` or `tf.while_loop`, which has some limitations compared to standard Python."
      ]
    },
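    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The call order of the trainer hooks can be sketched without TensorFlow. The toy class below is a hypothetical stand-in written for this illustration. It is not `orbit.StandardTrainer`, and it omits datasets, `tf.function`, and distribution. Its `train` method calls `train_loop_begin` once, `train_step` `num_steps` times, and `train_loop_end` once, then returns whatever `train_loop_end` returns."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "class ToyLoopTrainer:\n",
        "  # Hypothetical stand-in that mimics only the order of the hook calls.\n",
        "\n",
        "  def __init__(self, losses):\n",
        "    self.iterator = iter(losses)\n",
        "    self.total = 0.0\n",
        "    self.count = 0\n",
        "\n",
        "  def train_loop_begin(self):\n",
        "    # Reset the running metric at the start of every loop.\n",
        "    self.total, self.count = 0.0, 0\n",
        "\n",
        "  def train_step(self, iterator):\n",
        "    # Consume one item and accumulate it, like a metric update.\n",
        "    self.total += next(iterator)\n",
        "    self.count += 1\n",
        "\n",
        "  def train_loop_end(self):\n",
        "    # Report the mean over the steps in this loop.\n",
        "    return {'training_loss': self.total / self.count}\n",
        "\n",
        "  def train(self, num_steps):\n",
        "    self.train_loop_begin()\n",
        "    for _ in range(num_steps):\n",
        "      self.train_step(self.iterator)\n",
        "    return self.train_loop_end()\n",
        "\n",
        "toy_trainer = ToyLoopTrainer([4.0, 2.0, 1.0, 3.0])\n",
        "print(toy_trainer.train(num_steps=2))  # Mean of the first two values.\n",
        "print(toy_trainer.train(num_steps=2))  # Metric was reset: mean of the last two."
      ]
    },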
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3KPA0NDZt2JD"
      },
      "source": [
        "### Define the trainer class"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6LDPsvJwfuPR"
      },
      "source": [
        "In this section you'll create a subclass of `orbit.StandardTrainer` for this task.  \n",
        "\n",
        "Note: To better explain the `BertClassifierTrainer` class, this section defines each method as a stand-alone function and assembles them into a class at the end.\n",
        "\n",
        "The trainer needs access to the training data, model, optimizer, and distribution strategy. Pass these as arguments to the initializer.\n",
        "\n",
        "Define a single training metric, `training_loss`, using `tf.keras.metrics.Mean`. "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6DQYZN5ax-MG"
      },
      "outputs": [],
      "source": [
        "def trainer_init(self,\n",
        "                 train_dataset,\n",
        "                 model,\n",
        "                 optimizer,\n",
        "                 strategy):\n",
        "  self.strategy = strategy\n",
        "  with self.strategy.scope():\n",
        "    self.model = model\n",
        "    self.optimizer = optimizer\n",
        "    self.global_step = self.optimizer.iterations\n",
        "    \n",
        "\n",
        "    self.train_loss = tf.keras.metrics.Mean(\n",
        "        'training_loss', dtype=tf.float32)\n",
        "    orbit.StandardTrainer.__init__(self, train_dataset)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QOwHD7U5hVue"
      },
      "source": [
        "Before starting a run of the training loop, the `train_loop_begin` method will reset the `train_loss` metric."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "AkpcHqXShWL0"
      },
      "outputs": [],
      "source": [
        "def train_loop_begin(self):\n",
        "  # `reset_state` replaces the deprecated `reset_states` alias.\n",
        "  self.train_loss.reset_state()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UjtFOFyxn2BB"
      },
      "source": [
        "The `train_step` method is a straightforward loss calculation and gradient update that is run by the distribution strategy. This is accomplished by defining the gradient step as a nested function (`step_fn`).\n",
        "\n",
        "The method receives a `tf.distribute.DistributedIterator` to handle the [distributed input](https://www.tensorflow.org/tutorials/distribute/input). The method uses `Strategy.run` to execute `step_fn` and feeds it from the distributed iterator.\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "QuPwNnT5I-GP"
      },
      "outputs": [],
      "source": [
        "def train_step(self, iterator):\n",
        "\n",
        "  def step_fn(inputs):\n",
        "    labels = inputs.pop(\"label_ids\")\n",
        "    with tf.GradientTape() as tape:\n",
        "      model_outputs = self.model(inputs, training=True)\n",
        "      # Raw loss is used for reporting in metrics/logs.\n",
        "      raw_loss = loss_fn(labels, model_outputs)\n",
        "      # Scale the loss down so that the gradients, which are summed across\n",
        "      # replicas, do not depend on the number of replicas.\n",
        "      loss = raw_loss / self.strategy.num_replicas_in_sync\n",
        "\n",
        "    grads = tape.gradient(loss, self.model.trainable_variables)\n",
        "    # Use the trainer's own optimizer rather than a global variable.\n",
        "    self.optimizer.apply_gradients(zip(grads, self.model.trainable_variables))\n",
        "    # For reporting, the metric takes the mean of losses.\n",
        "    self.train_loss.update_state(raw_loss)\n",
        "\n",
        "  self.strategy.run(step_fn, args=(next(iterator),))"
      ]
    },
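    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The division by `num_replicas_in_sync` matters because, by default, per-replica gradients are summed across replicas when they are applied. The arithmetic below is a plain-Python sketch with made-up per-replica loss values: summing the scaled losses gives the mean across replicas, whereas summing the raw losses grows with the number of replicas."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Made-up mean losses from four replicas (illustrative values only).\n",
        "replica_losses = [0.6, 0.8, 1.0, 1.2]\n",
        "num_replicas = len(replica_losses)\n",
        "\n",
        "# Each replica scales its loss before the results are summed across replicas.\n",
        "scaled_sum = sum(value / num_replicas for value in replica_losses)\n",
        "unscaled_sum = sum(replica_losses)\n",
        "\n",
        "print('sum of scaled losses:', scaled_sum)\n",
        "print('sum of raw losses:   ', unscaled_sum)"
      ]
    },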
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VmQNwx5QpyDt"
      },
      "source": [
        "The `orbit.StandardTrainer` handles the `@tf.function` and loops.\n",
        "\n",
        "After running through `num_steps` of training, `StandardTrainer` calls `train_loop_end`. The function returns the metric results:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "GqCyVk1zzGod"
      },
      "outputs": [],
      "source": [
        "def train_loop_end(self):\n",
        "  return {\n",
        "      self.train_loss.name: self.train_loss.result(),\n",
        "  }"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xvmLONl80KUv"
      },
      "source": [
        "Build a subclass of `orbit.StandardTrainer` with those methods."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "oRoL7VE6xt1G"
      },
      "outputs": [],
      "source": [
        "class BertClassifierTrainer(orbit.StandardTrainer):\n",
        "  __init__ = trainer_init\n",
        "  train_loop_begin = train_loop_begin\n",
        "  train_step = train_step\n",
        "  train_loop_end = train_loop_end"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yjG4QAWj1B00"
      },
      "source": [
        "### Define the evaluator class\n",
        "\n",
        "Note: Like the previous section, this section defines each method as a stand-alone function and assembles them into a `BertClassifierEvaluator` class at the end.\n",
        "\n",
        "The evaluator is even simpler for this task. It needs access to the evaluation dataset, the model, and the strategy. After saving references to those objects, the constructor just needs to create the metrics."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "cvX7seCY1CWj"
      },
      "outputs": [],
      "source": [
        "def evaluator_init(self,\n",
        "                   eval_dataset,\n",
        "                   model,\n",
        "                   strategy):\n",
        "  self.strategy = strategy\n",
        "  with self.strategy.scope():\n",
        "    self.model = model\n",
        "    \n",
        "    self.eval_loss = tf.keras.metrics.Mean(\n",
        "        'evaluation_loss', dtype=tf.float32)\n",
        "    self.eval_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(\n",
        "        name='accuracy', dtype=tf.float32)\n",
        "    orbit.StandardEvaluator.__init__(self, eval_dataset)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0r-z-XK7ybyX"
      },
      "source": [
        "Similar to the trainer, the `eval_begin` and `eval_end` methods just need to reset the metrics before the loop and then report the results after the loop."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7VVb0Tg6yZjI"
      },
      "outputs": [],
      "source": [
        "def eval_begin(self):\n",
        "  self.eval_accuracy.reset_state()\n",
        "  self.eval_loss.reset_state()\n",
        "\n",
        "def eval_end(self):\n",
        "  return {\n",
        "      self.eval_accuracy.name: self.eval_accuracy.result(),\n",
        "      self.eval_loss.name: self.eval_loss.result(),\n",
        "  }"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iDOZcQvttdmZ"
      },
      "source": [
        "The `eval_step` method works like `train_step`. The inner `step_fn` defines the actual work of calculating the loss and accuracy and updating the metrics. The outer `eval_step` receives a `tf.distribute.DistributedIterator` as input, and uses `Strategy.run` to execute `step_fn` on each replica, feeding it from the distributed iterator."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JLJnYuuGJjvd"
      },
      "outputs": [],
      "source": [
        "def eval_step(self, iterator):\n",
        "\n",
        "  def step_fn(inputs):\n",
        "    labels = inputs.pop(\"label_ids\")\n",
        "    # Run in inference mode so that dropout is disabled during evaluation.\n",
        "    model_outputs = self.model(inputs, training=False)\n",
        "    loss = loss_fn(labels, model_outputs)\n",
        "    self.eval_loss.update_state(loss)\n",
        "    self.eval_accuracy.update_state(labels, model_outputs)\n",
        "\n",
        "  self.strategy.run(step_fn, args=(next(iterator),))"
      ]
    },
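    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The accuracy metric compares each integer label with the index of the largest logit. A rough plain-Python equivalent is sketched below, using made-up logits and a hypothetical `sparse_accuracy` helper that is not part of Keras."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "def sparse_accuracy(labels, logits):\n",
        "  # Hypothetical helper: fraction of rows whose top-scoring class is the label.\n",
        "  correct = 0\n",
        "  for label, row in zip(labels, logits):\n",
        "    predicted = max(range(len(row)), key=row.__getitem__)\n",
        "    correct += int(predicted == label)\n",
        "  return correct / len(labels)\n",
        "\n",
        "# Three of these four made-up predictions match their labels.\n",
        "example_labels = [0, 1, 1, 0]\n",
        "example_logits = [[2.0, -1.0], [0.1, 0.3], [1.5, 0.5], [0.2, 0.1]]\n",
        "print(sparse_accuracy(example_labels, example_logits))"
      ]
    },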
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Gt3hh0V30QcP"
      },
      "source": [
        "Build a subclass of `orbit.StandardEvaluator` with those methods."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3zqyLxfNyCgA"
      },
      "outputs": [],
      "source": [
        "class BertClassifierEvaluator(orbit.StandardEvaluator):\n",
        "  __init__ = evaluator_init\n",
        "  eval_begin = eval_begin\n",
        "  eval_end = eval_end\n",
        "  eval_step = eval_step"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "aK9gEja9qPOc"
      },
      "source": [
        "### End-to-end training and evaluation\n",
        "\n",
        "To run the training and evaluation, simply create the trainer, evaluator, and `orbit.Controller` instances. Then call the `Controller.train_and_evaluate` method."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PqQetxyXqRA9"
      },
      "outputs": [],
      "source": [
        "trainer = BertClassifierTrainer(\n",
        "    train_dataset, classifier_model, optimizer, strategy)\n",
        "\n",
        "evaluator = BertClassifierEvaluator(\n",
        "    eval_dataset, classifier_model, strategy)\n",
        "\n",
        "controller = orbit.Controller(\n",
        "    trainer=trainer,\n",
        "    evaluator=evaluator,\n",
        "    global_step=trainer.global_step,\n",
        "    steps_per_loop=20,\n",
        "    checkpoint_manager=checkpoint_manager)\n",
        "\n",
        "result = controller.train_and_evaluate(\n",
        "    train_steps=steps_per_epoch * num_train_epochs,\n",
        "    eval_steps=-1,\n",
        "    eval_interval=steps_per_epoch)"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [
        "Tce3stUlHN0L"
      ],
      "name": "Orbit Tutorial.ipynb",
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}