/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.deeplearning4j.nn.modelimport.keras.utils;

import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration;
import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.layers.KerasInput;
import org.deeplearning4j.nn.modelimport.keras.layers.KerasTFOpLayer;
import org.deeplearning4j.nn.modelimport.keras.layers.attention.KerasAttentionLayer;
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.*;
import org.deeplearning4j.nn.modelimport.keras.layers.core.*;
import org.deeplearning4j.nn.modelimport.keras.layers.embeddings.KerasEmbedding;
import org.deeplearning4j.nn.modelimport.keras.layers.local.KerasLocallyConnected1D;
import org.deeplearning4j.nn.modelimport.keras.layers.noise.KerasAlphaDropout;
import org.deeplearning4j.nn.modelimport.keras.layers.noise.KerasGaussianDropout;
import org.deeplearning4j.nn.modelimport.keras.layers.noise.KerasGaussianNoise;
import org.deeplearning4j.nn.modelimport.keras.layers.normalization.KerasBatchNormalization;
import org.deeplearning4j.nn.modelimport.keras.layers.pooling.KerasGlobalPooling;
import org.deeplearning4j.nn.modelimport.keras.layers.pooling.KerasPooling1D;
import org.deeplearning4j.nn.modelimport.keras.layers.pooling.KerasPooling2D;
import org.deeplearning4j.nn.modelimport.keras.layers.pooling.KerasPooling3D;
import org.deeplearning4j.nn.modelimport.keras.layers.recurrent.KerasLSTM;
import org.deeplearning4j.nn.modelimport.keras.layers.recurrent.KerasSimpleRnn;
import org.deeplearning4j.nn.modelimport.keras.layers.wrappers.KerasBidirectional;
import org.deeplearning4j.nn.conf.graph.ElementWiseVertex;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLambdaLayer;
import org.deeplearning4j.nn.modelimport.keras.layers.advanced.activations.*;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.common.primitives.Pair;

import java.lang.reflect.Constructor;
import java.util.*;

@Slf4j
public class KerasLayerUtils {

    /**
     * Checks whether layer config contains unsupported options.
     *
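     * <p>Schematic usage, assuming a Keras 2 model whose parsed JSON supplies {@code layerConfig}
     * as one entry of its {@code "layers"} list:
     * <pre>{@code
     * KerasLayerConfiguration conf = new Keras2LayerConfiguration();
     * KerasLayerUtils.checkForUnsupportedConfigurations(layerConfig, true, conf);
     * }</pre>
     *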
     * @param layerConfig           dictionary containing Keras layer configuration
     * @param enforceTrainingConfig whether to use Keras training configuration
     * @throws InvalidKerasConfigurationException     Invalid Keras config
     * @throws UnsupportedKerasConfigurationException Unsupported Keras config
     */
    public static void checkForUnsupportedConfigurations(Map<String, Object> layerConfig,
                                                         boolean enforceTrainingConfig,
                                                         KerasLayerConfiguration conf)
            throws UnsupportedKerasConfigurationException, InvalidKerasConfigurationException {
        getBiasL1RegularizationFromConfig(layerConfig, enforceTrainingConfig, conf);
        getBiasL2RegularizationFromConfig(layerConfig, enforceTrainingConfig, conf);
        Map<String, Object> innerConfig = KerasLayerUtils.getInnerLayerConfigFromConfig(layerConfig, conf);
        if (innerConfig.containsKey(conf.getLAYER_FIELD_W_REGULARIZER())) {
            checkForUnknownRegularizer((Map<String, Object>) innerConfig.get(conf.getLAYER_FIELD_W_REGULARIZER()),
                    enforceTrainingConfig, conf);
        }
        if (innerConfig.containsKey(conf.getLAYER_FIELD_B_REGULARIZER())) {
            checkForUnknownRegularizer((Map<String, Object>) innerConfig.get(conf.getLAYER_FIELD_B_REGULARIZER()),
                    enforceTrainingConfig, conf);
        }
    }

    /**
     * Get L1 bias regularization (if any) from Keras bias regularization configuration.
     *
     * @param layerConfig dictionary containing Keras layer configuration
     * @return L1 regularization strength (0.0 if none)
     */
    public static double getBiasL1RegularizationFromConfig(Map<String, Object> layerConfig,
                                                           boolean enforceTrainingConfig,
                                                           KerasLayerConfiguration conf)
            throws UnsupportedKerasConfigurationException, InvalidKerasConfigurationException {
        Map<String, Object> innerConfig = KerasLayerUtils.getInnerLayerConfigFromConfig(layerConfig, conf);
        if (innerConfig.containsKey(conf.getLAYER_FIELD_B_REGULARIZER())) {
            Map<String, Object> regularizerConfig =
                    (Map<String, Object>) innerConfig.get(conf.getLAYER_FIELD_B_REGULARIZER());
            if (regularizerConfig != null && regularizerConfig.containsKey(conf.getREGULARIZATION_TYPE_L1()))
                throw new UnsupportedKerasConfigurationException("L1 regularization for bias parameter not supported");
        }
        return 0.0;
    }

    /**
     * Get L2 bias regularization (if any) from Keras bias regularization configuration.
     *
     * @param layerConfig dictionary containing Keras layer configuration
     * @return L2 regularization strength (0.0 if none)
     */
    private static double getBiasL2RegularizationFromConfig(Map<String, Object> layerConfig,
                                                            boolean enforceTrainingConfig,
                                                            KerasLayerConfiguration conf)
            throws UnsupportedKerasConfigurationException, InvalidKerasConfigurationException {
        Map<String, Object> innerConfig = KerasLayerUtils.getInnerLayerConfigFromConfig(layerConfig, conf);
        if (innerConfig.containsKey(conf.getLAYER_FIELD_B_REGULARIZER())) {
            Map<String, Object> regularizerConfig =
                    (Map<String, Object>) innerConfig.get(conf.getLAYER_FIELD_B_REGULARIZER());
            if (regularizerConfig != null && regularizerConfig.containsKey(conf.getREGULARIZATION_TYPE_L2()))
                throw new UnsupportedKerasConfigurationException("L2 regularization for bias parameter not supported");
        }
        return 0.0;
    }

    /**
     * Check whether a Keras weight regularizer is of an unknown type. Unless training
     * configurations are enforced, this only logs a warning, since the main use case for model
     * import is inference rather than further training. Unknown types are unlikely in practice,
     * as the standard Keras weight regularizers are L1 and L2.
     *
     * @param regularizerConfig Map containing Keras weight regularization configuration
     */
    private static void checkForUnknownRegularizer(Map<String, Object> regularizerConfig, boolean enforceTrainingConfig,
                                                   KerasLayerConfiguration conf)
            throws UnsupportedKerasConfigurationException {
        if (regularizerConfig != null) {
            for (String field : regularizerConfig.keySet()) {
                if (!field.equals(conf.getREGULARIZATION_TYPE_L1()) && !field.equals(conf.getREGULARIZATION_TYPE_L2())
                        && !field.equals(conf.getLAYER_FIELD_NAME())
                        && !field.equals(conf.getLAYER_FIELD_CLASS_NAME())
                        && !field.equals(conf.getLAYER_FIELD_CONFIG())) {
                    if (enforceTrainingConfig)
                        throw new UnsupportedKerasConfigurationException("Unknown regularization field " + field);
                    else
                        log.warn("Ignoring unknown regularization field " + field);
                }
            }
        }
    }


    /**
     * Build KerasLayer from a Keras layer configuration.
     *
     * @param layerConfig map containing Keras layer properties
     * @return KerasLayer
     * @see Layer
     */
    public static KerasLayer getKerasLayerFromConfig(Map<String, Object> layerConfig,
                                                     KerasLayerConfiguration conf,
                                                     Map<String, Class<? extends KerasLayer>> customLayers,
                                                     Map<String, SameDiffLambdaLayer> lambdaLayers,
                                                     Map<String, ? extends KerasLayer> previousLayers)
            throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        return getKerasLayerFromConfig(layerConfig, false, conf, customLayers, lambdaLayers, previousLayers);
    }

    /**
     * Build KerasLayer from a Keras layer configuration. Building layer with
     * enforceTrainingConfig=true will throw exceptions for unsupported Keras
     * options related to training (e.g., unknown regularizers). Otherwise
     * we only generate warnings.
     *
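     * <p>A minimal usage sketch, assuming a Keras 2 model; {@code layerConfig} is one entry of the
     * parsed model JSON's {@code "layers"} list, and no custom, lambda or previous layers are
     * registered:
     * <pre>{@code
     * KerasLayerConfiguration conf = new Keras2LayerConfiguration();
     * KerasLayer layer = KerasLayerUtils.getKerasLayerFromConfig(
     *         layerConfig, false, conf,
     *         new HashMap<String, Class<? extends KerasLayer>>(),
     *         new HashMap<String, SameDiffLambdaLayer>(),
     *         new HashMap<String, KerasLayer>());
     * }</pre>
     *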
     * @param layerConfig           map containing Keras layer properties
     * @param enforceTrainingConfig whether to enforce training-only configurations
     * @return KerasLayer
     * @see Layer
     */
    public static KerasLayer getKerasLayerFromConfig(Map<String, Object> layerConfig,
                                                     boolean enforceTrainingConfig,
                                                     KerasLayerConfiguration conf,
                                                     Map<String, Class<? extends KerasLayer>> customLayers,
                                                     Map<String, SameDiffLambdaLayer> lambdaLayers,
                                                     Map<String, ? extends KerasLayer> previousLayers
    )
            throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        String layerClassName = getClassNameFromConfig(layerConfig, conf);
        if (layerClassName.equals(conf.getLAYER_CLASS_NAME_TIME_DISTRIBUTED())) {
            layerConfig = getTimeDistributedLayerConfig(layerConfig, conf);
            layerClassName = getClassNameFromConfig(layerConfig, conf);
        }
        KerasLayer layer = null;
        if (layerClassName.equals(conf.getLAYER_CLASS_NAME_ACTIVATION())) {
            layer = new KerasActivation(layerConfig, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_LEAKY_RELU())) {
            layer = new KerasLeakyReLU(layerConfig, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_MASKING())) {
            layer = new KerasMasking(layerConfig, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_THRESHOLDED_RELU())) {
            layer = new KerasThresholdedReLU(layerConfig, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_PRELU())) {
            layer = new KerasPReLU(layerConfig, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_DROPOUT())) {
            layer = new KerasDropout(layerConfig, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_SPATIAL_DROPOUT_1D())
                || layerClassName.equals(conf.getLAYER_CLASS_NAME_SPATIAL_DROPOUT_2D())
                || layerClassName.equals(conf.getLAYER_CLASS_NAME_SPATIAL_DROPOUT_3D())) {
            layer = new KerasSpatialDropout(layerConfig, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_ALPHA_DROPOUT())) {
            layer = new KerasAlphaDropout(layerConfig, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_GAUSSIAN_DROPOUT())) {
            layer = new KerasGaussianDropout(layerConfig, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_GAUSSIAN_NOISE())) {
            layer = new KerasGaussianNoise(layerConfig, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_DENSE()) ||
                layerClassName.equals(conf.getLAYER_CLASS_NAME_TIME_DISTRIBUTED_DENSE())) {
            layer = new KerasDense(layerConfig, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_BIDIRECTIONAL())) {
            layer = new KerasBidirectional(layerConfig, enforceTrainingConfig, previousLayers);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_LSTM())) {
            layer = new KerasLSTM(layerConfig, enforceTrainingConfig, previousLayers);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_SIMPLE_RNN())) {
            layer = new KerasSimpleRnn(layerConfig, enforceTrainingConfig, previousLayers);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_CONVOLUTION_3D())) {
            layer = new KerasConvolution3D(layerConfig, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_CONVOLUTION_2D())) {
            layer = new KerasConvolution2D(layerConfig, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_DECONVOLUTION_2D())) {
            layer = new KerasDeconvolution2D(layerConfig, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_DECONVOLUTION_3D())) {
            layer = new KerasDeconvolution3D(layerConfig, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_CONVOLUTION_1D())) {
            layer = new KerasConvolution1D(layerConfig, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_ATROUS_CONVOLUTION_2D())) {
            layer = new KerasAtrousConvolution2D(layerConfig, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_ATROUS_CONVOLUTION_1D())) {
            layer = new KerasAtrousConvolution1D(layerConfig, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_DEPTHWISE_CONVOLUTION_2D())) {
            layer = new KerasDepthwiseConvolution2D(layerConfig, previousLayers, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_SEPARABLE_CONVOLUTION_2D())) {
            layer = new KerasSeparableConvolution2D(layerConfig, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_MAX_POOLING_3D()) ||
                layerClassName.equals(conf.getLAYER_CLASS_NAME_AVERAGE_POOLING_3D())) {
            layer = new KerasPooling3D(layerConfig, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_MAX_POOLING_2D()) ||
                layerClassName.equals(conf.getLAYER_CLASS_NAME_AVERAGE_POOLING_2D())) {
            layer = new KerasPooling2D(layerConfig, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_MAX_POOLING_1D()) ||
                layerClassName.equals(conf.getLAYER_CLASS_NAME_AVERAGE_POOLING_1D())) {
            layer = new KerasPooling1D(layerConfig, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_GLOBAL_AVERAGE_POOLING_1D()) ||
                layerClassName.equals(conf.getLAYER_CLASS_NAME_GLOBAL_AVERAGE_POOLING_2D()) ||
                layerClassName.equals(conf.getLAYER_CLASS_NAME_GLOBAL_AVERAGE_POOLING_3D()) ||
                layerClassName.equals(conf.getLAYER_CLASS_NAME_GLOBAL_MAX_POOLING_1D()) ||
                layerClassName.equals(conf.getLAYER_CLASS_NAME_GLOBAL_MAX_POOLING_2D()) ||
                layerClassName.equals(conf.getLAYER_CLASS_NAME_GLOBAL_MAX_POOLING_3D())) {
            layer = new KerasGlobalPooling(layerConfig, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_BATCHNORMALIZATION())) {
            layer = new KerasBatchNormalization(layerConfig, enforceTrainingConfig, previousLayers);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_EMBEDDING())) {
            layer = new KerasEmbedding(layerConfig, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_INPUT())) {
            layer = new KerasInput(layerConfig, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_REPEAT())) {
            layer = new KerasRepeatVector(layerConfig, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_PERMUTE())) {
            layer = new KerasPermute(layerConfig, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_MERGE())) {
            layer = new KerasMerge(layerConfig, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_ADD()) ||
                layerClassName.equals(conf.getLAYER_CLASS_NAME_FUNCTIONAL_ADD())) {
            layer = new KerasMerge(layerConfig, ElementWiseVertex.Op.Add, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_SUBTRACT()) ||
                layerClassName.equals(conf.getLAYER_CLASS_NAME_FUNCTIONAL_SUBTRACT())) {
            layer = new KerasMerge(layerConfig, ElementWiseVertex.Op.Subtract, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_AVERAGE()) ||
                layerClassName.equals(conf.getLAYER_CLASS_NAME_FUNCTIONAL_AVERAGE())) {
            layer = new KerasMerge(layerConfig, ElementWiseVertex.Op.Average, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_MULTIPLY()) ||
                layerClassName.equals(conf.getLAYER_CLASS_NAME_FUNCTIONAL_MULTIPLY())) {
            layer = new KerasMerge(layerConfig, ElementWiseVertex.Op.Product, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_MAXIMUM()) ||
                layerClassName.equals(conf.getLAYER_CLASS_NAME_FUNCTIONAL_MAXIMUM())) {
            layer = new KerasMerge(layerConfig, ElementWiseVertex.Op.Max, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_CONCATENATE()) ||
                layerClassName.equals(conf.getLAYER_CLASS_NAME_FUNCTIONAL_CONCATENATE())) {
            layer = new KerasMerge(layerConfig, null, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_FLATTEN())) {
            layer = new KerasFlatten(layerConfig, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_RESHAPE())) {
            layer = new KerasReshape(layerConfig, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_ZERO_PADDING_1D())) {
            layer = new KerasZeroPadding1D(layerConfig, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_ZERO_PADDING_2D())) {
            layer = new KerasZeroPadding2D(layerConfig, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_ZERO_PADDING_3D())) {
            layer = new KerasZeroPadding3D(layerConfig, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_UPSAMPLING_1D())) {
            layer = new KerasUpsampling1D(layerConfig, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_UPSAMPLING_2D())) {
            layer = new KerasUpsampling2D(layerConfig, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_UPSAMPLING_3D())) {
            layer = new KerasUpsampling3D(layerConfig, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_CROPPING_3D())) {
            layer = new KerasCropping3D(layerConfig, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_CROPPING_2D())) {
            layer = new KerasCropping2D(layerConfig, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_CROPPING_1D())) {
            layer = new KerasCropping1D(layerConfig, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_ATTENTION())) {
            layer = new KerasAttentionLayer(layerConfig, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_LAMBDA())) {
            String lambdaLayerName = KerasLayerUtils.getLayerNameFromConfig(layerConfig, conf);
            if (!lambdaLayers.containsKey(lambdaLayerName) && !customLayers.containsKey(layerClassName)) {
                throw new UnsupportedKerasConfigurationException("No SameDiff Lambda layer found for Lambda " +
                        "layer " + lambdaLayerName + ". You can register a SameDiff Lambda layer using KerasLayer." +
                        "registerLambdaLayer(lambdaLayerName, sameDiffLambdaLayer);");
            }

            SameDiffLambdaLayer lambdaLayer = lambdaLayers.get(lambdaLayerName);
            if (lambdaLayer != null) {
                layer = new KerasLambda(layerConfig, enforceTrainingConfig, lambdaLayer);
            }
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_RELU())) {
            layer = new KerasReLU(layerConfig, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_ELU())) {
            layer = new KerasELU(layerConfig, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_SOFTMAX())) {
            layer = new KerasSoftmax(layerConfig, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_LOCALLY_CONNECTED_1D())) {
            layer = new KerasLocallyConnected1D(layerConfig, enforceTrainingConfig);
        } else if (conf instanceof Keras2LayerConfiguration) {
            Keras2LayerConfiguration k2conf = (Keras2LayerConfiguration) conf;
            if (layerClassName.equals(k2conf.getTENSORFLOW_OP_LAYER())) {
                layer = new KerasTFOpLayer(layerConfig, enforceTrainingConfig);
            }
        }
        if (layer == null) {
            Class<? extends KerasLayer> customConfig = customLayers.get(layerClassName);
            if (customConfig == null)
                throw new UnsupportedKerasConfigurationException("Unsupported keras layer type " + layerClassName);
            try {
                Constructor constructor = customConfig.getConstructor(Map.class);
                layer = (KerasLayer) constructor.newInstance(layerConfig);
            } catch (Exception e) {
                throw new RuntimeException("The keras custom class " + layerClassName + " needs to have a constructor with only Map<String,Object> as the argument. Please ensure this is defined."
                        , e);
            }
        }
        return layer;
    }

    /**
     * Get Keras layer class name from Keras layer configuration.
     *
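     * <p>Schematic example, assuming the Keras 2 field name {@code "class_name"}:
     * <pre>{@code
     * Map<String, Object> layerConfig = new HashMap<>();
     * layerConfig.put("class_name", "Dense");
     * String className = KerasLayerUtils.getClassNameFromConfig(layerConfig, new Keras2LayerConfiguration());
     * // className is "Dense"
     * }</pre>
     *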
     * @param layerConfig dictionary containing Keras layer configuration
     * @return Keras layer class name
     * @throws InvalidKerasConfigurationException Invalid Keras config
     */
    public static String getClassNameFromConfig(Map<String, Object> layerConfig, KerasLayerConfiguration conf)
            throws InvalidKerasConfigurationException {
        if (!layerConfig.containsKey(conf.getLAYER_FIELD_CLASS_NAME()))
            throw new InvalidKerasConfigurationException(
                    "Field " + conf.getLAYER_FIELD_CLASS_NAME() + " missing from layer config");
        return (String) layerConfig.get(conf.getLAYER_FIELD_CLASS_NAME());
    }

    /**
     * Extract inner layer config from TimeDistributed configuration and merge
     * it into the outer config.
     *
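     * <p>Schematic effect, assuming Keras 2 field names:
     * <pre>{@code
     * // before: {"class_name": "TimeDistributed", "config": {"name": "td_1", "layer": {"class_name": "Dense", "config": {...}}}}
     * // after:  {"class_name": "Dense", "config": {"name": "td_1", ...fields of the inner Dense config...}}
     * }</pre>
     *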
     * @param layerConfig dictionary containing Keras TimeDistributed configuration
     * @return Time distributed layer config
     * @throws InvalidKerasConfigurationException Invalid Keras config
     */
    public static Map<String, Object> getTimeDistributedLayerConfig(Map<String, Object> layerConfig,
                                                                    KerasLayerConfiguration conf)
            throws InvalidKerasConfigurationException {
        if (!layerConfig.containsKey(conf.getLAYER_FIELD_CLASS_NAME()))
            throw new InvalidKerasConfigurationException(
                    "Field " + conf.getLAYER_FIELD_CLASS_NAME() + " missing from layer config");
        if (!layerConfig.get(conf.getLAYER_FIELD_CLASS_NAME()).equals(conf.getLAYER_CLASS_NAME_TIME_DISTRIBUTED()))
            throw new InvalidKerasConfigurationException("Expected " + conf.getLAYER_CLASS_NAME_TIME_DISTRIBUTED()
                    + " layer, found " + layerConfig.get(conf.getLAYER_FIELD_CLASS_NAME()));
        if (!layerConfig.containsKey(conf.getLAYER_FIELD_CONFIG()))
            throw new InvalidKerasConfigurationException("Field "
                    + conf.getLAYER_FIELD_CONFIG() + " missing from layer config");
        Map<String, Object> outerConfig = getInnerLayerConfigFromConfig(layerConfig, conf);
        Map<String, Object> innerLayer = (Map<String, Object>) outerConfig.get(conf.getLAYER_FIELD_LAYER());
        layerConfig.put(conf.getLAYER_FIELD_CLASS_NAME(), innerLayer.get(conf.getLAYER_FIELD_CLASS_NAME()));
        Map<String, Object> innerConfig = getInnerLayerConfigFromConfig(innerLayer, conf);
        innerConfig.put(conf.getLAYER_FIELD_NAME(), outerConfig.get(conf.getLAYER_FIELD_NAME()));
        outerConfig.putAll(innerConfig);
        outerConfig.remove(conf.getLAYER_FIELD_LAYER());
        return layerConfig;
    }

    /**
     * Get inner layer config from Keras layer configuration.
     *
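     * <p>For a Keras 2 layer entry of the form {@code {"class_name": "Dense", "config": {"name": "dense_1", ...}}},
     * this returns the inner {@code "config"} map, i.e. {@code {"name": "dense_1", ...}}.
     *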
     * @param layerConfig dictionary containing Keras layer configuration
     * @return Inner layer config for a nested Keras layer configuration
     * @throws InvalidKerasConfigurationException Invalid Keras config
     */
    public static Map<String, Object> getInnerLayerConfigFromConfig(Map<String, Object> layerConfig, KerasLayerConfiguration conf)
            throws InvalidKerasConfigurationException {
        if (!layerConfig.containsKey(conf.getLAYER_FIELD_CONFIG()))
            throw new InvalidKerasConfigurationException("Field "
                    + conf.getLAYER_FIELD_CONFIG() + " missing from layer config");
        return (Map<String, Object>) layerConfig.get(conf.getLAYER_FIELD_CONFIG());
    }

    /**
     * Get layer name from Keras layer configuration.
     *
     * @param layerConfig dictionary containing Keras layer configuration
     * @return Keras layer name
     * @throws InvalidKerasConfigurationException Invalid Keras config
     */
    public static String getLayerNameFromConfig(Map<String, Object> layerConfig,
                                                KerasLayerConfiguration conf)
            throws InvalidKerasConfigurationException {
        if (conf instanceof Keras2LayerConfiguration) {
            Keras2LayerConfiguration k2conf = (Keras2LayerConfiguration) conf;
            if (getClassNameFromConfig(layerConfig, conf).equals(k2conf.getTENSORFLOW_OP_LAYER())) {
                if (!layerConfig.containsKey(conf.getLAYER_FIELD_NAME()))
                    throw new InvalidKerasConfigurationException("Field " + conf.getLAYER_FIELD_NAME()
                            + " missing from layer config");
                return (String) layerConfig.get(conf.getLAYER_FIELD_NAME());
            }
        }

        Map<String, Object> innerConfig = KerasLayerUtils.getInnerLayerConfigFromConfig(layerConfig, conf);
        if (!innerConfig.containsKey(conf.getLAYER_FIELD_NAME()))
            throw new InvalidKerasConfigurationException("Field " + conf.getLAYER_FIELD_NAME()
                    + " missing from layer config");
        return (String) innerConfig.get(conf.getLAYER_FIELD_NAME());
    }

    /**
     * Get Keras input shape from Keras layer configuration.
     *
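     * <p>Schematic example, assuming the Keras field name {@code "batch_input_shape"}: a value of
     * {@code [null, 28, 28, 1]} (batch dimension first) yields {@code [28, 28, 1]}; any remaining
     * {@code null} entries are mapped to 0.
     *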
     * @param layerConfig dictionary containing Keras layer configuration
     * @return input shape array
     */
    public static int[] getInputShapeFromConfig(Map<String, Object> layerConfig,
                                                KerasLayerConfiguration conf)
            throws InvalidKerasConfigurationException {
        // TODO: validate this. shouldn't we also have INPUT_SHAPE checked?
        Map<String, Object> innerConfig = KerasLayerUtils.getInnerLayerConfigFromConfig(layerConfig, conf);
        if (!innerConfig.containsKey(conf.getLAYER_FIELD_BATCH_INPUT_SHAPE()))
            return null;
        List<Integer> batchInputShape = (List<Integer>) innerConfig.get(conf.getLAYER_FIELD_BATCH_INPUT_SHAPE());
        int[] inputShape = new int[batchInputShape.size() - 1];
        for (int i = 1; i < batchInputShape.size(); i++) {
            inputShape[i - 1] = batchInputShape.get(i) != null ? batchInputShape.get(i) : 0;
        }
        return inputShape;
    }

    /**
     * Get Keras (backend) dimension order from Keras layer configuration.
     *
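     * <p>Schematic example, assuming the usual Keras data-format values: {@code "channels_last"}
     * (Keras 2) or {@code "tf"} (Keras 1) maps to {@link KerasLayer.DimOrder#TENSORFLOW}, while
     * {@code "channels_first"} or {@code "th"} maps to {@link KerasLayer.DimOrder#THEANO}.
     *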
     * @param layerConfig dictionary containing Keras layer configuration
     * @return Dimension order
     */
    public static KerasLayer.DimOrder getDimOrderFromConfig(Map<String, Object> layerConfig,
                                                            KerasLayerConfiguration conf)
            throws InvalidKerasConfigurationException {
        Map<String, Object> innerConfig = KerasLayerUtils.getInnerLayerConfigFromConfig(layerConfig, conf);
        KerasLayer.DimOrder dimOrder = KerasLayer.DimOrder.NONE;
        if (layerConfig.containsKey(conf.getLAYER_FIELD_BACKEND())) {
            String backend = (String) layerConfig.get(conf.getLAYER_FIELD_BACKEND());
            if (backend.equals("tensorflow") || backend.equals("cntk")) {
                dimOrder = KerasLayer.DimOrder.TENSORFLOW;
            } else if (backend.equals("theano")) {
                dimOrder = KerasLayer.DimOrder.THEANO;
            }
        }
        if (innerConfig.containsKey(conf.getLAYER_FIELD_DIM_ORDERING())) {
            String dimOrderStr = (String) innerConfig.get(conf.getLAYER_FIELD_DIM_ORDERING());
            if (dimOrderStr.equals(conf.getDIM_ORDERING_TENSORFLOW())) {
                dimOrder = KerasLayer.DimOrder.TENSORFLOW;
            } else if (dimOrderStr.equals(conf.getDIM_ORDERING_THEANO())) {
                dimOrder = KerasLayer.DimOrder.THEANO;
            } else {
                log.warn("Keras layer has unknown Keras dimension order: " + dimOrderStr);
            }
        }
        return dimOrder;
    }

    /**
     * Get list of inbound layers from Keras layer configuration.
     *
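     * <p>Schematic example: an {@code "inbound_nodes"} entry of the form
     * {@code [[["dense_1", 0, 0, {}], ["dense_2", 0, 0, {}]]]} yields
     * {@code ["dense_1", "dense_2"]}; only the layer name (first element of each connection) is kept.
     *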
     * @param layerConfig dictionary containing Keras layer configuration
     * @return List of inbound layer names
     */
    public static List<String> getInboundLayerNamesFromConfig(Map<String, Object> layerConfig, KerasLayerConfiguration conf) {
        List<String> inboundLayerNames = new ArrayList<>();
        if (layerConfig.containsKey(conf.getLAYER_FIELD_INBOUND_NODES())) {
            List<Object> inboundNodes = (List<Object>) layerConfig.get(conf.getLAYER_FIELD_INBOUND_NODES());
            for (Object node : inboundNodes) {
                /* Each inbound node is a list of connections of the form
                 * [inboundLayerName, nodeIndex, tensorIndex, ...]; only the layer name is needed. */
                List<Object> connections = (List<Object>) node;
                for (Object connection : connections) {
                    List<Object> connectionInfo = (List<Object>) connection;
                    inboundLayerNames.add(connectionInfo.get(0).toString());
                }
            }
        }
        return inboundLayerNames;
    }

    /**
     * Get list of outbound layers from Keras layer configuration.
     *
     * @param layerConfig dictionary containing Keras layer configuration
     * @return List of outbound layer names
     */
    public static List<String> getOutboundLayerNamesFromConfig(Map<String, Object> layerConfig, KerasLayerConfiguration conf) {
        List<String> outputLayerNames = new ArrayList<>();
        if (layerConfig.containsKey(conf.getLAYER_FIELD_OUTBOUND_NODES())) {
            List<Object> outboundNodes = (List<Object>) layerConfig.get(conf.getLAYER_FIELD_OUTBOUND_NODES());
            if (!outboundNodes.isEmpty()) {
                outboundNodes = (List<Object>) outboundNodes.get(0);
                for (Object o : outboundNodes) {
                    String nodeName = (String) ((List<Object>) o).get(0);
                    outputLayerNames.add(nodeName);
                }
            }
        }
        return outputLayerNames;
    }

    /**
     * Get number of outputs from Keras layer configuration.
     *
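     * <p>Schematic example, assuming Keras 2 field names: a Dense layer reads its output size from
     * {@code "units"}, an Embedding layer from {@code "output_dim"}, and a convolutional layer from
     * {@code "filters"}.
     *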
     * @param layerConfig dictionary containing Keras layer configuration
     * @return Number of output neurons of the Keras layer
     * @throws InvalidKerasConfigurationException Invalid Keras config
     */
    public static int getNOutFromConfig(Map<String, Object> layerConfig,
                                        KerasLayerConfiguration conf) throws InvalidKerasConfigurationException {
        Map<String, Object> innerConfig = KerasLayerUtils.getInnerLayerConfigFromConfig(layerConfig, conf);
        int nOut;
        if (innerConfig.containsKey(conf.getLAYER_FIELD_OUTPUT_DIM()))
            /* Most feedforward layers: Dense, RNN, etc. */
            nOut = (int) innerConfig.get(conf.getLAYER_FIELD_OUTPUT_DIM());
        else if (innerConfig.containsKey(conf.getLAYER_FIELD_EMBEDDING_OUTPUT_DIM()))
            /* Embedding layers. */
            nOut = (int) innerConfig.get(conf.getLAYER_FIELD_EMBEDDING_OUTPUT_DIM());
        else if (innerConfig.containsKey(conf.getLAYER_FIELD_NB_FILTER()))
            /* Convolutional layers. */
            nOut = (int) innerConfig.get(conf.getLAYER_FIELD_NB_FILTER());
        else
            throw new InvalidKerasConfigurationException("Could not determine number of outputs for layer: no "
                    + conf.getLAYER_FIELD_OUTPUT_DIM() + ", " + conf.getLAYER_FIELD_EMBEDDING_OUTPUT_DIM()
                    + " or " + conf.getLAYER_FIELD_NB_FILTER() + " field found");
        return nOut;
    }

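    /**
     * Get the number of inputs from the Keras input-dimension field ({@code input_dim}), if present.
     *
     * @param layerConfig dictionary containing Keras layer configuration
     * @param conf        Keras layer configuration
     * @return number of inputs, or null if no numeric input dimension is specified
     * @throws InvalidKerasConfigurationException Invalid Keras config
     */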
    public static Integer getNInFromInputDim(Map<String, Object> layerConfig, KerasLayerConfiguration conf) throws InvalidKerasConfigurationException {
        Map<String, Object> innerConfig = KerasLayerUtils.getInnerLayerConfigFromConfig(layerConfig, conf);
        if (innerConfig.containsKey(conf.getLAYER_FIELD_INPUT_DIM())) {
            Object id = innerConfig.get(conf.getLAYER_FIELD_INPUT_DIM());
            if (id instanceof Number) {
                return ((Number) id).intValue();
            }
        }
        return null;
    }

    /**
     * Get dropout from Keras layer configuration.
     *
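     * <p>Note the convention change: a Keras dropout rate of {@code 0.2} (probability of dropping
     * a unit) is returned as {@code 0.8} (probability of retaining it), per the implementation
     * note below.
     *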
     * @param layerConfig dictionary containing Keras layer configuration
     * @return get dropout value from Keras config
     * @throws InvalidKerasConfigurationException Invalid Keras config
     */
    public static double getDropoutFromConfig(Map<String, Object> layerConfig,
                                              KerasLayerConfiguration conf) throws InvalidKerasConfigurationException {
        Map<String, Object> innerConfig = KerasLayerUtils.getInnerLayerConfigFromConfig(layerConfig, conf);
        /* NOTE: Keras "dropout" parameter determines dropout probability,
         * while DL4J "dropout" parameter determines retention probability.
         */
        double dropout = 1.0;
        if (innerConfig.containsKey(conf.getLAYER_FIELD_DROPOUT())) {
            /* For most feedforward layers. */
            try {
                dropout = 1.0 - (double) innerConfig.get(conf.getLAYER_FIELD_DROPOUT());
            } catch (Exception e) {
                int kerasDropout = (int) innerConfig.get(conf.getLAYER_FIELD_DROPOUT());
                dropout = 1.0 - kerasDropout;
            }
        } else if (innerConfig.containsKey(conf.getLAYER_FIELD_DROPOUT_W())) {
            /* For LSTMs. */
            try {
                dropout = 1.0 - (double) innerConfig.get(conf.getLAYER_FIELD_DROPOUT_W());
            } catch (Exception e) {
                int kerasDropout = (int) innerConfig.get(conf.getLAYER_FIELD_DROPOUT_W());
                dropout = 1.0 - kerasDropout;
            }
        }
        return dropout;
    }

    /**
     * Determine if layer should be instantiated with bias
     *
     * @param layerConfig dictionary containing Keras layer configuration
     * @return whether layer has a bias term
     * @throws InvalidKerasConfigurationException Invalid Keras config
     */
    public static boolean getHasBiasFromConfig(Map<String, Object> layerConfig,
                                               KerasLayerConfiguration conf)
            throws InvalidKerasConfigurationException {
        Map<String, Object> innerConfig = KerasLayerUtils.getInnerLayerConfigFromConfig(layerConfig, conf);
        boolean hasBias = true;
        if (innerConfig.containsKey(conf.getLAYER_FIELD_USE_BIAS())) {
            hasBias = (boolean) innerConfig.get(conf.getLAYER_FIELD_USE_BIAS());
        }
        return hasBias;
    }

    /**
     * Get zero masking flag
     *
     * @param layerConfig dictionary containing Keras layer configuration
     * @return if masking zeros or not
     * @throws InvalidKerasConfigurationException Invalid Keras configuration
     */
    public static boolean getZeroMaskingFromConfig(Map<String, Object> layerConfig,
                                                   KerasLayerConfiguration conf)
            throws InvalidKerasConfigurationException {
        Map<String, Object> innerConfig = KerasLayerUtils.getInnerLayerConfigFromConfig(layerConfig, conf);
        boolean hasZeroMasking = true;
        if (innerConfig.containsKey(conf.getLAYER_FIELD_MASK_ZERO())) {
            hasZeroMasking = (boolean) innerConfig.get(conf.getLAYER_FIELD_MASK_ZERO());
        }
        return hasZeroMasking;
    }

    /**
     * Get mask value
     *
     * @param layerConfig dictionary containing Keras layer configuration
     * @return mask value, defaults to 0.0
     * @throws InvalidKerasConfigurationException Invalid Keras configuration
     */
    public static double getMaskingValueFromConfig(Map<String, Object> layerConfig,
                                                   KerasLayerConfiguration conf)
            throws InvalidKerasConfigurationException {
        Map<String, Object> innerConfig = KerasLayerUtils.getInnerLayerConfigFromConfig(layerConfig, conf);
        double maskValue = 0.0;
        if (innerConfig.containsKey(conf.getLAYER_FIELD_MASK_VALUE())) {
            try {
                maskValue = (double) innerConfig.get(conf.getLAYER_FIELD_MASK_VALUE());
            } catch (Exception e) {
                log.warn("Couldn't read masking value, defaulting to 0.0");
            }
        } else {
            throw new InvalidKerasConfigurationException("No mask value found, field "
                    + conf.getLAYER_FIELD_MASK_VALUE());
        }
        return maskValue;
    }


    /**
     * Remove the default weight and bias parameter entries from the weights map after they have
     * been set, and warn about any remaining (unknown) parameters.
     *
     * @param weights layer weights
     * @param conf    Keras layer configuration
     */
    public static void removeDefaultWeights(Map<String, INDArray> weights, KerasLayerConfiguration conf) {
        if (weights.size() > 2) {
            Set<String> paramNames = weights.keySet();
            paramNames.remove(conf.getKERAS_PARAM_NAME_W());
            paramNames.remove(conf.getKERAS_PARAM_NAME_B());
            String unknownParamNames = paramNames.toString();
            log.warn("Attempting to set weights for unknown parameters: "
                    + unknownParamNames.substring(1, unknownParamNames.length() - 1));
        }
    }

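    /**
     * Determine whether any of the given inbound layers produce a mask, and with which value.
     * A zero-masking {@link KerasEmbedding} or a {@link KerasMasking} layer upstream enables
     * masking; for a Masking layer its configured mask value is propagated (otherwise 0.0 is used).
     *
     * @param inboundLayerNames names of the inbound layers of the current layer
     * @param previousLayers    previously processed layers, keyed by layer name
     * @return pair of (has masking, masking value)
     */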
    public static Pair<Boolean, Double> getMaskingConfiguration(List<String> inboundLayerNames,
                                                                Map<String, ? extends KerasLayer> previousLayers) {
        Boolean hasMasking = false;
        Double maskingValue = 0.0;
        for (String inboundLayerName : inboundLayerNames) {
            if (previousLayers.containsKey(inboundLayerName)) {
                KerasLayer inbound = previousLayers.get(inboundLayerName);
                if (inbound instanceof KerasEmbedding && ((KerasEmbedding) inbound).isZeroMasking()) {
                    hasMasking = true;
                } else if (inbound instanceof KerasMasking) {
                    hasMasking = true;
                    maskingValue = ((KerasMasking) inbound).getMaskingValue();
                }
            }
        }
        return new Pair<>(hasMasking, maskingValue);
    }

}