deeplearning4j/deeplearning4j

View on GitHub
deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasConvolutionUtils.java

Summary

Maintainability
D
2 days
Test Coverage
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.deeplearning4j.nn.modelimport.keras.layers.convolutional;

import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.layers.Convolution3D;
import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils;
import org.nd4j.common.util.ArrayUtil;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;

public class KerasConvolutionUtils {




    /**
     * Get (convolution) stride from Keras layer configuration.
     *
     * @param layerConfig dictionary containing Keras layer configuration
     * @return Strides array from Keras configuration
     * @throws InvalidKerasConfigurationException Invalid Keras config
     */
    public static int[] getStrideFromConfig(Map<String, Object> layerConfig, int dimension,
                                            KerasLayerConfiguration conf)
            throws InvalidKerasConfigurationException {
        Map<String, Object> innerConfig = KerasLayerUtils.getInnerLayerConfigFromConfig(layerConfig, conf);
        int[] strides;
        if (innerConfig.containsKey(conf.getLAYER_FIELD_CONVOLUTION_STRIDES()) && dimension >= 2) {
            /* 2D/3D Convolutional layers. */
            @SuppressWarnings("unchecked")
            List<Integer> stridesList = (List<Integer>) innerConfig.get(conf.getLAYER_FIELD_CONVOLUTION_STRIDES());
            strides = ArrayUtil.toArray(stridesList);
        } else if (innerConfig.containsKey(conf.getLAYER_FIELD_SUBSAMPLE_LENGTH()) && dimension == 1) {
            /* 1D Convolutional layers. */
            if ((int) layerConfig.get("keras_version") == 2) {
                @SuppressWarnings("unchecked")
                List<Integer> stridesList = (List<Integer>) innerConfig.get(conf.getLAYER_FIELD_SUBSAMPLE_LENGTH());
                strides = ArrayUtil.toArray(stridesList);
            } else {
                int subsampleLength = (int) innerConfig.get(conf.getLAYER_FIELD_SUBSAMPLE_LENGTH());
                strides = new int[]{subsampleLength};
            }
        } else if (innerConfig.containsKey(conf.getLAYER_FIELD_POOL_STRIDES()) && dimension >= 2) {
            /* 2D/3D Pooling layers. */
            @SuppressWarnings("unchecked")
            List<Integer> stridesList = (List<Integer>) innerConfig.get(conf.getLAYER_FIELD_POOL_STRIDES());
            strides = ArrayUtil.toArray(stridesList);
        } else if (innerConfig.containsKey(conf.getLAYER_FIELD_POOL_1D_STRIDES()) && dimension == 1) {
            /* 1D Pooling layers. */
            int stride = (int) innerConfig.get(conf.getLAYER_FIELD_POOL_1D_STRIDES());
            strides = new int[]{stride};
        } else
            throw new InvalidKerasConfigurationException("Could not determine layer stride: no "
                    + conf.getLAYER_FIELD_CONVOLUTION_STRIDES() + " or "
                    + conf.getLAYER_FIELD_POOL_STRIDES() + " field found");
        return strides;
    }

    static int getDepthMultiplier(Map<String, Object> layerConfig, KerasLayerConfiguration conf)
            throws InvalidKerasConfigurationException {
        Map<String, Object> innerConfig = KerasLayerUtils.getInnerLayerConfigFromConfig(layerConfig, conf);
        return  (int) innerConfig.get(conf.getLAYER_FIELD_DEPTH_MULTIPLIER());
    }

