# autokeras/blocks/reduction.py
# Copyright 2020 The AutoKeras Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import tree
from keras import layers
from keras import ops
from autokeras.engine import block as block_module
from autokeras.utils import layer_utils
from autokeras.utils import utils
# Key under which the reduction strategy is stored: used in this module both
# as the hyperparameter name passed to `hp.Choice` and as the entry name in
# `Reduction.get_config`.
REDUCTION_TYPE = "reduction_type"
# Reduction strategies recognised by `Reduction._build_block`.
FLATTEN = "flatten"
GLOBAL_MAX = "global_max"
GLOBAL_AVG = "global_avg"
def shape_compatible(shape1, shape2):
    """Tell whether two shapes agree on rank and on all but the last axis.

    # Arguments
        shape1: Sequence describing the first shape.
        shape2: Sequence describing the second shape.

    # Returns
        `False` when the ranks differ; otherwise the result of comparing
        every dimension except the last one.
    """
    ranks_match = len(shape1) == len(shape2)
    # TODO: If they can be the same after passing through any layer,
    # they are compatible. e.g. (32, 32, 3), (16, 16, 2) are compatible
    return ranks_match and shape1[:-1] == shape2[:-1]
class Merge(block_module.Block):
    """Merge block to merge multiple nodes into one.

    When the inputs are not all shape-compatible with the first input, each
    of them is passed through a `Flatten` block first. Inputs that end up
    with identical shapes may be added; in every other case they are
    concatenated.

    # Arguments
        merge_type: String. 'add' or 'concatenate'. If left unspecified, it
            will be tuned automatically.
    """

    def __init__(self, merge_type: Optional[str] = None, **kwargs):
        super().__init__(**kwargs)
        self.merge_type = merge_type

    def get_config(self):
        """Return the parent configuration extended with `merge_type`."""
        config = super().get_config()
        config["merge_type"] = self.merge_type
        return config

    def build(self, hp, inputs=None):
        """Build the merge of all input nodes.

        # Arguments
            hp: Hyperparameters object used to tune `merge_type` when it was
                not given explicitly.
            inputs: Node or nested structure of nodes to merge.

        # Returns
            The flattened input list itself when it holds a single node,
            otherwise the output of a Keras `Add` or `Concatenate` layer.
        """
        nodes = tree.flatten(inputs)

        # A lone input has nothing to be merged with.
        if len(nodes) == 1:
            return nodes

        # Flatten everything as soon as one input disagrees with the first.
        has_mismatch = any(
            not shape_compatible(node.shape, nodes[0].shape) for node in nodes
        )
        if has_mismatch:
            nodes = [Flatten().build(hp, node) for node in nodes]

        # TODO: Even inputs have different shape[-1], they can still be Add(
        # ) after another layer. Check if the inputs are all of the same
        # shape
        if self._inputs_same_shape(nodes):
            merge_type = self.merge_type
            if not merge_type:
                merge_type = hp.Choice(
                    "merge_type", ["add", "concatenate"], default="add"
                )
            if merge_type == "add":
                return layers.Add()(nodes)

        # Reached for 'concatenate' and whenever the shapes differ.
        return layers.Concatenate()(nodes)

    def _inputs_same_shape(self, inputs):
        """Return True when every node has the same shape as the first."""
        for node in inputs:
            if node.shape != inputs[0].shape:
                return False
        return True
class Flatten(block_module.Block):
    """Flatten the input tensor with Keras Flatten layer."""

    def build(self, hp, inputs=None):
        """Flatten the single input node if it has more than two axes.

        # Arguments
            hp: Hyperparameters object (not used by this block).
            inputs: Node or nested structure holding exactly one node.

        # Returns
            The input node unchanged when its shape has at most two
            entries, otherwise the output of a Keras `Flatten` layer.
        """
        nodes = tree.flatten(inputs)
        utils.validate_num_inputs(nodes, 1)
        node = nodes[0]

        # Nothing to flatten for shapes with two or fewer entries.
        if len(node.shape) <= 2:
            return node
        return layers.Flatten()(node)
