
View on GitHub


1 day
Test Coverage
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.deeplearning4j.nn.conf.inputs;

import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.conf.DataFormat;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.layers.Convolution3D;
import org.nd4j.common.util.OneTimeLogger;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.shade.jackson.annotation.JsonIgnore;
import org.nd4j.shade.jackson.annotation.JsonInclude;
import org.nd4j.shade.jackson.annotation.JsonProperty;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;

import java.util.Arrays;

@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
public abstract class InputType implements Serializable {

     * The type of activations in/out of a given GraphVertex<br>
     * FF: Standard feed-foward (2d minibatch, 1d per example) data<br>
     * RNN: Recurrent neural network (3d minibatch) time series data<br>
     * CNN: 2D Convolutional neural network (4d minibatch, [miniBatchSize, channels, height, width])
     * CNNFlat: Flattened 2D conv net data (2d minibatch, [miniBatchSize, height * width * channels])
     * CNN3D: 3D convolutional neural network (5d minibatch, [miniBatchSize, channels, height, width, channels])
    public enum Type {
        FF, RNN, CNN, CNNFlat, CNN3D

    public static CNN2DFormat getDefaultCNN2DFormat() {
        return defaultCNN2DFormat;

    public static void setDefaultCNN2DFormat(CNN2DFormat defaultCNN2DFormat) {
        InputType.defaultCNN2DFormat = defaultCNN2DFormat;

    private static CNN2DFormat defaultCNN2DFormat = CNN2DFormat.NCHW;

    /**
     * @return the {@link Type} constant identifying which kind of activations this input type describes
     */
    public abstract Type getType();

    /**
     * @return a human-readable description of this input type, including its dimensions
     */
    public abstract String toString();

    /**
     * @return the number of array elements in a single example (i.e. excluding the minibatch dimension)
     */
    public abstract long arrayElementsPerExample();

     * Returns the shape of this InputType
     * @param includeBatchDim Whether to include minibatch in the return shape array
     * @return int[]
    public abstract long[] getShape(boolean includeBatchDim);

     * Returns the shape of this InputType without minibatch dimension in the returned array
     * @return int[]
    public long[] getShape() {
        return getShape(false);

     * InputType for feed forward network data
     * @param size The size of the activations
     * @return InputTypeFeedForward
    public static InputType feedForward(long size) {
        return new InputTypeFeedForward(size, null);

    public static InputType feedForward(long size, DataFormat timeDistributedFormat) {
        return new InputTypeFeedForward(size,timeDistributedFormat);

     * InputType for recurrent neural network (time series) data
     * @param size The size of the activations
     * @return InputTypeRecurrent
    public static InputType recurrent(long size) {
        return new InputTypeRecurrent(size);

     * InputType for recurrent neural network (time series) data
     * @param size             The size of the activations
     * @param timeSeriesLength Length of the input time series
     * @return InputTypeRecurrent
    public static InputType recurrent(long size, long timeSeriesLength) {
        return new InputTypeRecurrent(size, timeSeriesLength, RNNFormat.NCW);

    public static InputType recurrent(long size, RNNFormat format){
        return new InputTypeRecurrent(size, format);

    public static InputType recurrent(long size, long timeSeriesLength, RNNFormat format){
        return new InputTypeRecurrent(size, timeSeriesLength, format);
     * Input type for convolutional (CNN) data, that is 4d with shape [miniBatchSize, channels, height, width].
     * For CNN data that has been flattened, use {@link #convolutionalFlat(long, long, long)}
     * @param height height of the input
     * @param width  Width of the input
     * @param depth  Depth, or number of channels
     * @return InputTypeConvolutional
    public static InputType convolutional(long height, long width, long depth) {
        return convolutional(height, width, depth, getDefaultCNN2DFormat());

    public static InputType convolutional(long height, long width, long depth, CNN2DFormat format){
        return new InputTypeConvolutional(height, width, depth, format);

     * Input type for 3D convolutional (CNN3D) data in NDHWC format, that is 5d with shape
     * [miniBatchSize, depth, height, width, channels].
     * @param height   height of the input
     * @param width    Width of the input
     * @param depth    Depth of the input
     * @param channels Number of channels of the input
     * @return InputTypeConvolutional3D
     * @deprecated Use {@link #convolutional3D(Convolution3D.DataFormat, long, long, long, long)}
    public static InputType convolutional3D(long depth, long height, long width,  long channels) {
        return convolutional3D(Convolution3D.DataFormat.NDHWC, depth, height, width, channels);

     * Input type for 3D convolutional (CNN3D) 5d data:<br>
     * If NDHWC format [miniBatchSize, depth, height, width, channels]<br>
     * If NDCWH
     * @param height   height of the input
     * @param width    Width of the input
     * @param depth    Depth of the input
     * @param channels Number of channels of the input
     * @return InputTypeConvolutional3D
    public static InputType convolutional3D(Convolution3D.DataFormat dataFormat, long depth, long height, long width, long channels) {
        return new InputTypeConvolutional3D(dataFormat, depth, height, width, channels);

     * Input type for convolutional (CNN) data, where the data is in flattened (row vector) format.
     * Expect data with shape [miniBatchSize, height * width * channels]. For CNN data in 4d format,
     * use {@link #convolutional(long, long, long)}
     * @param height Height of the (unflattened) data represented by this input type
     * @param width  Width of the (unflattened) data represented by this input type
     * @param depth  Depth of the (unflattened) data represented by this input type
     * @return InputTypeConvolutionalFlat
    public static InputType convolutionalFlat(long height, long width, long depth) {
        return new InputTypeConvolutionalFlat(height, width, depth);

    @EqualsAndHashCode(callSuper = false)
    public static class InputTypeFeedForward extends InputType {
        private long size;
        private DataFormat timeDistributedFormat;

        public InputTypeFeedForward(@JsonProperty("size") long size, @JsonProperty("timeDistributedFormat") DataFormat timeDistributedFormat) {
            if(size <= 0) {
                OneTimeLogger.warn(log,"Assigning a size of zero. This is normally only valid in model import cases with unknown dimensions.");
            this.size = size;
            this.timeDistributedFormat = timeDistributedFormat;

        public Type getType() {
            return Type.FF;

        public String toString() {
            return "InputTypeFeedForward(" + size + (timeDistributedFormat != null ? "," + timeDistributedFormat : "") + ")";

        public long arrayElementsPerExample() {
            return size;

        public long[] getShape(boolean includeBatchDim) {
            if(includeBatchDim) return new long[]{-1, size};
            else return new long[]{size};

    @EqualsAndHashCode(callSuper = false)
    public static class InputTypeRecurrent extends InputType {
        private long size;
        private long timeSeriesLength;
        private RNNFormat format = RNNFormat.NCW;
        public InputTypeRecurrent(long size) {
            this(size, -1);
        public InputTypeRecurrent(long size, long timeSeriesLength){
            this(size, timeSeriesLength, RNNFormat.NCW);

        public  InputTypeRecurrent(long size, RNNFormat format){
            this(size, -1, format);
        public InputTypeRecurrent(@JsonProperty("size") long size,
                                  @JsonProperty("timeSeriesLength") long timeSeriesLength,
                                  @JsonProperty("format") RNNFormat format) {
            this.size = size;
            this.timeSeriesLength = timeSeriesLength;
            this.format = format;

        public Type getType() {
            return Type.RNN;

        public String toString() {
            if (timeSeriesLength > 0) {
                return "InputTypeRecurrent(" + size + ",timeSeriesLength=" + timeSeriesLength + ",format=" + format + ")";
            } else {
                return "InputTypeRecurrent(" + size + ",format=" + format + ")";

        public long arrayElementsPerExample() {
            if (timeSeriesLength <= 0) {
                throw new IllegalStateException("Cannot calculate number of array elements per example: "
                        + "time series length is not set. Use InputType.recurrent(int size, int timeSeriesLength) instead?");
            return timeSeriesLength * size;

        public long[] getShape(boolean includeBatchDim) {
            if (includeBatchDim){
                if (format == RNNFormat.NCW) {
                    return new long[]{-1, size, timeSeriesLength};
                    return new long[]{-1, timeSeriesLength, size};

                if (format == RNNFormat.NCW) {
                    return new long[]{size, timeSeriesLength};
                    return new long[]{timeSeriesLength, size};

    @EqualsAndHashCode(callSuper = false)
    public static class InputTypeConvolutional extends InputType {
        private long height;
        private long width;
        private long channels;
        private CNN2DFormat format = CNN2DFormat.NCHW;  //Default for JSON deserialization of older configurations

        public InputTypeConvolutional(@JsonProperty("height") long height, @JsonProperty("width") long width,
                                      @JsonProperty("channels") long channels, @JsonProperty("format") CNN2DFormat format) {
            if(height <= 0) {
                OneTimeLogger.warn(log,"Assigning height of 0. Normally this is not valid. Exceptions for this are generally related" +
                        "to model import and unknown dimensions");

            if(width <= 0) {
                OneTimeLogger.warn(log,"Assigning width of 0. Normally this is not valid. Exceptions for this are generally related" +
                        "to model import and unknown dimensions");

            if(channels <= 0) {
                OneTimeLogger.warn(log,"Assigning channels of 0. Normally this is not valid. Exceptions for this are generally related" +
                        "to model import and unknown dimensions");

            this.height = height;
            this.width = width;
            this.channels = channels;
            if(format != null)
                this.format = format;

        public InputTypeConvolutional(long height, long width, long channels) {
            this(height, width, channels, CNN2DFormat.NCHW);

         * Return the number of channels / depth for this 2D convolution. This method has been deprecated,
         * for consistency purposes, use getChannels() instead.
         * @return number of channels, i.e. depth for 2D convolutions
        public long getDepth() {
            return channels;

         * Set the number of channels / depth for this 2D convolution. This method has been deprecated,
         * for consistency purposes, use setChannels(channels) instead.
        public void setDepth(long depth) {
            this.channels = depth;

        public Type getType() {
            return Type.CNN;

        public String toString() {
            return "InputTypeConvolutional(h=" + height + ",w=" + width + ",c=" + channels + "," + format + ")";

        public long arrayElementsPerExample() {
            return height * width * channels;

        public long[] getShape(boolean includeBatchDim) {
            if(format == CNN2DFormat.NCHW){
                if(includeBatchDim) return new long[]{-1, channels, height, width};
                else return new long[]{channels, height, width};
            } else {
                if(includeBatchDim) return new long[]{-1, height, width, channels};
                else return new long[]{height, width, channels};

    @EqualsAndHashCode(callSuper = false)
    public static class InputTypeConvolutional3D extends InputType {
        private Convolution3D.DataFormat dataFormat;
        private long depth;
        private long height;
        private long width;
        private long channels;

        public InputTypeConvolutional3D(@JsonProperty("dataFormat") Convolution3D.DataFormat dataFormat,
                                        @JsonProperty("depth") long depth, @JsonProperty("height") long height, @JsonProperty("width") long width, @JsonProperty("channels") long channels) {
            this.dataFormat = dataFormat;
            this.depth = depth;
            this.height = height;
            this.width = width;
            this.channels = channels;

        public Type getType() {
            return Type.CNN3D;

        public String toString() {
            return "InputTypeConvolutional3D(format=" + dataFormat + ",d=" + depth + ",h=" + height + ",w=" + width + ",c=" + channels + ")";

        public long arrayElementsPerExample() {
            return height * width * depth * channels;

        public long[] getShape(boolean includeBatchDim) {
            if(dataFormat == Convolution3D.DataFormat.NDHWC){
                if(includeBatchDim) return new long[]{-1, depth, height, width, channels};
                else return new long[]{depth, height, width, channels};
            } else {
                if(includeBatchDim) return new long[]{-1, channels, depth, height, width};
                else return new long[]{channels, depth, height, width};

    @EqualsAndHashCode(callSuper = false)
    public static class InputTypeConvolutionalFlat extends InputType {
        private long height;
        private long width;
        private long depth;

        public InputTypeConvolutionalFlat(@JsonProperty("height") long height, @JsonProperty("width") long width, @JsonProperty("depth") long depth) {
            this.height = height;
            this.width = width;
            this.depth = depth;

        public Type getType() {
            return Type.CNNFlat;

        public long getFlattenedSize() {
            return height * width * depth;

        public InputType getUnflattenedType() {
            return InputType.convolutional(height, width, depth);

        public String toString() {
            return "InputTypeConvolutionalFlat(h=" + height + ",w=" + width + ",d=" + depth + ")";

        public long arrayElementsPerExample() {
            return height * width * depth;

        public long[] getShape(boolean includeBatchDim) {
            if(includeBatchDim) return new long[]{-1, depth, height, width};
            else return new long[]{depth, height, width};

    public static InputType inferInputType(INDArray inputArray) {
        //Note: ConvolutionalFlat and FeedForward look identical... but either should work OK if using something
        // like FeedForwardToCnnPreProcessor

        switch (inputArray.rank()) {
            case 2:
                return InputType.feedForward(inputArray.size(1));
            case 3:
                return InputType.recurrent(inputArray.size(1), (int) inputArray.size(2));
            case 4:
                //Order: [minibatch, channels, height, width] -> [h, w, c]
                return InputType.convolutional(inputArray.size(2), (int) inputArray.size(3), (int) inputArray.size(1));
            case 5:
                //Order: [minibatch, channels, depth, height, width] -> [d, h, w, c]
                return InputType.convolutional3D(inputArray.size(2), (int) inputArray.size(3),
                        (int) inputArray.size(4), (int) inputArray.size(1));
                throw new IllegalArgumentException(
                        "Cannot infer input type for array with shape: " + Arrays.toString(inputArray.shape()));

    public static InputType[] inferInputTypes(INDArray... inputArrays) {
        InputType[] out = new InputType[inputArrays.length];
        for (int i = 0; i < inputArrays.length; i++) {
            out[i] = inferInputType(inputArrays[i]);

        return out;
