# tensorflow/models
#
# View on GitHub
# official/vision/modeling/factory_3d.py
#
# Summary
#
# Maintainability
# A
# 35 mins
# Test Coverage
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Factory methods to build models."""

# Import libraries
from typing import Optional

import tensorflow as tf, tf_keras

from official.core import registry
from official.vision.configs import video_classification as video_classification_cfg
from official.vision.modeling import video_classification_model
from official.vision.modeling import backbones

# Registry of model builders keyed by model type. Entries are added through
# `register_model_builder` and looked up by `build_model`.
_REGISTERED_MODEL_CLS = {}


def register_model_builder(key: str):
  """Decorates a builder of a model class so it can be built by key.

  The builder should be a Callable (a class or a function) that accepts the
  positional arguments `(input_specs, model_config, num_classes,
  l2_regularizer)`, which is how `build_model` invokes it. This decorator
  supports registration of a model builder as follows:

  ```
  class MyModel(tf_keras.Model):
    pass

  @register_model_builder('my_model')
  def builder(input_specs, model_config, num_classes, l2_regularizer):
    return MyModel(...)

  # Builds a MyModel object.
  my_model = build_model(
      'my_model', input_specs, model_config, num_classes, l2_regularizer)
  ```

  Args:
    key: the key to look up the builder. `build_model` uses its `model_type`
      argument as this key.

  Returns:
    A callable for use as a decorator that registers the decorated builder
    under `key` in this module's model registry.
  """
  # NOTE(review): handling of duplicate keys and of hierarchical keys is
  # delegated to `registry.register`; confirm there before relying on either.
  return registry.register(_REGISTERED_MODEL_CLS, key)


def build_model(
    model_type: str,
    input_specs: tf_keras.layers.InputSpec,
    model_config: video_classification_cfg.hyperparams.Config,
    num_classes: int,
    l2_regularizer: Optional[tf_keras.regularizers.Regularizer] = None
) -> tf_keras.Model:
  """Builds a model with the builder registered for `model_type`.

  Args:
    model_type: string name of model type, used as the registry key. It should
      be consistent with ModelConfig.model_type and match a key registered via
      `register_model_builder` (e.g. 'video_classification').
    input_specs: tf_keras.layers.InputSpec describing the model input.
    model_config: model config, forwarded unchanged to the registered builder.
    num_classes: number of classes, forwarded to the registered builder.
    l2_regularizer: optional tf_keras.regularizers.Regularizer instance,
      forwarded to the registered builder. Defaults to None.

  Returns:
    The tf_keras.Model instance returned by the registered builder.
  """
  # NOTE(review): what happens for an unregistered `model_type` is decided by
  # `registry.lookup`; confirm it raises rather than returning None.
  model_builder = registry.lookup(_REGISTERED_MODEL_CLS, model_type)
  return model_builder(input_specs, model_config, num_classes, l2_regularizer)


@register_model_builder('video_classification')
def build_video_classification_model(
    input_specs: tf_keras.layers.InputSpec,
    model_config: video_classification_cfg.VideoClassificationModel,
    num_classes: int,
    l2_regularizer: Optional[tf_keras.regularizers.Regularizer] = None
) -> tf_keras.Model:
  """Builds the video classification model.

  A backbone is built from `model_config.backbone` with the backbone factory
  and wrapped in a `VideoClassificationModel`.

  Args:
    input_specs: tf_keras.layers.InputSpec for the model input. It is passed
      as-is to the backbone factory and under the 'image' key to the model.
    model_config: VideoClassificationModel config. Its `backbone`,
      `norm_activation`, `dropout_rate`, `aggregate_endpoints` and
      `require_endpoints` fields are read.
    num_classes: number of classes.
    l2_regularizer: optional tf_keras.regularizers.Regularizer instance. It is
      passed to the backbone factory and as the model's `kernel_regularizer`.
      Defaults to None.

  Returns:
    A `VideoClassificationModel` instance.
  """
  backbone = backbones.factory.build_backbone(
      input_specs=input_specs,
      backbone_config=model_config.backbone,
      norm_activation_config=model_config.norm_activation,
      l2_regularizer=l2_regularizer)

  # The classification model takes a dict of input specs; the single input is
  # registered under the 'image' key.
  return video_classification_model.VideoClassificationModel(
      backbone=backbone,
      num_classes=num_classes,
      input_specs={'image': input_specs},
      dropout_rate=model_config.dropout_rate,
      aggregate_endpoints=model_config.aggregate_endpoints,
      kernel_regularizer=l2_regularizer,
      require_endpoints=model_config.require_endpoints)