jhfjhfj1/autokeras

View on GitHub
autokeras/blocks/basic_test.py

Summary

Maintainability
F
1 wk
Test Coverage
# Copyright 2020 The AutoKeras Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import keras
import keras_tuner
import pytest
import tensorflow as tf
import tree

from autokeras import blocks
from autokeras import test_utils


def test_resnet_build_return_tensor():
    """A default ResNetBlock built on a 32x32x3 input flattens to one output."""
    resnet = blocks.ResNetBlock()
    hp = keras_tuner.HyperParameters()
    input_node = keras.Input(shape=(32, 32, 3), dtype=tf.float32)

    flat_outputs = tree.flatten(resnet.build(hp, input_node))

    assert len(flat_outputs) == 1


def test_resnet_v1_return_tensor():
    """A ResNetBlock with version="v1" flattens to a single output."""
    resnet = blocks.ResNetBlock(version="v1")
    hp = keras_tuner.HyperParameters()
    input_node = keras.Input(shape=(32, 32, 3), dtype=tf.float32)

    flat_outputs = tree.flatten(resnet.build(hp, input_node))

    assert len(flat_outputs) == 1


def test_efficientnet_b0_return_tensor():
    """A non-pretrained "b0" EfficientNetBlock flattens to a single output."""
    efficientnet = blocks.EfficientNetBlock(version="b0", pretrained=False)
    hp = keras_tuner.HyperParameters()
    input_node = keras.Input(shape=(32, 32, 3), dtype=tf.float32)

    flat_outputs = tree.flatten(efficientnet.build(hp, input_node))

    assert len(flat_outputs) == 1


def test_resnet_pretrained_build_return_tensor():
    """A ResNetBlock with pretrained=True flattens to a single output."""
    resnet = blocks.ResNetBlock(pretrained=True)
    hp = keras_tuner.HyperParameters()
    input_node = keras.Input(shape=(32, 32, 3), dtype=tf.float32)

    flat_outputs = tree.flatten(resnet.build(hp, input_node))

    assert len(flat_outputs) == 1


def test_resnet_pretrained_with_one_channel_input():
    """A pretrained ResNetBlock builds on a one-channel 28x28 input."""
    resnet = blocks.ResNetBlock(pretrained=True)
    hp = keras_tuner.HyperParameters()
    # Single-channel input, unlike the three-channel cases in this module.
    input_node = keras.Input(shape=(28, 28, 1), dtype=tf.float32)

    flat_outputs = tree.flatten(resnet.build(hp, input_node))

    assert len(flat_outputs) == 1


def test_resnet_pretrained_error_with_two_channels():
    """A pretrained ResNetBlock raises ValueError on a two-channel input."""
    resnet = blocks.ResNetBlock(pretrained=True)

    with pytest.raises(ValueError) as excinfo:
        resnet.build(
            keras_tuner.HyperParameters(),
            keras.Input(shape=(224, 224, 2), dtype=tf.float32),
        )

    error_message = str(excinfo.value)
    assert "When pretrained is set to True" in error_message


def test_resnet_deserialize_to_resnet():
    """Serializing then deserializing a ResNetBlock gives a ResNetBlock."""
    payload = blocks.serialize(blocks.ResNetBlock())

    restored = blocks.deserialize(payload)

    assert isinstance(restored, blocks.ResNetBlock)


def test_resnet_get_config_has_all_attributes():
    """Every ResNetBlock.__init__ argument appears as a key in get_config()."""
    config_keys = blocks.ResNetBlock().get_config().keys()

    init_args = test_utils.get_func_args(blocks.ResNetBlock.__init__)

    assert init_args.issubset(config_keys)


def test_resnet_wrong_version_error():
    """Constructing a ResNetBlock with version="abc" raises ValueError."""
    with pytest.raises(ValueError) as excinfo:
        blocks.ResNetBlock(version="abc")

    error_message = str(excinfo.value)
    assert "Expect version to be" in error_message


def test_efficientnet_wrong_version_error():
    """Constructing an EfficientNetBlock with version="abc" raises ValueError."""
    with pytest.raises(ValueError) as excinfo:
        blocks.EfficientNetBlock(version="abc")

    error_message = str(excinfo.value)
    assert "Expect version to be" in error_message