    /**
     * Get atrous / dilation rate from config
     *
     * @param layerConfig   dictionary containing Keras layer configuration
     * @param dimension     dimension of the convolution layer (1 or 2)
     * @param conf          Keras layer configuration
     * @param forceDilation boolean to indicate if dilation argument should be in config
     * @return list of integers with atrous rates
     *
     * @throws InvalidKerasConfigurationException Invalid Keras config
     */
    public static int[] getDilationRate(Map<String, Object> layerConfig, int dimension, KerasLayerConfiguration conf,
                                        boolean forceDilation)
            throws InvalidKerasConfigurationException {
        Map<String, Object> innerConfig = KerasLayerUtils.getInnerLayerConfigFromConfig(layerConfig, conf);
        int[] atrousRate;
        if (innerConfig.containsKey(conf.getLAYER_FIELD_DILATION_RATE()) && dimension >= 2) {
            @SuppressWarnings("unchecked")
            List<Integer> atrousRateList = (List<Integer>) innerConfig.get(conf.getLAYER_FIELD_DILATION_RATE());
            atrousRate = ArrayUtil.toArray(atrousRateList);
        } else if (innerConfig.containsKey(conf.getLAYER_FIELD_DILATION_RATE()) && dimension == 1) {
            if ((int) layerConfig.get("keras_version") == 2) {
                @SuppressWarnings("unchecked")
                List<Integer> atrousRateList = (List<Integer>) innerConfig.get(conf.getLAYER_FIELD_DILATION_RATE());
                atrousRate = new int[]{atrousRateList.get(0), atrousRateList.get(0)};
            } else {
                int atrous = (int) innerConfig.get(conf.getLAYER_FIELD_DILATION_RATE());
                atrousRate = new int[]{atrous, atrous};
            }
        } else {
            // If we are using keras 1, for regular convolutions, there is no "atrous" argument, for keras
            // 2 there always is.
            if (forceDilation)
                throw new InvalidKerasConfigurationException("Could not determine dilation rate: no "
                        + conf.getLAYER_FIELD_DILATION_RATE() + " field found");
            else
                atrousRate = null;
        }
        return atrousRate;

    }


    /**
     * Return the {@link Convolution3D.DataFormat}
     * from the configuration .
     * If the value is {@link KerasLayerConfiguration#getDIM_ORDERING_TENSORFLOW()}
     * then the value is {@link Convolution3D.DataFormat#NDHWC }
     * else it's {@link KerasLayerConfiguration#getDIM_ORDERING_THEANO()}
     * which is {@link Convolution3D.DataFormat#NDHWC}
     * @param layerConfig the layer configuration to get the values from
     * @param layerConfiguration the keras configuration used for retrieving
     *                           values from the configuration
     * @return the {@link CNN2DFormat} given the configuration
     * @throws InvalidKerasConfigurationException
     */
    public static Convolution3D.DataFormat getCNN3DDataFormatFromConfig(Map<String,Object> layerConfig, KerasLayerConfiguration layerConfiguration) throws InvalidKerasConfigurationException {
        Map<String, Object> innerConfig = KerasLayerUtils.getInnerLayerConfigFromConfig(layerConfig,layerConfiguration);
        String dataFormat = innerConfig.containsKey(layerConfiguration.getLAYER_FIELD_DIM_ORDERING()) ?
                innerConfig.get(layerConfiguration.getLAYER_FIELD_DIM_ORDERING()).toString() : "channels_last";
        return dataFormat.equals("channels_last") ? Convolution3D.DataFormat.NDHWC : Convolution3D.DataFormat.NCDHW;

    }

    /**
     * Return the {@link CNN2DFormat}
     * from the configuration .
     * If the value is {@link KerasLayerConfiguration#getDIM_ORDERING_TENSORFLOW()}
     * then the value is {@link CNN2DFormat#NHWC}
     * else it's {@link KerasLayerConfiguration#getDIM_ORDERING_THEANO()}
     * which is {@link CNN2DFormat#NCHW}
     * @param layerConfig the layer configuration to get the values from
     * @param layerConfiguration the keras configuration used for retrieving
     *                           values from the configuration
     * @return the {@link CNN2DFormat} given the configuration
     * @throws InvalidKerasConfigurationException
     */
    public static CNN2DFormat getDataFormatFromConfig(Map<String,Object> layerConfig,KerasLayerConfiguration layerConfiguration) throws InvalidKerasConfigurationException {
        Map<String, Object> innerConfig = KerasLayerUtils.getInnerLayerConfigFromConfig(layerConfig,layerConfiguration);
        String dataFormat = innerConfig.containsKey(layerConfiguration.getLAYER_FIELD_DIM_ORDERING()) ?
                innerConfig.get(layerConfiguration.getLAYER_FIELD_DIM_ORDERING()).toString() : "channels_last";
        return dataFormat.equals("channels_last") ? CNN2DFormat.NHWC : CNN2DFormat.NCHW;

    }

