tensorflow/models

View on GitHub
official/projects/mae/tasks/image_classification_test.py

Summary

Maintainability
F
3 days
Test Coverage
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for image_classification."""

import numpy as np
import tensorflow as tf, tf_keras
import tensorflow_datasets as tfds

from official.modeling import optimization
from official.projects.mae.tasks import image_classification as vit_cls
from official.vision.configs import image_classification


# Number of synthetic examples the mocked dataset generator yields per pass.
_NUM_EXAMPLES = 10


def _gen_fn():
  """Generates one synthetic image classification example.

  Returns:
    A dict with three entries:
      'image': an all-ones `uint8` array of shape `(h, w, 3)`, where `h` and
        `w` are each drawn uniformly from [1, 300).
      'label': an integer drawn uniformly from [0, 100).
      'file_name': the constant string 'test'.
  """
  # `np.random.randint` includes its lower bound, so start at 1: a lower bound
  # of 0 would occasionally emit an empty (zero-height or zero-width) image and
  # make the test depend on how preprocessing treats degenerate inputs.
  h = np.random.randint(1, 300)
  w = np.random.randint(1, 300)
  return {
      'image': np.ones(shape=(h, w, 3), dtype=np.uint8),
      'label': np.random.randint(0, 100),
      'file_name': 'test',
  }


def _as_dataset(self, *args, **kwargs):
  """Builds a synthetic dataset; used as `as_dataset_fn` for TFDS mocking.

  Args:
    self: Object exposing `info.features`, whose `dtype` and `shape` describe
      the structure of the generated examples.
    *args: Ignored.
    **kwargs: Ignored.

  Returns:
    A `tf.data.Dataset` backed by a generator that yields `_NUM_EXAMPLES`
    examples produced by `_gen_fn`.
  """
  del args, kwargs  # Unused.

  def _examples():
    for _ in range(_NUM_EXAMPLES):
      yield _gen_fn()

  features = self.info.features
  return tf.data.Dataset.from_generator(
      _examples,
      output_types=features.dtype,
      output_shapes=features.shape,
  )


class ImageClassificationTest(tf.test.TestCase):
  """Smoke test for the ViT classification task's training step."""

  def _make_optimization_config(self):
    """Returns an AdamW config with cosine decay and linear warmup."""
    adamw_options = {
        'weight_decay_rate': 0.05,
        # Avoid AdamW legacy behavior.
        'gradient_clip_norm': 0.0,
    }
    cosine_options = {
        'initial_learning_rate': 1.5 * 1e-4,
        'decay_steps': 5,
    }
    linear_warmup_options = {
        'warmup_steps': 1,
        'warmup_learning_rate': 0,
    }
    return optimization.OptimizationConfig({
        'optimizer': {'type': 'adamw', 'adamw': adamw_options},
        'learning_rate': {'type': 'cosine', 'cosine': cosine_options},
        'warmup': {'type': 'linear', 'linear': linear_warmup_options},
    })

  def test_train_step(self):
    """Runs a single `train_step` on mocked TFDS data.

    The test makes no explicit assertions; it passes as long as building the
    task, model, inputs and optimizer and running one step does not raise.
    """
    train_data_config = image_classification.DataConfig(
        tfds_name='imagenet2012',
        tfds_split='validation',
        is_training=True,
        global_batch_size=2,
    )
    task_config = vit_cls.ViTConfig(
        num_classes=1000,
        train_data=train_data_config,
    )
    # Everything that may touch TFDS stays inside the mock context so that the
    # synthetic `_as_dataset` is used in place of the real dataset.
    with tfds.testing.mock_data(as_dataset_fn=_as_dataset):
      task = vit_cls.ViTClassificationTask(task_config)
      model = task.build_model()
      batches = iter(task.build_inputs(task_config.train_data))
      optimizer = vit_cls.ViTClassificationTask.create_optimizer(
          self._make_optimization_config()
      )
      task.train_step(next(batches), model, optimizer)


if __name__ == '__main__':
  # Hand control to TensorFlow's test runner when executed as a script.
  tf.test.main()