official/projects/simclr/heads/simclr_head_test.py

# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
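
"""Tests for simclr_head."""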

from absl.testing import parameterized

import numpy as np
import tensorflow as tf, tf_keras

from official.projects.simclr.heads import simclr_head


class ProjectionHeadTest(tf.test.TestCase, parameterized.TestCase):

  @parameterized.parameters(
      (0, None),
      (1, 128),
      (2, 128),
  )
  def test_head_creation(self, num_proj_layers, proj_output_dim):
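    """Checks output shapes when building the head on symbolic Keras inputs."""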
    test_layer = simclr_head.ProjectionHead(
        num_proj_layers=num_proj_layers,
        proj_output_dim=proj_output_dim)

    input_dim = 64
    x = tf_keras.Input(shape=(input_dim,))
    proj_head_output, proj_finetune_output = test_layer(x)

    proj_head_output_dim = input_dim
    if num_proj_layers > 0:
      proj_head_output_dim = proj_output_dim
    self.assertAllEqual(proj_head_output.shape.as_list(),
                        [None, proj_head_output_dim])

    if num_proj_layers > 0:
      proj_finetune_output_dim = input_dim
      self.assertAllEqual(proj_finetune_output.shape.as_list(),
                          [None, proj_finetune_output_dim])

  @parameterized.parameters(
      (0, None, 0),
      (1, 128, 0),
      (2, 128, 1),
      (2, 128, 2),
  )
  def test_outputs(self, num_proj_layers, proj_output_dim, ft_proj_idx):
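    """Checks concrete output shapes and values for different ft_proj_idx settings."""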
    test_layer = simclr_head.ProjectionHead(
        num_proj_layers=num_proj_layers,
        proj_output_dim=proj_output_dim,
        ft_proj_idx=ft_proj_idx
    )

    input_dim = 64
    batch_size = 2
    inputs = np.random.rand(batch_size, input_dim)
    proj_head_output, proj_finetune_output = test_layer(inputs)

    if num_proj_layers == 0:
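      # With zero projection layers the head acts as an identity, so both
      # outputs should match the raw inputs.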
      self.assertAllClose(inputs, proj_head_output)
      self.assertAllClose(inputs, proj_finetune_output)
    else:
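      # With projection layers present, the projection output is mapped to
      # proj_output_dim; the finetune output depends on ft_proj_idx.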
      self.assertAllEqual(proj_head_output.shape.as_list(),
                          [batch_size, proj_output_dim])
      if ft_proj_idx == 0:
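        # ft_proj_idx == 0 selects the raw input as the finetune output.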
        self.assertAllClose(inputs, proj_finetune_output)
      elif ft_proj_idx < num_proj_layers:
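        # An intermediate projection layer is expected to preserve the
        # input dimensionality.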
        self.assertAllEqual(proj_finetune_output.shape.as_list(),
                            [batch_size, input_dim])
      else:
        self.assertAllEqual(proj_finetune_output.shape.as_list(),
                            [batch_size, proj_output_dim])


class ClassificationHeadTest(tf.test.TestCase, parameterized.TestCase):

  @parameterized.parameters(
      10, 20
  )
  def test_head_creation(self, num_classes):
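    """Checks that the classification head outputs num_classes logits."""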
    test_layer = simclr_head.ClassificationHead(num_classes=num_classes)

    input_dim = 64
    x = tf_keras.Input(shape=(input_dim,))
    out_x = test_layer(x)

    self.assertAllEqual(out_x.shape.as_list(),
                        [None, num_classes])


if __name__ == '__main__':
  tf.test.main()