    /**
     * Get upsampling size from Keras layer configuration.
     *
     * @param layerConfig dictionary containing Keras layer configuration
     *
     * @return Upsampling integer array from Keras config
     * @throws InvalidKerasConfigurationException Invalid Keras configuration
     */
    static int[] getUpsamplingSizeFromConfig(Map<String, Object> layerConfig, int dimension,
                                             KerasLayerConfiguration conf)
            throws InvalidKerasConfigurationException {
        Map<String, Object> innerConfig = KerasLayerUtils.getInnerLayerConfigFromConfig(layerConfig, conf);
        int[] size;
        if (innerConfig.containsKey(conf.getLAYER_FIELD_UPSAMPLING_2D_SIZE()) && dimension == 2
                || innerConfig.containsKey(conf.getLAYER_FIELD_UPSAMPLING_3D_SIZE()) && dimension == 3) {
            @SuppressWarnings("unchecked")
            List<Integer> sizeList = (List<Integer>) innerConfig.get(conf.getLAYER_FIELD_UPSAMPLING_2D_SIZE());
            size = ArrayUtil.toArray(sizeList);
        } else if (innerConfig.containsKey(conf.getLAYER_FIELD_UPSAMPLING_1D_SIZE()) && dimension == 1) {
            int upsamplingSize1D = (int) innerConfig.get(conf.getLAYER_FIELD_UPSAMPLING_1D_SIZE());
            size = new int[]{upsamplingSize1D};
        } else {
            throw new InvalidKerasConfigurationException("Could not determine kernel size: no "
                    + conf.getLAYER_FIELD_UPSAMPLING_1D_SIZE() + ", "
                    + conf.getLAYER_FIELD_UPSAMPLING_2D_SIZE());
        }
        return size;
    }


    /**
     * Get (convolution) kernel size from Keras layer configuration.
     *
     * @param layerConfig dictionary containing Keras layer configuration
     *
     * @return Convolutional kernel sizes
     * @throws InvalidKerasConfigurationException Invalid Keras config
     */
    public static int[] getKernelSizeFromConfig(Map<String, Object> layerConfig, int dimension,
                                                KerasLayerConfiguration conf, int kerasMajorVersion)
            throws InvalidKerasConfigurationException {
        Map<String, Object> innerConfig = KerasLayerUtils.getInnerLayerConfigFromConfig(layerConfig, conf);
        int[] kernelSize;
        if (kerasMajorVersion != 2) {
            if (innerConfig.containsKey(conf.getLAYER_FIELD_NB_ROW()) && dimension == 2
                    && innerConfig.containsKey(conf.getLAYER_FIELD_NB_COL())) {
                /* 2D Convolutional layers. */
                List<Integer> kernelSizeList = new ArrayList<>();
                kernelSizeList.add((Integer) innerConfig.get(conf.getLAYER_FIELD_NB_ROW()));
                kernelSizeList.add((Integer) innerConfig.get(conf.getLAYER_FIELD_NB_COL()));
                kernelSize = ArrayUtil.toArray(kernelSizeList);
            } else if (innerConfig.containsKey(conf.getLAYER_FIELD_3D_KERNEL_1()) && dimension == 3
                    && innerConfig.containsKey(conf.getLAYER_FIELD_3D_KERNEL_2())
                    && innerConfig.containsKey(conf.getLAYER_FIELD_3D_KERNEL_3())) {
                /* 3D Convolutional layers. */
                List<Integer> kernelSizeList = new ArrayList<>();
                kernelSizeList.add((Integer) innerConfig.get(conf.getLAYER_FIELD_3D_KERNEL_1()));
                kernelSizeList.add((Integer) innerConfig.get(conf.getLAYER_FIELD_3D_KERNEL_2()));
                kernelSizeList.add((Integer) innerConfig.get(conf.getLAYER_FIELD_3D_KERNEL_3()));
                kernelSize = ArrayUtil.toArray(kernelSizeList);
            } else if (innerConfig.containsKey(conf.getLAYER_FIELD_FILTER_LENGTH()) && dimension == 1) {
                /* 1D Convolutional layers. */
                int filterLength = (int) innerConfig.get(conf.getLAYER_FIELD_FILTER_LENGTH());
                kernelSize = new int[]{filterLength};
            } else if (innerConfig.containsKey(conf.getLAYER_FIELD_POOL_SIZE()) && dimension >= 2) {
                /* 2D/3D Pooling layers. */
                @SuppressWarnings("unchecked")
                List<Integer> kernelSizeList = (List<Integer>) innerConfig.get(conf.getLAYER_FIELD_POOL_SIZE());
                kernelSize = ArrayUtil.toArray(kernelSizeList);
            } else if (innerConfig.containsKey(conf.getLAYER_FIELD_POOL_1D_SIZE()) && dimension == 1) {
                /* 1D Pooling layers. */
                int poolSize1D = (int) innerConfig.get(conf.getLAYER_FIELD_POOL_1D_SIZE());
                kernelSize = new int[]{poolSize1D};
            } else {
                throw new InvalidKerasConfigurationException("Could not determine kernel size: no "
                        + conf.getLAYER_FIELD_NB_ROW() + ", "
                        + conf.getLAYER_FIELD_NB_COL() + ", or "
                        + conf.getLAYER_FIELD_FILTER_LENGTH() + ", or "
                        + conf.getLAYER_FIELD_POOL_1D_SIZE() + ", or "
                        + conf.getLAYER_FIELD_POOL_SIZE() + " field found");
            }
        } else {
            /* 2D/3D Convolutional layers. */
            if (innerConfig.containsKey(conf.getLAYER_FIELD_KERNEL_SIZE()) && dimension >= 2) {
                @SuppressWarnings("unchecked")
                List<Integer> kernelSizeList = (List<Integer>) innerConfig.get(conf.getLAYER_FIELD_KERNEL_SIZE());
                kernelSize = ArrayUtil.toArray(kernelSizeList);
            } else if (innerConfig.containsKey(conf.getLAYER_FIELD_FILTER_LENGTH()) && dimension == 1) {
                /* 1D Convolutional layers. */
                @SuppressWarnings("unchecked")
                List<Integer> kernelSizeList = (List<Integer>) innerConfig.get(conf.getLAYER_FIELD_FILTER_LENGTH());
                kernelSize = ArrayUtil.toArray(kernelSizeList);
            } else if (innerConfig.containsKey(conf.getLAYER_FIELD_POOL_SIZE()) && dimension >= 2) {
                /* 2D Pooling layers. */
                @SuppressWarnings("unchecked")
                List<Integer> kernelSizeList = (List<Integer>) innerConfig.get(conf.getLAYER_FIELD_POOL_SIZE());
                kernelSize = ArrayUtil.toArray(kernelSizeList);
            } else if (innerConfig.containsKey(conf.getLAYER_FIELD_POOL_1D_SIZE()) && dimension == 1) {
                /* 1D Pooling layers. */
                @SuppressWarnings("unchecked")
                List<Integer> kernelSizeList = (List<Integer>) innerConfig.get(conf.getLAYER_FIELD_POOL_1D_SIZE());
                kernelSize = ArrayUtil.toArray(kernelSizeList);
            } else {
                throw new InvalidKerasConfigurationException("Could not determine kernel size: no "
                        + conf.getLAYER_FIELD_KERNEL_SIZE() + ", or "
                        + conf.getLAYER_FIELD_FILTER_LENGTH() + ", or "
                        + conf.getLAYER_FIELD_POOL_SIZE() + " field found");
            }
        }

        return kernelSize;
    }

