tensorflow/models

View on GitHub
official/projects/edgetpu/vision/modeling/mobilenet_edgetpu_v2_model.py

Summary

Maintainability
A
35 mins
Test Coverage
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Contains definitions for MobilenetEdgeTPUV2 image classification models."""

from typing import Any, Mapping, Optional
from absl import logging
import tensorflow as tf, tf_keras

from official.projects.edgetpu.vision.modeling import common_modules
from official.projects.edgetpu.vision.modeling import mobilenet_edgetpu_v2_model_blocks

# Re-export the config class from the blocks module under a shorter name.
ModelConfig = mobilenet_edgetpu_v2_model_blocks.ModelConfig

# Registry mapping a model-config name to a config object produced by the
# matching factory in `mobilenet_edgetpu_v2_model_blocks`. Every factory is
# called once, when this module is imported.
MODEL_CONFIGS = {
    # The un-suffixed name is built by the same factory as the `_s` variant.
    'mobilenet_edgetpu_v2':
        mobilenet_edgetpu_v2_model_blocks.mobilenet_edgetpu_v2_s(),
    'mobilenet_edgetpu_v2_tiny':
        mobilenet_edgetpu_v2_model_blocks.mobilenet_edgetpu_v2_tiny(),
    'mobilenet_edgetpu_v2_xs':
        mobilenet_edgetpu_v2_model_blocks.mobilenet_edgetpu_v2_xs(),
    'mobilenet_edgetpu_v2_s':
        mobilenet_edgetpu_v2_model_blocks.mobilenet_edgetpu_v2_s(),
    'mobilenet_edgetpu_v2_m':
        mobilenet_edgetpu_v2_model_blocks.mobilenet_edgetpu_v2_m(),
    'mobilenet_edgetpu_v2_l':
        mobilenet_edgetpu_v2_model_blocks.mobilenet_edgetpu_v2_l(),
    # Configs produced by the `autoseg_edgetpu_backbone_*` factories.
    'autoseg_edgetpu_backbone_xs':
        mobilenet_edgetpu_v2_model_blocks.autoseg_edgetpu_backbone_xs(),
    'autoseg_edgetpu_backbone_s':
        mobilenet_edgetpu_v2_model_blocks.autoseg_edgetpu_backbone_s(),
    'autoseg_edgetpu_backbone_m':
        mobilenet_edgetpu_v2_model_blocks.autoseg_edgetpu_backbone_m(),
}


@tf_keras.utils.register_keras_serializable(package='Vision')
class MobilenetEdgeTPUV2(tf_keras.Model):
  """Wrapper class for a MobilenetEdgeTPUV2 Keras model.

  Contains helper methods to build, manage, and save metadata about the model.

  The model is assembled in functional style: a Keras `Input` is passed
  through `mobilenet_edgetpu_v2_model_blocks.mobilenet_edgetpu_v2` and the
  result is handed to the base `tf_keras.Model` as the model output.
  """

  def __init__(self,
               model_config_name: Optional[str] = None,
               overrides: Optional[Mapping[str, Any]] = None,
               **kwargs):
    """Creates a MobilenetEdgeTPUV2 model.

    Args:
      model_config_name: (optional) the model parameters to create the model.
        Must be a key of `MODEL_CONFIGS`; when None, a default `ModelConfig()`
        is used.
      overrides: (optional) a dict containing keys that can override config.
      **kwargs: All the rest model arguments in a dictionary. They are
        forwarded to the base `tf_keras.Model` initializer.

    Raises:
      ValueError: If `model_config_name` is given but is not a key of
        `MODEL_CONFIGS`.
    """
    self.model_config_name = model_config_name
    # NOTE(review): `_self_setattr_tracking` is a Keras-internal flag;
    # presumably it suspends automatic attribute tracking while plain
    # bookkeeping attributes are assigned before the base initializer runs.
    # It is switched back on at the end of this method.
    self._self_setattr_tracking = False
    self.overrides = overrides or {}

    if model_config_name is None:
      model_config = ModelConfig()
    else:
      if model_config_name not in MODEL_CONFIGS:
        supported_model_list = list(MODEL_CONFIGS.keys())
        # The trailing space after "support" keeps the two adjacent string
        # literals from being glued together as "supportmodel".
        raise ValueError(f'Unknown model name {model_config_name}. Only support '
                         f'model configs in {supported_model_list}.')
      model_config = MODEL_CONFIGS[model_config_name]

    self.config = model_config.replace(**self.overrides)

    input_channels = self.config.input_channels
    model_name = self.config.model_name
    resolution = self.config.resolution
    # A (height, width) pair may arrive as a list rather than a tuple, e.g.
    # after the overrides returned by `get_config` went through JSON.
    if isinstance(resolution, (tuple, list)):
      height, width = resolution[0], resolution[1]
    else:
      # A scalar resolution describes a square input.
      height = width = resolution
    input_shape = (height, width, input_channels)
    image_input = tf_keras.layers.Input(shape=input_shape)

    output = mobilenet_edgetpu_v2_model_blocks.mobilenet_edgetpu_v2(
        image_input, self.config)

    if not isinstance(output, list):
      # Cast to float32 in case we have a different model dtype
      output = tf.cast(output, tf.float32)
      self._output_specs = output.get_shape()
    elif self.config.features_as_dict:
      # Dict output is required for the decoder ASPP module. Keys are the
      # stringified positions of the features in the returned list.
      output = {str(i): feat for i, feat in enumerate(output)}
      self._output_specs = {
          key: feat.get_shape() for key, feat in output.items()
      }
    else:
      # edgetpu/tasks/segmentation assumes features as list.
      self._output_specs = [feat.get_shape() for feat in output]

    logging.info('Building model %s with params %s',
                 model_name,
                 self.config)

    super().__init__(inputs=image_input, outputs=output, **kwargs)
    self._self_setattr_tracking = True

  @classmethod
  def from_name(cls,
                model_name: str,
                model_weights_path: Optional[str] = None,
                checkpoint_format: Optional[str] = 'tf_checkpoint',
                overrides: Optional[Mapping[str, Any]] = None):
    """Constructs a MobilenetEdgeTPUV2 model from a predefined model name.

    E.g., `MobilenetEdgeTPUV2.from_name('mobilenet_edgetpu_v2_s')`.

    Note that an optional `'model_config'` entry in `overrides` is merged into
    the module-level `MODEL_CONFIGS` registry, so custom configs registered
    this way remain visible to later calls.

    Args:
      model_name: the predefined model name
      model_weights_path: the path to the weights (h5 file or saved model dir)
      checkpoint_format: the model weights format. One of 'tf_checkpoint' or
        'keras_checkpoint'.
      overrides: (optional) a dict containing keys that can override config

    Returns:
      A constructed MobilenetEdgeTPUV2 instance.

    Raises:
      ValueError: If `model_name` is not a key of `MODEL_CONFIGS` after any
        custom configs have been registered.
    """
    # Work on a copy so that popping 'model_config' does not mutate the
    # caller's mapping.
    overrides = dict(overrides) if overrides else {}

    # One can define their own custom models if necessary
    MODEL_CONFIGS.update(overrides.pop('model_config', {}))

    model = cls(model_config_name=model_name, overrides=overrides)

    if model_weights_path:
      common_modules.load_weights(model,
                                  model_weights_path,
                                  checkpoint_format=checkpoint_format)
    return model

  def get_config(self):
    """Returns the config needed to rebuild this model via `from_config`.

    The result holds `model_config_name` and `overrides` merged with the base
    Keras model config; on a key collision the base Keras value wins.
    """
    config = {'model_config_name': self.model_config_name,
              'overrides': self.overrides}
    keras_model_config = super().get_config()
    return {**config, **keras_model_config}

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Rebuilds a model from a config produced by `get_config`.

    Only the `model_config_name` and `overrides` entries are used; the
    remaining entries and `custom_objects` are ignored.

    Args:
      config: a dict containing the keys 'model_config_name' and 'overrides'.
      custom_objects: unused; accepted for signature compatibility.

    Returns:
      A newly constructed instance of this class.
    """
    return cls(model_config_name=config['model_config_name'],
               overrides=config['overrides'])

  @property
  def output_specs(self):
    """Shape spec(s) of the model output.

    Depending on how the model was built this is a single TensorShape, a dict
    of {feature index as str: TensorShape}, or a list of TensorShapes.
    """
    return self._output_specs