# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Definitions for AssembleNet/++ structures.

This structure is a `list` corresponding to a graph representation of the
network, where a node is a convolutional block and an edge specifies a
connection from one block to another.

Each node itself (in the structure list) is a list with the following format:
[block_level, [list_of_input_blocks], num_filters, temporal_dilation,
spatial_stride]. [list_of_input_blocks] should be a list of node indexes whose
values are less than the index of the node itself. The 'stems' of the network,
which directly take the raw inputs, follow a different node format:
[stem_type, temporal_dilation]. The stem_type is -1 for an RGB stem and -2 for
an optical flow stem; the stem_type -3 is reserved for the object segmentation
input.

In AssembleNet++lite, instead of passing a single `int` for num_filters, we
pass a list/tuple of three `int`s, specifying the number of channels to be
used for each layer of the inverted bottleneck modules.

The structure_weights specify the learned connection weights.
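
For example (illustrative only, not one of the published models), the
two-node structure

  [[-1, 1], [0, [0], 32, 1, 2, 0]]

describes an RGB stem with temporal dilation 1, feeding a single level-0
block with 32 filters, temporal dilation 1, and spatial stride 2. The
trailing 0 in each non-stem node of the raw lists is an unused placeholder;
the parsing code below only reads the first five elements.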
"""
import dataclasses
from typing import List, Optional, Tuple

from official.core import config_definitions as cfg
from official.core import exp_factory
from official.modeling import hyperparams
from official.vision.configs import backbones_3d
from official.vision.configs import common
from official.vision.configs import video_classification


@dataclasses.dataclass
class BlockSpec(hyperparams.Config):
  level: int = -1
  input_blocks: Tuple[int, ...] = tuple()
  num_filters: int = -1
  temporal_dilation: int = 1
  spatial_stride: int = 1
  input_block_weight: Tuple[float, ...] = tuple()


def flat_lists_to_blocks(model_structures, model_edge_weights):
  """Transforms the raw list structure configs to BlockSpec tuple."""
  blocks = []
  for node, edge_weights in zip(model_structures, model_edge_weights):
    if node[0] < 0:
      block = BlockSpec(level=node[0], temporal_dilation=node[1])
    else:
      block = BlockSpec(  # pytype: disable=wrong-arg-types
          level=node[0],
          input_blocks=node[1],
          num_filters=node[2],
          temporal_dilation=node[3],
          spatial_stride=node[4])
      if edge_weights:
        assert len(edge_weights[0]) == len(block.input_blocks), (
            f'{len(edge_weights[0])} != {len(block.input_blocks)} at block '
            f'{block} weight {edge_weights}')
        block.input_block_weight = tuple(edge_weights[0])
    blocks.append(block)
  return tuple(blocks)
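
# A minimal usage sketch for `flat_lists_to_blocks`; the two-node structure
# below is hypothetical and exists only to illustrate the parsing:
#
#   structure = [[-1, 1], [0, [0], 32, 1, 2, 0]]
#   edge_weights = [[], [[1.0]]]
#   blocks = flat_lists_to_blocks(structure, edge_weights)
#   # blocks[0] is the RGB stem: BlockSpec(level=-1, temporal_dilation=1).
#   # blocks[1] is a level-0 block with input_blocks [0], num_filters 32,
#   # spatial_stride 2, and input_block_weight (1.0,).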


def blocks_to_flat_lists(blocks: List[BlockSpec]):
  """Transforms BlockSpec tuple to the raw list structure configs."""
  # pylint: disable=g-complex-comprehension
  # pylint: disable=g-long-ternary
  model_structure = [[
      b.level,
      list(b.input_blocks), b.num_filters, b.temporal_dilation,
      b.spatial_stride, 0
  ] if b.level >= 0 else [b.level, b.temporal_dilation] for b in blocks]
  model_edge_weights = [
      [list(b.input_block_weight)] if b.input_block_weight else []
      for b in blocks
  ]
  return model_structure, model_edge_weights
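
# Round-trip sanity check: the two helpers above are intended to be inverses,
# so (assuming hyperparams.Config stores the tuple values without modifying
# them) the following should hold for the structures defined below:
#
#   s, w = blocks_to_flat_lists(
#       list(flat_lists_to_blocks(asn50_structure, asn_structure_weights)))
#   assert s == asn50_structure
#   assert w == asn_structure_weights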


# AssembleNet structure for 50/101 layer models, found using evolution with the
# Moments-in-Time dataset. This is the structure used for the experiments in the
# AssembleNet paper. The learned connectivity weights are also provided.
asn50_structure = [[-1, 4], [-1, 4], [-2, 1], [-2, 1], [0, [1], 32, 1, 1, 0],
                   [0, [0], 32, 4, 1, 0], [0, [0, 1, 2, 3], 32, 1, 1, 0],
                   [0, [2, 3], 32, 2, 1, 0], [1, [0, 4, 5, 6, 7], 64, 2, 2, 0],
                   [1, [0, 2, 4, 7], 64, 1, 2, 0], [1, [0, 5, 7], 64, 4, 2, 0],
                   [1, [0, 5], 64, 1, 2, 0], [2, [4, 8, 10, 11], 256, 1, 2, 0],
                   [2, [8, 9], 256, 4, 2, 0], [3, [12, 13], 512, 2, 2, 0]]
asn101_structure = [[-1, 4], [-1, 4], [-2, 1], [-2, 1], [0, [1], 32, 1, 1, 0],
                    [0, [0], 32, 4, 1, 0], [0, [0, 1, 2, 3], 32, 1, 1, 0],
                    [0, [2, 3], 32, 2, 1, 0], [1, [0, 4, 5, 6, 7], 64, 2, 2, 0],
                    [1, [0, 2, 4, 7], 64, 1, 2, 0], [1, [0, 5, 7], 64, 4, 2, 0],
                    [1, [0, 5], 64, 1, 2, 0], [2, [4, 8, 10, 11], 192, 1, 2, 0],
                    [2, [8, 9], 192, 4, 2, 0], [3, [12, 13], 256, 2, 2, 0]]
asn_structure_weights = [
    [], [], [], [], [], [],
    [[
        0.13810564577579498, 0.8465337157249451, 0.3072969317436218,
        0.2867436408996582
    ]], [[0.5846117734909058, 0.6066334843635559]],
    [[
        0.16382087767124176, 0.8852924704551697, 0.4039595425128937,
        0.6823437809944153, 0.5331538319587708
    ]],
    [[
        0.028569204732775688, 0.10333596915006638, 0.7517264485359192,
        0.9260114431381226
    ]], [[0.28832191228866577, 0.7627848982810974, 0.404977947473526]],
    [[0.23474831879138947, 0.7841425538063049]],
    [[
        0.27616503834724426, 0.9514784812927246, 0.6568767428398132,
        0.9547983407974243
    ]], [[0.5047007203102112, 0.8876819610595703]],
    [[0.9892204403877258, 0.8454614877700806]]
]
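
# The weights list is aligned index-by-index with asn50_structure: the first
# four empty entries correspond to the stems, and blocks 4 and 5 (each with a
# single input) also carry no learned connection weights.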


# AssembleNet++ structure for 50 layer models, found with the Charades dataset.
# This is the model used in the experiments in the AssembleNet++ paper.
# Note that, in order to build AssembleNet++ with this structure, you also need
# to feed the 'object segmentation input' to the network, indicated by the
# [-3, 4] stem (the 5th block in the architecture).
# If you don't plan to use the object input but still want to benefit from
# peer-attention in AssembleNet++ (with RGB and OF), use the above
# AssembleNet-50 structure instead with the assemblenet_plus.py code.
full_asnp50_structure = [[-1, 2], [-1, 4], [-2, 2], [-2, 1], [-3, 4],
                         [0, [0, 1, 2, 3, 4], 32, 1, 1, 0],
                         [0, [0, 1, 4], 32, 4, 1, 0],
                         [0, [2, 3, 4], 32, 8, 1, 0],
                         [0, [2, 3, 4], 32, 1, 1, 0],
                         [1, [0, 1, 2, 4, 5, 6, 7, 8], 64, 4, 2, 0],
                         [1, [2, 3, 4, 7, 8], 64, 1, 2, 0],
                         [1, [0, 4, 5, 6, 7], 128, 8, 2, 0],
                         [2, [4, 11], 256, 8, 2, 0],
                         [2, [2, 3, 4, 5, 6, 7, 8, 10, 11], 256, 4, 2, 0],
                         [3, [12, 13], 512, 2, 2, 0]]
full_asnp_structure_weights = [[], [], [], [], [], [[0.6143830418586731, 0.7111759185791016, 0.19351491332054138, 0.1701001077890396, 0.7178536653518677]], [[0.5755624771118164, 0.5644599795341492, 0.7128658294677734]], [[0.26563042402267456, 0.3033692538738251, 0.8244096636772156]], [[0.07013848423957825, 0.07905343919992447, 0.8767927885055542]], [[0.5008697509765625, 0.5020178556442261, 0.49819135665893555, 0.5015180706977844, 0.4987695813179016, 0.4990265369415283, 0.499239057302475, 0.4974501430988312]], [[0.47034338116645813, 0.4694305658340454, 0.767791748046875, 0.5539310574531555, 0.4520096182823181]], [[0.2769702076911926, 0.8116549253463745, 0.597356915473938, 0.6585626602172852, 0.5915306210517883]], [[0.501274824142456, 0.5016682147979736]], [[0.0866393893957138, 0.08469288796186447, 0.9739039540290833, 0.058271341025829315, 0.08397126197814941, 0.10285478830337524, 0.18506969511508942, 0.23874442279338837, 0.9188644886016846]], [[0.4174623489379883, 0.5844835638999939]]]  # pylint: disable=line-too-long


# AssembleNet++lite structure using inverted bottleneck blocks. By specifying
# the connection weights as [], the model can also learn the connection
# weights automatically during its training.
asnp_lite_structure = [[-1, 1], [-2, 1],
                       [0, [0, 1], [27, 27, 12], 1, 2, 0],
                       [0, [0, 1], [27, 27, 12], 4, 2, 0],
                       [1, [0, 1, 2, 3], [54, 54, 24], 2, 2, 0],
                       [1, [0, 1, 2, 3], [54, 54, 24], 1, 2, 0],
                       [1, [0, 1, 2, 3], [54, 54, 24], 4, 2, 0],
                       [1, [0, 1, 2, 3], [54, 54, 24], 1, 2, 0],
                       [2, [0, 1, 2, 3, 4, 5, 6, 7], [152, 152, 68], 1, 2, 0],
                       [2, [0, 1, 2, 3, 4, 5, 6, 7], [152, 152, 68], 4, 2, 0],
                       [3, [2, 3, 4, 5, 6, 7, 8, 9], [432, 432, 192], 2, 2, 0]]
asnp_lite_structure_weights = [[], [], [[0.19914183020591736, 0.9278576374053955]], [[0.010816320776939392, 0.888792097568512]], [[0.9473835825920105, 0.6303419470787048, 0.1704932451248169, 0.05950307101011276]], [[0.9560931324958801, 0.7898273468017578, 0.36138781905174255, 0.07344610244035721]], [[0.9213919043540955, 0.13418640196323395, 0.8371981978416443, 0.07936054468154907]], [[0.9441559910774231, 0.9435100555419922, 0.7253988981246948, 0.13498817384243011]], [[0.9964852333068848, 0.8427878618240356, 0.8895476460456848, 0.11014710366725922, 0.6270533204078674, 0.44782018661499023, 0.61344975233078, 0.44898226857185364]], [[0.9970942735671997, 0.7105681896209717, 0.5078442096710205, 0.0951600968837738, 0.624282717704773, 0.8527252674102783, 0.8105692863464355, 0.7857823967933655]], [[0.6180334091186523, 0.11882413923740387, 0.06102970987558365, 0.04484326392412186, 0.05602221190929413, 0.052324872463941574, 0.9969874024391174, 0.9987731575965881]]]  # pylint: disable=line-too-long
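
# Note that in the lite structure each node's filter entry is a list of three
# `int`s (the inverted bottleneck channel counts described in the module
# docstring) rather than a single `int`, which is why flat_lists_to_blocks
# builds BlockSpec under a `pytype: disable=wrong-arg-types` annotation.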


@dataclasses.dataclass
class AssembleNet(hyperparams.Config):
  model_id: str = '50'
  num_frames: int = 0
  combine_method: str = 'sigmoid'
  blocks: Tuple[BlockSpec, ...] = tuple()


@dataclasses.dataclass
class AssembleNetPlus(hyperparams.Config):
  model_id: str = '50'
  num_frames: int = 0
  attention_mode: str = 'None'
  blocks: Tuple[BlockSpec, ...] = tuple()
  use_object_input: bool = False
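
# Illustrative only: a backbone config for AssembleNet++ with the object
# segmentation stream enabled might look like the sketch below. The 'peer'
# attention_mode value is an assumption based on the peer-attention mechanism
# of the AssembleNet++ paper, not a value defined in this file:
#
#   asnp = AssembleNetPlus(
#       model_id='50',
#       num_frames=32,
#       attention_mode='peer',  # hypothetical value; see assemblenet_plus.py
#       use_object_input=True,
#       blocks=flat_lists_to_blocks(full_asnp50_structure,
#                                   full_asnp_structure_weights))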


@dataclasses.dataclass
class Backbone3D(backbones_3d.Backbone3D):
  """Configuration for backbones.

  Attributes:
    type: 'str', type of backbone to be used, one of the fields below.
    assemblenet: AssembleNet backbone config.
    assemblenet_plus: AssembleNetPlus backbone config.
  """
  type: Optional[str] = None
  assemblenet: AssembleNet = dataclasses.field(default_factory=AssembleNet)
  assemblenet_plus: AssembleNetPlus = dataclasses.field(
      default_factory=AssembleNetPlus
  )


@dataclasses.dataclass
class AssembleNetModel(video_classification.VideoClassificationModel):
  """The AssembleNet model config."""
  model_type: str = 'assemblenet'
  backbone: Backbone3D = dataclasses.field(
      default_factory=lambda: Backbone3D(type='assemblenet')
  )
  norm_activation: common.NormActivation = dataclasses.field(
      default_factory=lambda: common.NormActivation(  # pylint: disable=g-long-lambda
          norm_momentum=0.99, norm_epsilon=1e-5, use_sync_bn=True
      )
  )
  max_pool_predictions: bool = False


@dataclasses.dataclass
class AssembleNetPlusModel(video_classification.VideoClassificationModel):
  """The AssembleNet model config."""
  model_type: str = 'assemblenet_plus'
  backbone: Backbone3D = dataclasses.field(
      default_factory=lambda: Backbone3D(type='assemblenet_plus')
  )
  norm_activation: common.NormActivation = dataclasses.field(
      default_factory=lambda: common.NormActivation(  # pylint: disable=g-long-lambda
          norm_momentum=0.99, norm_epsilon=1e-5, use_sync_bn=True
      )
  )
  max_pool_predictions: bool = False


@exp_factory.register_config_factory('assemblenet50_kinetics600')
def assemblenet_kinetics600() -> cfg.ExperimentConfig:
  """Video classification on Videonet with assemblenet."""
  exp = video_classification.video_classification_kinetics600()

  feature_shape = (32, 224, 224, 3)
  exp.task.train_data.global_batch_size = 1024
  exp.task.validation_data.global_batch_size = 32
  exp.task.train_data.feature_shape = feature_shape
  exp.task.validation_data.feature_shape = (120, 224, 224, 3)
  exp.task.train_data.dtype = 'bfloat16'
  exp.task.validation_data.dtype = 'bfloat16'

  model = AssembleNetModel()
  model.backbone.assemblenet.model_id = '50'
  model.backbone.assemblenet.blocks = flat_lists_to_blocks(
      asn50_structure, asn_structure_weights)
  model.backbone.assemblenet.num_frames = feature_shape[0]
  exp.task.model = model

  assert exp.task.model.backbone.assemblenet.num_frames > 0, (
      f'backbone num_frames '
      f'{exp.task.model.backbone.assemblenet}')

  return exp
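
# Hedged usage sketch: once this module is imported, the experiment registered
# above can be retrieved through the standard factory accessor:
#
#   config = exp_factory.get_exp_config('assemblenet50_kinetics600')
#   assert config.task.model.backbone.assemblenet.num_frames == 32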


@exp_factory.register_config_factory('assemblenet_ucf101')
def assemblenet_ucf101() -> cfg.ExperimentConfig:
  """Video classification on Videonet with assemblenet."""
  exp = video_classification.video_classification_ucf101()
  exp.task.train_data.dtype = 'bfloat16'
  exp.task.validation_data.dtype = 'bfloat16'
  feature_shape = (32, 224, 224, 3)
  model = AssembleNetModel()
  model.backbone.assemblenet.blocks = flat_lists_to_blocks(
      asn50_structure, asn_structure_weights)
  model.backbone.assemblenet.num_frames = feature_shape[0]
  exp.task.model = model

  assert exp.task.model.backbone.assemblenet.num_frames > 0, (
      f'backbone num_frames '
      f'{exp.task.model.backbone.assemblenet}')

  return exp


@exp_factory.register_config_factory('assemblenetplus_ucf101')
def assemblenetplus_ucf101() -> cfg.ExperimentConfig:
  """Video classification on Videonet with assemblenet."""
  exp = video_classification.video_classification_ucf101()
  exp.task.train_data.dtype = 'bfloat16'
  exp.task.validation_data.dtype = 'bfloat16'
  feature_shape = (32, 224, 224, 3)
  model = AssembleNetPlusModel()
  model.backbone.assemblenet_plus.blocks = flat_lists_to_blocks(
      asn50_structure, asn_structure_weights)
  model.backbone.assemblenet_plus.num_frames = feature_shape[0]
  exp.task.model = model

  assert exp.task.model.backbone.assemblenet_plus.num_frames > 0, (
      f'backbone num_frames '
      f'{exp.task.model.backbone.assemblenet_plus}')

  return exp