class Reduction(block_module.Block):
    """Base block for reducing a tensor with one of several strategies.

    Subclasses supply `global_max` and `global_avg`; the 'flatten' strategy
    is handled here through the `Flatten` block.

    # Arguments
        reduction_type: String. 'flatten', 'global_max' or 'global_avg'.
            If left unspecified, it will be tuned automatically.
    """

    def __init__(self, reduction_type: Optional[str] = None, **kwargs):
        super().__init__(**kwargs)
        self.reduction_type = reduction_type

    def get_config(self):
        """Return the parent configuration extended with the reduction type."""
        config = super().get_config()
        config[REDUCTION_TYPE] = self.reduction_type
        return config

    def global_max(self, input_node):
        """Apply the global-max reduction; must be provided by subclasses."""
        raise NotImplementedError

    def global_avg(self, input_node):
        """Apply the global-average reduction; must be provided by subclasses."""
        raise NotImplementedError

    def build(self, hp, inputs=None):
        """Build the reduction for the single input node.

        # Arguments
            hp: Hyperparameters object used to choose the reduction type
                when none was given explicitly.
            inputs: Node or nested structure holding exactly one node.

        # Returns
            The input node unchanged when its shape has at most two
            entries, otherwise the result of `_build_block`.
        """
        nodes = tree.flatten(inputs)
        utils.validate_num_inputs(nodes, 1)
        node = nodes[0]

        # No need to reduce.
        if len(node.shape) <= 2:
            return node

        # An explicitly configured type bypasses the search space.
        if self.reduction_type is not None:
            return self._build_block(hp, node, self.reduction_type)

        chosen_type = hp.Choice(
            REDUCTION_TYPE, [FLATTEN, GLOBAL_MAX, GLOBAL_AVG]
        )
        with hp.conditional_scope(REDUCTION_TYPE, [chosen_type]):
            return self._build_block(hp, node, chosen_type)

    def _build_block(self, hp, output_node, reduction_type):
        """Dispatch to the reduction selected by `reduction_type`.

        A value other than the three known types leaves the node untouched.
        """
        if reduction_type == FLATTEN:
            return Flatten().build(hp, output_node)
        if reduction_type == GLOBAL_MAX:
            return self.global_max(output_node)
        if reduction_type == GLOBAL_AVG:
            return self.global_avg(output_node)
        return output_node
class SpatialReduction(Reduction):
    """Reduce the dimension of a spatial tensor, e.g. image, to a vector.

    # Arguments
        reduction_type: String. 'flatten', 'global_max' or 'global_avg'.
            If left unspecified, it will be tuned automatically.
    """

    def __init__(self, reduction_type: Optional[str] = None, **kwargs):
        super().__init__(reduction_type, **kwargs)

    def global_max(self, input_node):
        """Apply the layer obtained from `layer_utils.get_global_max_pooling`.

        NOTE(review): the helper presumably returns a Keras global max
        pooling layer class matching the input shape — confirm in
        `layer_utils`.
        """
        pooling_factory = layer_utils.get_global_max_pooling(input_node.shape)
        pooling_layer = pooling_factory()
        return pooling_layer(input_node)

    def global_avg(self, input_node):
        """Apply the layer obtained from `layer_utils.get_global_average_pooling`.

        NOTE(review): the helper presumably returns a Keras global average
        pooling layer class matching the input shape — confirm in
        `layer_utils`.
        """
        pooling_factory = layer_utils.get_global_average_pooling(
            input_node.shape
        )
        pooling_layer = pooling_factory()
        return pooling_layer(input_node)
class TemporalReduction(Reduction):
    """Reduce the dim of a temporal tensor, e.g. output of RNN, to a vector.

    # Arguments
        reduction_type: String. 'flatten', 'global_max' or 'global_avg'. If
            left unspecified, it will be tuned automatically.
    """

    def __init__(self, reduction_type: Optional[str] = None, **kwargs):
        super().__init__(reduction_type, **kwargs)

    def global_max(self, input_node):
        """Take the maximum over the second-to-last axis of the input."""
        # assumes the second-to-last axis is the time axis — TODO confirm
        reduce_axis = -2
        return ops.max(input_node, axis=reduce_axis)

    def global_avg(self, input_node):
        """Take the mean over the second-to-last axis of the input."""
        # assumes the second-to-last axis is the time axis — TODO confirm
        reduce_axis = -2
        return ops.mean(input_node, axis=reduce_axis)