jhfjhfj1/autokeras

View on GitHub
autokeras/utils/utils.py

Summary

Maintainability
A
1 hr
Test Coverage
# Copyright 2020 The AutoKeras Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re

import keras
import keras_tuner
import tensorflow as tf
import tree


def validate_num_inputs(inputs, num):
    """Check that `inputs` holds exactly `num` elements once flattened.

    # Arguments
        inputs: A possibly nested structure; it is flattened with
            `tree.flatten` before counting.
        num: Int. The expected number of flattened elements.

    # Raises
        ValueError: If the flattened element count differs from `num`.
    """
    flat_inputs = tree.flatten(inputs)
    received = len(flat_inputs)
    if received != num:
        raise ValueError(
            "Expected {num} elements in the inputs list "
            "but received {len} inputs.".format(num=num, len=received)
        )


def to_snake_case(name):
    intermediate = re.sub("(.)([A-Z][a-z0-9]+)", r"\1_\2", name)
    insecure = re.sub("([a-z])([A-Z])", r"\1_\2", intermediate).lower()
    return insecure


def contain_instance(instance_list, instance_type):
    """Return whether any element of `instance_list` is an `instance_type`.

    # Arguments
        instance_list: Iterable of objects to check.
        instance_type: A type or tuple of types, as accepted by `isinstance`.

    # Returns
        Boolean. False for an empty iterable.
    """
    # A generator (not a list) lets `any` stop at the first match instead of
    # evaluating `isinstance` on every element up front.
    return any(
        isinstance(instance, instance_type) for instance in instance_list
    )


def evaluate_with_adaptive_batch_size(
    model, batch_size, verbose=1, **fit_kwargs
):
    """Call `model.evaluate` through `run_with_adaptive_batch_size`.

    # Arguments
        model: The model whose `evaluate` method is called.
        batch_size: Int. The initial batch size handed to the retry helper.
        verbose: Forwarded to `model.evaluate`.
        **fit_kwargs: Must contain `x`. It may contain `validation_data`,
            which is accepted but not forwarded to `model.evaluate`. Every
            other entry is forwarded to `model.evaluate`.

    # Returns
        The return value of `model.evaluate`.
    """

    def _evaluate(x, validation_data, **kwargs):
        # The retry helper always supplies `validation_data`; evaluation has
        # no use for it, so it is discarded.
        del validation_data
        return model.evaluate(x, verbose=verbose, **kwargs)

    return run_with_adaptive_batch_size(batch_size, _evaluate, **fit_kwargs)


def predict_with_adaptive_batch_size(
    model, batch_size, verbose=1, **fit_kwargs
):
    """Call `model.predict` through `run_with_adaptive_batch_size`.

    # Arguments
        model: The model whose `predict` method is called.
        batch_size: Int. The initial batch size handed to the retry helper.
        verbose: Forwarded to `model.predict`.
        **fit_kwargs: Must contain `x`. It may contain `validation_data`,
            which is accepted but not forwarded to `model.predict`. Every
            other entry is forwarded to `model.predict`.

    # Returns
        The return value of `model.predict`.
    """

    def _predict(x, validation_data, **kwargs):
        # The retry helper always supplies `validation_data`; prediction has
        # no use for it, so it is discarded.
        del validation_data
        return model.predict(x, verbose=verbose, **kwargs)

    return run_with_adaptive_batch_size(batch_size, _predict, **fit_kwargs)


def fit_with_adaptive_batch_size(model, batch_size, **fit_kwargs):
    """Call `model.fit` through `run_with_adaptive_batch_size`.

    # Arguments
        model: The model whose `fit` method is called.
        batch_size: Int. The initial batch size handed to the retry helper.
        **fit_kwargs: Must contain `x`. `x`, `validation_data` and all other
            entries are passed to `model.fit` as keyword arguments.

    # Returns
        A tuple `(model, history)`, where `history` is the return value of
        `model.fit`.
    """

    def _fit(**kwargs):
        return model.fit(**kwargs)

    return model, run_with_adaptive_batch_size(batch_size, _fit, **fit_kwargs)


