official/projects/centernet/modeling/layers/cn_nn_blocks_test.py
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for CenterNet nn_blocks.

HourglassBlockPyTorch below is a literal translation of the PyTorch
implementation of the CornerNet-style hourglass block.
"""
from absl.testing import parameterized
import numpy as np
import tensorflow as tf, tf_keras
from official.projects.centernet.modeling.layers import cn_nn_blocks
from official.vision.modeling.layers import nn_blocks
class HourglassBlockPyTorch(tf_keras.layers.Layer):
  """A CornerNet-style implementation of the hourglass block.

  This is a literal translation of the PyTorch implementation. The `k` and
  `inp_dim` arguments of the `make_*` helpers are kept for signature parity
  with that implementation even where the helpers do not use them.
  """

  def __init__(self, dims, modules, k=0, **kwargs):
    """A CornerNet-style implementation of the hourglass block.

    Args:
      dims: input sizes of residual blocks, one entry per hourglass level.
      modules: number of repetitions of the residual blocks in each hourglass
        upsampling and downsampling level; must have the same length as
        `dims`.
      k: recursive parameter; index into `dims`/`modules` of the level this
        block builds (nested levels are created with `k + 1`).
      **kwargs: Additional keyword arguments forwarded to every
        `nn_blocks.ResidualBlock` this block creates.

    Raises:
      ValueError: if `dims` and `modules` have different lengths.
    """
    # The two-argument form binds super() to `self`; `super(Cls)` alone is an
    # unbound super object and would never run Layer.__init__ on this layer.
    super(HourglassBlockPyTorch, self).__init__()
    if len(dims) != len(modules):
      raise ValueError('dims and modules lists must have the same length')
    self.n = len(dims) - 1
    self.k = k
    self.modules = modules
    self.dims = dims
    self._kwargs = kwargs

  def build(self, input_shape):
    """Creates the sub-layers for hourglass level `self.k`.

    Args:
      input_shape: shape of the input tensor, passed through to
        `Layer.build`.
    """
    modules = self.modules
    dims = self.dims
    k = self.k
    kwargs = self._kwargs

    curr_mod = modules[k]
    next_mod = modules[k + 1]
    curr_dim = dims[k + 0]
    next_dim = dims[k + 1]

    # Skip branch at the current resolution.
    self.up1 = self.make_up_layer(3, curr_dim, curr_dim, curr_mod, **kwargs)

    # Low-resolution branch: downsample, process, recurse, process back.
    self.max1 = tf_keras.layers.MaxPool2D(strides=2)
    self.low1 = self.make_hg_layer(3, curr_dim, next_dim, curr_mod, **kwargs)
    if self.n - k > 1:
      # Recurse into the next (deeper) hourglass level.
      self.low2 = type(self)(dims, modules, k=k + 1, **kwargs)
    else:
      # Innermost level: plain residual stack instead of another hourglass.
      self.low2 = self.make_low_layer(
          3, next_dim, next_dim, next_mod, **kwargs)
    self.low3 = self.make_hg_layer_revr(
        3, next_dim, curr_dim, curr_mod, **kwargs)
    self.up2 = tf_keras.layers.UpSampling2D(2)

    self.merge = tf_keras.layers.Add()
    super(HourglassBlockPyTorch, self).build(input_shape)

  def call(self, x):
    """Sums the skip branch and the upsampled low-resolution branch."""
    up1 = self.up1(x)
    max1 = self.max1(x)
    low1 = self.low1(max1)
    low2 = self.low2(low1)
    low3 = self.low3(low2)
    up2 = self.up2(low3)
    return self.merge([up1, up2])

  def make_layer(self, k, inp_dim, out_dim, modules, **kwargs):
    """Stacks `modules` ResidualBlocks with `out_dim` filters.

    Only the first block uses a projection shortcut. `k` and `inp_dim` are
    unused (kept for parity with the PyTorch implementation).
    """
    layers = [
        nn_blocks.ResidualBlock(out_dim, 1, use_projection=True, **kwargs)]
    for _ in range(1, modules):
      layers.append(nn_blocks.ResidualBlock(out_dim, 1, **kwargs))
    return tf_keras.Sequential(layers)

  def make_layer_revr(self, k, inp_dim, out_dim, modules, **kwargs):
    """Stacks `modules - 1` `inp_dim` blocks, then one projection block.

    The final block has `out_dim` filters and a projection shortcut. `k` is
    unused (kept for parity with the PyTorch implementation).
    """
    layers = []
    for _ in range(modules - 1):
      layers.append(
          nn_blocks.ResidualBlock(inp_dim, 1, **kwargs))
    layers.append(
        nn_blocks.ResidualBlock(out_dim, 1, use_projection=True, **kwargs))
    return tf_keras.Sequential(layers)

  def make_up_layer(self, k, inp_dim, out_dim, modules, **kwargs):
    """Builds the skip branch; delegates to `make_layer`."""
    return self.make_layer(k, inp_dim, out_dim, modules, **kwargs)

  def make_low_layer(self, k, inp_dim, out_dim, modules, **kwargs):
    """Builds the innermost low branch; delegates to `make_layer`."""
    return self.make_layer(k, inp_dim, out_dim, modules, **kwargs)

  def make_hg_layer(self, k, inp_dim, out_dim, modules, **kwargs):
    """Builds the post-pooling stack; delegates to `make_layer`."""
    return self.make_layer(k, inp_dim, out_dim, modules, **kwargs)

  def make_hg_layer_revr(self, k, inp_dim, out_dim, modules, **kwargs):
    """Builds the pre-upsampling stack; delegates to `make_layer_revr`."""
    return self.make_layer_revr(k, inp_dim, out_dim, modules, **kwargs)
class NNBlocksTest(parameterized.TestCase, tf.test.TestCase):
  """Output-shape tests for `cn_nn_blocks.HourglassBlock`."""

  def test_hourglass_block(self):
    """Checks that HourglassBlock output shapes match its input shapes."""
    filter_sizes = [256, 256, 384, 384, 384, 512]
    rep_sizes = [2, 2, 2, 2, 2, 4]

    # Symbolic call on a Keras Input, with positional constructor arguments.
    model = cn_nn_blocks.HourglassBlock(filter_sizes, rep_sizes)
    test_input = tf_keras.Input((512, 512, 256))
    _ = model(test_input)

    # Eager call after an explicit build, with keyword constructor arguments.
    hg_test_input_shape = (1, 512, 512, 256)
    x_hg = tf.ones(shape=hg_test_input_shape)
    hg = cn_nn_blocks.HourglassBlock(
        channel_dims_per_stage=filter_sizes,
        blocks_per_stage=rep_sizes)
    hg.build(input_shape=hg_test_input_shape)
    out = hg(x_hg)
    self.assertAllEqual(
        tf.shape(out), hg_test_input_shape,
        'Hourglass module output shape and expected shape differ')

    # ODAPI Test: small configuration called on a NumPy batch.
    layer = cn_nn_blocks.HourglassBlock(
        blocks_per_stage=[2, 3, 4, 5, 6],
        channel_dims_per_stage=[4, 6, 8, 10, 12])
    output = layer(np.zeros((2, 64, 64, 4), dtype=np.float32))
    self.assertEqual(output.shape, (2, 64, 64, 4))
if __name__ == '__main__':
  # Run the test cases defined in this module via TensorFlow's test runner.
  tf.test.main()