tensorflow/models

View on GitHub
official/vision/serving/export_module_factory.py

Summary

Maintainability
A
1 hr
Test Coverage
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Factory for vision export modules."""

from typing import List, Optional

import tensorflow as tf, tf_keras

from official.core import config_definitions as cfg
from official.vision import configs
from official.vision.dataloaders import classification_input
from official.vision.modeling import factory
from official.vision.serving import export_base_v2 as export_base
from official.vision.serving import export_utils


def create_classification_export_module(params: cfg.ExperimentConfig,
                                        input_type: str,
                                        batch_size: Optional[int],
                                        input_image_size: List[int],
                                        num_channels: int = 3):
  """Creates a classification export module.

  Args:
    params: Experiment config. `params.task.model` is used to build the
      classification model via `factory.build_classification_model`.
    input_type: Serving input type. It is forwarded to
      `export_utils.get_image_input_signatures` and `export_utils.parse_image`.
      When it equals `'tflite'`, the parsed image tensor is returned without
      `classification_input.Parser.inference_fn` preprocessing.
    batch_size: Batch dimension used for the serving signature and the model
      `InputSpec`. `None` leaves the batch dimension unspecified.
    input_image_size: Spatial size of the input image (presumably
      `[height, width]`; confirm against `export_utils`). A tuple is also
      accepted and is converted to a list.
    num_channels: Number of image channels.

  Returns:
    An `export_base.ExportModule` that wraps the model. Its postprocessor
    returns a dict with `'logits'` and their softmax `'probs'`.
  """
  # Normalize to a list so the list concatenations below (for the InputSpec
  # shape and the map_fn output signature) also work for tuple inputs.
  input_image_size = list(input_image_size)

  input_signature = export_utils.get_image_input_signatures(
      input_type, batch_size, input_image_size, num_channels)
  input_specs = tf_keras.layers.InputSpec(
      shape=[batch_size] + input_image_size + [num_channels])

  # Regularization is irrelevant at export time, so no L2 regularizer is used.
  model = factory.build_classification_model(
      input_specs=input_specs,
      model_config=params.task.model,
      l2_regularizer=None)

  def preprocess_fn(inputs):
    """Parses serving inputs and applies per-image inference preprocessing."""
    image_tensor = export_utils.parse_image(inputs, input_type,
                                            input_image_size, num_channels)
    # If input_type is `tflite`, do not apply image preprocessing.
    if input_type == 'tflite':
      return image_tensor

    def preprocess_image_fn(inputs):
      return classification_input.Parser.inference_fn(
          inputs, input_image_size, num_channels)

    # Apply preprocessing to each element of the batch. The explicit output
    # signature fixes the per-image shape and the float32 dtype.
    images = tf.map_fn(
        preprocess_image_fn, elems=image_tensor,
        fn_output_signature=tf.TensorSpec(
            shape=input_image_size + [num_channels],
            dtype=tf.float32))

    return images

  def postprocess_fn(logits):
    """Returns the raw logits together with their softmax probabilities."""
    probs = tf.nn.softmax(logits)
    return {'logits': logits, 'probs': probs}

  export_module = export_base.ExportModule(params,
                                           model=model,
                                           input_signature=input_signature,
                                           preprocessor=preprocess_fn,
                                           postprocessor=postprocess_fn)
  return export_module


def get_export_module(params: cfg.ExperimentConfig,
                      input_type: str,
                      batch_size: Optional[int],
                      input_image_size: List[int],
                      num_channels: int = 3) -> export_base.ExportModule:
  """Builds the export module that matches the task type in `params`.

  Only image classification tasks are currently supported.

  Args:
    params: Experiment config whose `task` selects the export module.
    input_type: Serving input type, forwarded to the module builder.
    batch_size: Batch size for the serving signature (may be `None`).
    input_image_size: Spatial size of the input image.
    num_channels: Number of image channels.

  Returns:
    The export module created for the task.

  Raises:
    ValueError: If no export module is implemented for `type(params.task)`.
  """
  task_config = params.task
  # Reject unsupported tasks up front; everything past this point is the
  # image classification path.
  if not isinstance(task_config,
                    configs.image_classification.ImageClassificationTask):
    raise ValueError('Export module not implemented for {} task.'.format(
        type(task_config)))
  return create_classification_export_module(
      params, input_type, batch_size, input_image_size, num_channels)