def run_with_adaptive_batch_size(batch_size, func, **fit_kwargs):
    """Run `func`, halving the batch size whenever it runs out of memory.

    `func` is called with the keyword arguments `x`, `validation_data` and the
    remaining `fit_kwargs`. If it raises `tf.errors.ResourceExhaustedError`,
    the following happens:

    - The batch size is halved with integer division.
    - `x` (and `validation_data`, when given) is re-batched with
      `.unbatch().batch(batch_size)`.
    - `func` is called again.

    Once the batch size is already 1, the error is re-raised.

    # Arguments
        batch_size: Int. The initial batch size; must be at least 1.
        func: Callable accepting `x`, `validation_data` and `**fit_kwargs`.
        **fit_kwargs: Must contain `x`. `validation_data` is optional and
            defaults to None.

    # Returns
        The return value of the first call to `func` that succeeds.

    # Raises
        ValueError: If `batch_size` is less than 1.
        tf.errors.ResourceExhaustedError: If `func` still runs out of memory
            with a batch size of 1.
    """
    if batch_size < 1:
        # Without this guard the retry loop would never call `func`, leaving
        # nothing to return.
        raise ValueError(
            "batch_size must be at least 1, but received {batch_size}.".format(
                batch_size=batch_size
            )
        )
    x = fit_kwargs.pop("x")
    validation_data = fit_kwargs.pop("validation_data", None)
    while True:
        try:
            return func(x=x, validation_data=validation_data, **fit_kwargs)
        except tf.errors.ResourceExhaustedError:
            if batch_size <= 1:
                raise
            batch_size //= 2
            print(
                "Not enough memory, reduce batch size to {batch_size}.".format(
                    batch_size=batch_size
                )
            )
            # NOTE(review): assumes x / validation_data are tf.data.Dataset-like
            # objects exposing unbatch()/batch(); tuple inputs would fail here.
            x = x.unbatch().batch(batch_size)
            if validation_data is not None:
                validation_data = validation_data.unbatch().batch(batch_size)


def get_hyperparameter(value, hp, dtype):
    """Return `value`, falling back to `hp` only when `value` is None.

    # Arguments
        value: The explicitly specified value, or None.
        hp: The fallback returned when `value` is None.
        dtype: Not used by this function.

    # Returns
        `hp` if `value is None`, otherwise `value` (falsy values such as 0
        are returned as-is).
    """
    return hp if value is None else value


def add_to_hp(hp, hps, name=None):
    """Re-register the HyperParameter `hp` in the HyperParameters `hps`.

    If `hp` is not a `keras_tuner` HyperParameter instance, it is returned
    unchanged.

    # Arguments
        hp: keras_tuner HyperParameter, or any other object.
        hps: keras_tuner.HyperParameters. The container to register `hp` in.
        name: String. If left unspecified, the hp name is used.

    # Returns
        For a HyperParameter, whatever the `hps` method named after `hp`'s
        class returns (e.g. `hps.Choice(...)` for a `Choice`). Otherwise `hp`
        itself.
    """
    if not isinstance(hp, keras_tuner.engine.hyperparameters.HyperParameter):
        return hp
    kwargs = hp.get_config()
    if name is None:
        name = hp.name
    # `name` is supplied explicitly below. `conditions` is dropped,
    # presumably so the conditions active in `hps` apply instead
    # (TODO confirm). Tolerate configs that lack either key.
    kwargs.pop("conditions", None)
    kwargs.pop("name", None)
    # Dispatch to the HyperParameters method named after the class,
    # e.g. Choice -> hps.Choice, Int -> hps.Int.
    register = getattr(hps, hp.__class__.__name__)
    return register(name=name, **kwargs)


def serialize_keras_object(obj):
    """Serialize `obj` with `keras.utils.serialize_keras_object`."""
    serializer = keras.utils.serialize_keras_object  # pragma: no cover
    return serializer(obj)  # pragma: no cover


def deserialize_keras_object(config, module_objects=None, custom_objects=None):
    """Deserialize a Keras object with `keras.utils.deserialize_keras_object`.

    # Arguments
        config: The serialized config to deserialize.
        module_objects: Optional dict of objects used to look up names in the
            config. Forwarded to Keras as `module_objects`.
        custom_objects: Optional dict mapping names to custom classes or
            functions. Forwarded to Keras as `custom_objects`.

    # Returns
        The deserialized object.
    """
    # Pass the optional arguments by keyword. Keras' positional order is not
    # (config, custom_objects, module_objects):
    # - In Keras 3 the third positional parameter is `safe_mode`, so
    #   `module_objects` (default None) silently disabled safe mode.
    # - Legacy tf.keras swaps the two dicts.
    return keras.utils.deserialize_keras_object(
        config, custom_objects=custom_objects, module_objects=module_objects
    )