jhfjhfj1/autokeras

View on GitHub
autokeras/engine/block.py

Summary

Maintainability
A
0 mins
Test Coverage
# Copyright 2020 The AutoKeras Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tree

from autokeras.engine import named_hypermodel
from autokeras.engine import node as node_module


class Block(named_hypermodel.NamedHyperModel):
    """The base class for different Block.

    Blocks can be connected together to build the search space for an
    AutoModel. Notably, many args in the __init__ function default to a
    tunable variable when not specified by the user.

    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Flat list of input nodes; None until the block is called.
        self.inputs = None
        # List of output nodes created by __call__; None until then.
        self.outputs = None
        # Number of output nodes __call__ creates.
        self._num_output_node = 1

    def _build_wrapper(self, hp, *args, **kwargs):
        """Run the parent's build wrapper inside this block's name scope.

        # Arguments
            hp: HyperParameters. The hyperparameters for building the model.
            *args: Positional arguments forwarded to the parent wrapper.
            **kwargs: Keyword arguments forwarded to the parent wrapper.

        # Returns
            Whatever the parent class's `_build_wrapper` returns.
        """
        with hp.name_scope(self.name):
            return super()._build_wrapper(hp, *args, **kwargs)

    def __call__(self, inputs):
        """Functional API.

        # Arguments
            inputs: A list of input node(s) or a single input node for the
                block.

        # Returns
            list: A list of output node(s) of the Block.

        # Raises
            TypeError: If any of the (flattened) inputs is not a Node. In that
                case neither the block nor any input node is modified.
        """
        input_nodes = tree.flatten(inputs)
        # Validate every input before wiring anything, so that an invalid
        # input cannot leave the graph partially connected to this block.
        for input_node in input_nodes:
            if not isinstance(input_node, node_module.Node):
                raise TypeError(
                    "Expect the inputs to block {name} to be "
                    "a Node, but got {type}.".format(
                        name=self.name, type=type(input_node)
                    )
                )
        self.inputs = input_nodes
        for input_node in input_nodes:
            input_node.add_out_block(self)
        # Create fresh output nodes, each linked back to this block.
        self.outputs = []
        for _ in range(self._num_output_node):
            output_node = node_module.Node()
            output_node.add_in_block(self)
            self.outputs.append(output_node)
        return self.outputs

    def build(self, hp, inputs=None):
        """Build the Block into a real Keras Model.

        The subclasses should override this function and return the output node.

        # Arguments
            hp: HyperParameters. The hyperparameters for building the model.
            inputs: A list of input node(s).

        # Raises
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError