deeplearning4j/deeplearning4j

View on GitHub
deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/ComputationGraphConfiguration.java

Summary

Maintainability
F
1 wk
Test Coverage
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.deeplearning4j.nn.conf;

import lombok.*;
import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.conf.graph.GraphVertex;
import org.deeplearning4j.nn.conf.graph.LayerVertex;
import org.deeplearning4j.nn.conf.graph.MergeVertex;
import org.deeplearning4j.nn.conf.graph.rnn.LastTimeStepVertex;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.BaseLayer;
import org.deeplearning4j.nn.conf.layers.GlobalPoolingLayer;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep;
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffVertex;
import org.deeplearning4j.nn.conf.memory.MemoryReport;
import org.deeplearning4j.nn.conf.memory.NetworkMemoryReport;
import org.deeplearning4j.nn.conf.serde.JsonMappers;
import org.deeplearning4j.nn.weights.IWeightInit;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.util.OutputLayerUtil;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.shade.jackson.databind.JsonNode;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import org.nd4j.shade.jackson.databind.exc.InvalidTypeIdException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.io.Serializable;
import java.util.*;

@Data
@EqualsAndHashCode(exclude = {"trainingWorkspaceMode", "inferenceWorkspaceMode", "cacheMode", "topologicalOrder", "topologicalOrderStr"})
@AllArgsConstructor(access = AccessLevel.PRIVATE)
@NoArgsConstructor
public class ComputationGraphConfiguration implements Serializable, Cloneable {
    private static Logger log = LoggerFactory.getLogger(ComputationGraphConfiguration.class);

    protected Map<String, GraphVertex> vertices = new LinkedHashMap<>();
    protected Map<String, List<String>> vertexInputs = new LinkedHashMap<>();

    @Getter
    @Setter
    protected WorkspaceMode trainingWorkspaceMode = WorkspaceMode.ENABLED;

    @Getter
    @Setter
    protected WorkspaceMode inferenceWorkspaceMode = WorkspaceMode.ENABLED;

    @Getter
    @Setter
    protected CacheMode cacheMode;

    @Getter
    @Setter
    protected DataType dataType = DataType.FLOAT;   //Default to float for 1.0.0-beta3 and earlier nets

    protected boolean validateOutputLayerConfig = true;     //Default for 1.0.0-beta3 and earlier nets

    /**
     * List of inputs to the network, by name
     */
    protected List<String> networkInputs;

    /**
     * List of network outputs, by name
     */
    protected List<String> networkOutputs;
    protected BackpropType backpropType = BackpropType.Standard;
    protected int tbpttFwdLength = 20;
    protected int tbpttBackLength = 20;

    protected NeuralNetConfiguration defaultConfiguration;

    //Counter for the number of parameter updates so far
    // This is important for learning rate schedules, for example, and is stored here to ensure it is persisted
    // for Spark and model serialization
    protected int iterationCount = 0;

    //Counter for the number of epochs completed so far. Used for per-epoch schedules
    protected int epochCount = 0;

    protected int[] topologicalOrder;
    protected List<String> topologicalOrderStr;

    /**
     * @return YAML representation of configuration
     */
    public String toYaml() {
        ObjectMapper mapper = NeuralNetConfiguration.mapperYaml();
        synchronized (mapper) {
            try {
                return mapper.writeValueAsString(this);
            } catch (org.nd4j.shade.jackson.core.JsonProcessingException e) {
                throw new RuntimeException(e);
            }
        }
    }

    /**
     * Create a neural net configuration from YAML
     *
     * @param json the neural net configuration from YAML
     * @return {@link ComputationGraphConfiguration}
     */
    public static ComputationGraphConfiguration fromYaml(String json) {
        ObjectMapper mapper = NeuralNetConfiguration.mapperYaml();
        try {
            return mapper.readValue(json, ComputationGraphConfiguration.class);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * @return JSON representation of computation graph configuration
     */
    public String toJson() {
        //As per MultiLayerConfiguration.toJson()
        ObjectMapper mapper = NeuralNetConfiguration.mapper();
        synchronized (mapper) {
            //JSON mappers are supposed to be thread safe: however, in practice they seem to miss fields occasionally
            //when writeValueAsString is used by multiple threads. This results in invalid JSON. See issue #3243
            try {
                return mapper.writeValueAsString(this);
            } catch (org.nd4j.shade.jackson.core.JsonProcessingException e) {
                throw new RuntimeException(e);
            }
        }
    }

    /**
     * Create a computation graph configuration from json
     *
     * @param json the neural net configuration from json
     * @return {@link ComputationGraphConfiguration}
     */
    public static ComputationGraphConfiguration fromJson(String json) {
        //As per MultiLayerConfiguration.fromJson()
        ObjectMapper mapper = NeuralNetConfiguration.mapper();
        ComputationGraphConfiguration conf;
        try {
            conf = mapper.readValue(json, ComputationGraphConfiguration.class);
        } catch (InvalidTypeIdException e){
            if(e.getMessage().contains("@class")){
                try{
                    //JSON may be legacy (1.0.0-alpha or earlier), attempt to load it using old format
                    return JsonMappers.getLegacyMapper().readValue(json, ComputationGraphConfiguration.class);
                } catch (InvalidTypeIdException e2){
                    //Check for legacy custom layers: "Could not resolve type id 'CustomLayer' as a subtype of [simple type, class org.deeplearning4j.nn.conf.layers.Layer]: known type ids = [Bidirectional, CenterLossOutputLayer, CnnLossLayer, ..."
                    //1.0.0-beta5: dropping support for custom layers defined in pre-1.0.0-beta format. Built-in layers from these formats still work
                    String msg = e2.getMessage();
                    if(msg != null && msg.contains("Could not resolve type id")){
                        throw new RuntimeException("Error deserializing ComputationGraphConfiguration - configuration may have a custom " +
                                "layer, vertex or preprocessor, in pre version 1.0.0-beta JSON format.\nModels in legacy format with custom" +
                                " layers should be loaded in 1.0.0-beta to 1.0.0-beta4 and saved again, before loading in the current version of DL4J", e);
                    }
                    throw new RuntimeException(e2);
                } catch (IOException e2){
                    throw new RuntimeException(e2);
                }
            }
            throw new RuntimeException(e);
        } catch (Exception e) {
            //Check if this exception came from legacy deserializer...
            String msg = e.getMessage();
            if(msg != null && msg.contains("legacy")){
                throw new RuntimeException("Error deserializing ComputationGraphConfiguration - configuration may have a custom " +
                        "layer, vertex or preprocessor, in pre version 1.0.0-alpha JSON format. These layers can be " +
                        "deserialized by first registering them with NeuralNetConfiguration.registerLegacyCustomClassesForJSON(Class...)", e);
            }
            throw new RuntimeException(e);
        }

        //To maintain backward compatibility after activation function refactoring (configs generated with v0.7.1 or earlier)
        // Previously: enumeration used for activation functions. Now: use classes
        int layerCount = 0;
        Map<String, GraphVertex> vertexMap = conf.getVertices();
        JsonNode vertices = null;
        for (Map.Entry<String, GraphVertex> entry : vertexMap.entrySet()) {
            if (!(entry.getValue() instanceof LayerVertex)) {
                continue;
            }

            LayerVertex lv = (LayerVertex) entry.getValue();
            if (lv.getLayerConf() != null && lv.getLayerConf().getLayer() != null) {
                Layer layer = lv.getLayerConf().getLayer();

                if (layer instanceof BaseLayer && ((BaseLayer) layer).getActivationFn() == null) {
                    String layerName = layer.getLayerName();

                    try {
                        if (vertices == null) {
                            JsonNode jsonNode = mapper.readTree(json);
                            vertices = jsonNode.get("vertices");
                        }

                        JsonNode vertexNode = vertices.get(layerName);
                        JsonNode layerVertexNode = vertexNode.get("LayerVertex");
                        if (layerVertexNode == null || !layerVertexNode.has("layerConf")
                                || !layerVertexNode.get("layerConf").has("layer")) {
                            continue;
                        }
                        JsonNode layerWrapperNode = layerVertexNode.get("layerConf").get("layer");

                        if (layerWrapperNode == null || layerWrapperNode.size() != 1) {
                            continue;
                        }

                        JsonNode layerNode = layerWrapperNode.elements().next();
                        JsonNode activationFunction = layerNode.get("activationFunction"); //Should only have 1 element: "dense", "output", etc

                        if (activationFunction != null) {
                            IActivation ia = Activation.fromString(activationFunction.asText()).getActivationFunction();
                            ((BaseLayer) layer).setActivationFn(ia);
                        }

                    } catch (IOException e) {
                        log.warn("Layer with null ActivationFn field or pre-0.7.2 activation function detected: could not parse JSON",
                                e);
                    }
                }

                handleLegacyWeightInitFromJson(json, layer, mapper, vertices);
            }
        }

        return conf;
    }

    /**
     * Handle {@link WeightInit} and {@link Distribution} from legacy configs in Json format. Copied from handling of {@link Activation}
     * above.
     * @return True if all is well and layer iteration shall continue. False else-wise.
     */
    private static void handleLegacyWeightInitFromJson(String json, Layer layer, ObjectMapper mapper, JsonNode vertices) {
        if (layer instanceof BaseLayer && ((BaseLayer) layer).getWeightInitFn() == null) {
            String layerName = layer.getLayerName();

            try {
                if (vertices == null) {
                    JsonNode jsonNode = mapper.readTree(json);
                    vertices = jsonNode.get("vertices");
                }

                JsonNode vertexNode = vertices.get(layerName);
                JsonNode layerVertexNode = vertexNode.get("LayerVertex");
                if (layerVertexNode == null || !layerVertexNode.has("layerConf")
                        || !layerVertexNode.get("layerConf").has("layer")) {
                    return;
                }
                JsonNode layerWrapperNode = layerVertexNode.get("layerConf").get("layer");

                if (layerWrapperNode == null || layerWrapperNode.size() != 1) {
                    return;
                }

                JsonNode layerNode = layerWrapperNode.elements().next();
                JsonNode weightInit = layerNode.get("weightInit"); //Should only have 1 element: "dense", "output", etc
                JsonNode distribution = layerNode.get("dist");

                Distribution dist = null;
                if(distribution != null) {
                    dist = mapper.treeToValue(distribution, Distribution.class);
                }

                if (weightInit != null) {
                    final IWeightInit wi = WeightInit.valueOf(weightInit.asText()).getWeightInitFunction(dist);
                    ((BaseLayer) layer).setWeightInitFn(wi);
                }

            } catch (IOException e) {
                log.warn("Layer with null ActivationFn field or pre-0.7.2 activation function detected: could not parse JSON",
                        e);
            }
        }
    }

    @Override
    public String toString() {
        return toJson();
    }

    @Override
    public ComputationGraphConfiguration clone() {
        ComputationGraphConfiguration conf = new ComputationGraphConfiguration();

        conf.vertices = new LinkedHashMap<>();
        for (Map.Entry<String, GraphVertex> entry : this.vertices.entrySet()) {
            conf.vertices.put(entry.getKey(), entry.getValue().clone());
        }

        conf.vertexInputs = new LinkedHashMap<>();
        for (Map.Entry<String, List<String>> entry : this.vertexInputs.entrySet()) {
            conf.vertexInputs.put(entry.getKey(), new ArrayList<>(entry.getValue()));
        }

        conf.networkInputs = new ArrayList<>();

        conf.networkInputs = new ArrayList<>(this.networkInputs);
        conf.networkOutputs = new ArrayList<>(this.networkOutputs);

        conf.backpropType = backpropType;
        conf.tbpttFwdLength = tbpttFwdLength;
        conf.tbpttBackLength = tbpttBackLength;
        conf.defaultConfiguration = defaultConfiguration.clone();
        conf.trainingWorkspaceMode = trainingWorkspaceMode;
        conf.inferenceWorkspaceMode = inferenceWorkspaceMode;
        conf.cacheMode = this.cacheMode;
        conf.defaultConfiguration.cacheMode = this.cacheMode;
        conf.validateOutputLayerConfig = this.validateOutputLayerConfig;
        conf.dataType = this.dataType;

        return conf;
    }


    /**
     * Check the configuration, make sure it is valid
     *
     * @throws IllegalStateException if configuration is not valid
     */
    public void validate() {
        validate(false, false);
    }

    /**
     * Check the configuration, make sure it is valid
     *
     * @param allowDisconnected If true: don't throw an exception on vertices that are 'disconnected'. A disconnected
     *                          vertex is one that is not an output, and doesn't connect to any other vertices. i.e.,
     *                          it's output activations don't go anywhere
     * @throws IllegalStateException if configuration is not valid
     */
    public void validate(boolean allowDisconnected, boolean allowNoOutput){

        if (networkInputs == null || networkInputs.isEmpty()) {
            throw new IllegalStateException( "Invalid configuration: network has no inputs. " +
                    "Use .addInputs(String...) to label (and give an ordering to) the network inputs");
        }
        if ((networkOutputs == null || networkOutputs.isEmpty()) && !allowNoOutput) {
            throw new IllegalStateException("Invalid configuration: network has no outputs. " +
                    "Use .setOutput(String...) to specify (and give an ordering to) the output vertices, " +
                    "or use allowNoOutputs(true) to disable this check");
        }

        //Check uniqueness of names for inputs, layers, GraphNodes
        for (String s : networkInputs) {
            if (vertices.containsKey(s)) {
                throw new IllegalStateException("Invalid configuration: name \"" + s
                        + "\" is present in both network inputs and graph vertices/layers");
            }
        }

        //Check: each layer & node has at least one input
        for (Map.Entry<String, List<String>> e : vertexInputs.entrySet()) {
            String nodeName = e.getKey();
            if (e.getValue() == null || e.getValue().isEmpty()) {
                throw new IllegalStateException("Invalid configuration: vertex \"" + nodeName + "\" has no inputs");
            }
            for (String inputName : e.getValue()) {
                if (!vertices.containsKey(inputName) && !networkInputs.contains(inputName)) {
                    throw new IllegalStateException("Invalid configuration: Vertex \"" + nodeName + "\" has input \""
                            + inputName + "\" that does not exist");
                }
            }
        }

        //Check output names:
        if(networkOutputs != null) {
            for (String s : networkOutputs) {
                if (!vertices.containsKey(s)) {
                    throw new IllegalStateException(
                            "Invalid configuration: Output name \"" + s + "\" is not a valid vertex");
                }
            }
        }

        //Check that there aren't any disconnected vertices
        if(!allowDisconnected){
            //A vertex is considered disconnected if it is (a) not an output vertex, and (b) isn't used an as input
            // to another layer

            Set<String> seenAsInput = new HashSet<>();
            seenAsInput.addAll(networkOutputs);
            for(Map.Entry<String,List<String>> e : vertexInputs.entrySet()){
                seenAsInput.addAll(e.getValue());
            }

            Set<String> disconnected = new HashSet<>();
            disconnected.addAll(networkInputs);
            disconnected.addAll(vertices.keySet());
            disconnected.removeAll(seenAsInput);
            if(!disconnected.isEmpty() && !allowNoOutput){  //If allowing no output: by definition we have disconnected vertices
                throw new IllegalStateException("Invalid configuration: disconnected vertices found - " + disconnected
                        + ". Disconnected vertices are those that do not connect to either another vertex, and are also"
                        + " not a network output. This vertex can be set as an output using setOutputs(String...). "
                        + "To disable this error (i.e., allow network configurations with" +
                        " disconnected vertices) use GraphBuilder.allowDisconnected(true)");
            }
        }

        //Check for no graph cycles: done in ComputationGraph.init()
    }

    /**
     * Add preprocessors automatically, given the specified types of inputs for the network. Inputs are specified using the
     * {@link InputType} class, in the same order in which the inputs were defined in the original configuration.<br>
     * For example, in a network with two inputs: a convolutional input (28x28x1 images) and feed forward inputs, use
     * {@code .addPreProcessors(InputType.convolutional(28,28,1),InputType.feedForward())}.<br>
     * For the CNN->Dense and CNN->RNN transitions, the nIns on the Dense/RNN layers will also be added automatically.
     * <b>NOTE</b>: This method will be called automatically when using the
     * {@link GraphBuilder#setInputTypes(InputType...)} functionality.
     * See that method for details.
     */
    public void addPreProcessors(InputType... inputTypes) {
        getLayerActivationTypes(true, inputTypes);
    }

    /**
     * Add preprocessors automatically, given the specified types of inputs for the network. Inputs are specified using the
     * {@link InputType} class, in the same order in which the inputs were defined in the original configuration.<br>
     * For example, in a network with two inputs: a convolutional input (28x28x1 images) and feed forward inputs, use
     * {@code .addPreProcessors(InputType.convolutional(28,28,1),InputType.feedForward())}.<br>
     * For the CNN->Dense and CNN->RNN transitions, the nIns on the Dense/RNN layers will also be added automatically.
     * <b>NOTE</b>: This method will be called automatically when using the
     * {@link GraphBuilder#setInputTypes(InputType...)} functionality.
     * See that method for details.
     * @param forceOverrideInputs  whether to forcibly over ride inputs or not
     *                             when setting up pre processing
     * @param inputTypes  the input types to set
     */
    public void addPreProcessors(boolean addPreprocIfNecessary,boolean forceOverrideInputs,InputType... inputTypes) {
        getLayerActivationTypes(addPreprocIfNecessary,forceOverrideInputs, inputTypes);
    }

    /**
     * Add preprocessors automatically, given the specified types of inputs for the network. Inputs are specified using the
     * {@link InputType} class, in the same order in which the inputs were defined in the original configuration.<br>
     * For example, in a network with two inputs: a convolutional input (28x28x1 images) and feed forward inputs, use
     * {@code .addPreProcessors(InputType.convolutional(28,28,1),InputType.feedForward())}.<br>
     * For the CNN->Dense and CNN->RNN transitions, the nIns on the Dense/RNN layers will also be added automatically.
     * <b>NOTE</b>: This method will be called automatically when using the
     * {@link GraphBuilder#setInputTypes(InputType...)} functionality.
     * See that method for details.
     * @param forceOverrideInputs  whether to forcibly over ride inputs or not
     *                             when setting up pre processing
     * @param inputTypes  the input types to set
     */
    public void addPreProcessors(boolean forceOverrideInputs,InputType... inputTypes) {
        getLayerActivationTypes(true,forceOverrideInputs, inputTypes);
    }



    /**
     * For the given input shape/type for the network, return a map of activation sizes for each layer and vertex
     * in the graph. Note that this method will automatically add preprocessors if required, to handle (for example)
     * the transition between CNN and dense layers.
     * @param inputTypes                Input types for the network
     * @return A map of activation types for the graph (key: vertex name. value: type of activations out of that vertex)
     */
    public Map<String,InputType> getLayerActivationTypes(InputType... inputTypes) {
        return getLayerActivationTypes(true, inputTypes);
    }

    /**
     * For the given input shape/type for the network, return a map of activation sizes for each layer and vertex
     * in the graph. Note that this method can also add preprocessors if required (to handle transitions between some
     * layer types such as convolutional -> dense, for example)
     * @param addPreprocIfNecessary     If true: add any required preprocessors, in the process of calculating the layer
     *                                  activation sizes
     * @param overrideInputs            whether to forcibly over ride inputs when
     *                                  setting inputs
     * @param inputTypes                Input types for the network
     * @return A map of activation types for the graph (key: vertex name. value: type of activations out of that vertex)
     */
    public Map<String,InputType> getLayerActivationTypes(boolean addPreprocIfNecessary,boolean overrideInputs, InputType... inputTypes) {

        if (inputTypes == null || inputTypes.length != networkInputs.size()) {
            throw new IllegalArgumentException(
                    "Invalid number of InputTypes: cannot add preprocessors if number of InputType "
                            + "objects differs from number of network inputs");
        }

        //Now: need to do essentially a forward pass through the network, to work out what type of preprocessors to add
        //To do this: need to know what the output types are for each GraphVertex.

        //Do topological sort
        List<String> topologicalOrdering = topologicalOrdering();

        //Now, given the topological sort: do equivalent of forward pass
        Map<String, InputType> vertexOutputs = new LinkedHashMap<>();
        int currLayerIdx = -1;
        for (String s : topologicalOrdering) {
            int inputIdx = networkInputs.indexOf(s);
            if (inputIdx != -1) {
                vertexOutputs.put(s, inputTypes[inputIdx]);
                continue;
            }

            GraphVertex gv = vertices.get(s);

            List<InputType> inputTypeList = new ArrayList<>();

            if (gv instanceof LayerVertex) {
                //Add preprocessor, if necessary:
                String in = vertexInputs.get(s).get(0);
                InputType layerInput = vertexOutputs.get(in);
                inputTypeList.add(layerInput);

                LayerVertex lv = (LayerVertex) gv;
                Layer l = lv.getLayerConf().getLayer();

                //Preprocessors - add if necessary
                if (lv.getPreProcessor() == null) {
                    //But don't override preprocessors that are manually defined; if none has been defined,
                    //add the appropriate preprocessor for this input type/layer combination
                    InputPreProcessor preproc = l.getPreProcessorForInputType(layerInput);
                    lv.setPreProcessor(preproc);
                }

                //Set nIn value for layer (if not already set)
                InputType afterPreproc = layerInput;
                if (lv.getPreProcessor() != null  && addPreprocIfNecessary) {
                    InputPreProcessor ip = lv.getPreProcessor();
                    afterPreproc = ip.getOutputType(layerInput);
                }

                l.setNIn(afterPreproc, overrideInputs);

                currLayerIdx++;
            } else {
                List<String> inputs = vertexInputs.get(s);
                if (inputs != null) {
                    for (String inputVertexName : inputs) {
                        inputTypeList.add(vertexOutputs.get(inputVertexName));
                    }
                }
            }

            InputType outputFromVertex =
                    gv.getOutputType(currLayerIdx, inputTypeList.toArray(new InputType[inputTypeList.size()]));
            vertexOutputs.put(s, outputFromVertex);
        }

        return vertexOutputs;
    }

    /**
     * For the given input shape/type for the network, return a map of activation sizes for each layer and vertex
     * in the graph. Note that this method can also add preprocessors if required (to handle transitions between some
     * layer types such as convolutional -> dense, for example)
     * @param addPreprocIfNecessary     If true: add any required preprocessors, in the process of calculating the layer
     *                                  activation sizes
     * @param inputTypes                Input types for the network
     * @return A map of activation types for the graph (key: vertex name. value: type of activations out of that vertex)
     */
    public Map<String,InputType> getLayerActivationTypes(boolean addPreprocIfNecessary, InputType... inputTypes) {
        return getLayerActivationTypes(addPreprocIfNecessary,true,inputTypes);
    }

    private Map<String, List<String>> verticesOutputTo() {
        Map<String, List<String>> verticesOutputTo = new HashMap<>(); //Key: vertex. Values: vertices that this node is an input for
        for (Map.Entry<String, GraphVertex> entry : vertices.entrySet()) {
            String vertexName = entry.getKey();
            List<String> vertexInputNames;
            vertexInputNames = vertexInputs.get(vertexName);

            if (vertexInputNames == null)
                continue;

            //Build reverse network structure:
            for (String s : vertexInputNames) {
                List<String> list = verticesOutputTo.get(s);
                if (list == null) {
                    list = new ArrayList<>();
                    verticesOutputTo.put(s, list);
                }
                list.add(vertexName); //Edge: s -> vertexName
            }
        }

        return verticesOutputTo;
    }

    private List<String> topologicalOrdering() {
        //First step: build network in reverse order (i.e., define map of a -> list(b) instead of list(a) -> b)
        Map<String, List<String>> verticesOutputTo = verticesOutputTo();
        LinkedList<String> noIncomingEdges = new LinkedList<>(networkInputs); //Set of all nodes with no incoming edges
        List<String> topologicalOrdering = new ArrayList<>();

        Map<String, Set<String>> inputEdges = new HashMap<>();
        for (Map.Entry<String, List<String>> entry : vertexInputs.entrySet()) {
            inputEdges.put(entry.getKey(), new HashSet<>(entry.getValue()));
        }

        while (!noIncomingEdges.isEmpty()) {
            String next = noIncomingEdges.removeFirst();
            topologicalOrdering.add(next);

            //Remove edges next -> vertexOuputsTo[...] from graph;
            List<String> nextEdges = verticesOutputTo.get(next);

            if (nextEdges != null && !nextEdges.isEmpty()) {
                for (String s : nextEdges) {
                    Set<String> set = inputEdges.get(s);
                    set.remove(next);
                    if (set.isEmpty()) {
                        noIncomingEdges.add(s); //No remaining edges for vertex i -> add to list for processing
                    }
                }
            }
        }

        //If any edges remain in the graph: graph has cycles:
        for (Map.Entry<String, Set<String>> entry : inputEdges.entrySet()) {
            Set<String> set = entry.getValue();
            if (set == null)
                continue;
            if (!set.isEmpty())
                throw new IllegalStateException(
                        "Invalid configuration: cycle detected in graph. Cannot calculate topological ordering with graph cycle ("
                                + "cycle includes vertex \"" + entry.getKey() + "\")");
        }

        return topologicalOrdering;
    }

    /**
     * Get a {@link MemoryReport} for the given computation graph configuration. This is used to estimate the
     * memory requirements for the given network configuration and input
     *
     * @param inputTypes Input types for the network
     * @return Memory report for the network
     */
    public NetworkMemoryReport getMemoryReport(InputType... inputTypes) {


        Map<String, MemoryReport> memoryReportMap = new LinkedHashMap<>();
        List<String> topologicalOrdering = topologicalOrdering();

        Map<String, InputType> vertexOutputs = new HashMap<>();
        int currLayerIdx = -1;
        for (String s : topologicalOrdering) {
            int inputIdx = networkInputs.indexOf(s);
            if (inputIdx != -1) {
                vertexOutputs.put(s, inputTypes[inputIdx]);
                continue;
            }

            GraphVertex gv = vertices.get(s);

            List<InputType> inputTypeList = new ArrayList<>();

            if (gv instanceof LayerVertex) {
                //Add preprocessor, if necessary:
                String in = vertexInputs.get(s).get(0);
                InputType layerInput = vertexOutputs.get(in);
                inputTypeList.add(layerInput);
                currLayerIdx++;
            } else {
                List<String> inputs = vertexInputs.get(s);
                if (inputs != null) {
                    for (String inputVertexName : inputs) {
                        inputTypeList.add(vertexOutputs.get(inputVertexName));
                    }
                }
            }



            InputType outputFromVertex =
                    gv.getOutputType(currLayerIdx, inputTypeList.toArray(new InputType[inputTypeList.size()]));
            vertexOutputs.put(s, outputFromVertex);

            MemoryReport mr = gv.getMemoryReport(inputTypeList.toArray(new InputType[inputTypeList.size()]));

            memoryReportMap.put(s, mr);
        }

        return new NetworkMemoryReport(memoryReportMap, ComputationGraphConfiguration.class, "ComputationGraph",
                inputTypes);
    }

    @Data
    public static class GraphBuilder {
        private static final int DEFAULT_TBPTT_LENGTH = 20;

        protected Map<String, GraphVertex> vertices = new LinkedHashMap<>();

        /**
         * Key: graph node. Values: input to that node
         */
        protected Map<String, List<String>> vertexInputs = new LinkedHashMap<>();

        protected List<String> networkInputs = new ArrayList<>();
        protected List<InputType> networkInputTypes = new ArrayList<>();
        protected List<String> networkOutputs = new ArrayList<>();
        protected BackpropType backpropType = BackpropType.Standard;
        protected int tbpttFwdLength = DEFAULT_TBPTT_LENGTH;
        protected int tbpttBackLength = DEFAULT_TBPTT_LENGTH;

        protected Map<String, InputPreProcessor> inputPreProcessors = new LinkedHashMap<>();

        protected NeuralNetConfiguration.Builder globalConfiguration;

        protected boolean allowDisconnected = false;
        protected boolean allowNoOutput = false;
        protected boolean validateOutputConfig = true;
        protected boolean validateTbpttConfig = true;

        protected String lastAdded = null;

        public GraphBuilder(NeuralNetConfiguration.Builder globalConfiguration) {
            this.globalConfiguration = globalConfiguration;
        }

        public GraphBuilder(ComputationGraphConfiguration newConf, NeuralNetConfiguration.Builder globalConfiguration) {

            ComputationGraphConfiguration clonedConf = newConf.clone();

            this.vertices = clonedConf.getVertices();
            this.vertexInputs = clonedConf.getVertexInputs();

            this.networkInputs = clonedConf.getNetworkInputs();
            this.networkOutputs = clonedConf.getNetworkOutputs();
            this.backpropType = clonedConf.getBackpropType();
            this.tbpttFwdLength = clonedConf.getTbpttFwdLength();
            this.tbpttBackLength = clonedConf.getTbpttBackLength();
            this.globalConfiguration = globalConfiguration;
            //this.getGlobalConfiguration().setSeed(clonedConf.getDefaultConfiguration().getSeed());
        }

        /**
         * Specify the processors for a given layer
         * These are used at each layer for doing things like normalization and shaping of input.<br>
         * <b>Note</b>: preprocessors can also be defined using the {@link #addLayer(String, Layer, InputPreProcessor, String...)} method.
         *
         * @param layer     the name of the layer that this preprocessor will be used with
         * @param processor the preprocessor to use for the specified layer
         */
        public GraphBuilder inputPreProcessor(String layer, InputPreProcessor processor) {
            inputPreProcessors.put(layer, processor);
            return this;
        }

        /**
         * The type of backprop. Default setting is used for most networks (MLP, CNN etc),
         * but optionally truncated BPTT can be used for training recurrent neural networks.
         * If using TruncatedBPTT make sure you set both tBPTTForwardLength() and tBPTTBackwardLength()
         *
         * @param type Type of backprop. Default: BackpropType.Standard
         */
        public GraphBuilder backpropType(BackpropType type) {
            this.backpropType = type;
            return this;
        }

        /**
         * When doing truncated BPTT: how many steps of forward pass should we do
         * before doing (truncated) backprop?<br>
         * Only applicable when doing backpropType(BackpropType.TruncatedBPTT)<br>
         * Typically tBPTTForwardLength parameter is same as the tBPTTBackwardLength parameter,
         * but may be larger than it in some circumstances (but never smaller)<br>
         * Ideally your training data time series length should be divisible by this
         * This is the k1 parameter on pg23 of
         * <a href="http://www.cs.utoronto.ca/~ilya/pubs/ilya_sutskever_phd_thesis.pdf">http://www.cs.utoronto.ca/~ilya/pubs/ilya_sutskever_phd_thesis.pdf</a>
         *
         * @param forwardLength Forward length > 0, >= backwardLength
         */
        public GraphBuilder tBPTTForwardLength(int forwardLength) {
            this.tbpttFwdLength = forwardLength;
            return this;
        }

        /**
         * When doing truncated BPTT: how many steps of backward should we do?<br>
         * Only applicable when doing backpropType(BackpropType.TruncatedBPTT)<br>
         * This is the k2 parameter on pg23 of
         * <a href="http://www.cs.utoronto.ca/~ilya/pubs/ilya_sutskever_phd_thesis.pdf">http://www.cs.utoronto.ca/~ilya/pubs/ilya_sutskever_phd_thesis.pdf</a>
         *
         * @param backwardLength <= forwardLength
         */
        public GraphBuilder tBPTTBackwardLength(int backwardLength) {
            this.tbpttBackLength = backwardLength;
            return this;
        }

        /**
         * When doing truncated backpropagation through time (tBPTT): how many steps should we do?<br>
         * Only applicable when doing backpropType(BackpropType.TruncatedBPTT)<br>
         * See: <a href="http://www.cs.utoronto.ca/~ilya/pubs/ilya_sutskever_phd_thesis.pdf">http://www.cs.utoronto.ca/~ilya/pubs/ilya_sutskever_phd_thesis.pdf</a>
         *
         * @param tbpttLength length > 0
         */
        public GraphBuilder tBPTTLength(int tbpttLength){
            tBPTTForwardLength(tbpttLength);
            return tBPTTBackwardLength(tbpttLength);
        }

        /**
         * Add a layer, with no {@link InputPreProcessor}, with the specified name and specified inputs.
         *
         * @param layerName   Name/label of the layer to add
         * @param layer       The layer configuration
         * @param layerInputs Inputs to this layer. Inputs may be other layers, GraphVertex objects,
         *                    on a combination of the two.
         * @see #addLayer(String, Layer, InputPreProcessor, String...)
         */
        public GraphBuilder addLayer(String layerName, Layer layer, String... layerInputs) {
            return addLayer(layerName, layer, null, layerInputs);
        }

        /**
         * Add a layer, with no {@link InputPreProcessor}, with the specified name
         * and input from the last added layer/vertex.
         *
         * @param layerName   Name/label of the layer to add
         * @param layer       The layer configuration
         * @see #addLayer(String, Layer, InputPreProcessor, String...)
         */
        public GraphBuilder appendLayer(String layerName, Layer layer) {
            return appendLayer(layerName, layer, null);
        }

        /**
         * Add a layer, with no {@link InputPreProcessor}, with the specified name and specified inputs.
         *
         * @param layerName   Name/label of the layer to add
         * @param layer       The layer configuration
         * @param layerInputs Inputs to this layer. Inputs may be other layers, GraphVertex objects,
         *                    on a combination of the two.
         * @see #addLayer(String, Layer, InputPreProcessor, String...)
         */
        public GraphBuilder layer(int layerName, Layer layer, String... layerInputs) {
            return addLayer(String.valueOf(layerName), layer, null, layerInputs);
        }

        /**
         * Add a layer, with no {@link InputPreProcessor}, with the specified name and specified inputs.
         *
         * @param layerName   Name/label of the layer to add
         * @param layer       The layer configuration
         * @param layerInputs Inputs to this layer. Inputs may be other layers, GraphVertex objects,
         *                    on a combination of the two.
         * @see #addLayer(String, Layer, InputPreProcessor, String...)
         */
        public GraphBuilder layer(String layerName, Layer layer, String... layerInputs) {
            return addLayer(layerName, layer, null, layerInputs);
        }

        /**
         * Add a layer and an {@link InputPreProcessor}, with the specified name and specified inputs.
         *
         * @param layerName    Name/label of the layer to add
         * @param layer        The layer configuration
         * @param preProcessor The InputPreProcessor to use with this layer.
         * @param layerInputs  Inputs to this layer. Inputs may be other layers, GraphVertex objects,
         *                     on a combination of the two.
         */
        public GraphBuilder addLayer(String layerName, Layer layer, InputPreProcessor preProcessor,
                                     String... layerInputs) {
            NeuralNetConfiguration.Builder builder = globalConfiguration.clone();
            builder.layer(layer);
            addVertex(layerName, new LayerVertex(builder.build(), preProcessor), layerInputs);
            layer.setLayerName(layerName);
            return this;
        }

        /**
         * Add a layer and an {@link InputPreProcessor}, with the specified name
         * and input from the last added layer/vertex.
         *
         * @param layerName    Name/label of the layer to add
         * @param layer        The layer configuration
         * @param preProcessor The InputPreProcessor to use with this layer.
         */
        public GraphBuilder appendLayer(String layerName, Layer layer, InputPreProcessor preProcessor) {

            if(lastAdded == null){
                throw new IllegalStateException("Can not use appendLayer with no previous layers");
            }

            addLayer(layerName, layer, preProcessor, lastAdded);
            return this;
        }

        /**
         * Add a layer and an {@link InputPreProcessor}, with the specified name and specified inputs.
         *
         * @param layerName    Name/label of the layer to add
         * @param layer        The layer configuration
         * @param preProcessor The InputPreProcessor to use with this layer.
         * @param layerInputs  Inputs to this layer. Inputs may be other layers, GraphVertex objects,
         *                     on a combination of the two.
         */
        public GraphBuilder layer(String layerName, Layer layer, InputPreProcessor preProcessor,
                                  String... layerInputs) {
            return addLayer(layerName, layer, preProcessor, layerInputs);
        }

        /**
         * Intended for use with the transfer learning API. Users discouraged from employing it directly.
         * Removes the specified vertex from the vertices list, it's connections and associated preprocessor
         * If the vertex removed is an output vertex it will also be removed from the list of outputs
         * @param vertexName Name of the vertex to remove
         */
        public GraphBuilder removeVertex(String vertexName) {
            removeVertex(vertexName, true);
            return this;
        }

        /**
         * Intended for use with the transfer learning API. Users discouraged from employing it directly.
         * Removes the specified vertex from the vertices list,
         * Removes it's connections (associated preprocessor and if an output also removes it from list of outputs) if "removeConnections" is specified as true
         * Specifying as false can leave the graph in an invalid state with references to vertices that donot exist unless a new vertex is added back in with the same name
         * @param removeConnections Specify true to remove connections
         * @param vertexName Name of the vertex to remove
         */
        public GraphBuilder removeVertex(String vertexName, boolean removeConnections) {
            vertices.remove(vertexName);
            vertexInputs.remove(vertexName);
            if (networkInputs.contains(vertexName)) {
                networkInputs.remove(vertexName);
            }
            if (removeConnections) {
                if (networkOutputs.contains(vertexName)) {
                    networkOutputs.remove(vertexName);
                }
                Map<String,List<String>> newVertexInputs = new LinkedHashMap<>();
                for (Map.Entry<String, List<String>> entry : this.vertexInputs.entrySet()) {
                    List<String> inputs = entry.getValue();
                    if (inputs.contains(vertexName)) {
                        //Some lists are not modifiable. So we'll make a new copy, minus the one to be removed
                        List<String> newList = new ArrayList<>(inputs.size()-1);
                        for(String s : inputs){
                            if(!vertexName.equals(s)){
                                newList.add(s);
                            }
                        }
                        newVertexInputs.put(entry.getKey(), newList);
                    } else {
                        newVertexInputs.put(entry.getKey(), entry.getValue());
                    }
                }
                this.vertexInputs = newVertexInputs;

                if (inputPreProcessors.containsKey(vertexName)) {
                    inputPreProcessors.remove(vertexName);
                }
            }
            return this;
        }

        /**
         * Specify the inputs to the network, and their associated labels.
         *
         * @param inputNames The names of the inputs. This also defines their order
         */
        public GraphBuilder addInputs(String... inputNames) {
            Collections.addAll(networkInputs, inputNames);
            lastAdded = networkInputs.get(networkInputs.size() - 1);
            return this;
        }

        /**
         * Specify the inputs to the network, and their associated labels.
         *
         * @param inputNames The names of the inputs. This also defines their order
         */
        public GraphBuilder addInputs(Collection<String> inputNames) {
            networkInputs.addAll(inputNames);
            lastAdded = networkInputs.get(networkInputs.size() - 1);
            return this;
        }

        /**Specify the types of inputs to the network, so that:<br>
         * (a) preprocessors can be automatically added, and<br>
         * (b) the nIns (input size) for each layer can be automatically calculated and set<br>
         * The order here is the same order as .addInputs(). Thus, if you do .addInputs("a","b") and .setInputTypes(InputType.feedForward(),
         * InputType.convolutional(28,28,1)) then the input labelled "a" is a feed forward input, whereas the input labelled "b" in a CNN
         * input, with 28x28x1 images as input.<br>
         * <b>Note</b>: Using setInputTypes is not always necessary, but can be especially helpful for example with CNNs such that
         * the calculations on input/output sizes (width, height, channels, etc) don't need to be done manually.<br>
         * <b>Note 2</b>: If a preprocessor is manually added for a given layer, it will not be overridden by the automatic
         * addition of preprocessors.
         * <b>Note 3</b>: If a layer has an nIn set manually, this will not be overridden
         */
        public GraphBuilder setInputTypes(InputType... inputTypes) {
            if (inputTypes != null && inputTypes.length > 0) {
                if (networkInputs.size() > 0 &&     //If no network inputs have been set here - can't valid number of input types here...
                        networkInputTypes.size() + inputTypes.length != networkInputs.size()) {
                    throw new IllegalArgumentException(
                            "Invalid number of InputTypes: " +
                                    "existing inputTypes ("+networkInputTypes.size()+") + additional inputTypes ("+inputTypes.length+")" +
                                    " != number of network inputs ("+networkInputs.size()+")");
                }
                Collections.addAll(networkInputTypes, inputTypes);
            }
            return this;
        }


        /**
         * Set the network output labels. These should be the names of the OutputLayer instances in the network
         *
         * @param outputNames The names of the output layers. This also defines their order.
         */
        public GraphBuilder setOutputs(String... outputNames) {
            networkOutputs.clear();
            Collections.addAll(networkOutputs, outputNames);

            return this;
        }

        /**
         * Add a {@link GraphVertex} to the network configuration. A GraphVertex defines forward and backward pass methods,
         * and can contain a {@link LayerVertex}, a {@link org.deeplearning4j.nn.conf.graph.ElementWiseVertex} to do element-wise
         * addition/subtraction, a {@link MergeVertex} to combine/concatenate the activations out of multiple layers or vertices,
         * a {@link org.deeplearning4j.nn.conf.graph.SubsetVertex} to select a subset of the activations out of another layer/GraphVertex.<br>
         * Custom GraphVertex objects (that extend the abstract {@link GraphVertex} class) may also be used.
         *
         * @param vertexName   The name of the GraphVertex to add
         * @param vertex       The GraphVertex to add
         * @param vertexInputs The inputs/activations to this GraphVertex.
         */
        public GraphBuilder addVertex(String vertexName, GraphVertex vertex, String... vertexInputs) {

            Preconditions.checkState(!vertices.containsKey(vertexName), "Cannot add vertex: a vertex with name \"%s\" already exists", vertexName);
            vertices.put(vertexName, vertex);

            //Automatically insert a MergeNode if this vertex can only take 1 input (layer vertices, etc)
            if (vertex.maxVertexInputs() == 1 && vertexInputs != null && vertexInputs.length > 1) {
                String mergeName = vertexName + "-merge";
                addVertex(mergeName, new MergeVertex(), vertexInputs);
                this.vertexInputs.put(vertexName, Collections.singletonList(mergeName));
            } else if (vertexInputs != null) {
                this.vertexInputs.put(vertexName, Arrays.asList(vertexInputs));
            }

            this.lastAdded = vertexName;

            return this;
        }

        /**
         * Add a {@link GraphVertex} to the network configuration, with input from the last added vertex/layer. A GraphVertex defines forward and backward pass methods,
         * and can contain a {@link LayerVertex}, a {@link org.deeplearning4j.nn.conf.graph.ElementWiseVertex} to do element-wise
         * addition/subtraction, a {@link MergeVertex} to combine/concatenate the activations out of multiple layers or vertices,
         * a {@link org.deeplearning4j.nn.conf.graph.SubsetVertex} to select a subset of the activations out of another layer/GraphVertex.<br>
         * Custom GraphVertex objects (that extend the abstract {@link GraphVertex} class) may also be used.
         *
         * @param vertexName   The name of the GraphVertex to add
         * @param vertex       The GraphVertex to add
         */
        public GraphBuilder appendVertex(String vertexName, GraphVertex vertex) {

            if(lastAdded == null){
                throw new IllegalStateException("Can not use appendLayer with no previous layers");
            }

            addVertex(vertexName, vertex, lastAdded);
            return this;
        }

        /**
         * Used only during validation after building.<br>
         * If true: don't throw an exception on configurations containing vertices that are 'disconnected'. A disconnected
         * vertex is one that is not an output, and doesn't connect to any other vertices. i.e., it's output activations
         * don't go anywhere. Most users can (and should) leave this as the default value of false.
         *
         * @param allowDisconnected Whether to allow disconnected vertices, during validation
         */
        public GraphBuilder allowDisconnected(boolean allowDisconnected){
            this.allowDisconnected = allowDisconnected;
            return this;
        }

        /**
         * Used only during validation after building.<br>
         * If true: don't throw an exception on configurations without any outputs. This is enabled by default
         * to avoid creating invalid graphs, but can be disabled if required.<br>
         * Most users can (and should) leave this as the default value of false.
         *
         * @param allowNoOutput Whether to allow no outputs, during validation
         */
        public GraphBuilder allowNoOutput(boolean allowNoOutput){
            this.allowNoOutput = allowNoOutput;
            return this;
        }

        /**
         * Enabled by default. If enabled, the output layer configuration will be validated, to throw an exception on
         * likely invalid outputs - such as softmax + nOut=1, or LossMCXENT + Tanh.<br>
         * If disabled (false) no output layer validation will be performed.<br>
         * Disabling this validation is not recommended, as the configurations that fail validation usually will
         * not be able to learn correctly. However, the option to disable this validation is provided for advanced users
         * when creating non-standard architectures.
         *
         * @param validate If true: validate output layer configuration. False: don't validate
         */
        public GraphBuilder validateOutputLayerConfig(boolean validate) {
            this.validateOutputConfig = validate;
            return this;
        }

        /**
         * Enabled by default. If enabled, an exception will be throw when using the (invalid) combination of truncated
         * backpropagation through time (TBPTT) with either a GlobalPoolingLayer or LastTimeStepLayer.<br>
         * It is possible to disable this validation to allow what is almost certainly an invalid configuration to be used,
         * however this is not recommended.
         *
         * @param validate Whether TBPTT validation should be performed
         */
        public GraphBuilder validateTbpttConfig(boolean validate){
            this.validateTbpttConfig = validate;
            return this;
        }

        /**
         * For the (perhaps partially constructed) network configuration, return a map of activation sizes for each
         * layer and vertex in the graph.<br>
         * Note 1: The network configuration may be incomplete, but the inputs have been added to the layer already.<br>
         * Note 2: To use this method, the network input types must have been set using {@link #setInputTypes(InputType...)}
         * first
         * @return A map of activation types for the graph (key: vertex name. value: type of activations out of that vertex)
         */
        public Map<String,InputType> getLayerActivationTypes() {
            Preconditions.checkArgument(networkInputs != null && networkInputs.size() > 0,
                    "Cannot calculate activation types if no inputs have been set (use addInputs(String...))");
            Preconditions.checkArgument(networkInputTypes != null && networkInputTypes.size() == networkInputs.size(),
                    "Cannot calculate layer activation types if network if network input types have not" +
                            "been set (use ");

            //Instantiate temporary ComputationGraphConfiguration and calculate output shapes
            ComputationGraphConfiguration conf;
            try{
                conf = buildConfig();
            } catch (Exception e){
                throw new RuntimeException("Error calculating activation types for layers: error occured when constructing " +
                        "temporary ComputationGraphConfiguration)", e);
            }

            try{
                conf.validate(true, true);
            } catch (Exception e){
                throw new RuntimeException("Error calculating activation types for layers: validation of temporary" +
                        " ComputationGraphConfiguration failed", e);
            }

            return conf.getLayerActivationTypes(true, networkInputTypes.toArray(new InputType[networkInputTypes.size()]));
        }


        private ComputationGraphConfiguration buildConfig(){
            //Validate BackpropType setting
            if((tbpttBackLength != DEFAULT_TBPTT_LENGTH || tbpttFwdLength != DEFAULT_TBPTT_LENGTH) && backpropType != BackpropType.TruncatedBPTT){
                log.warn("Truncated backpropagation through time lengths have been configured with values " + tbpttFwdLength
                        + " and " + tbpttBackLength + " but backprop type is set to " + backpropType + ". TBPTT configuration" +
                        " settings will only take effect if backprop type is set to BackpropType.TruncatedBPTT");
            }

            ComputationGraphConfiguration conf = new ComputationGraphConfiguration();
            conf.backpropType = backpropType;
            conf.tbpttBackLength = tbpttBackLength;
            conf.tbpttFwdLength = tbpttFwdLength;

            conf.networkInputs = networkInputs;
            conf.networkOutputs = networkOutputs;

            conf.vertices = this.vertices;
            conf.vertexInputs = this.vertexInputs;
            conf.trainingWorkspaceMode = globalConfiguration.trainingWorkspaceMode;
            conf.inferenceWorkspaceMode = globalConfiguration.inferenceWorkspaceMode;
            conf.cacheMode = globalConfiguration.cacheMode;
            conf.validateOutputLayerConfig = validateOutputConfig;
            conf.dataType = globalConfiguration.dataType;

            conf.defaultConfiguration = globalConfiguration.build();

            //Add preprocessors that were defined separately to the Layers to which they belong
            for (Map.Entry<String, InputPreProcessor> entry : inputPreProcessors.entrySet()) {
                GraphVertex gv = vertices.get(entry.getKey());
                if (gv instanceof LayerVertex) {
                    LayerVertex lv = (LayerVertex) gv;
                    lv.setPreProcessor(entry.getValue());
                } else {
                    throw new IllegalStateException(
                            "Invalid configuration: InputPreProcessor defined for GraphVertex \""
                                    + entry.getKey() + "\", but this vertex is not a LayerVertex");
                }

            }

            for (Map.Entry<String, GraphVertex> gv : vertices.entrySet()) {
                if (gv.getValue() instanceof LayerVertex) {
                    LayerVertex lv = (LayerVertex) gv.getValue();
                    Layer l = lv.getLayerConf().getLayer();
                }
                if (gv.getValue() instanceof SameDiffVertex)
                    ((SameDiffVertex) gv.getValue()).applyGlobalConfig(globalConfiguration);

            }

            return conf;
        }


        /**
         * Create the ComputationGraphConfiguration from the Builder pattern
         */
        public ComputationGraphConfiguration build() {

            ComputationGraphConfiguration conf = buildConfig();
            conf.validate(allowDisconnected, allowNoOutput); //throws exception for invalid configuration

            //Automatically add preprocessors, set nIns for CNN->dense transitions, etc
            if (!networkInputTypes.isEmpty()) {
                conf.addPreProcessors(networkInputTypes.toArray(new InputType[networkInputs.size()]));
            }

            if(validateOutputConfig) {
                //Validate output layer configurations...
                for (Map.Entry<String, GraphVertex> e : conf.getVertices().entrySet()) {
                    if (e.getValue() instanceof LayerVertex) {
                        Layer l = ((LayerVertex) e.getValue()).getLayerConf().getLayer();
                        OutputLayerUtil.validateOutputLayer(e.getKey(), l); //No-op for non output/loss layers
                    }
                }
            }

            if(backpropType == BackpropType.TruncatedBPTT && validateTbpttConfig) {
                //Check for invalid combination - tbptt plus LastTimeStepLayer or
                for(Map.Entry<String,GraphVertex> e : vertices.entrySet()){
                    GraphVertex gv = e.getValue();
                    Layer l = (gv instanceof LayerVertex ? ((LayerVertex)gv).getLayerConf().getLayer() : null);
                    if(gv instanceof LastTimeStepVertex || (l != null && (l instanceof LastTimeStep || l instanceof GlobalPoolingLayer))){
                        String s = (l == null ? gv.getClass().getName() : l.getClass().getName());
                        String n = e.getKey();
                        throw new IllegalStateException("Invalid network configuration detected: Truncated backpropagation through time (TBPTT)" +
                                " cannot be used with layer \"" + n + "\" of type " + s + ": TBPTT is incompatible with this layer type (which is designed " +
                                "to process entire sequences at once, and does support the type of sequence segments that TPBTT uses).\n" +
                                "This check can be disabled using validateTbpttConfig(false) but this is not recommended.");
                    }
                }
            }

            return conf;
        }
    }
}