    /**
     * Get convolution border mode from Keras layer configuration.
     *
     * @param layerConfig dictionary containing Keras layer configuration
     * @return Border mode of convolutional layers
     * @throws InvalidKerasConfigurationException Invalid Keras configuration
     */
    public static ConvolutionMode getConvolutionModeFromConfig(Map<String, Object> layerConfig,
                                                               KerasLayerConfiguration conf)
            throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        Map<String, Object> innerConfig = KerasLayerUtils.getInnerLayerConfigFromConfig(layerConfig, conf);
        if (!innerConfig.containsKey(conf.getLAYER_FIELD_BORDER_MODE()))
            throw new InvalidKerasConfigurationException("Could not determine convolution border mode: no "
                    + conf.getLAYER_FIELD_BORDER_MODE() + " field found");
        String borderMode = (String) innerConfig.get(conf.getLAYER_FIELD_BORDER_MODE());
        ConvolutionMode convolutionMode;
        if (borderMode.equals(conf.getLAYER_BORDER_MODE_SAME())) {
            /* Keras relies upon the Theano and TensorFlow border mode definitions and operations:
             * TH: http://deeplearning.net/software/theano/library/tensor/nnet/conv.html#theano.tensor.nnet.conv.conv2d
             * TF: https://www.tensorflow.org/api_docs/python/nn/convolution#conv2d
             */
            convolutionMode = ConvolutionMode.Same;

        } else if (borderMode.equals(conf.getLAYER_BORDER_MODE_VALID()) ||
                borderMode.equals(conf.getLAYER_BORDER_MODE_FULL())) {
            convolutionMode = ConvolutionMode.Truncate;
        } else if(borderMode.equals(conf.getLAYER_BORDER_MODE_CAUSAL())) {
            convolutionMode = ConvolutionMode.Causal;
        } else {
            throw new UnsupportedKerasConfigurationException("Unsupported convolution border mode: " + borderMode);
        }
        return convolutionMode;
    }

    /**
     * Get (convolution) padding from Keras layer configuration.
     *
     * @param layerConfig dictionary containing Keras layer configuration
     * @return Padding values derived from border mode
     * @throws InvalidKerasConfigurationException Invalid Keras config
     */
    public static int[] getPaddingFromBorderModeConfig(Map<String, Object> layerConfig, int dimension,
                                                       KerasLayerConfiguration conf, int kerasMajorVersion)
            throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        Map<String, Object> innerConfig = KerasLayerUtils.getInnerLayerConfigFromConfig(layerConfig, conf);
        int[] padding = null;
        if (!innerConfig.containsKey(conf.getLAYER_FIELD_BORDER_MODE()))
            throw new InvalidKerasConfigurationException("Could not determine convolution border mode: no "
                    + conf.getLAYER_FIELD_BORDER_MODE() + " field found");
        String borderMode = (String) innerConfig.get(conf.getLAYER_FIELD_BORDER_MODE());
        if (borderMode.equals(conf.getLAYER_FIELD_BORDER_MODE())) {
            padding = getKernelSizeFromConfig(layerConfig, dimension, conf, kerasMajorVersion);
            for (int i = 0; i < padding.length; i++)
                padding[i]--;
        }
        return padding;
    }

    /**
     * Get padding and cropping configurations from Keras layer configuration.
     *
     * @param layerConfig dictionary containing Keras layer configuration
     * @param conf        KerasLayerConfiguration
     * @param layerField  String value of the layer config name to check for (e.g. "padding" or "cropping")
     * @param dimension   Dimension of the padding layer
     * @return padding list of integers
     * @throws InvalidKerasConfigurationException Invalid keras configuration
     */
    static int[] getPaddingFromConfig(Map<String, Object> layerConfig,
                                      KerasLayerConfiguration conf,
                                      String layerField,
                                      int dimension)
            throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        Map<String, Object> innerConfig = KerasLayerUtils.getInnerLayerConfigFromConfig(layerConfig, conf);
        if (!innerConfig.containsKey(layerField))
            throw new InvalidKerasConfigurationException(
                    "Field " + layerField + " not found in Keras cropping or padding layer");
        int[] padding;
        if (dimension >= 2) {
            List<Integer> paddingList;
            // For 2D layers, padding/cropping can either be a pair [[x_0, x_1].[y_0, y_1]] or a pair [x, y]
            // or a single integer x. Likewise for the 3D case.
            try {
                List paddingNoCast = (List) innerConfig.get(layerField);
                boolean isNested;
                try {
                    @SuppressWarnings("unchecked")
                    List<Integer> firstItem = (List<Integer>) paddingNoCast.get(0);
                    isNested = true;
                    paddingList = new ArrayList<>(2 * dimension);
                } catch (Exception e) {
                    int firstItem = (int) paddingNoCast.get(0);
                    isNested = false;
                    paddingList = new ArrayList<>(dimension);
                }

                if ((paddingNoCast.size() == dimension) && !isNested) {
                    for (int i = 0; i < dimension; i++)
                        paddingList.add((int) paddingNoCast.get(i));
                    padding = ArrayUtil.toArray(paddingList);
                } else if ((paddingNoCast.size() == dimension) && isNested) {
                    for (int j = 0; j < dimension; j++) {
                        @SuppressWarnings("unchecked")
                        List<Integer> item = (List<Integer>) paddingNoCast.get(j);
                        paddingList.add((item.get(0)));
                        paddingList.add((item.get(1)));
                    }

                    padding = ArrayUtil.toArray(paddingList);
                } else {
                    throw new InvalidKerasConfigurationException("Found Keras ZeroPadding" + dimension
                            + "D layer with invalid " + paddingList.size() + "D padding.");
                }
            } catch (Exception e) {
                int paddingInt = (int) innerConfig.get(layerField);
                if (dimension == 2) {
                    padding = new int[]{paddingInt, paddingInt, paddingInt, paddingInt};
                } else {
                    padding = new int[]{paddingInt, paddingInt, paddingInt, paddingInt, paddingInt, paddingInt};
                }
            }

        } else if (dimension == 1) {
            Object paddingObj  = innerConfig.get(layerField);
            if (paddingObj instanceof List){
                List<Integer> paddingList = (List)paddingObj;
                padding = new int[]{
                        paddingList.get(0),
                        paddingList.get(1)
                };
            }
            else{
                int paddingInt = (int) innerConfig.get(layerField);
                padding = new int[]{paddingInt, paddingInt};
            }

        } else {
            throw new UnsupportedKerasConfigurationException(
                    "Keras padding layer not supported");
        }
        return padding;
    }
}