# autokeras/graph.py
# Copyright 2020 The AutoKeras Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import keras
import keras_tuner
import tree
from autokeras import blocks as blocks_module
from autokeras import nodes as nodes_module
from autokeras.engine import head as head_module
from autokeras.engine import serializable
from autokeras.utils import io_utils
def load_graph(filepath, custom_objects=None):
    """Load a Graph from the JSON config stored at `filepath`.

    # Arguments
        filepath: Path of the JSON file to read the graph config from.
        custom_objects: Optional mapping passed to
            `keras.utils.custom_object_scope` while the graph is rebuilt.
            Defaults to an empty dict.

    # Returns
        The Graph built by `Graph.from_config` from the loaded config.
    """
    scope_objects = {} if custom_objects is None else custom_objects
    with keras.utils.custom_object_scope(scope_objects):
        graph_config = io_utils.load_json(filepath)
        return Graph.from_config(graph_config)
class Graph(keras_tuner.HyperModel, serializable.Serializable):
    """A graph consists of connected Blocks, or Heads.

    # Arguments
        inputs: A list of input node(s) for the Graph.
        outputs: A list of output node(s) for the Graph.
        **kwargs: Forwarded unchanged to the parent class initializer.
    """

    def __init__(self, inputs=None, outputs=None, **kwargs):
        super().__init__(**kwargs)
        self.inputs = tree.flatten(inputs)
        self.outputs = tree.flatten(outputs)
        self._node_to_id = {}
        self._nodes = []
        self.blocks = []
        self._block_to_id = {}
        # The network is only analysed when both ends are provided.
        if inputs and outputs:
            self._build_network()

        # Temporary attributes, filled in by `set_fit_args`.
        self.epochs = None
        self.num_samples = None

    def _build_network(self):
        """Index the nodes and order the blocks between inputs and outputs.

        Populates `self._node_to_id`, `self._nodes`, `self.blocks` and
        `self._block_to_id`.

        # Raises
            ValueError: If an input or output node is not part of the
                discovered network, if a block needs an input node that is
                not part of the network, if the network has a cycle, or if
                the blocks cannot be ordered topologically.
        """
        self._node_to_id = {}

        # Recursively find all the interested nodes.
        for input_node in self.inputs:
            self._search_network(input_node, self.outputs, set(), set())
        # Nodes ordered by the id they were assigned during the search.
        self._nodes = sorted(self._node_to_id, key=self._node_to_id.get)

        for node in self.inputs + self.outputs:
            if node not in self._node_to_id:
                raise ValueError("Inputs and outputs not connected.")

        # Find the blocks: those with at least one output node in the network.
        blocks = []
        for input_node in self._nodes:
            for block in input_node.out_blocks:
                if (
                    any(
                        output_node in self._node_to_id
                        for output_node in block.outputs
                    )
                    and block not in blocks
                ):
                    blocks.append(block)

        # Check if all the inputs of the blocks are set as inputs.
        for block in blocks:
            for input_node in block.inputs:
                if input_node not in self._node_to_id:
                    raise ValueError(
                        "A required input is missing for HyperModel "
                        f"{block.name}."
                    )

        # Calculate the in degree of all the nodes, counting only the
        # blocks that belong to this graph.
        in_degree = [
            sum(1 for block in node.in_blocks if block in blocks)
            for node in self._nodes
        ]

        # Add the blocks in topological order.
        self.blocks = []
        self._block_to_id = {}
        while blocks:
            # Collect blocks whose input nodes all have in degree 0.
            new_added = [
                block
                for block in blocks
                if not any(
                    in_degree[self._node_to_id[node]] for node in block.inputs
                )
            ]

            # Without progress neither `blocks` nor `in_degree` changes, so
            # the loop would never terminate; fail loudly instead.
            if not new_added:
                raise ValueError(
                    "The blocks of the network cannot be sorted "
                    "topologically."
                )

            # Remove the collected blocks from blocks.
            for block in new_added:
                blocks.remove(block)

            for block in new_added:
                # Add the collected blocks to the Graph.
                self._add_block(block)

                # Decrease the in degree of the output nodes.
                for output_node in block.outputs:
                    in_degree[self._node_to_id[output_node]] -= 1

    def _search_network(
        self, input_node, outputs, in_stack_nodes, visited_nodes
    ):
        """Depth-first search registering nodes that lead to an output.

        A node is registered (given an id) when it is itself in `outputs`
        or when one of the nodes produced by its outgoing blocks has
        already been registered.

        # Arguments
            input_node: The node to start the search from.
            outputs: The output nodes to look for.
            in_stack_nodes: Set of nodes on the current search path; used
                for cycle detection and mutated in place.
            visited_nodes: Set of nodes already explored; mutated in place.

        # Raises
            ValueError: If a node on the current search path is reached
                again, i.e. the network has a cycle.
        """
        visited_nodes.add(input_node)
        in_stack_nodes.add(input_node)

        outputs_reached = input_node in outputs

        for block in input_node.out_blocks:
            for output_node in block.outputs:
                if output_node in in_stack_nodes:
                    raise ValueError("The network has a cycle.")
                if output_node not in visited_nodes:
                    self._search_network(
                        output_node, outputs, in_stack_nodes, visited_nodes
                    )
                # A registered successor means this node leads to an output.
                if output_node in self._node_to_id:
                    outputs_reached = True

        if outputs_reached:
            self._add_node(input_node)

        in_stack_nodes.remove(input_node)

    def _add_block(self, block):
        """Append `block` to the graph and assign it the next block id."""
        if block not in self.blocks:
            block_id = len(self.blocks)
            self._block_to_id[block] = block_id
            self.blocks.append(block)

    def _add_node(self, input_node):
        """Assign the next node id to `input_node` if it has none yet."""
        if input_node not in self._node_to_id:
            self._node_to_id[input_node] = len(self._node_to_id)

    def get_config(self):
        """Return a dict describing the graph.

        Only the input nodes are serialized; the remaining nodes are
        referenced by id through `block_inputs` and `block_outputs`.

        # Returns
            A dict with the keys `blocks`, `nodes`, `outputs`,
            `block_inputs` and `block_outputs`.
        """
        blocks = [blocks_module.serialize(block) for block in self.blocks]
        nodes = {
            str(self._node_to_id[node]): nodes_module.serialize(node)
            for node in self.inputs
        }
        block_inputs = {
            str(block_id): [self._node_to_id[node] for node in block.inputs]
            for block_id, block in enumerate(self.blocks)
        }
        block_outputs = {
            str(block_id): [self._node_to_id[node] for node in block.outputs]
            for block_id, block in enumerate(self.blocks)
        }

        outputs = [self._node_to_id[node] for node in self.outputs]

        return {
            "blocks": blocks,  # List of serialized blocks, indexed by id.
            "nodes": nodes,  # Dict {id: serialized}, input nodes only.
            "outputs": outputs,  # List of node_ids.
            "block_inputs": block_inputs,  # Dict {id: List of node_ids}.
            "block_outputs": block_outputs,  # Dict {id: List of node_ids}.
        }

    @classmethod
    def from_config(cls, config):
        """Rebuild a Graph from a config produced by `get_config`.

        # Arguments
            config: Dict with the keys `blocks`, `nodes`, `outputs`,
                `block_inputs` and `block_outputs`.

        # Returns
            A new instance of `cls` connecting the deserialized input
            nodes to the recreated output nodes.
        """
        blocks = [
            blocks_module.deserialize(block) for block in config["blocks"]
        ]
        nodes = {
            int(node_id): nodes_module.deserialize(node)
            for node_id, node in config["nodes"].items()
        }
        # Must be captured before block outputs are added to `nodes` below.
        inputs = list(nodes.values())
        for block_id, block in enumerate(blocks):
            input_nodes = [
                nodes[node_id]
                for node_id in config["block_inputs"][str(block_id)]
            ]
            # Calling the block yields its output node(s), which are then
            # stored under the ids recorded in the config.
            output_nodes = tree.flatten(block(input_nodes))
            for output_node, node_id in zip(
                output_nodes, config["block_outputs"][str(block_id)]
            ):
                nodes[node_id] = output_node

        outputs = [nodes[node_id] for node_id in config["outputs"]]
        return cls(inputs=inputs, outputs=outputs)

    def build(self, hp):
        """Build the HyperModel into a Keras Model.

        # Arguments
            hp: The hyperparameters object passed on to every node and
                block and used to choose the optimizer and learning rate.

        # Returns
            The compiled Keras Model.
        """
        keras_nodes = {}
        keras_input_nodes = []
        for node in self.inputs:
            node_id = self._node_to_id[node]
            input_node = node.build_node(hp)
            output_node = node.build(hp, input_node)
            keras_input_nodes.append(input_node)
            keras_nodes[node_id] = output_node
        # `self.blocks` is in topological order, so the inputs of each
        # block are already available in `keras_nodes`.
        for block in self.blocks:
            temp_inputs = [
                keras_nodes[self._node_to_id[input_node]]
                for input_node in block.inputs
            ]
            outputs = block.build(hp, inputs=temp_inputs)
            outputs = tree.flatten(outputs)
            for output_node, real_output_node in zip(block.outputs, outputs):
                keras_nodes[self._node_to_id[output_node]] = real_output_node
        model = keras.Model(
            keras_input_nodes,
            [
                keras_nodes[self._node_to_id[output_node]]
                for output_node in self.outputs
            ],
        )

        return self._compile_keras_model(hp, model)

    def _output_heads(self):
        """Return the Head blocks that are the first in-block of an output."""
        heads = []
        for output_node in self.outputs:
            block = output_node.in_blocks[0]
            if isinstance(block, head_module.Head):
                heads.append(block)
        return heads

    def _get_metrics(self):
        """Return a dict mapping each output head's name to its metrics."""
        return {head.name: head.metrics for head in self._output_heads()}

    def _get_loss(self):
        """Return a dict mapping each output head's name to its loss."""
        return {head.name: head.loss for head in self._output_heads()}

    def _compile_keras_model(self, hp, model):
        """Compile `model` with an optimizer chosen through `hp`.

        # Arguments
            hp: The hyperparameters object used to pick the optimizer name
                and the learning rate.
            model: The Keras Model to compile.

        # Returns
            The same `model`, compiled with the losses and metrics of the
            output heads.

        # Raises
            ValueError: If the chosen optimizer name is not supported.
        """
        # Specify hyperparameters from compile(...)
        optimizer_name = hp.Choice(
            "optimizer",
            ["adam", "sgd", "adam_weight_decay"],
            default="adam",
        )
        # TODO: add adadelta optimizer when it can optimize embedding layer on
        # GPU.
        learning_rate = hp.Choice(
            "learning_rate", [1e-1, 1e-2, 1e-3, 1e-4, 2e-5, 1e-5], default=1e-3
        )

        if optimizer_name == "adam":
            optimizer = keras.optimizers.Adam(learning_rate=learning_rate)
        elif optimizer_name == "sgd":
            optimizer = keras.optimizers.SGD(learning_rate=learning_rate)
        elif optimizer_name == "adam_weight_decay":
            # Relies on `self.num_samples` and `self.epochs`, which are None
            # until `set_fit_args` has been called.
            steps_per_epoch = int(self.num_samples / self.batch_size)
            num_train_steps = steps_per_epoch * self.epochs

            lr_schedule = keras.optimizers.schedules.PolynomialDecay(
                initial_learning_rate=learning_rate,
                decay_steps=num_train_steps,
                end_learning_rate=0.0,
            )

            optimizer = keras.optimizers.AdamW(
                learning_rate=lr_schedule,
                weight_decay=0.01,
                beta_1=0.9,
                beta_2=0.999,
                epsilon=1e-6,
            )
        else:
            # Without this branch `optimizer` would be unbound below.
            raise ValueError(
                f"Unsupported optimizer: {optimizer_name!r}."
            )

        model.compile(
            optimizer=optimizer,
            metrics=self._get_metrics(),
            loss=self._get_loss(),
        )

        return model

    def save(self, filepath):
        """Write the config returned by `get_config` to `filepath` as JSON."""
        io_utils.save_json(filepath, self.get_config())

    def set_io_shapes(self, shapes):
        """Set the shapes of the input nodes and of the output blocks.

        # Arguments
            shapes: A pair whose first element holds the input shapes and
                whose second element holds the output shapes. The leading
                dimension of every shape is dropped (presumably the batch
                dimension; confirm against the caller).
        """
        for node, shape in zip(self.inputs, tree.flatten(shapes[0])):
            node.shape = tuple(shape[1:])
        for node, shape in zip(self.outputs, tree.flatten(shapes[1])):
            # The shape is stored on the block producing the output node.
            node.in_blocks[0].shape = tuple(shape[1:])

    def set_fit_args(self, validation_split, epochs=None):
        """Record the fit arguments needed when compiling the model.

        # Arguments
            validation_split: Fraction of the samples held out for
                validation; used to scale the number of samples.
            epochs: Number of epochs. Falls back to 1 when not specified.
        """
        # Epochs not specified by the user
        self.epochs = 1 if epochs is None else epochs
        # num_samples from analysers are before split
        self.num_samples = self.inputs[0].num_samples * (1 - validation_split)

    @property
    def batch_size(self):
        """The batch size recorded on the first input node."""
        return self.inputs[0].batch_size