def test_xception_build_return_tensor():
    """A default XceptionBlock built on a 32x32x2 input flattens to one output."""
    xception = blocks.XceptionBlock()
    hp = keras_tuner.HyperParameters()
    input_node = keras.Input(shape=(32, 32, 2), dtype=tf.float32)

    flat_outputs = tree.flatten(xception.build(hp, input_node))

    assert len(flat_outputs) == 1


def test_xception_pretrained_build_return_tensor():
    """An XceptionBlock with pretrained=True flattens to a single output."""
    xception = blocks.XceptionBlock(pretrained=True)
    hp = keras_tuner.HyperParameters()
    input_node = keras.Input(shape=(32, 32, 3), dtype=tf.float32)

    flat_outputs = tree.flatten(xception.build(hp, input_node))

    assert len(flat_outputs) == 1


def test_xception_pretrained_with_one_channel_input():
    """A pretrained XceptionBlock builds on a one-channel 224x224 input."""
    xception = blocks.XceptionBlock(pretrained=True)
    hp = keras_tuner.HyperParameters()
    # Single-channel input, unlike the three-channel pretrained case.
    input_node = keras.Input(shape=(224, 224, 1), dtype=tf.float32)

    flat_outputs = tree.flatten(xception.build(hp, input_node))

    assert len(flat_outputs) == 1


def test_xception_pretrained_error_with_two_channels():
    """A pretrained XceptionBlock raises ValueError on a two-channel input."""
    xception = blocks.XceptionBlock(pretrained=True)

    with pytest.raises(ValueError) as excinfo:
        xception.build(
            keras_tuner.HyperParameters(),
            keras.Input(shape=(224, 224, 2), dtype=tf.float32),
        )

    error_message = str(excinfo.value)
    assert "When pretrained is set to True" in error_message


def test_xception_deserialize_to_xception():
    """Serializing then deserializing an XceptionBlock gives an XceptionBlock."""
    payload = blocks.serialize(blocks.XceptionBlock())

    restored = blocks.deserialize(payload)

    assert isinstance(restored, blocks.XceptionBlock)


def test_xception_get_config_has_all_attributes():
    """Every XceptionBlock.__init__ argument appears as a key in get_config()."""
    config_keys = blocks.XceptionBlock().get_config().keys()

    init_args = test_utils.get_func_args(blocks.XceptionBlock.__init__)

    assert init_args.issubset(config_keys)


def test_conv_build_return_tensor():
    """A default ConvBlock built on a 32x32x3 input flattens to one output."""
    conv = blocks.ConvBlock()
    hp = keras_tuner.HyperParameters()
    input_node = keras.Input(shape=(32, 32, 3), dtype=tf.float32)

    flat_outputs = tree.flatten(conv.build(hp, input_node))

    assert len(flat_outputs) == 1


def test_conv_with_small_image_size_return_tensor():
    """A default ConvBlock also builds on a small 10x10x3 input."""
    conv = blocks.ConvBlock()
    hp = keras_tuner.HyperParameters()
    # Smaller spatial size than the 32x32 inputs used by the other conv tests.
    input_node = keras.Input(shape=(10, 10, 3), dtype=tf.float32)

    flat_outputs = tree.flatten(conv.build(hp, input_node))

    assert len(flat_outputs) == 1


def test_conv_build_with_dropout_return_tensor():
    """A ConvBlock with dropout=0.5 flattens to a single output."""
    conv = blocks.ConvBlock(dropout=0.5)
    hp = keras_tuner.HyperParameters()
    input_node = keras.Input(shape=(32, 32, 3), dtype=tf.float32)

    flat_outputs = tree.flatten(conv.build(hp, input_node))

    assert len(flat_outputs) == 1


def test_conv_deserialize_to_conv():
    """Serializing then deserializing a ConvBlock gives a ConvBlock."""
    payload = blocks.serialize(blocks.ConvBlock())

    restored = blocks.deserialize(payload)

    assert isinstance(restored, blocks.ConvBlock)


def test_conv_get_config_has_all_attributes():
    """Every ConvBlock.__init__ argument appears as a key in get_config()."""
    config_keys = blocks.ConvBlock().get_config().keys()

    init_args = test_utils.get_func_args(blocks.ConvBlock.__init__)

    assert init_args.issubset(config_keys)


def test_rnn_build_return_tensor():
    """An RNNBlock with bidirectional=False flattens to a single output."""
    rnn = blocks.RNNBlock(bidirectional=False)
    hp = keras_tuner.HyperParameters()
    input_node = keras.Input(shape=(32, 10), dtype=tf.float32)

    flat_outputs = tree.flatten(rnn.build(hp, input_node))

    assert len(flat_outputs) == 1


def test_rnn_input_shape_one_dim_error():
    """Building an RNNBlock on an input of shape (32,) raises ValueError."""
    rnn = blocks.RNNBlock()

    with pytest.raises(ValueError) as excinfo:
        rnn.build(
            keras_tuner.HyperParameters(),
            keras.Input(shape=(32,), dtype=tf.float32),
        )

    error_message = str(excinfo.value)
    assert "Expect the input tensor of RNNBlock" in error_message


def test_rnn_deserialize_to_rnn():
    """Serializing then deserializing an RNNBlock gives an RNNBlock."""
    payload = blocks.serialize(blocks.RNNBlock())

    restored = blocks.deserialize(payload)

    assert isinstance(restored, blocks.RNNBlock)


def test_rnn_get_config_has_all_attributes():
    """Every RNNBlock.__init__ argument appears as a key in get_config()."""
    config_keys = blocks.RNNBlock().get_config().keys()

    init_args = test_utils.get_func_args(blocks.RNNBlock.__init__)

    assert init_args.issubset(config_keys)


def test_dense_build_return_tensor():
    """A DenseBlock given a Choice for num_units flattens to a single output."""
    # num_units is passed as a hyperparameter object, not a fixed integer.
    units_choice = keras_tuner.engine.hyperparameters.Choice(
        "num_units", [10, 20]
    )
    dense = blocks.DenseBlock(num_units=units_choice)
    hp = keras_tuner.HyperParameters()
    input_node = keras.Input(shape=(32,), dtype=tf.float32)

    flat_outputs = tree.flatten(dense.build(hp, input_node))

    assert len(flat_outputs) == 1


def test_dense_build_with_dropout_return_tensor():
    """A DenseBlock with dropout=0.5 flattens to a single output."""
    dense = blocks.DenseBlock(dropout=0.5)
    hp = keras_tuner.HyperParameters()
    input_node = keras.Input(shape=(32,), dtype=tf.float32)

    flat_outputs = tree.flatten(dense.build(hp, input_node))

    assert len(flat_outputs) == 1


def test_dense_build_with_bn_return_tensor():
    """A DenseBlock with use_batchnorm=True flattens to a single output."""
    dense = blocks.DenseBlock(use_batchnorm=True)
    hp = keras_tuner.HyperParameters()
    input_node = keras.Input(shape=(32,), dtype=tf.float32)

    flat_outputs = tree.flatten(dense.build(hp, input_node))

    assert len(flat_outputs) == 1


def test_dense_deserialize_to_dense():
    """Serializing then deserializing a DenseBlock gives a DenseBlock."""
    payload = blocks.serialize(blocks.DenseBlock())

    restored = blocks.deserialize(payload)

    assert isinstance(restored, blocks.DenseBlock)


def test_dense_get_config_has_all_attributes():
    """Every DenseBlock.__init__ argument appears as a key in get_config()."""
    config_keys = blocks.DenseBlock().get_config().keys()

    init_args = test_utils.get_func_args(blocks.DenseBlock.__init__)

    assert init_args.issubset(config_keys)


def test_bert_build_return_tensor():
    """A default BertBlock built on a string input flattens to one output."""
    bert = blocks.BertBlock()
    hp = keras_tuner.HyperParameters()
    # The input dtype here is tf.string, not float32 as in the other tests.
    input_node = keras.Input(shape=(1,), dtype=tf.string)

    flat_outputs = tree.flatten(bert.build(hp, input_node))

    assert len(flat_outputs) == 1


def test_bert_deserialize_to_bert():
    """Serializing then deserializing a BertBlock gives a BertBlock."""
    payload = blocks.serialize(blocks.BertBlock())

    restored = blocks.deserialize(payload)

    assert isinstance(restored, blocks.BertBlock)


def test_bert_get_config_has_all_attributes():
    """Every BertBlock.__init__ argument appears as a key in get_config()."""
    config_keys = blocks.BertBlock().get_config().keys()

    init_args = test_utils.get_func_args(blocks.BertBlock.__init__)

    assert init_args.issubset(config_keys)