docs/vision/image_classification.ipynb
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "Tce3stUlHN0L"
},
"source": [
"##### Copyright 2020 The TensorFlow Authors."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "tuOe1ymfHZPu"
},
"outputs": [],
"source": [
"#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# https://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qFdPvlXBOdUN"
},
"source": [
"# Image classification with Model Garden"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MfBg1C5NB3X0"
},
"source": [
"\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/tfmodels/vision/image_classification\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/models/blob/master/docs/vision/image_classification.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/models/blob/master/docs/vision/image_classification.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView on GitHub\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/models/docs/vision/image_classification.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
" \u003c/td\u003e\n",
"\u003c/table\u003e"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Ta_nFXaVAqLD"
},
"source": [
"This tutorial fine-tunes a Residual Network (ResNet) from the TensorFlow [Model Garden](https://github.com/tensorflow/models) package (`tensorflow-models`) to classify images in the [CIFAR](https://www.cs.toronto.edu/~kriz/cifar.html) dataset.\n",
"\n",
"Model Garden contains a collection of state-of-the-art vision models, implemented with TensorFlow's high-level APIs. The implementations demonstrate the best practices for modeling, letting users to take full advantage of TensorFlow for their research and product development.\n",
"\n",
"This tutorial uses a [ResNet](https://arxiv.org/pdf/1512.03385.pdf) model, a state-of-the-art image classifier. This tutorial uses the ResNet-18 model, a convolutional neural network with 18 layers.\n",
"\n",
"This tutorial demonstrates how to:\n",
"1. Use models from the TensorFlow Models package.\n",
"2. Fine-tune a pre-built ResNet for image classification.\n",
"3. Export the tuned ResNet model."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "G2FlaQcEPOER"
},
"source": [
"## Setup\n",
"\n",
"Install and import the necessary modules."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "XvWfdCrvrV5W"
},
"outputs": [],
"source": [
"!pip install -U -q \"tf-models-official\""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "CKYMTPjOE400"
},
"source": [
"Import TensorFlow, TensorFlow Datasets, and a few helper libraries."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Wlon1uoIowmZ"
},
"outputs": [],
"source": [
"import pprint\n",
"import tempfile\n",
"\n",
"from IPython import display\n",
"import matplotlib.pyplot as plt\n",
"\n",
"import tensorflow as tf\n",
"import tensorflow_datasets as tfds"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "AVTs0jDd1b24"
},
"source": [
"The `tensorflow_models` package contains the ResNet vision model, and the `official.vision.serving` model contains the function to save and export the tuned model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "NHT1iiIiBzlC"
},
"outputs": [],
"source": [
"import tensorflow_models as tfm\n",
"\n",
"# These are not in the tfm public API for v2.9. They will be available in v2.10\n",
"from official.vision.serving import export_saved_model_lib\n",
"import official.core.train_lib"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "aKv3wdqkQ8FU"
},
"source": [
"## Configure the ResNet-18 model for the Cifar-10 dataset"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5iN8mHEJjKYE"
},
"source": [
"The CIFAR10 dataset contains 60,000 color images in mutually exclusive 10 classes, with 6,000 images in each class.\n",
"\n",
"In Model Garden, the collections of parameters that define a model are called *configs*. Model Garden can create a config based on a known set of parameters via a [factory](https://en.wikipedia.org/wiki/Factory_method_pattern).\n",
"\n",
"Use the `resnet_imagenet` factory configuration, as defined by `tfm.vision.configs.image_classification.image_classification_imagenet`. The configuration is set up to train ResNet to converge on [ImageNet](https://www.image-net.org/)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "1M77f88Dj2Td"
},
"outputs": [],
"source": [
"exp_config = tfm.core.exp_factory.get_exp_config('resnet_imagenet')\n",
"tfds_name = 'cifar10'\n",
"ds,ds_info = tfds.load(\n",
"tfds_name,\n",
"with_info=True)\n",
"ds_info"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "U6PVwXA-j3E7"
},
"source": [
"Adjust the model and dataset configurations so that it works with Cifar-10 (`cifar10`)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "YWI7faVStQaV"
},
"outputs": [],
"source": [
"# Configure model\n",
"exp_config.task.model.num_classes = 10\n",
"exp_config.task.model.input_size = list(ds_info.features[\"image\"].shape)\n",
"exp_config.task.model.backbone.resnet.model_id = 18\n",
"\n",
"# Configure training and testing data\n",
"batch_size = 128\n",
"\n",
"exp_config.task.train_data.input_path = ''\n",
"exp_config.task.train_data.tfds_name = tfds_name\n",
"exp_config.task.train_data.tfds_split = 'train'\n",
"exp_config.task.train_data.global_batch_size = batch_size\n",
"\n",
"exp_config.task.validation_data.input_path = ''\n",
"exp_config.task.validation_data.tfds_name = tfds_name\n",
"exp_config.task.validation_data.tfds_split = 'test'\n",
"exp_config.task.validation_data.global_batch_size = batch_size\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "DE3ggKzzTD56"
},
"source": [
"Adjust the trainer configuration."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "inE_-4UGkLud"
},
"outputs": [],
"source": [
"logical_device_names = [logical_device.name for logical_device in tf.config.list_logical_devices()]\n",
"\n",
"if 'GPU' in ''.join(logical_device_names):\n",
" print('This may be broken in Colab.')\n",
" device = 'GPU'\n",
"elif 'TPU' in ''.join(logical_device_names):\n",
" print('This may be broken in Colab.')\n",
" device = 'TPU'\n",
"else:\n",
" print('Running on CPU is slow, so only train for a few steps.')\n",
" device = 'CPU'\n",
"\n",
"if device=='CPU':\n",
" train_steps = 20\n",
" exp_config.trainer.steps_per_loop = 5\n",
"else:\n",
" train_steps=5000\n",
" exp_config.trainer.steps_per_loop = 100\n",
"\n",
"exp_config.trainer.summary_interval = 100\n",
"exp_config.trainer.checkpoint_interval = train_steps\n",
"exp_config.trainer.validation_interval = 1000\n",
"exp_config.trainer.validation_steps = ds_info.splits['test'].num_examples // batch_size\n",
"exp_config.trainer.train_steps = train_steps\n",
"exp_config.trainer.optimizer_config.learning_rate.type = 'cosine'\n",
"exp_config.trainer.optimizer_config.learning_rate.cosine.decay_steps = train_steps\n",
"exp_config.trainer.optimizer_config.learning_rate.cosine.initial_learning_rate = 0.1\n",
"exp_config.trainer.optimizer_config.warmup.linear.warmup_steps = 100"
]
},
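{
"cell_type": "markdown",
"metadata": {},
"source": [
"To get some intuition for the learning rate schedule this configuration produces, you can plot an equivalent standalone Keras `CosineDecay`. This is a sketch only: the trainer builds its own schedule, including the 100-step linear warmup, from `exp_config`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Plot a standalone cosine schedule with the same initial LR and decay\n",
"# steps as the trainer config above. Illustration only; the trainer\n",
"# constructs its own schedule from exp_config.\n",
"cosine = tf.keras.optimizers.schedules.CosineDecay(\n",
"    initial_learning_rate=0.1, decay_steps=train_steps)\n",
"steps = range(train_steps)\n",
"plt.plot(steps, [cosine(step).numpy() for step in steps])\n",
"plt.xlabel('Step')\n",
"plt.ylabel('Learning rate')"
]
},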
{
"cell_type": "markdown",
"metadata": {
"id": "5mTcDnBiTOYD"
},
"source": [
"Print the modified configuration."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "tuVfxSBCTK-y"
},
"outputs": [],
"source": [
"pprint.pprint(exp_config.as_dict())\n",
"\n",
"display.Javascript(\"google.colab.output.setIframeHeight('300px');\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "w7_X0UHaRF2m"
},
"source": [
"Set up the distribution strategy."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ykL14FIbTaSt"
},
"outputs": [],
"source": [
"logical_device_names = [logical_device.name for logical_device in tf.config.list_logical_devices()]\n",
"\n",
"if exp_config.runtime.mixed_precision_dtype == tf.float16:\n",
" tf.keras.mixed_precision.set_global_policy('mixed_float16')\n",
"\n",
"if 'GPU' in ''.join(logical_device_names):\n",
" distribution_strategy = tf.distribute.MirroredStrategy()\n",
"elif 'TPU' in ''.join(logical_device_names):\n",
" tf.tpu.experimental.initialize_tpu_system()\n",
" tpu = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='/device:TPU_SYSTEM:0')\n",
" distribution_strategy = tf.distribute.experimental.TPUStrategy(tpu)\n",
"else:\n",
" print('Warning: this will be really slow.')\n",
" distribution_strategy = tf.distribute.OneDeviceStrategy(logical_device_names[0])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "W4k5YH5pTjaK"
},
"source": [
"Create the `Task` object (`tfm.core.base_task.Task`) from the `config_definitions.TaskConfig`.\n",
"\n",
"The `Task` object has all the methods necessary for building the dataset, building the model, and running training \u0026 evaluation. These methods are driven by `tfm.core.train_lib.run_experiment`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "6MgYSH0PtUaW"
},
"outputs": [],
"source": [
"with distribution_strategy.scope():\n",
" model_dir = tempfile.mkdtemp()\n",
" task = tfm.core.task_factory.get_task(exp_config.task, logging_dir=model_dir)\n",
"\n",
"# tf.keras.utils.plot_model(task.build_model(), show_shapes=True)"
]
},
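{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick check, you can build the model from the task and inspect its size. This is a minimal sketch; the `inspection_model` name is just for illustration, and the copy built here is discarded, since `run_experiment` below builds its own model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Build a throwaway copy of the model just to inspect it;\n",
"# run_experiment builds and trains its own instance.\n",
"with distribution_strategy.scope():\n",
"  inspection_model = task.build_model()\n",
"\n",
"print(f'Layers: {len(inspection_model.layers)}')\n",
"print(f'Trainable variables: {len(inspection_model.trainable_variables)}')"
]
},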
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "IFXEZYdzBKoX"
},
"outputs": [],
"source": [
"for images, labels in task.build_inputs(exp_config.task.train_data).take(1):\n",
" print()\n",
" print(f'images.shape: {str(images.shape):16} images.dtype: {images.dtype!r}')\n",
" print(f'labels.shape: {str(labels.shape):16} labels.dtype: {labels.dtype!r}')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "yrwxnGDaRU0U"
},
"source": [
"## Visualize the training data"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "683c255c6c52"
},
"source": [
"The dataloader applies a z-score normalization using \n",
"`preprocess_ops.normalize_image(image, offset=MEAN_RGB, scale=STDDEV_RGB)`, so the images returned by the dataset can't be directly displayed by standard tools. The visualization code needs to rescale the data into the [0,1] range."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "PdmOz2EC0Nx2"
},
"outputs": [],
"source": [
"plt.hist(images.numpy().flatten());"
]
},
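{
"cell_type": "markdown",
"metadata": {},
"source": [
"To recover approximate pixel values, you can invert the normalization. The sketch below assumes the `MEAN_RGB` and `STDDEV_RGB` constants exported by `official.vision.ops.preprocess_ops`; it is illustration only, since the `show_batch` helper below rescales with the batch min/max instead."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from official.vision.ops import preprocess_ops\n",
"\n",
"# Invert the z-score normalization to recover approximate [0, 255] pixels.\n",
"denormalized = images * preprocess_ops.STDDEV_RGB + preprocess_ops.MEAN_RGB\n",
"print('min:', denormalized.numpy().min(), ' max:', denormalized.numpy().max())"
]
},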
{
"cell_type": "markdown",
"metadata": {
"id": "7a8582ebde7b"
},
"source": [
"Use `ds_info` (which is an instance of `tfds.core.DatasetInfo`) to lookup the text descriptions of each class ID."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Wq4Wq_CuDG3Q"
},
"outputs": [],
"source": [
"label_info = ds_info.features['label']\n",
"label_info.int2str(1)"
]
},
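{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can also list all ten class names at once; `names` is a standard attribute of the `tfds.features.ClassLabel` feature:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# All CIFAR-10 class names, indexed by label ID.\n",
"label_info.names"
]
},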
{
"cell_type": "markdown",
"metadata": {
"id": "8c652a6fdbcf"
},
"source": [
"Visualize a batch of the data."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ZKfTxytf1l0d"
},
"outputs": [],
"source": [
"def show_batch(images, labels, predictions=None):\n",
" plt.figure(figsize=(10, 10))\n",
" min = images.numpy().min()\n",
" max = images.numpy().max()\n",
" delta = max - min\n",
"\n",
" for i in range(12):\n",
" plt.subplot(6, 6, i + 1)\n",
" plt.imshow((images[i]-min) / delta)\n",
" if predictions is None:\n",
" plt.title(label_info.int2str(labels[i]))\n",
" else:\n",
" if labels[i] == predictions[i]:\n",
" color = 'g'\n",
" else:\n",
" color = 'r'\n",
" plt.title(label_info.int2str(predictions[i]), color=color)\n",
" plt.axis(\"off\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "xkA5h_RBtYYU"
},
"outputs": [],
"source": [
"plt.figure(figsize=(10, 10))\n",
"for images, labels in task.build_inputs(exp_config.task.train_data).take(1):\n",
" show_batch(images, labels)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "v_A9VnL2RbXP"
},
"source": [
"## Visualize the testing data"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "AXovuumW_I2z"
},
"source": [
"Visualize a batch of images from the validation dataset."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Ma-_Eb-nte9A"
},
"outputs": [],
"source": [
"plt.figure(figsize=(10, 10));\n",
"for images, labels in task.build_inputs(exp_config.task.validation_data).take(1):\n",
" show_batch(images, labels)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ihKJt2FHRi2N"
},
"source": [
"## Train and evaluate"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "0AFMNvYxtjXx"
},
"outputs": [],
"source": [
"model, eval_logs = tfm.core.train_lib.run_experiment(\n",
" distribution_strategy=distribution_strategy,\n",
" task=task,\n",
" mode='train_and_eval',\n",
" params=exp_config,\n",
" model_dir=model_dir,\n",
" run_post_eval=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "gCcHMQYhozmA"
},
"outputs": [],
"source": [
"# tf.keras.utils.plot_model(model, show_shapes=True)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "L7nVfxlBA8Gb"
},
"source": [
"Print the `accuracy`, `top_5_accuracy`, and `validation_loss` evaluation metrics."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "0124f938a1b9"
},
"outputs": [],
"source": [
"for key, value in eval_logs.items():\n",
" if isinstance(value, tf.Tensor):\n",
" value = value.numpy()\n",
" print(f'{key:20}: {value:.3f}')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TDys5bZ1zsml"
},
"source": [
"Run a batch of the processed training data through the model, and view the results"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "GhI7zR-Uz1JT"
},
"outputs": [],
"source": [
"for images, labels in task.build_inputs(exp_config.task.train_data).take(1):\n",
" predictions = model.predict(images)\n",
" predictions = tf.argmax(predictions, axis=-1)\n",
"\n",
"show_batch(images, labels, tf.cast(predictions, tf.int32))\n",
"\n",
"if device=='CPU':\n",
" plt.suptitle('The model was only trained for a few steps, it is not expected to do well.')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fkE9locGTBgt"
},
"source": [
"## Export a SavedModel"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9669d08c91af"
},
"source": [
"The `keras.Model` object returned by `train_lib.run_experiment` expects the data to be normalized by the dataset loader using the same mean and variance statiscics in `preprocess_ops.normalize_image(image, offset=MEAN_RGB, scale=STDDEV_RGB)`. This export function handles those details, so you can pass `tf.uint8` images and get the correct results.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "AQCFa7BvtmDg"
},
"outputs": [],
"source": [
"# Saving and exporting the trained model\n",
"export_saved_model_lib.export_inference_graph(\n",
" input_type='image_tensor',\n",
" batch_size=1,\n",
" input_image_size=[32, 32],\n",
" params=exp_config,\n",
" checkpoint_path=tf.train.latest_checkpoint(model_dir),\n",
" export_dir='./export/')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vVr6DxNqTyLZ"
},
"source": [
"Test the exported model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "gP7nOvrftsB0"
},
"outputs": [],
"source": [
"# Importing SavedModel\n",
"imported = tf.saved_model.load('./export/')\n",
"model_fn = imported.signatures['serving_default']"
]
},
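{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can inspect the loaded signature to confirm that it accepts a single `tf.uint8` image tensor; `structured_input_signature` and `structured_outputs` are standard `ConcreteFunction` attributes:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# The input spec should show a (1, 32, 32, 3) uint8 tensor;\n",
"# the outputs include the classification logits.\n",
"print(model_fn.structured_input_signature)\n",
"print(model_fn.structured_outputs)"
]
},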
{
"cell_type": "markdown",
"metadata": {
"id": "GiOp2WVIUNUZ"
},
"source": [
"Visualize the predictions."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "BTRMrZQAN4mk"
},
"outputs": [],
"source": [
"plt.figure(figsize=(10, 10))\n",
"for data in tfds.load('cifar10', split='test').batch(12).take(1):\n",
" predictions = []\n",
" for image in data['image']:\n",
" index = tf.argmax(model_fn(image[tf.newaxis, ...])['logits'], axis=1)[0]\n",
" predictions.append(index)\n",
" show_batch(data['image'], data['label'], predictions)\n",
"\n",
" if device=='CPU':\n",
" plt.suptitle('The model was only trained for a few steps, it is not expected to do better than random.')"
]
}
],
"metadata": {
"colab": {
"name": "classification_with_model_garden.ipynb",
"provenance": [],
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}