# official/projects/yt8m/configs/yt8m.py
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Video classification configuration definition."""
import dataclasses
from typing import Optional, Tuple
from official.core import config_definitions as cfg
from official.core import exp_factory
from official.modeling import hyperparams
from official.modeling import optimization
from official.vision.configs import common
# Example counts of the YT8M frame-level train and validation splits; used to
# derive steps-per-epoch and validation steps in `add_trainer`.
YT8M_TRAIN_EXAMPLES = 3888919
YT8M_VAL_EXAMPLES = 1112356
# GCS path versions: 2/frame -> frame-level labels,
#                    3/frame -> segment-level labels.
YT8M_TRAIN_PATH = 'gs://youtube8m-ml/2/frame/train/train*.tfrecord'
YT8M_VAL_PATH = 'gs://youtube8m-ml/3/frame/validate/validate*.tfrecord'
@dataclasses.dataclass
class DataConfig(cfg.DataConfig):
  """The base configuration for building datasets.

  Attributes:
    name: Dataset name.
    split: Dataset split, 'train' or 'valid'.
    feature_sizes: Shape (length) of each feature specified in feature_names.
    feature_names: Names of the features in the tf.SequenceExample.
    feature_sources: Whether each feature comes from 'context' or 'feature'.
    feature_dtypes: Dtype of each decoded feature.
    feature_from_bytes: Whether to decode each feature from bytes or as a
      dtype list.
    label_field: Name of the label field to read from tf.SequenceExample.
    segment_size: Number of frames in each segment.
    segment_labels: Use segment-level labels. Default: False, i.e. video-level
      labels.
    include_video_id: `True` means include the video id (string) in the input
      to the model.
    temporal_stride: Not used. To be deprecated.
    max_frames: Maximum number of frames in an input example. It is used to
      crop the input in the temporal dimension.
    sample_random_frames: Whether to sample random frames or a random
      sequence.
    num_sample_frames: Number of frames to sample for each input example. No
      frame sampling if None.
    input_per_feature_l2_norm: Whether to apply per-feature l2 normalization
      on the input features (consumed by the dataloader/model; verify there).
    prefetch_buffer_size: Size of the input-pipeline prefetch buffer.
    shuffle_buffer_size: Size of the input-pipeline shuffle buffer.
    num_classes: Number of classes to classify. Assuming it is a
      classification task.
    num_devices: Not used. To be deprecated.
    input_path: The path to the input.
    is_training: Whether this data is used for training or not.
    num_examples: Number of examples in the dataset. It is used to compute the
      steps for train or eval. Set the value to `-1` to make the experiment
      run until the end of the dataset.
    file_type: Type of input files.
  """
  name: Optional[str] = 'yt8m'
  split: Optional[str] = None
  feature_sizes: Tuple[int, ...] = (1024, 128)
  feature_names: Tuple[str, ...] = ('rgb', 'audio')
  feature_sources: Tuple[str, ...] = ('feature', 'feature')
  feature_dtypes: Tuple[str, ...] = ('uint8', 'uint8')
  feature_from_bytes: Tuple[bool, ...] = (True, True)
  label_field: str = 'labels'
  segment_size: int = 1
  segment_labels: bool = False
  include_video_id: bool = False
  temporal_stride: int = 1
  max_frames: int = 300  # Cap input frames.
  sample_random_frames: bool = True
  # Sample random frames if not None. No sampling in inference.
  num_sample_frames: Optional[int] = 300
  input_per_feature_l2_norm: bool = False
  prefetch_buffer_size: int = 100
  shuffle_buffer_size: int = 100
  num_classes: int = 3862
  num_devices: int = 1
  input_path: str = ''
  is_training: bool = True
  num_examples: int = -1
  file_type: str = 'tfrecord'
def yt8m(is_training):
  """Builds the YT8M DataConfig for one dataset split.

  Args:
    is_training: If True, configure the 'train' split; otherwise the 'valid'
      split.

  Returns:
    A `DataConfig` pointing at the matching YT8M tfrecord shards.
  """
  if is_training:
    split_name = 'train'
    example_count = YT8M_TRAIN_EXAMPLES
    tfrecord_path = YT8M_TRAIN_PATH
  else:
    split_name = 'valid'
    example_count = YT8M_VAL_EXAMPLES
    tfrecord_path = YT8M_VAL_PATH
  # pylint: disable=unexpected-keyword-arg
  config = DataConfig(
      temporal_stride=1,
      segment_labels=False,
      segment_size=5,
      is_training=is_training,
      split=split_name,
      drop_remainder=is_training,  # pytype: disable=wrong-keyword-args
      num_examples=example_count,
      input_path=tfrecord_path)
  # pylint: enable=unexpected-keyword-arg
  return config
@dataclasses.dataclass
class DbofModel(hyperparams.Config):
  """The Deep Bag-of-Frames (DBoF) backbone config.

  Attributes:
    cluster_size: Output size of the cluster (projection) layer.
    hidden_size: Output size of the hidden layer.
    add_batch_norm: Whether to use batch normalization.
    pooling_method: Frame-pooling method name, e.g. 'average'.
    use_context_gate_cluster_layer: Whether to add a context gate on the
      cluster layer.
    context_gate_cluster_bottleneck_size: Bottleneck size for that context
      gate; 0 presumably disables the bottleneck (confirm in model code).
  """
  cluster_size: int = 3000
  hidden_size: int = 2000
  add_batch_norm: bool = True
  pooling_method: str = 'average'
  use_context_gate_cluster_layer: bool = False
  context_gate_cluster_bottleneck_size: int = 0
@dataclasses.dataclass
class Backbone(hyperparams.OneOfConfig):
  """Configuration for backbones.

  Attributes:
    type: 'str', type of backbone to be used, one of the fields below.
    dbof: dbof backbone config.
  """
  type: Optional[str] = None
  dbof: DbofModel = dataclasses.field(default_factory=DbofModel)
@dataclasses.dataclass
class MoeModel(hyperparams.Config):
  """The mixture-of-experts (MoE) head config.

  Attributes:
    num_mixtures: Number of expert mixtures.
    vocab_as_last_dim: Whether the class vocabulary is placed in the last
      output dimension (see the MoE head implementation for exact semantics).
    use_input_context_gate: Whether to apply a context gate on the head input.
    use_output_context_gate: Whether to apply a context gate on the head
      output.
  """
  num_mixtures: int = 5
  vocab_as_last_dim: bool = False
  use_input_context_gate: bool = False
  use_output_context_gate: bool = False
@dataclasses.dataclass
class LogisticModel(hyperparams.Config):
  """The logistic model config.

  Attributes:
    return_logits: Whether the head should return raw logits.
  """
  return_logits: bool = False
@dataclasses.dataclass
class Head(hyperparams.OneOfConfig):
  """Configuration for aggregation heads.

  Attributes:
    type: 'str', type of head to be used, one of the fields below.
    moe: MoE head config.
    logistic: Logistic head config.
  """
  type: Optional[str] = None
  moe: MoeModel = dataclasses.field(default_factory=MoeModel)
  logistic: LogisticModel = dataclasses.field(default_factory=LogisticModel)
@dataclasses.dataclass
class VideoClassificationModel(hyperparams.Config):
  """The classifier model config.

  Attributes:
    backbone: Backbone config; defaults to the 'dbof' backbone.
    head: Aggregation head config; defaults to the 'moe' head.
    norm_activation: Normalization/activation config (ReLU, no sync BN by
      default).
  """
  backbone: Backbone = dataclasses.field(
      default_factory=lambda: Backbone(type='dbof')
  )
  head: Head = dataclasses.field(default_factory=lambda: Head(type='moe'))
  norm_activation: common.NormActivation = dataclasses.field(
      default_factory=lambda: common.NormActivation(  # pylint: disable=g-long-lambda
          activation='relu', use_sync_bn=False
      )
  )
@dataclasses.dataclass
class Losses(hyperparams.Config):
  """The loss config.

  Attributes:
    name: Loss name.
    from_logits: Whether model outputs are raw logits rather than
      probabilities.
    label_smoothing: Label-smoothing factor; 0.0 disables smoothing.
    l2_weight_decay: L2 regularization weight.
  """
  name: str = 'binary_crossentropy'
  from_logits: bool = False
  label_smoothing: float = 0.0
  l2_weight_decay: float = 1e-5
@dataclasses.dataclass
class AveragePrecisionConfig(hyperparams.Config):
  """Average-precision (AP) metric config.

  Attributes:
    top_k: Number of top predictions considered per example.
    top_n: Optional cap on predictions considered overall; None means no cap
      (confirm exact semantics in the metric implementation).
    return_per_class_ap: Whether to report AP per class in addition to the
      aggregate.
  """
  top_k: int = 20
  top_n: Optional[int] = None
  return_per_class_ap: bool = False
@dataclasses.dataclass
class Evaluation(hyperparams.Config):
  """Evaluation config.

  Attributes:
    average_precision: If set, enables the average-precision metric with the
      given config; None disables it.
  """
  average_precision: Optional[AveragePrecisionConfig] = None
@dataclasses.dataclass
class YT8MTask(cfg.TaskConfig):
  """The task config.

  Attributes:
    model: Video classification model config.
    train_data: Train-split data config (frame-level YT8M).
    validation_data: Validation-split data config.
    losses: Loss config.
    evaluation: Evaluation config; average precision enabled by default.
    gradient_clip_norm: Global-norm gradient clipping threshold.
  """
  model: VideoClassificationModel = dataclasses.field(
      default_factory=VideoClassificationModel
  )
  train_data: DataConfig = dataclasses.field(
      default_factory=lambda: yt8m(is_training=True)
  )
  validation_data: DataConfig = dataclasses.field(
      default_factory=lambda: yt8m(is_training=False)
  )
  losses: Losses = dataclasses.field(default_factory=Losses)
  evaluation: Evaluation = dataclasses.field(
      default_factory=lambda: Evaluation(  # pylint: disable=g-long-lambda
          average_precision=AveragePrecisionConfig()
      )
  )
  gradient_clip_norm: float = 1.0
def add_trainer(
    experiment: cfg.ExperimentConfig,
    train_batch_size: int,
    eval_batch_size: int,
    learning_rate: float = 0.0001,
    train_epochs: int = 50,
    num_train_examples: int = YT8M_TRAIN_EXAMPLES,
    num_val_examples: int = YT8M_VAL_EXAMPLES,
    steps_per_loop: int = 500,
) -> cfg.ExperimentConfig:
  """Adds and configures a trainer to the experiment config.

  Mutates `experiment` in place: overwrites the train/validation global batch
  sizes and replaces `experiment.trainer` with an Adam + exponential-decay +
  linear-warmup schedule.

  Args:
    experiment: The experiment config to update.
    train_batch_size: Global training batch size.
    eval_batch_size: Global evaluation batch size.
    learning_rate: Initial learning rate of the exponential-decay schedule.
    train_epochs: Number of training epochs.
    num_train_examples: Size of the train dataset; must be positive.
    num_val_examples: Size of the validation dataset; must be positive.
    steps_per_loop: Steps per host training loop; also used as the summary,
      checkpoint and validation interval. Previously hard-coded to 500.

  Returns:
    The updated experiment config (the same object that was passed in).

  Raises:
    ValueError: If `num_train_examples` or `num_val_examples` is not positive.
  """
  # Report the offending value itself, not the whole data config.
  if num_train_examples <= 0:
    raise ValueError(
        'Wrong train dataset size {!r}'.format(num_train_examples))
  if num_val_examples <= 0:
    raise ValueError(
        'Wrong validation dataset size {!r}'.format(num_val_examples))
  experiment.task.train_data.global_batch_size = train_batch_size
  experiment.task.validation_data.global_batch_size = eval_batch_size
  steps_per_epoch = num_train_examples // train_batch_size
  experiment.trainer = cfg.TrainerConfig(
      steps_per_loop=steps_per_loop,
      summary_interval=steps_per_loop,
      checkpoint_interval=steps_per_loop,
      train_steps=train_epochs * steps_per_epoch,
      validation_steps=num_val_examples // eval_batch_size,
      validation_interval=steps_per_loop,
      optimizer_config=optimization.OptimizationConfig({
          'optimizer': {
              'type': 'adam',
              'adam': {}
          },
          'learning_rate': {
              'type': 'exponential',
              'exponential': {
                  'initial_learning_rate': learning_rate,
                  'decay_rate': 0.95,
                  # Decay every 1.5 epochs, offset past the 500 warmup steps.
                  'decay_steps': int(steps_per_epoch * 1.5),
                  'offset': 500,
              }
          },
          'warmup': {
              'linear': {
                  'name': 'linear',
                  'warmup_learning_rate': 0,
                  'warmup_steps': 500,
              },
              'type': 'linear',
          }
      }))
  return experiment
@exp_factory.register_config_factory('yt8m_experiment')
def yt8m_experiment() -> cfg.ExperimentConfig:
  """Video classification general."""
  experiment = cfg.ExperimentConfig(
      runtime=cfg.RuntimeConfig(mixed_precision_dtype='bfloat16'),
      task=YT8MTask(),
      trainer=cfg.TrainerConfig(),
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None',
          'task.train_data.num_classes == task.validation_data.num_classes',
          'task.train_data.feature_sizes != None',
          'task.train_data.feature_names != None',
          'task.train_data.feature_sources != None',
          'task.train_data.feature_dtypes != None',
      ])
  # Per-TPUv3-core batch size for 16GB HBM; `batch_multiplier` in range(1, 26).
  batch_multiplier = 1
  tpu_cores = 32  # TPUv3 4x4 topology.
  global_train_bs = 32 * batch_multiplier * tpu_cores
  global_eval_bs = (4 * 50) * tpu_cores  # per-core multiplier <= 100
  # Linearly scale the base lr=0.0001 (tuned for global batch size 512).
  return add_trainer(
      experiment,
      train_batch_size=global_train_bs,
      eval_batch_size=global_eval_bs,
      learning_rate=0.0001 * (global_train_bs / 512),
      train_epochs=100)