deeplearning4j/deeplearning4j

View on GitHub
deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java

Summary

Maintainability
F
1 mo
Test Coverage
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.deeplearning4j.nn.graph;

import lombok.Getter;
import lombok.NonNull;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.StringUtils;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.adapters.OutputAdapter;
import org.nd4j.linalg.dataset.AsyncMultiDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MultiDataSetIteratorAdapter;
import org.deeplearning4j.exception.DL4JException;
import org.deeplearning4j.nn.api.*;
import org.deeplearning4j.nn.api.Updater;
import org.deeplearning4j.nn.api.layers.IOutputLayer;
import org.deeplearning4j.nn.api.layers.RecurrentLayer;
import org.deeplearning4j.nn.conf.*;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.util.ComputationGraphUtil;
import org.deeplearning4j.nn.graph.util.GraphIndices;
import org.deeplearning4j.nn.graph.vertex.GraphVertex;
import org.deeplearning4j.nn.graph.vertex.VertexIndices;
import org.deeplearning4j.nn.graph.vertex.impl.FrozenVertex;
import org.deeplearning4j.nn.graph.vertex.impl.InputVertex;
import org.deeplearning4j.nn.graph.vertex.impl.LayerVertex;
import org.deeplearning4j.nn.layers.FrozenLayer;
import org.deeplearning4j.nn.layers.FrozenLayerWithBackprop;
import org.deeplearning4j.nn.layers.recurrent.BidirectionalLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.optimize.Solver;
import org.deeplearning4j.optimize.api.ConvexOptimizer;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.optimize.solvers.accumulation.GradientsAccumulator;
import org.deeplearning4j.util.CrashReportingUtil;
import org.deeplearning4j.util.ModelSerializer;
import org.deeplearning4j.util.NetworkUtils;
import org.deeplearning4j.util.OutputLayerUtil;
import org.nd4j.base.Preconditions;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.evaluation.classification.ROC;
import org.nd4j.evaluation.classification.ROCMultiClass;
import org.nd4j.evaluation.regression.RegressionEvaluation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
import org.nd4j.linalg.api.memory.enums.AllocationPolicy;
import org.nd4j.linalg.api.memory.enums.LearningPolicy;
import org.nd4j.linalg.api.memory.enums.ResetPolicy;
import org.nd4j.linalg.api.memory.enums.SpillPolicy;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.DataSetUtil;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.exception.ND4JArraySizeException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.heartbeat.Heartbeat;
import org.nd4j.linalg.heartbeat.reports.Environment;
import org.nd4j.linalg.heartbeat.reports.Event;
import org.nd4j.linalg.heartbeat.reports.Task;
import org.nd4j.linalg.heartbeat.utils.EnvironmentUtils;
import org.nd4j.linalg.heartbeat.utils.TaskUtils;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.memory.abstracts.DummyWorkspace;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.primitives.Triple;
import org.nd4j.linalg.schedule.ISchedule;
import org.nd4j.linalg.workspace.ND4JWorkspaceException;
import org.nd4j.linalg.workspace.WorkspaceUtils;
import org.nd4j.util.OneTimeLogger;

import java.io.*;
import java.util.*;
import java.util.concurrent.atomic.AtomicLong;

/**
 * A ComputationGraph network is a neural network with arbitrary (directed acyclic graph) connection structure.
 * A ComputationGraph may also have an arbitrary number of inputs and outputs.
 *
 * @author Alex Black
 */
@Slf4j
public class ComputationGraph implements Serializable, Model, NeuralNetwork {

    protected ComputationGraphConfiguration configuration;
    protected boolean initCalled = false;
    protected transient Solver solver; //Used to call optimizers during backprop
    protected INDArray flattenedParams; //Params for all layers are a view/subset of this array
    @Getter
    protected transient INDArray flattenedGradients; //Gradients for all layers are a view/subset of this array
    protected Gradient gradient;
    protected double score;
    @Setter
    private boolean initDone = false;
    protected boolean clearTbpttState = true;  //Mainly for unit testing (should be enabled otherwise)
    //Workspaces for CUDNN. Pass to LayerWorkspaceMgr for re-use in cudnn helpers
    @Getter
    protected transient Map<String,Pointer> helperWorkspaces = new HashMap<>();

    private transient final AtomicLong occupiedBy = new AtomicLong(-1);

    /**
     * Workspace for working memory for a single layer: forward pass and backward pass
     * Note that this is opened/closed once per op (activate/backpropGradient call)
     */
    protected static final String WS_LAYER_WORKING_MEM = "WS_LAYER_WORKING_MEM";
    /**
     * Workspace for storing all layers' activations - used only to store activations (layer inputs) as part of backprop
     * Not used for inference
     */
    protected static final String WS_ALL_LAYERS_ACT = "WS_ALL_LAYERS_ACT";
    /**
     * Workspace for working memory in RNNs - opened and closed once per RNN time step
     */
    protected static final String WS_RNN_LOOP_WORKING_MEM = "WS_RNN_LOOP_WORKING_MEM";

    /**
     * Workspace for output methods that use OutputAdapter
     */
    protected static final String WS_OUTPUT_MEM = "WS_OUTPUT_MEM";

    protected final WorkspaceConfiguration WS_LAYER_WORKING_MEM_CONFIG;

    protected static final WorkspaceConfiguration WS_ALL_LAYERS_ACT_CONFIG = WorkspaceConfiguration.builder()
            .initialSize(0)
            .overallocationLimit(0.05)
            .policyLearning(LearningPolicy.FIRST_LOOP)
            .policyReset(ResetPolicy.BLOCK_LEFT)
            .policySpill(SpillPolicy.REALLOCATE)
            .policyAllocation(AllocationPolicy.OVERALLOCATE)
            .build();

    protected final WorkspaceConfiguration WS_LAYER_ACT_X_CONFIG;

    protected static final WorkspaceConfiguration WS_RNN_LOOP_WORKING_MEM_CONFIG = WorkspaceConfiguration.builder()
            .initialSize(0).overallocationLimit(0.05).policyReset(ResetPolicy.BLOCK_LEFT)
            .policyAllocation(AllocationPolicy.OVERALLOCATE).policySpill(SpillPolicy.REALLOCATE)
            .policyLearning(LearningPolicy.FIRST_LOOP).build();


    protected transient ThreadLocal<Long> lastEtlTime = new ThreadLocal<>();

    /**
     * All GraphVertex objects in the network.
     */
    protected GraphVertex[] vertices;
    /**
     * Map of vertices by name
     */
    protected Map<String, GraphVertex> verticesMap;
    /**
     * Indexes of graph vertices, in topological order. The topological order defines the order in which forward pass
     * (and hence also backward pass, which is the opposite to this) is conducted in the network.
     */
    protected int[] topologicalOrder;
    /**
     * Topological sort and vertex index/name + name/index mapping
     */
    protected GraphIndices graphIndices;

    /**
     * A list of layers. Each of these layers is present in a GraphVertex, but are here for easy reference.
     * This array also defines the order in which the getLayer(int) method returns layers.
     */
    protected Layer[] layers;

    /**
     * The number of input arrays to the network. Many networks only have 1 input; however, a ComputationGraph may
     * have an arbitrary number (>=1) separate input arrays
     */
    private int numInputArrays;
    /**
     * The number of output arrays to the network. Many networks only have 1 output; however, a ComputationGraph may
     * have an arbitrary number (>=1) separate output arrays
     */
    private int numOutputArrays;

    //Current inputs, labels, input mask arrays and label mask arrays
    private transient INDArray[] inputs;
    private transient INDArray[] labels;
    private transient INDArray[] inputMaskArrays;
    private transient INDArray[] labelMaskArrays;

    private transient int[] outputLayerIdxs;

    private NeuralNetConfiguration defaultConfiguration;
    private Collection<TrainingListener> trainingListeners = new ArrayList<>();


    public ComputationGraph(ComputationGraphConfiguration configuration) {
        this.configuration = configuration;
        this.numInputArrays = configuration.getNetworkInputs().size();
        this.numOutputArrays = configuration.getNetworkOutputs().size();
        this.inputs = new INDArray[numInputArrays];
        this.labels = new INDArray[numOutputArrays];
        this.defaultConfiguration = configuration.getDefaultConfiguration();

        //Working memory: should learn over course of: (a) full forward pass, and (b) full backward pass
        //Working memory should be opened once per vertex, for each of forward and backward passes
        int numWorkingMem = 2 * configuration.getVertices().size();
        WS_LAYER_WORKING_MEM_CONFIG = WorkspaceConfiguration.builder()
                .initialSize(0)
                .overallocationLimit(0.02)
                .policyLearning(LearningPolicy.OVER_TIME)
                .cyclesBeforeInitialization(numWorkingMem)
                .policyReset(ResetPolicy.BLOCK_LEFT)
                .policySpill(SpillPolicy.REALLOCATE)
                .policyAllocation(AllocationPolicy.OVERALLOCATE)
                .build();

        //Activations memory: opened once per layer - for every second layer (preprocessors are within the loop).
        //Technically we could set learning to numLayers / 2, but will set to numLayers for simplicity, and also to
        // account for a backward pass
        WS_LAYER_ACT_X_CONFIG = WorkspaceConfiguration.builder()
                .initialSize(0)
                .overallocationLimit(0.02)
                .policyLearning(LearningPolicy.OVER_TIME)
                .cyclesBeforeInitialization(configuration.getVertices().size())
                .policyReset(ResetPolicy.BLOCK_LEFT)
                .policySpill(SpillPolicy.REALLOCATE)
                .policyAllocation(AllocationPolicy.OVERALLOCATE)
                .build();
    }

    /**
     * This method allows to set ETL field time, useful for performance tracking
     *
     * @param time
     */
    public void setLastEtlTime(long time) {
        lastEtlTime.set(time);
    }

    /**
     * This method returns ETL time field value
     *
     * @return
     */
    public long getLastEtlTime() {
        Long time = lastEtlTime.get();
        return time == null ? 0L : time;
    }

    /**
     * This method sets specified CacheMode for all layers within network
     *
     * @param mode
     */
    public void setCacheMode(CacheMode mode) {
        if (mode == null)
            mode = CacheMode.NONE;

        for (Layer layer : layers) {
            layer.setCacheMode(mode);
        }
    }

    /**
     * This method returns configuration of this ComputationGraph
     *
     * @return
     */
    public ComputationGraphConfiguration getConfiguration() {
        return configuration;
    }

    /**
     * Returns the number of layers in the ComputationGraph
     */
    public int getNumLayers() {
        return (layers != null ? layers.length : 0);
    }

    /**
     * Get the layer by the number of that layer, in range 0 to getNumLayers()-1
     * NOTE: This is different from the internal GraphVertex index for the layer
     */
    public Layer getLayer(int idx) {
        return layers[idx];
    }

    /**
     * Get all layers in the ComputationGraph
     */
    public Layer[] getLayers() {
        return layers;
    }

    /**
     * Get a given layer by name.
     */
    public Layer getLayer(String name) {
        Preconditions.checkState(verticesMap.containsKey(name), "Layer with name %s does not exist in the network", name);
        return verticesMap.get(name).getLayer();
    }

    /**
     * Returns an array of all GraphVertex objects.
     */
    public GraphVertex[] getVertices() {
        return vertices;
    }

    /**
     * Return a given GraphVertex by name, or null if no vertex with that name exists
     */
    public GraphVertex getVertex(String name) {
        return verticesMap.get(name);
    }

    /**
     * The number of inputs to this network
     */
    public int getNumInputArrays() {
        return numInputArrays;
    }

    /**
     * The number of output (arrays) for this network
     */
    public int getNumOutputArrays() {
        return numOutputArrays;
    }

    /**
     * Set the specified input for the ComputationGraph
     */
    public void setInput(int inputNum, INDArray input) {
        if (inputs == null) {
            //May be null after clear()
            inputs = new INDArray[numInputArrays];
        }
        inputs[inputNum] = input;
    }

    /**
     * Set all inputs for the ComputationGraph network
     */
    public void setInputs(INDArray... inputs) {
        if (inputs != null && inputs.length != this.numInputArrays) {
            throw new IllegalArgumentException("Invalid input array: network has " + numInputArrays
                    + " inputs, but array is of length " + inputs.length);
        }
        this.inputs = inputs;
    }

    /**
     * Get the previously set input for the ComputationGraph
     */
    public INDArray getInput(int inputNum) {
        if (inputs == null)
            return null;
        return inputs[inputNum];
    }

    /**
     * Get the previously set inputs for the ComputationGraph
     */
    public INDArray[] getInputs() {
        return inputs;
    }

    /**
     * Get the previously set feature/input mask arrays for the ComputationGraph
     */
    public INDArray[] getInputMaskArrays() {
        return inputMaskArrays;
    }

    /**
     * Get the previously set label/output mask arrays for the ComputationGraph
     */
    public INDArray[] getLabelMaskArrays() {
        return labelMaskArrays;
    }

    /**
     * Set the specified label for the ComputationGraph
     */
    public void setLabel(int labelNum, INDArray label) {
        labels[labelNum] = label;
    }

    /**
     * Set all labels for the ComputationGraph network
     */
    public void setLabels(INDArray... labels) {
        if (labels != null && labels.length != this.numOutputArrays) {
            throw new IllegalArgumentException("Invalid output array: network has " + numOutputArrays
                    + " outputs, but array is of length " + labels.length);
        }
        this.labels = labels;
    }

    /**
     * This method allows you to specificy GradientsAccumulator instance to be used with this model
     * <p>
     * PLEASE NOTE: Do not use this method unless you understand how to use GradientsAccumulator & updates sharing.
     * PLEASE NOTE: Do not use this method on standalone model
     *
     * @param accumulator
     */
    public void setGradientsAccumulator(GradientsAccumulator accumulator) {
        if (!initCalled)
            init();

        solver.getOptimizer().setGradientsAccumulator(accumulator);
    }

    /**
     * Initialize the ComputationGraph network
     */
    public void init() {
        init(null, false);
    }

    /**
     * Initialize the ComputationGraph, optionally with an existing parameters array.
     * If an existing parameters array is specified, it will be used (and the values will not be modified) in the network;
     * if no parameters array is specified, parameters will be initialized randomly according to the network configuration.
     *
     * @param parameters           Network parameter. May be null. If null: randomly initialize.
     * @param cloneParametersArray Whether the parameter array (if any) should be cloned, or used directly
     */
    public void init(INDArray parameters, boolean cloneParametersArray) {
        if (initCalled)
            return;

        DataType netDtype = getConfiguration().getDataType();
        if(parameters != null && parameters.dataType() != netDtype){
            Preconditions.checkState(parameters.rank() == 2 && parameters.size(0) == 1, "Invalid parameters array: should be rank 2 with shape [1,numParams]. Got %ndShape", parameters);
            if(cloneParametersArray){
                try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
                    parameters = parameters.castTo(netDtype);
                }
            } else {
                throw new IllegalStateException("Error initializing network: Network datatype is set to " + netDtype
                        + " but provided array has datatype " + parameters.dataType() + " with cloneParametersArray argument" +
                        " set to false. Cannot initialize net with specified datatype array if that array does not match network datatype");
            }
        }

        if (configuration.getTrainingWorkspaceMode() == null)
            configuration.setTrainingWorkspaceMode(WorkspaceMode.NONE);

        if (configuration.getInferenceWorkspaceMode() == null)
            configuration.setInferenceWorkspaceMode(WorkspaceMode.NONE);

        if (configuration.getCacheMode() == null)
            configuration.setCacheMode(CacheMode.NONE);

        OneTimeLogger.info(log, "Starting ComputationGraph with WorkspaceModes set to [training: {}; inference: {}], cacheMode set to [{}]",
                configuration.getTrainingWorkspaceMode(), configuration.getInferenceWorkspaceMode(), configuration.getCacheMode());

        //First: build topological ordering, based on configuration. Used for forward pass, backprop and order of parameters/gradients
        GraphIndices indices = calculateIndices();
        topologicalOrder = indices.getTopologicalSortOrder();

        //Initialization: create the GraphVertex objects, based on configuration structure
        Map<String, org.deeplearning4j.nn.conf.graph.GraphVertex> configVertexMap = configuration.getVertices();

        //Names of all of the (data) inputs to the ComputationGraph
        List<String> networkInputNames = configuration.getNetworkInputs();

        //Inputs for each layer and GraphNode:
        Map<String, List<String>> vertexInputs = configuration.getVertexInputs();
        this.vertices = new GraphVertex[networkInputNames.size() + configuration.getVertices().size()];

        //All names: inputs, layers and graph nodes (index to name map)
        Map<String, Integer> allNamesReverse = new HashMap<>();

        //Create network input vertices:
        int vertexNumber = 0;
        for (String name : networkInputNames) {
            GraphVertex gv = new InputVertex(this, name, vertexNumber, null, netDtype); //Output vertices: set later
            allNamesReverse.put(name, vertexNumber);
            vertices[vertexNumber++] = gv;
        }

        //Go through layers, and work out total number of parameters. Then allocate full parameters array
        long numParams = 0;
        long[] numParamsForVertex = new long[topologicalOrder.length];
        int i = 0;
        for (; i < configuration.getNetworkInputs().size(); i++) {
            numParamsForVertex[i] = 0; //No parameters for input vertices
        }
        for(; i<topologicalOrder.length; i++ ){
            String name = indices.getIdxToName().get(i);
            org.deeplearning4j.nn.conf.graph.GraphVertex n = configVertexMap.get(name);
            n.setDataType(netDtype);
            numParamsForVertex[i] = n.numParams(true);
            numParams += numParamsForVertex[i];
        }

        boolean initializeParams;
        if (parameters != null) {
            if (!parameters.isRowVectorOrScalar())
                throw new IllegalArgumentException("Invalid parameters: should be a row vector");
            if (parameters.length() != numParams)
                throw new IllegalArgumentException("Invalid parameters: expected length " + numParams + ", got length "
                        + parameters.length());

            if (cloneParametersArray)
                flattenedParams = parameters.dup();
            else
                flattenedParams = parameters;

            initializeParams = false;
        } else if(numParams > 0){
            flattenedParams = Nd4j.create(netDtype, 1, numParams);
            initializeParams = true;
        } else {
            flattenedParams = null;
            initializeParams = false;
        }

        //Set RNG seed, for repeatability between initializations when set
        if (initializeParams) {
            Nd4j.getRandom().setSeed(conf().getSeed());
        }

        //Given the topological ordering: work out the subset of the parameters array used for each layer
        // Then extract out for use when initializing the Layers
        INDArray[] paramsViewForVertex = new INDArray[topologicalOrder.length];
        long paramOffsetSoFar = 0;
        i = 0;
        for (int vertexIdx : topologicalOrder) {
            long nParamsThisVertex = numParamsForVertex[vertexIdx];
            if (nParamsThisVertex != 0) {
                paramsViewForVertex[vertexIdx] = flattenedParams.get(NDArrayIndex.interval(0,0,true),
                        NDArrayIndex.interval(paramOffsetSoFar, paramOffsetSoFar + nParamsThisVertex));
            }
            i++;
            paramOffsetSoFar += nParamsThisVertex;
        }


        int numLayers = 0;
        List<Layer> tempLayerList = new ArrayList<>();
        defaultConfiguration.clearVariables();
        List<String> variables = defaultConfiguration.variables(false);
        i = configuration.getNetworkInputs().size();
        for(; i<topologicalOrder.length; i++ ){
            String name = indices.getIdxToName().get(i);
            org.deeplearning4j.nn.conf.graph.GraphVertex n = configVertexMap.get(name);

            GraphVertex gv = n.instantiate(this, name, vertexNumber, paramsViewForVertex[vertexNumber],
                    initializeParams, netDtype);

            if(gv == null){
                throw new IllegalStateException("Encountered null layer/vertex during initialization for layer \"" + name +
                        "\": " + n.getClass().getSimpleName() + " initialization returned null layer/vertex?");
            }

            if (gv.hasLayer()) {
                numLayers++;
                Layer l = gv.getLayer();
                tempLayerList.add(l);
                List<String> layerVariables = l.conf().variables();
                if (layerVariables != null) {
                    for (String s : layerVariables) {
                        variables.add(gv.getVertexName() + "_" + s);
                    }
                }
            }

            allNamesReverse.put(name, vertexNumber);
            vertices[vertexNumber++] = gv;
        }
        layers = tempLayerList.toArray(new Layer[numLayers]);

        //Create the lookup table, so we can find vertices easily by name
        verticesMap = new HashMap<>();
        for (GraphVertex gv : vertices) {
            verticesMap.put(gv.getVertexName(), gv);
        }

        //Now: do another pass to set the input and output indices, for each vertex
        // These indices are used during forward and backward passes
        //To get output indices: need to essentially build the graph in reverse...
        Map<String, List<String>> verticesOutputTo = new HashMap<>(); //Key: vertex. Values: vertices that this node is an input for
        for (GraphVertex gv : vertices) {
            String vertexName = gv.getVertexName();
            List<String> vertexInputNames;
            vertexInputNames = vertexInputs.get(vertexName);

            if (vertexInputNames == null)
                continue;

            //Build reverse network structure:
            for (String s : vertexInputNames) {
                List<String> list = verticesOutputTo.get(s);
                if (list == null) {
                    list = new ArrayList<>();
                    verticesOutputTo.put(s, list);
                }
                list.add(vertexName); //Edge: s -> vertexName
            }
        }


        for (GraphVertex gv : vertices) {
            String vertexName = gv.getVertexName();
            int vertexIndex = gv.getVertexIndex();
            List<String> vertexInputNames;
            vertexInputNames = vertexInputs.get(vertexName);

            if (vertexInputNames == null)
                continue;

            VertexIndices[] inputIndices = new VertexIndices[vertexInputNames.size()];
            for (int j = 0; j < vertexInputNames.size(); j++) {
                String inName = vertexInputNames.get(j);
                int inputVertexIndex = allNamesReverse.get(inName);

                //Here: we have x -> gv connection
                //gv might have multiple inputs, not just x
                //Need to know which input x is
                int inputNumber = vertexInputs.get(vertexName).indexOf(inName);

                if (inputNumber == -1)
                    throw new IllegalStateException("Could not find vertex " + vertexIndex + " in the list of inputs "
                            + "for vertex " + gv.getVertexName() + "; error in graph structure?");

                inputIndices[j] = new VertexIndices(inputVertexIndex, inputNumber);
            }

            gv.setInputVertices(inputIndices);
        }

        //Handle the outputs for this vertex
        for (GraphVertex gv : vertices) {
            String vertexName = gv.getVertexName();

            List<String> thisVertexOutputsTo = verticesOutputTo.get(vertexName);

            if (thisVertexOutputsTo == null || thisVertexOutputsTo.isEmpty())
                continue; //Output vertex
            VertexIndices[] outputIndices = new VertexIndices[thisVertexOutputsTo.size()];
            int j = 0;
            for (String s : new HashSet<>(thisVertexOutputsTo)) {
                //First, we have gv -> s
                //Which input in s does gv connect to? s may in general have multiple inputs...
                List<String> nextVertexInputNames = vertexInputs.get(s);

                for (int k = 0; k < nextVertexInputNames.size(); k++) {
                    if(vertexName.equals(nextVertexInputNames.get(k))){
                        int outputVertexIndex = allNamesReverse.get(s);
                        outputIndices[j++] = new VertexIndices(outputVertexIndex, k);
                    }
                }
            }
            gv.setOutputVertices(outputIndices);
        }

        //Mark any output vertices as outputs:
        for (String s : configuration.getNetworkOutputs()) {
            GraphVertex gv = verticesMap.get(s);
            gv.setOutputVertex(true);
        }

        // now we init solver & optimizer
        if (solver == null) {
            try (MemoryWorkspace wsO = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
                solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this).build();
                solver.initOptimizer();
            }
        }

        //Mark which layers can safely modify their input in-place. This is a performance optimization for
        // dropout and similar operations.
        // Safe when the input is: (a) it's not a graph input, and (b) isn't shared by any other layers/vertices

        Map<String,List<String>> seenAsInputTo = new HashMap<>();
        for(Map.Entry<String,List<String>> entry : configuration.getVertexInputs().entrySet()){
            for(String s : entry.getValue() ){
                if (!seenAsInputTo.containsKey(s)) {
                    seenAsInputTo.put(s, new ArrayList<String>());
                }
                List<String> seen = seenAsInputTo.get(s);
                seen.add(entry.getKey());
            }
        }

        for(Layer l : layers){
            String layerName = l.conf().getLayer().getLayerName();
            List<String> inputs = configuration.getVertexInputs().get(layerName);
            String in = inputs.get(0);  //For now: layers should have exactly 1 input

            if(configuration.getNetworkInputs().contains(in)){
                //TODO When is it safe to NOT allow input modifucation? It's not always safe...
                // For example dropout + iterating over List<MultiDataSet> that is used for multiple epochs...
                continue;
            }

            List<String> seen = seenAsInputTo.get(in);
            if(seen.size() == 1){
                l.allowInputModification(true);
            } else {
                //For the count > 1 case, we can work out if it's the last one in the topological order... at which point,
                // it should be safe to use
                int thisIdx = indices.getNameToIdx().get(layerName);
                int thisTopoPos = ArrayUtils.indexOf(indices.getTopologicalSortOrder(), thisIdx);
                int maxTopoPosition = -1;
                for(String s : seen){
                    int idx = indices.getNameToIdx().get(s);
                    int topoPos = ArrayUtils.indexOf(indices.getTopologicalSortOrder(), idx);
                    maxTopoPosition = Math.max(maxTopoPosition, topoPos);
                }

                if(thisTopoPos == maxTopoPosition){
                    //Last one in the topo sort... all other layers have already consumed this input by the time this layer's
                    // forward pass is done
                    l.allowInputModification(true);
                }   //Otherwise: keep default of false
            }
        }

        synchronizeIterEpochCounts();
        initCalled = true;
    }

    /**
     * This method: initializes the flattened gradients array (used in backprop) and sets the appropriate subset in all layers.
     * As a general rule, this shouldn't ever need to be called manually when doing training via fit(DataSet), fit(DataSetIterator)
     * or fit(MultiDataSet) methods
     */
    public void initGradientsView() {
        try (MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            if (!initCalled)
                init();

            GraphIndices indices = calculateIndices();

            //Go through layers, and work out total number of parameters. Then allocate full parameters array
            long numParams = 0;
            long[] numParamsForVertex = new long[topologicalOrder.length];
            int i = 0;
            for (; i < configuration.getNetworkInputs().size(); i++) {
                numParamsForVertex[i] = 0; //No parameters for input vertices
            }
            Map<String, org.deeplearning4j.nn.conf.graph.GraphVertex> configVertexMap = configuration.getVertices();
            for (; i < topologicalOrder.length; i++) {
                String name = indices.getIdxToName().get(i);
                org.deeplearning4j.nn.conf.graph.GraphVertex n = configVertexMap.get(name);
                numParamsForVertex[i] = n.numParams(true);
                numParams += numParamsForVertex[i];
            }

            if(numParams > 0) {
                flattenedGradients = Nd4j.create(flattenedParams.dataType(), 1, numParams);
            }

            //Given the topological ordering: work out the subset of the gradient array used for each layer, and set it
            long paramOffsetSoFar = 0;
            i = 0;
            for (int vertexIdx : topologicalOrder) {
                long nParamsThisVertex = numParamsForVertex[vertexIdx];
                if (nParamsThisVertex != 0) {
                    INDArray gradientView = flattenedGradients.get(NDArrayIndex.interval(0,0,true),
                            NDArrayIndex.interval(paramOffsetSoFar, paramOffsetSoFar + nParamsThisVertex));
                    vertices[vertexIdx].setBackpropGradientsViewArray(gradientView);
                }
                i++;
                paramOffsetSoFar += nParamsThisVertex;
            }
        }
    }

    protected int[] getOutputLayerIndices(){
        if(outputLayerIdxs == null) {
            outputLayerIdxs = new int[numOutputArrays];
            int i = 0;
            for (String s : configuration.getNetworkOutputs()) {
                outputLayerIdxs[i++] = verticesMap.get(s).getVertexIndex();
            }
        }
        return  outputLayerIdxs;
    }

    /**
     * Perform layerwise pretraining for one epoch - see {@link #pretrain(DataSetIterator, int)}
     */
    public void pretrain(DataSetIterator iter) {
        pretrain(iter, 1);
    }

    /**
     * Pretrain network with a single input and single output. DataSetIterators can only be used if the number of input
     * arrays for the ComputationGraph is 1.<br>
     * This method performs layerwise pretraining on all pre-trainable layers in the network (VAEs, Autoencoders, etc), for the specified
     * number of epochs each. For example, if numEpochs=3, then layer 0 will be fit for 3 epochs, followed by layer 1
     * for 3 epochs, and so on.<br>
     * For networks with more than one input use {@link #pretrain(MultiDataSetIterator)}
     */
    public void pretrain(DataSetIterator iter, int numEpochs) {
        if (numInputArrays != 1) {
            throw new UnsupportedOperationException(
                    "Cannot train ComputationGraph network with  multiple inputs using a DataSetIterator");
        }

        pretrain(ComputationGraphUtil.toMultiDataSetIterator(iter), numEpochs);
    }

    /**
     * Pretrain network with multiple inputs and/or outputs
     */
    public void pretrain(MultiDataSetIterator iter) {
        pretrain(iter, 1);
    }

    /**
     * Pretrain network with multiple inputs and/or outputs<br>
     * This method performs layerwise pretraining on all pre-trainable layers in the network (VAEs, Autoencoders, etc), for the specified
     * number of epochs each. For example, if numEpochs=3, then layer 0 will be fit for 3 epochs, followed by layer 1
     * for 3 epochs, and so on.<br>
     * Non-pretrainable layers are ignored
     *
     * @param iter      Training data
     * @param numEpochs Number of epochs to fit each layer with
     * @see #pretrainLayer(String, MultiDataSetIterator)
     */
    public void pretrain(MultiDataSetIterator iter, int numEpochs) {
        try{
            pretrainHelper(iter, numEpochs);
        } catch (OutOfMemoryError e){
            CrashReportingUtil.writeMemoryCrashDump(this, e);
            throw e;
        }
    }

    private void pretrainHelper(MultiDataSetIterator iter, int numEpochs){
        if (flattenedGradients == null) {
            initGradientsView();
        }

        //Assume here that all layers are pretrainable layers
        for (int i = 0; i < topologicalOrder.length; i++) {
            if (!vertices[i].hasLayer())
                continue;
            if (vertices[i].getLayer() instanceof IOutputLayer)
                continue; //Don't pretrain output layer
            if (!vertices[i].getLayer().isPretrainLayer())
                continue; //Skip layers that aren't pretrainable

            pretrainLayerHelper(vertices[i].getVertexName(), iter, numEpochs);
        }
    }

    /**
     * Pretrain a specified layer with the given DataSetIterator
     *
     * @param layerName       Layer name
     * @param dataSetIterator Data
     */
    public void pretrainLayer(String layerName, DataSetIterator dataSetIterator) {
        if (numInputArrays != 1) {
            throw new UnsupportedOperationException(
                    "Cannot train ComputationGraph network with  multiple inputs using a DataSetIterator");
        }

        pretrainLayer(layerName, ComputationGraphUtil.toMultiDataSetIterator(dataSetIterator));
    }

    /**
     * Pretrain a specified layer with the given MultiDataSetIterator
     *
     * @param layerName Layer name
     * @param iter      Training data
     */
    public void pretrainLayer(String layerName, MultiDataSetIterator iter) {
        try{
            pretrainLayerHelper(layerName, iter, 1);
        } catch (OutOfMemoryError e){
            CrashReportingUtil.writeMemoryCrashDump(this, e);
            throw e;
        }
    }

    private void pretrainLayerHelper(String layerName, MultiDataSetIterator iter, int numEpochs){
        if (flattenedGradients == null) {
            initGradientsView();
        }

        if (!verticesMap.containsKey(layerName)) {
            throw new IllegalStateException("Invalid vertex name: " + layerName + " - all vertex names: " +
                    verticesMap.keySet());
        }
        if (!verticesMap.get(layerName).hasLayer()) {
            //No op
            return;
        }

        GraphVertex toTrain = verticesMap.get(layerName);
        int idx = toTrain.getVertexIndex();

        LayerWorkspaceMgr workspaceMgr;
        if(configuration.getTrainingWorkspaceMode() == WorkspaceMode.NONE){
            workspaceMgr = LayerWorkspaceMgr.noWorkspaces();
        } else {
            workspaceMgr = LayerWorkspaceMgr.builder()
                    .with(ArrayType.ACTIVATIONS, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG)
                    .with(ArrayType.INPUT, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG)
                    .with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG)
                    .with(ArrayType.BP_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG)
                    .with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG)
                    .with(ArrayType.RNN_BP_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG)
                    //Use FF/BP working memory for updater also
                    .with(ArrayType.UPDATER_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG)
                    .build();
        }
        workspaceMgr.setHelperWorkspacePointers(helperWorkspaces);

        if(!iter.hasNext() && iter.resetSupported())
            iter.reset();

        MultiDataSetIterator withAsync = iter.asyncSupported() ? new AsyncMultiDataSetIterator(iter) : iter;

        while(withAsync.hasNext()) {
            MultiDataSet mds = withAsync.next();
            try(MemoryWorkspace ws = workspaceMgr.notifyScopeEntered(ArrayType.ACTIVATIONS)) {
                //FF - note should be using TEST mode here for the layers that feed into the specified layer
                Map<String, INDArray> activations = ffToLayerActivationsInWS(false, idx, new int[]{idx}, FwdPassType.STANDARD,
                        false, mds.getFeatures(), mds.getFeaturesMaskArrays(), mds.getLabelsMaskArrays(), true);

                //Get input to the current layer
                VertexIndices[] inputsToLayer = toTrain.getInputVertices();
                for (VertexIndices vi : inputsToLayer) {
                    String inName = vertices[vi.getVertexIndex()].getVertexName();
                    INDArray act = activations.get(inName);
                    toTrain.setInput(vi.getVertexEdgeNumber(), act, workspaceMgr);
                }

                Layer layer = toTrain.getLayer();
                layer.fit(layer.input(), workspaceMgr);
            }
        }
    }

    /**
     * Fit the ComputationGraph using a DataSet.
     * Note that this method can only be used with ComputationGraphs with 1 input and 1 output.
     * For networks with more than one input or output, use {@link #fit(MultiDataSetIterator)}
     */
    public void fit(DataSet dataSet) {
        if (numInputArrays != 1 || numOutputArrays != 1)
            throw new UnsupportedOperationException("Cannot train ComputationGraph network with "
                    + " multiple inputs or outputs using a DataSet");

        boolean hasMaskArrays = dataSet.hasMaskArrays();
        if (hasMaskArrays) {
            INDArray[] fMask = (dataSet.getFeaturesMaskArray() != null ? new INDArray[]{dataSet.getFeaturesMaskArray()}
                    : null);
            INDArray[] lMask = (dataSet.getLabelsMaskArray() != null ? new INDArray[]{dataSet.getLabelsMaskArray()}
                    : null);
            fit(new INDArray[]{dataSet.getFeatures()}, new INDArray[]{dataSet.getLabels()}, fMask, lMask);
        } else {
            fit(new INDArray[]{dataSet.getFeatures()}, new INDArray[]{dataSet.getLabels()});
        }

        if (hasMaskArrays)
            clearLayerMaskArrays();

        clearLayersStates();
    }

    /**
     * Perform minibatch training on all minibatches in the DataSetIterator, for the specified number of epochs.
     * Equvalent to calling {@link #fit(DataSetIterator)} numEpochs times in a loop
     *
     * @param iterator  Training data (DataSetIterator). Iterator must support resetting
     * @param numEpochs Number of training epochs, >= 1
     */
    public void fit(@NonNull DataSetIterator iterator, int numEpochs){
        Preconditions.checkArgument(numEpochs > 0, "Number of epochs much be > 0. Got numEpochs = %s", numEpochs);
        Preconditions.checkArgument(numEpochs == 1 || iterator.resetSupported(), "Cannot perform multiple epochs training using" +
                "iterator thas does not support resetting (iterator.resetSupported() returned false)");

        for(int i=0; i<numEpochs; i++ ){
            fit(iterator);
        }
    }

    /**
     * Fit the ComputationGraph using a DataSetIterator.<br>
     * Note that this method can only be used with ComputationGraphs with 1 input and 1 output<br>
     * Method doesn't do layerwise  pretraining.<br>
     * For pretraining use method pretrain.. {@link #pretrain(DataSetIterator)}<br>
     * @param iterator Training data (DataSetIterator)
     */
    public void fit(@NonNull DataSetIterator iterator) {
        fit(new MultiDataSetIteratorAdapter(iterator));
    }

    /**
     * Fit the ComputationGraph using a MultiDataSet
     */
    public void fit(MultiDataSet multiDataSet) {
        fit(multiDataSet.getFeatures(), multiDataSet.getLabels(), multiDataSet.getFeaturesMaskArrays(),
                multiDataSet.getLabelsMaskArrays());
        if (multiDataSet.hasMaskArrays())
            clearLayerMaskArrays();
    }

    /**
     * Perform minibatch training on all minibatches in the MultiDataSetIterator, for the specified number of epochs.
     * Equvalent to calling {@link #fit(MultiDataSetIterator)} numEpochs times in a loop
     *
     * @param iterator  Training data (DataSetIterator). Iterator must support resetting
     * @param numEpochs Number of training epochs, >= 1
     */
    public void fit(@NonNull MultiDataSetIterator iterator, int numEpochs){
        Preconditions.checkArgument(numEpochs > 0, "Number of epochs much be > 0. Got numEpochs = %s", numEpochs);
        Preconditions.checkArgument(numEpochs == 1 || iterator.resetSupported(), "Cannot perform multiple epochs training using" +
                "iterator thas does not support resetting (iterator.resetSupported() returned false)");

        for(int i=0; i<numEpochs; i++ ){
            fit(iterator);
        }
    }

    /**
     * Fit the ComputationGraph using a MultiDataSetIterator
     * Method doesn't do layerwise  pretraining.<br>
     * For pretraining use method pretrain.. {@link #pretrain(MultiDataSetIterator)}<br>
     * @param multi Training data (MultiDataSetIterator)
     */
    public synchronized void fit(MultiDataSetIterator multi) {
        if (flattenedGradients == null) {
            initGradientsView();
        }

        if(!multi.hasNext() && multi.resetSupported()){
            multi.reset();
        }

        for (TrainingListener tl : trainingListeners) {
            tl.onEpochStart(this);
        }

        boolean destructable = false;

        MultiDataSetIterator multiDataSetIterator;
        if (multi.asyncSupported()) {
            multiDataSetIterator = new AsyncMultiDataSetIterator(multi, Math.max(Nd4j.getAffinityManager().getNumberOfDevices() * 2, 2), true);
            destructable = true;
        } else
            multiDataSetIterator = multi;

        long time1 = System.currentTimeMillis();
        while(multiDataSetIterator.hasNext()){
            MultiDataSet mds = multiDataSetIterator.next();
            long time2 = System.currentTimeMillis();
            lastEtlTime.set((time2 - time1));

            fit(mds.getFeatures(),mds.getLabels(), mds.getFeaturesMaskArrays(), mds.getLabelsMaskArrays());
            time1 = System.currentTimeMillis();
        }

        if (destructable)
            ((AsyncMultiDataSetIterator) multiDataSetIterator).shutdown();

        for (TrainingListener tl : trainingListeners) {
            tl.onEpochEnd(this);
        }

        incrementEpochCount();
    }
    /**
     * Fit the ComputationGraph given arrays of inputs and labels.
     *
     * @param inputs The network inptus
     * @param labels The labels
     */
    public void fit(INDArray[] inputs, INDArray[] labels) {
        fit(inputs, labels, null, null);
    }

    /**
     * Fit the ComputationGraph using the specified inputs and labels (and mask arrays)
     *
     * @param inputs            The network inputs (features)
     * @param labels            The network labels
     * @param featureMaskArrays Mask arrays for inputs/features. Typically used for RNN training. May be null.
     * @param labelMaskArrays   Mas arrays for the labels/outputs. Typically used for RNN training. May be null.
     */
    public void fit(INDArray[] inputs, INDArray[] labels, INDArray[] featureMaskArrays, INDArray[] labelMaskArrays) {
        try{
            fitHelper(inputs, labels, featureMaskArrays, labelMaskArrays);
        } catch (OutOfMemoryError e){
            CrashReportingUtil.writeMemoryCrashDump(this, e);
            throw e;
        }
    }

    private synchronized void fitHelper(INDArray[] inputs, INDArray[] labels, INDArray[] featureMaskArrays, INDArray[] labelMaskArrays) {
        if (numParams() == 0) {
            return; //Edge case: net with no params: fitting is a no-op
        }

        if (flattenedGradients == null) {
            initGradientsView();
        }

        setInputs(inputs);
        setLabels(labels);
        setLayerMaskArrays(featureMaskArrays, labelMaskArrays);
        update(TaskUtils.buildTask(inputs, labels));

        LayerWorkspaceMgr workspaceMgr;
        if(configuration.getTrainingWorkspaceMode() == WorkspaceMode.NONE){
            workspaceMgr = LayerWorkspaceMgr.noWorkspaces();
        } else {
            workspaceMgr = LayerWorkspaceMgr.builder()
                    .with(ArrayType.ACTIVATIONS, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG)
                    .with(ArrayType.INPUT, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG)
                    .with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG)
                    .with(ArrayType.BP_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG)
                    .with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG)
                    .with(ArrayType.RNN_BP_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG)
                    //Note for updater working memory, we have the option to re-use WS_ALL_LAYERS_ACT or FF/BP_WORKING_MEM
                    // as these should be closed by the time updaters are executed
                    //Generally, WS_ALL_LAYERS_ACT will be the larger of the two, so we'll use this
                    .with(ArrayType.UPDATER_WORKING_MEM, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG)
                    .build();
        }
        workspaceMgr.setHelperWorkspacePointers(helperWorkspaces);

        if (configuration.getBackpropType() == BackpropType.TruncatedBPTT) {
            doTruncatedBPTT(inputs, labels, featureMaskArrays, labelMaskArrays, workspaceMgr);
        } else {
            if (solver == null) {
                try (MemoryWorkspace wsO = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
                    solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this).build();
                }
            }

            //TODO: cache workspace
            solver.optimize(workspaceMgr);

        }

        if (featureMaskArrays != null || labelMaskArrays != null) {
            clearLayerMaskArrays();
        }

        clearLayersStates();
        synchronizeIterEpochCounts();
    }



    /**
     * Calculate a topological sort order for the vertices in the graph.
     * Note that this is used for
     * (a) working out what order to do forward pass,
     * (b) what order to do backprop (i.e., reverse of this)
     * (c) order to flatten parameters (and gradients)
     * <p>
     * Specifically, gradients/params/forward pass are executed on vertex[topologicalSortOrder[i]], for i=0..nVertices-1
     */
    public int[] topologicalSortOrder() {
        return calculateIndices().getTopologicalSortOrder();
    }

    /**
     * Calculate the indices needed for the network:<br>
     * (a) topological sort order<br>
     * (b) Map: vertex index -> vertex name<br>
     * (c) Map: vertex name -> vertex index<br>
     *
     * @return Calculated indices
     */
    public GraphIndices calculateIndices(){
        if(graphIndices != null)
            return graphIndices;


        //Get cached topological sort order from config, if present
        if(configuration.getTopologicalOrder() != null && configuration.getTopologicalOrderStr() != null){
            int[] t = configuration.getTopologicalOrder();
            List<String> s = configuration.getTopologicalOrderStr();
            Map<String,Integer> m1 = new HashMap<>();
            Map<Integer,String> m2 = new HashMap<>();
            for( int i=0; i<t.length; i++ ){
                m1.put(s.get(i), t[i]);
                m2.put(t[i], s.get(i));
            }

            graphIndices = GraphIndices.builder()
                    .topologicalSortOrder(t)
                    .nameToIdx(m1)
                    .idxToName(m2)
                    .build();
            return graphIndices;
        }


        //https://en.wikipedia.org/wiki/Topological_sorting#Kahn.27s_algorithm
        Map<String, org.deeplearning4j.nn.conf.graph.GraphVertex> nodeMap = configuration.getVertices();
        List<String> networkInputNames = configuration.getNetworkInputs();
        int numVertices = networkInputNames.size() + configuration.getVertices().size();
        int[] out = new int[numVertices];
        int outCounter = 0;

        //First: represent the graph more usefully as a Map<Integer,Set<Integer>>, where map represents edges i -> j
        // key represents j, set is set of i (inputs) for vertices j
        Map<Integer, String> vertexNamesMap = new HashMap<>();
        Map<String, Integer> vertexNamesMap2 = new HashMap<>();
        int i = 0;
        for (String inputName : configuration.getNetworkInputs()) {
            vertexNamesMap.put(i, inputName);
            vertexNamesMap2.put(inputName, i);
            i++;
        }
        for (Map.Entry<String, org.deeplearning4j.nn.conf.graph.GraphVertex> entry : nodeMap.entrySet()) {
            String name = entry.getKey();
            vertexNamesMap.put(i, name);
            vertexNamesMap2.put(name, i);
            i++;
        }

        Map<Integer, Set<Integer>> inputEdges = new HashMap<>(); //key: vertex. Values: vertices that the key vertex receives input from
        Map<Integer, Set<Integer>> outputEdges = new HashMap<>(); //key: vertex. Values: vertices that the key vertex outputs to

        for (String s : configuration.getNetworkInputs()) {
            int idx = vertexNamesMap2.get(s);
            inputEdges.put(idx, null);
        }

        for (Map.Entry<String, org.deeplearning4j.nn.conf.graph.GraphVertex> entry : nodeMap.entrySet()) {
            String thisVertexName = entry.getKey();
            int idx = vertexNamesMap2.get(thisVertexName);
            List<String> inputsToThisVertex = configuration.getVertexInputs().get(thisVertexName);

            if (inputsToThisVertex == null || inputsToThisVertex.isEmpty()) {
                inputEdges.put(idx, null);
                continue;
            }

            Set<Integer> inputSet = new HashSet<>();
            for (String s : inputsToThisVertex) {
                Integer inputIdx = vertexNamesMap2.get(s);
                inputSet.add(inputIdx);
                Set<Integer> outputSetForInputIdx = outputEdges.get(inputIdx);
                if (outputSetForInputIdx == null) {
                    outputSetForInputIdx = new HashSet<>();
                    outputEdges.put(inputIdx, outputSetForInputIdx);
                }
                outputSetForInputIdx.add(idx); //input vertex outputs to the current vertex
            }

            inputEdges.put(idx, inputSet);
        }

        //Now: do topological sort
        //Set of all nodes with no incoming edges: (this would be: input vertices)
        LinkedList<Integer> noIncomingEdges = new LinkedList<>();
        for (Map.Entry<Integer, Set<Integer>> entry : inputEdges.entrySet()) {
            Set<Integer> inputsFrom = entry.getValue();
            if (inputsFrom == null || inputsFrom.isEmpty()) {
                noIncomingEdges.add(entry.getKey());
            }
        }

        while (!noIncomingEdges.isEmpty()) {
            int next = noIncomingEdges.removeFirst();
            out[outCounter++] = next; //Add to sorted list

            Set<Integer> vertexOutputsTo = outputEdges.get(next);

            //Remove edges next -> vertexOuputsTo[...] from graph;
            if (vertexOutputsTo != null) {
                for (Integer v : vertexOutputsTo) {
                    Set<Integer> set = inputEdges.get(v);
                    set.remove(next);
                    if (set.isEmpty()) {
                        noIncomingEdges.add(v); //No remaining edges for vertex i -> add to list for processing
                    }
                }
            }
        }

        //If any edges remain in the graph: graph has cycles:
        for (Map.Entry<Integer, Set<Integer>> entry : inputEdges.entrySet()) {
            Set<Integer> set = entry.getValue();
            if (set == null)
                continue;
            if (!set.isEmpty())
                throw new IllegalStateException(
                        "Invalid configuration: cycle detected in graph. Cannot calculate topological ordering with graph cycle ("
                                + "cycle includes vertex \"" + vertexNamesMap.get(entry.getKey())
                                + "\")");
        }

        //Store: the topological sort order in the configuraation... this is to ensure that when the config is
        // deserialized, it has exactly the same topological sort order on all platforms
        List<String> s = new ArrayList<>(out.length);
        for( int idx : out){
            s.add(vertexNamesMap.get(idx));
        }
        configuration.setTopologicalOrder(out);
        configuration.setTopologicalOrderStr(s);

        graphIndices = GraphIndices.builder()
                .topologicalSortOrder(out)
                .nameToIdx(vertexNamesMap2)
                .idxToName(vertexNamesMap)
                .build();
        return graphIndices;
    }

    @Override
    public void computeGradientAndScore(LayerWorkspaceMgr workspaceMgr){
        computeGradientAndScore();
    }

    public void computeGradientAndScore() {
        synchronizeIterEpochCounts();

        LayerWorkspaceMgr workspaceMgr;
        if(configuration.getTrainingWorkspaceMode() == WorkspaceMode.NONE){
            workspaceMgr = LayerWorkspaceMgr.noWorkspaces();
        } else {
            workspaceMgr = LayerWorkspaceMgr.builder()
                    .with(ArrayType.ACTIVATIONS, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG)
                    .with(ArrayType.INPUT, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG)
                    .with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG)
                    .with(ArrayType.BP_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG)
                    .with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG)
                    .with(ArrayType.RNN_BP_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG)
                    //Note for updater working memory, we have the option to re-use WS_ALL_LAYERS_ACT or FF/BP_WORKING_MEM
                    // as these should be closed by the time updaters are executed
                    //Generally, WS_ALL_LAYERS_ACT will be the larger of the two, so we'll use this
                    .with(ArrayType.UPDATER_WORKING_MEM, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG)
                    .build();
        }
        workspaceMgr.setHelperWorkspacePointers(helperWorkspaces);

        boolean tbptt = configuration.getBackpropType() == BackpropType.TruncatedBPTT;
        FwdPassType fwdType = (tbptt ? FwdPassType.RNN_ACTIVATE_WITH_STORED_STATE : FwdPassType.STANDARD);
        synchronizeIterEpochCounts();

        //Calculate activations (which are stored in each layer, and used in backprop)
        try(MemoryWorkspace wsAllActivations = workspaceMgr.notifyScopeEntered(ArrayType.ACTIVATIONS)) {
            Map<String, INDArray> activations = ffToLayerActivationsInWS(true, -1, getOutputLayerIndices(),
                    fwdType, tbptt, inputs, inputMaskArrays, labelMaskArrays, false);
            if (!trainingListeners.isEmpty()) {
                try (MemoryWorkspace workspace = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
                    for (TrainingListener tl : trainingListeners) {
                        tl.onForwardPass(this, activations);
                    }
                }
            }
            calcBackpropGradients(false,false);

            workspaceMgr.assertCurrentWorkspace(ArrayType.ACTIVATIONS, null);

            //Score: sum of the scores for the various output layers...
            double r = calcRegularizationScore(true);

            score = 0.0;
            int outNum = 0;
            for (String s : configuration.getNetworkOutputs()) {
                GraphVertex gv = verticesMap.get(s);
                if(gv instanceof LayerVertex) {
                    //At this point: the input to the output layer might not be set on the layer itself - just the vertex
                    LayerVertex lv = (LayerVertex) gv;
                    if(!lv.isSetLayerInput()) {
                        lv.applyPreprocessorAndSetInput(workspaceMgr);
                    }
                }
                Layer vertexLayer = gv.getLayer();
                if (vertexLayer instanceof FrozenLayerWithBackprop) {
                    vertexLayer = ((FrozenLayerWithBackprop) vertexLayer).getInsideLayer();
                }
                vertexLayer.setMaskArray((labelMaskArrays == null) ? null : labelMaskArrays[outNum]);

                try(MemoryWorkspace ws = workspaceMgr.notifyScopeEntered(ArrayType.FF_WORKING_MEM)) {
                    score += ((IOutputLayer) vertexLayer).computeScore(r, true, workspaceMgr);
                }

                //Only want to add l1/l2 component once...
                r = 0.0;
                outNum++;
            }

            //Listeners
            if (!trainingListeners.isEmpty()) {
                try (MemoryWorkspace workspace = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
                    for (TrainingListener tl : trainingListeners) {
                        tl.onBackwardPass(this);
                    }
                }
            }
        }

        for(GraphVertex gv : vertices){
            gv.clear();
        }
    }


    /**
     * Conduct forward pass using a single input array. Note that this method can only be used with ComputationGraphs
     * with a single input array.
     *
     * @param input The input array
     * @param layerTillIndex the layer to feed forward to
     * @param train If true: do forward pass at training time
     * @return A map of activations for each layer (not each GraphVertex). Keys = layer name, values = layer activations
     */
    public Map<String, INDArray> feedForward(INDArray input, int layerTillIndex,boolean train) {
        if (numInputArrays != 1)
            throw new UnsupportedOperationException("Cannot feedForward with single input for graph network with "
                    + numInputArrays + " expected inputs");
        setInput(0, input);
        return feedForward(train,layerTillIndex);
    }



    /**
     * Conduct forward pass using an array of inputs. This overload allows the forward pass to be conducted, optionally
     * (not) clearing the layer input arrays.<br>
     * Note: when using clearInputs=false, there can be some performance and memory overhead: this is because the arrays are
     * defined outside of workspaces (which are enabled by default) - otherwise, old/invalidated arrays could still be
     * accessed after calling this method. Consequently: Don't use clearInputs=false unless you have a use case that
     * requires them to remain after feed-forward has been completed
     *
     * @param input An array of ComputationGraph inputs
     * @param layerTillIndex the index of the layer to feed forward to
     * @param train If true: do forward pass at training time; false: do forward pass at test time
     * @param clearInputs If true (default for other methods): clear the inputs of all layers after doing forward
     *                    pass. False don't clear layer inputs.
     * @return A map of activations for each layer (not each GraphVertex). Keys = layer name, values = layer activations
     */
    public Map<String, INDArray> feedForward(INDArray[] input, int layerTillIndex,boolean train, boolean clearInputs) {
        setInputs(input);
        try {
            return ffToLayerActivationsDetached(train, FwdPassType.STANDARD, false, layerTillIndex, null,
                    input, inputMaskArrays, labelMaskArrays, clearInputs);
        } catch (OutOfMemoryError e){
            CrashReportingUtil.writeMemoryCrashDump(this, e);
            throw e;
        }
    }


    /**
     * Conduct forward pass using an array of inputs
     *
     * @param input An array of ComputationGraph inputs
     * @param layerTillIndex the index of the layer to feed forward to
     * @param train If true: do forward pass at training time; false: do forward pass at test time
     * @return A map of activations for each layer (not each GraphVertex). Keys = layer name, values = layer activations
     */
    public Map<String, INDArray> feedForward(INDArray[] input, int layerTillIndex,boolean train) {
        setInputs(input);
        return feedForward(train, layerTillIndex);
    }


    /**
     * Conduct forward pass using the stored inputs
     *
     * @param train If true: do forward pass at training time; false: do forward pass at test time
     * @param layerTillIndex the index of the layer to feed forward to
     * @return A map of activations for each layer (not each GraphVertex). Keys = layer name, values = layer activations
     */
    public Map<String, INDArray> feedForward(boolean train,int layerTillIndex) {
        int graphVertexIndexOfLayer = layers[layerTillIndex].getIndex();
        try{
            return ffToLayerActivationsDetached(train, FwdPassType.STANDARD, false, graphVertexIndexOfLayer,
                null, inputs, inputMaskArrays, labelMaskArrays, true);
        } catch (OutOfMemoryError e){
            CrashReportingUtil.writeMemoryCrashDump(this, e);
            throw e;
        }
    }



    /**
     * Conduct forward pass using a single input array. Note that this method can only be used with ComputationGraphs
     * with a single input array.
     *
     * @param input The input array
     * @param train If true: do forward pass at training time
     * @return A map of activations for each layer (not each GraphVertex). Keys = layer name, values = layer activations
     */
    public Map<String, INDArray> feedForward(INDArray input, boolean train) {
        if (numInputArrays != 1)
            throw new UnsupportedOperationException("Cannot feedForward with single input for graph network with "
                    + numInputArrays + " expected inputs");
        setInput(0, input);
        return feedForward(train);
    }

    /**
     * Conduct forward pass using an array of inputs
     *
     * @param input An array of ComputationGraph inputs
     * @param train If true: do forward pass at training time; false: do forward pass at test time
     * @return A map of activations for each layer (not each GraphVertex). Keys = layer name, values = layer activations
     */
    public Map<String, INDArray> feedForward(INDArray[] input, boolean train) {
        return feedForward(input, train, true);
    }

    /**
     * Conduct forward pass using an array of inputs. This overload allows the forward pass to be conducted, optionally
     * (not) clearing the layer input arrays.<br>
     * Note: this method should NOT be used with clearInputs = true, unless you know what you are doing. Specifically:
     * when using clearInputs=false, in combination with workspaces, the layer input fields may leak outside of the
     * workspaces in which they were defined - potentially causing a crash. See <a href="https://deeplearning4j.org/docs/latest/deeplearning4j-config-workspaces">
     *     https://deeplearning4j.org/docs/latest/deeplearning4j-config-workspaces</a>
     * for more details
     *
     * @param input An array of ComputationGraph inputs
     * @param train If true: do forward pass at training time; false: do forward pass at test time
     * @param clearInputs If true (default for other methods): clear the inputs of all layers after doing forward
     *                    pass. False don't clear layer inputs.
     * @return A map of activations for each layer (not each GraphVertex). Keys = layer name, values = layer activations
     */
    public Map<String, INDArray> feedForward(INDArray[] input, boolean train, boolean clearInputs){
        setInputs(input);
        try {
            return ffToLayerActivationsDetached(train, FwdPassType.STANDARD, false, vertices.length - 1,
                    null, input, inputMaskArrays, labelMaskArrays, clearInputs);
        } catch (OutOfMemoryError e){
            CrashReportingUtil.writeMemoryCrashDump(this, e);
            throw e;
        }
    }

    /**
     * Conduct forward pass using the stored inputs, at test time
     *
     * @return A map of activations for each layer (not each GraphVertex). Keys = layer name, values = layer activations
     */
    public Map<String, INDArray> feedForward() {
        return feedForward(false);
    }

    /**
     * Conduct forward pass using the stored inputs
     *
     * @param train If true: do forward pass at training time; false: do forward pass at test time
     * @return A map of activations for each layer (not each GraphVertex). Keys = layer name, values = layer activations
     */
    public Map<String, INDArray> feedForward(boolean train) {
        try {
            return ffToLayerActivationsDetached(train, FwdPassType.STANDARD, false, vertices.length - 1,
                    null, inputs, inputMaskArrays, labelMaskArrays, true);
        } catch (OutOfMemoryError e){
            CrashReportingUtil.writeMemoryCrashDump(this, e);
            throw e;
        }
    }

    /**
     * @param train                            True: training time. False: test time
     * @param excludeOutputLayers              Should we exclude the output layers during forward pass? (usually: false)
     * @param includeNonLayerVertexActivations Include non-layer vertices in the output may?
     * @return Map of activations. Key: vertex name. Value: activations.
     */
    public Map<String, INDArray> feedForward(boolean train, boolean excludeOutputLayers,
                                             boolean includeNonLayerVertexActivations) {
        int[] exclude = null;
        if(excludeOutputLayers){
            exclude = getOutputLayerIndices();
        }

        Map<String,INDArray> m = ffToLayerActivationsDetached(train, FwdPassType.STANDARD, false,
                vertices.length-1, exclude, inputs, inputMaskArrays, labelMaskArrays, true);
        if(includeNonLayerVertexActivations){
            return m;
        } else {
            //Include only layers - in previous versions, we've always included inputs too for this method...
            Map<String,INDArray> out = new HashMap<>();
            for(Map.Entry<String,INDArray> e : m.entrySet()){
                GraphVertex v = verticesMap.get(e.getKey());
                if(v instanceof LayerVertex || v instanceof InputVertex){
                    out.put(e.getKey(), e.getValue());
                }
            }
            return out;
        }
    }

    /**
     * Return an array of network outputs (predictions) at test time, given the specified network inputs
     * Network outputs are for output layers only.
     *
     * @param input Inputs to the network
     * @return Output activations (order: same as defined in network configuration)
     */
    public INDArray[] output(INDArray... input) {
        return output(false, input, inputMaskArrays, labelMaskArrays);
    }

    /**
     * A convenience method that returns a single INDArray, instead of an INDArray[].
     * Useful for ComputationGraphs that have only a single output.
     * Otherwise identical to {@link #output(INDArray...)}
     *
     * @param input Inputs to the network
     * @return Output activations array
     */
    public INDArray outputSingle(INDArray... input) {
        return outputSingle(false, input);
    }

    /**
     * Return an array of network outputs (predictions), given the specified network inputs
     * Network outputs are for output layers only.
     *
     * @param train If true: do forward pass at training time; false: do forward pass at test time
     * @param input Inputs to the network
     * @return Output activations (order: same as defined in network configuration)
     */
    public INDArray[] output(boolean train, INDArray... input) {
        return output(train, (MemoryWorkspace)null, input);
    }

    /**
     * Return an array of network outputs (predictions), given the specified network inputs
     * Network outputs are for output layers only.<br>
     * If no memory workspace is provided, the output will be detached (not in any workspace).<br>
     * If a memory workspace is provided, the output activation array (i.e., the INDArray returned by this method)
     * will be placed in the specified workspace. This workspace must be opened by the user before calling this method -
     * and the user is responsible for (a) closing this workspace, and (b) ensuring the output array is not used out
     * of scope (i.e., not used after closing the workspace to which it belongs - as this is likely to cause either
     * an exception when used, or a crash).
     *
     * @param train           If true: do forward pass at training time; false: do forward pass at test time
     * @param outputWorkspace May be null. If not null: the workspace MUST be opened before calling this method.
     * @param input           Inputs to the network
     * @return Output activations (order: same as defined in network configuration)
     */
    public INDArray[] output(boolean train, MemoryWorkspace outputWorkspace, INDArray... input) {
        return output(train, input, inputMaskArrays, labelMaskArrays, outputWorkspace);
    }

    /**
     * Return an array of network outputs (predictions), given the specified network inputs
     * Network outputs are for output layers only.
     *
     * @param train      If true: forward pass for training mode. False: test mode
     * @param input      Input arrays to the netwonk
     * @param inputMasks Optional input mask arrays (may be null)
     * @return           Network output activations
     */
    public INDArray[] output(boolean train, @NonNull INDArray[] input, INDArray[] inputMasks){
        return output(train, input, inputMasks, null);
    }

    /**
     * Return an array of network outputs (predictions), given the specified network inputs
     * Network outputs are for output layers only.
     *
     * @param train      If true: forward pass for training mode. False: test mode
     * @param input      Input arrays to the netwonk
     * @param inputMasks Optional input mask arrays (may be null)
     * @param labelMasks Optional label mask arrays (may be null
     * @return           Network output activations
     */
    public INDArray[] output(boolean train, @NonNull INDArray[] input, INDArray[] inputMasks, INDArray[] labelMasks) {
        return output(train, input, inputMasks, labelMasks, null);
    }

    /**
     * This method uses provided OutputAdapter to return custom object built from INDArray
     *
     * PLEASE NOTE: This method uses dedicated Workspace for output generation to avoid redundant allocations
     *
     * @param inputs Input arrays to the netwonk
     * @param inputMasks Optional input mask arrays (may be null)
     * @param labelMasks Optional label mask arrays (may be null
     * @param outputAdapter OutputAdapter<T> instance
     * @param <T> T extends Object
     * @return T instance produced by OutputAdapter
     */
    public synchronized <T> T output(@NonNull INDArray[] inputs, INDArray[] inputMasks, INDArray[] labelMasks, @NonNull OutputAdapter<T> outputAdapter) {
        try (val ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(WS_ALL_LAYERS_ACT_CONFIG, WS_OUTPUT_MEM)) {
            if (outputAdapter instanceof ModelAdapter)
                return ((ModelAdapter<T>) outputAdapter).apply(this, inputs, inputMasks, labelMasks);
            else
                return outputAdapter.apply(output(false, inputs, inputMasks, labelMasks, ws));
        }
    }

    /**
     * Return an array of network outputs (predictions), given the specified network inputs
     * Network outputs are for output layers only.<br>
     * If no memory workspace is provided, the output will be detached (not in any workspace).<br>
     * If a memory workspace is provided, the output activation array (i.e., the INDArray returned by this method)
     * will be placed in the specified workspace. This workspace must be opened by the user before calling this method -
     * and the user is responsible for (a) closing this workspace, and (b) ensuring the output array is not used out
     * of scope (i.e., not used after closing the workspace to which it belongs - as this is likely to cause either
     * an exception when used, or a crash).
     *
     * @param train           If true: forward pass for training mode. False: test mode
     * @param input           Input arrays to the netwonk
     * @param inputMasks      Optional input mask arrays (may be null)
     * @param labelMasks      Optional label mask arrays (may be null
     * @param outputWorkspace May be null. If not null: the workspace MUST be opened before calling this method.
     * @return Network output activations
     */
    public synchronized INDArray[] output(boolean train, @NonNull INDArray[] input, INDArray[] inputMasks, INDArray[] labelMasks, MemoryWorkspace outputWorkspace){
        try {
            setLayerMaskArrays(inputMasks, labelMasks);
            INDArray[] out = outputOfLayersDetached(train, FwdPassType.STANDARD, getOutputLayerIndices(), input, inputMasks, labelMasks, true, false, outputWorkspace);
            clearLayerMaskArrays();
            clearLayersStates();
            return out;
        } catch (OutOfMemoryError e){
            CrashReportingUtil.writeMemoryCrashDump(this, e);
            throw e;
        }
    }

    /**
     * A convenience method that returns a single INDArray, instead of an INDArray[].
     * Useful for ComputationGraphs that have only a single output.
     * Otherwise identical to {@link #output(boolean, INDArray...)}
     *
     * @param train If true: do forward pass at training time; false: do forward pass at test time
     * @param input Inputs to the network
     * @return Output activations array
     */
    public INDArray outputSingle(boolean train, INDArray... input) {
        return outputSingle(train, true, input);
    }

    /**
     * Identical to {@link #outputSingle(boolean, boolean, INDArray...)} but has the option of not clearing the input
     * arrays (useful when later backpropagating external errors). Most users should use {@link #outputSingle(boolean, INDArray...)}
     * in preference to this method.
     */
    public INDArray outputSingle(boolean train, boolean clearInputs, INDArray... input){
        if (numOutputArrays != 1) {
            throw new IllegalStateException(
                    "Cannot use outputSingle with ComputationGraph that does not have exactly 1 output. nOutputs: "
                            + numOutputArrays);
        }
        return output(train, clearInputs, input)[0];
    }

    /**
     * An output method for the network, with optional clearing of the layer inputs.<br>
     * Note: most users should use {@link #output(boolean, INDArray...)} or similar methods, unless they are doing
     * non-standard operations (like providing the input arrays externally)
     *
     * @param train       If true: output during training. False: output during testing. Affects some things such as
     *                    dropout
     * @param clearInputs If true: clear the input arrays for all layers. False: leave the input arrays as-is - which
     *                    can be useful for "external errors" (no output layer) backprop use cases
     * @param input       Input to the network
     * @return            Output from the network
     */
    public synchronized INDArray[] output(boolean train, boolean clearInputs, INDArray... input){
        boolean detachedInputs = !clearInputs;  //If !clearInputs, then inputs should be detached (otherwise: will be out of scope)
        try {
            return outputOfLayersDetached(train, FwdPassType.STANDARD, getOutputLayerIndices(), input, null, null, clearInputs, detachedInputs, null);
        } catch (OutOfMemoryError e){
            CrashReportingUtil.writeMemoryCrashDump(this, e);
            throw e;
        }
    }

    /**
     * Generate the output for all examples/batches in the input iterator, and concatenate them into a single array
     * per network output
     *
     * @param iterator Data to pass through the network
     * @return output for all examples in the iterator
     */
    public INDArray[] output(DataSetIterator iterator){
        return output(new MultiDataSetIteratorAdapter(iterator));
    }

    /**
     * Generate the output for all examples/batches in the input iterator, and concatenate them into a single array
     * per network output
     *
     * @param iterator Data to pass through the network
     * @return output for all examples in the iterator
     */
    public INDArray[] output(MultiDataSetIterator iterator){
        List<INDArray[]> outputs = new ArrayList<>();
        while(iterator.hasNext()){
            MultiDataSet next = iterator.next();
            INDArray[] out = output(false, next.getFeatures(), next.getFeaturesMaskArrays(), next.getLabelsMaskArrays());
            outputs.add(out);
        }
        INDArray[][] arr = outputs.toArray(new INDArray[outputs.size()][0]);
        return DataSetUtil.mergeFeatures(arr, null).getFirst();
    }

    /**
     * Generate the output for all examples/batches in the input iterator, and concatenate them into a single array.
     * Can only be used with ComputationGraphs with 1 output
     *
     * @param iterator Data to pass through the network
     * @return output for all examples in the iterator
     */
    public INDArray outputSingle(DataSetIterator iterator){
        Preconditions.checkArgument(numOutputArrays == 1, "Cannot use this method with nets that have more" +
                " than 1 output array. This network has %s outputs", numOutputArrays);
        return output(iterator)[0];
    }

    /**
     * Generate the output for all examples/batches in the input iterator, and concatenate them into a single array.
     * Can only be used with ComputationGraphs with 1 output
     *
     * @param iterator Data to pass through the network
     * @return output for all examples in the iterator
     */
    public INDArray outputSingle(MultiDataSetIterator iterator){
        Preconditions.checkArgument(numOutputArrays == 1, "Cannot use this method with nets that have more" +
                " than 1 output array. This network has %s outputs", numOutputArrays);
        return output(iterator)[0];
    }

    /**
     * Get the activations for the specific layers only
     * @param layers       Layers to get the specified activations for
     * @param train        If true: train mode. False: test (inference) mode
     * @param features     Features array
     * @param featureMasks Feature masks array. May be null
     * @return Activations of the selected layers, in the same order as the "layers" arg/list
     */
    public INDArray[] output(List<String> layers, boolean train, INDArray[] features, INDArray[] featureMasks){
        Preconditions.checkState(layers != null && layers.size() > 0, "Layers must not be null: got later names %s", layers);
        int[] layerNums = new int[layers.size()];
        for( int i=0; i<layers.size(); i++ ){
            String n = layers.get(i);
            Preconditions.checkState(verticesMap.containsKey(n), "Layer with name %s not found in network", n);
            layerNums[i] = verticesMap.get(n).getVertexIndex();
        }
        INDArray[] out = outputOfLayersDetached(train, FwdPassType.STANDARD, layerNums, features, featureMasks, null, true,
                false, null);
        return out;
    }


    protected void validateArrayWorkspaces(LayerWorkspaceMgr mgr, INDArray array, ArrayType arrayType, String vertexName, boolean isInputVertex, String op){
        try{
            mgr.validateArrayLocation(arrayType, array, false, isInputVertex);
        } catch (ND4JWorkspaceException e){
            String clazz;
            GraphVertex v = verticesMap.get(vertexName);
            if(v instanceof LayerVertex){
                clazz = v.getLayer().getClass().getSimpleName();
            } else {
                clazz = v.getClass().getSimpleName();
            }
            throw new IllegalStateException(op + ": array (" + arrayType + ") workspace validation failed (vertex " +
                    vertexName + " - class: " + clazz + ") - array is defined in incorrect workspace", e);
        }
    }

    /**
     * Feed-forward through the network - returning all array activations detached from any workspace.
     * Note that no workspace should be active externally when calling this method (an exception will be thrown
     * if a workspace is open externally)
     *
     * @param train             Training mode (true) or test/inference mode (false)
     * @param fwdPassType       Type of forward pass to perform (STANDARD or RNN_ACTIVATE_WITH_STORED_STATE only)
     * @param storeLastForTBPTT ONLY used if fwdPassType == FwdPassType.RNN_ACTIVATE_WITH_STORED_STATE
     * @param layerIndex        Index (inclusive) to stop forward pass at. For all layers, use numLayers-1
     * @param excludeIdxs       Layers (vertices) to exclude from forward pass. These layers will be skipped, and hence
     *                          are usually output layers or at the end of the network. May be null.
     * @param features          Input feature arrays
     * @param fMask             Feature mask arrays. May be null.
     * @param lMask             Label mask array. May be null.
     * @param clearLayers       Whether the layer inputs should be cleared
     * @return Map of activations (including the input), detached from any workspace
     */
    protected synchronized Map<String,INDArray> ffToLayerActivationsDetached(boolean train, @NonNull FwdPassType fwdPassType, boolean storeLastForTBPTT,
                                                                int layerIndex, int[] excludeIdxs, @NonNull INDArray[] features,
                                                                INDArray[] fMask, INDArray[] lMask, boolean clearLayers){
        if(layerIndex < 0 || layerIndex >= topologicalOrder.length){
            throw new IllegalArgumentException("Invalid layer index - index must be >= 0 and < " + topologicalOrder.length
                    + ", got index " + layerIndex);
        }

        setInputs(features);
        setLayerMaskArrays(fMask, lMask);

        //Verify that no workspace is open externally
        WorkspaceUtils.assertNoWorkspacesOpen("Expected no workspace active before call to ffToLayerActivationsDetached", true);

        LayerWorkspaceMgr workspaceMgr;
        WorkspaceMode wsm = (train ? configuration.getTrainingWorkspaceMode() : configuration.getInferenceWorkspaceMode());
        if (wsm == WorkspaceMode.NONE) {
            workspaceMgr = LayerWorkspaceMgr.noWorkspaces();
        } else {
            workspaceMgr = LayerWorkspaceMgr.builder()
                    .noWorkspaceFor(ArrayType.ACTIVATIONS)
                    .with(ArrayType.INPUT, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG)
                    .with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG)
                    .with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG)
                    .build();

            if(features[0].isAttached()){
                //Don't leverage out of async DataMultiSetIterator workspaces
                workspaceMgr.setNoLeverageOverride(features[0].data().getParentWorkspace().getId());
            }
        }
        workspaceMgr.setHelperWorkspacePointers(helperWorkspaces);

        Map<String, INDArray> activations = new HashMap<>();

        //Add the inputs:
        for( int i=0; i<features.length; i++){
            activations.put(configuration.getNetworkInputs().get(i), features[i]);
        }

        boolean traceLog = log.isTraceEnabled();

        //Do forward pass according to the topological ordering of the network
        for (int i = 0; i <= layerIndex; i++) {
            GraphVertex current = vertices[topologicalOrder[i]];
            String vName = current.getVertexName();
            int vIdx = current.getVertexIndex();

            if(excludeIdxs != null && ArrayUtils.contains(excludeIdxs, vIdx)){
                continue;
            }

            if(traceLog){
                log.trace("About forward pass: {} (\"{}\") - {}", i, vName, current.getClass().getSimpleName());
            }

            try(MemoryWorkspace wsFFWorking = workspaceMgr.notifyScopeEntered(ArrayType.FF_WORKING_MEM)){
                VertexIndices[] inputsTo = current.getOutputVertices();

                INDArray out;
                if(current.isInputVertex()){
                    out = inputs[vIdx];
                } else {

                    if(fwdPassType == FwdPassType.STANDARD) {
                        //Standard feed-forward case
                        out = current.doForward(train, workspaceMgr);
                    } else if(fwdPassType == FwdPassType.RNN_TIMESTEP){
                        if (current.hasLayer()) {
                            //Layer
                            INDArray input = current.getInputs()[0];
                            Layer l = current.getLayer();
                            if (l instanceof RecurrentLayer) {
                                out = ((RecurrentLayer) l).rnnTimeStep(reshapeTimeStepInput(input), workspaceMgr);
                            }  else if(l instanceof org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer && ((org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer)l).getUnderlying() instanceof RecurrentLayer){
                                RecurrentLayer rl = ((RecurrentLayer) ((org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer)l).getUnderlying());
                                out = rl.rnnTimeStep(reshapeTimeStepInput(input), workspaceMgr);
                            } else if (l instanceof MultiLayerNetwork) {
                                out = ((MultiLayerNetwork) l).rnnTimeStep(reshapeTimeStepInput(input));
                            } else {
                                //non-recurrent layer
                                out = current.doForward(train, workspaceMgr);
                            }
                        } else {
                            //GraphNode
                            out = current.doForward(train, workspaceMgr);
                        }
                    } else if(fwdPassType == FwdPassType.RNN_ACTIVATE_WITH_STORED_STATE) {
                        if (current.hasLayer()) {
                            Layer l = current.getLayer();
                            if (l instanceof RecurrentLayer) {
                                out = ((RecurrentLayer) l).rnnActivateUsingStoredState(current.getInputs()[0], train, storeLastForTBPTT, workspaceMgr);
                            } else if(l instanceof org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer && ((org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer)l).getUnderlying() instanceof RecurrentLayer) {
                                RecurrentLayer rl = (RecurrentLayer) ((org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer)l).getUnderlying();
                                out = rl.rnnActivateUsingStoredState(current.getInputs()[0], train,storeLastForTBPTT, workspaceMgr);
                            } else if (l instanceof MultiLayerNetwork) {
                                List<INDArray> temp = ((MultiLayerNetwork) l).rnnActivateUsingStoredState(
                                        current.getInputs()[0], train, storeLastForTBPTT);
                                out = temp.get(temp.size() - 1);
                            } else {
                                //non-recurrent layer
                                out = current.doForward(train, workspaceMgr);
                            }
                        } else {
                            out = current.doForward(train, workspaceMgr);
                        }
                    } else {
                        throw new IllegalArgumentException("Unsupported forward pass type for this method: " + fwdPassType);
                    }
                    validateArrayWorkspaces(workspaceMgr, out, ArrayType.ACTIVATIONS, vName, false, "Feed forward (inference)");
                }
                activations.put(current.getVertexName(), out);

                if(inputsTo != null) {  //May be null for output vertices (which don't feed into any other vertices)
                    for (VertexIndices v : inputsTo) {
                        //Note that we don't have to do anything special here: the activations are always detached in
                        // this method
                        int inputToIndex = v.getVertexIndex();
                        int vIdxEdge = v.getVertexEdgeNumber();
                        vertices[inputToIndex].setInput(vIdxEdge, out, workspaceMgr);
                    }
                }

                if(clearLayers) {
                    current.clear();
                }
            }

            if(traceLog){
                log.trace("Completed forward pass: {} (\"{}\") - {}", i, vName, current.getClass().getSimpleName());
            }
        }

        return activations;
    }

    /**
     * Feed-forward through the network - if workspaces are used, all returned activations will be present in workspace
     * WS_ALL_LAYERS_ACT.<br>
     * Note: if using workspaces for training, requires that WS_ALL_LAYERS_ACT is open externally.
     * If using NO workspaces, requires that no external workspace is open
     *
     * @param train             Training mode (true) or test/inference mode (false)
     * @param layerIndex        Index (inclusive) to stop forward pass at. For all layers, use -1
     * @param excludeIdxs       Layers (vertices) to exclude from forward pass. These layers will be skipped, and hence
     *                          are usually output layers or at the end of the network. May be null.
     * @param fwdPassType       Type of forward pass to perform (STANDARD or RNN_ACTIVATE_WITH_STORED_STATE only)
     * @param storeLastForTBPTT ONLY used if fwdPassType == FwdPassType.RNN_ACTIVATE_WITH_STORED_STATE
     * @param input             Input feature arrays
     * @param fMask             Feature mask arrays. May be null.
     * @param lMask             Label mask array. May be null.
     * @param clearInputs       Whether the layer inputs should be cleared
     * @return Map of activations (including the input), in workspace WS_ALL_LAYERS_ACT if workspaces are used (detached
     * otherwise)
     */
    protected synchronized Map<String,INDArray> ffToLayerActivationsInWS(boolean train, int layerIndex, int[] excludeIdxs,
                                                            FwdPassType fwdPassType, boolean storeLastForTBPTT,
                                                            INDArray[] input, INDArray[] fMask, INDArray[] lMask, boolean clearInputs) {
        if(layerIndex != -1 && (layerIndex < 0 || layerIndex >= topologicalOrder.length)){
            throw new IllegalArgumentException("Invalid input index - index must be >= 0 and < " + topologicalOrder.length
                    + ", got index " + layerIndex);
        }
        setInputs(input);
        setLayerMaskArrays(fMask, lMask);

        LayerWorkspaceMgr workspaceMgr;
        WorkspaceMode wsm = (train ? configuration.getTrainingWorkspaceMode() : configuration.getInferenceWorkspaceMode());
        if(wsm == WorkspaceMode.NONE){
            //Verify that no workspace is open externally
            WorkspaceUtils.assertNoWorkspacesOpen("Expected no workspace active in ffToLayerActivationsDetached", true);

            workspaceMgr = LayerWorkspaceMgr.noWorkspaces();
        } else {
            WorkspaceUtils.assertOpenAndActive(WS_ALL_LAYERS_ACT, "ffToLayerActivationsInWs method requires workspace WS_ALL_LAYERS_ACT to be open");

            workspaceMgr = LayerWorkspaceMgr.builder()
                    .with(ArrayType.ACTIVATIONS, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG)
                    .with(ArrayType.INPUT, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG)
                    .with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG)
                    .with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG)
                    .build();

            if(input[0].isAttached()){
                //Don't leverage out of async DataMultiSetIterator workspaces
                workspaceMgr.setNoLeverageOverride(input[0].data().getParentWorkspace().getId());
            }

            if(configuration.getCacheMode() != CacheMode.NONE){
                //For now: store cache mode activations in activations workspace
                workspaceMgr.setWorkspace(ArrayType.FF_CACHE, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG);
            }
        }
        workspaceMgr.setHelperWorkspacePointers(helperWorkspaces);

        boolean traceLog = log.isTraceEnabled();

        Map<String, INDArray> activations = new HashMap<>();
        //Do forward pass according to the topological ordering of the network
        int stopIndex;
        if (layerIndex > 0) {
            stopIndex = ArrayUtils.indexOf(topologicalOrder, layerIndex);
        } else {
            stopIndex = topologicalOrder.length -1;
        }
        for (int i = 0; i <= stopIndex; i++) {
            GraphVertex current = vertices[topologicalOrder[i]];
            String vName = current.getVertexName();
            int vIdx = current.getVertexIndex();

            if(traceLog){
                log.trace("About forward pass: {} (\"{}\") - {}", i, vName, current.getClass().getSimpleName());
            }

            if(excludeIdxs != null && ArrayUtils.contains(excludeIdxs, vIdx)){
                continue;
            }

            try(MemoryWorkspace wsFFWorking = workspaceMgr.notifyScopeEntered(ArrayType.FF_WORKING_MEM)){
                VertexIndices[] inputsTo = current.getOutputVertices();

                INDArray out;
                if(current.isInputVertex()){
                    out = inputs[vIdx];
                } else {

                    if(fwdPassType == FwdPassType.STANDARD){
                        out = current.doForward(train, workspaceMgr);
                    } else if(fwdPassType == FwdPassType.RNN_ACTIVATE_WITH_STORED_STATE) {
                        if (current.hasLayer()) {
                            Layer l = current.getLayer();
                            if (l instanceof RecurrentLayer) {
                                out = ((RecurrentLayer) l).rnnActivateUsingStoredState(current.getInputs()[0], train,
                                        storeLastForTBPTT, workspaceMgr);
                            } else if(l instanceof org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer && ((org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer)l).getUnderlying() instanceof RecurrentLayer) {
                                RecurrentLayer rl = (RecurrentLayer) ((org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer)l).getUnderlying();
                                out = rl.rnnActivateUsingStoredState(current.getInputs()[0], train,storeLastForTBPTT, workspaceMgr);
                            } else if (l instanceof MultiLayerNetwork) {
                                List<INDArray> temp = ((MultiLayerNetwork) l).rnnActivateUsingStoredState(
                                        current.getInputs()[0], train, storeLastForTBPTT);
                                out = temp.get(temp.size() - 1);
                            } else {
                                //non-recurrent layer
                                out = current.doForward(train, workspaceMgr);
                            }
                        } else {
                            out = current.doForward(train, workspaceMgr);
                        }
                    } else {
                        throw new IllegalStateException("FwdPassType not supported for this method: " + fwdPassType);
                    }

                    validateArrayWorkspaces(workspaceMgr, out, ArrayType.ACTIVATIONS, vName, false, "Feed forward (inference)");
                }
                activations.put(current.getVertexName(), out);

                if(inputsTo != null) {
                    //Can be null for output layers
                    for (VertexIndices v : inputsTo) {
                        //Note that we don't have to do anything special here: the activations are always detached in
                        // this method
                        int inputToIndex = v.getVertexIndex();
                        int vIdxEdge = v.getVertexEdgeNumber();
                        vertices[inputToIndex].setInput(vIdxEdge, out, workspaceMgr);
                    }
                }

                if(clearInputs) {
                    current.clear();
                }
            }

            if(traceLog){
                log.trace("Completed forward pass: {} (\"{}\") - {}", i, vName, current.getClass().getSimpleName());
            }
        }
        return activations;
    }


    /**
     * Provide the output of the specified layers, detached from any workspace. This is most commonly used at inference/test
     * time, and is more memory efficient than {@link #ffToLayerActivationsDetached(boolean, FwdPassType, boolean, int, int[], INDArray[], INDArray[], INDArray[], boolean)}
     * and {@link #ffToLayerActivationsInWS(boolean, int, int[], FwdPassType, boolean, INDArray[], INDArray[], INDArray[], boolean)}.<br>
     * This method clears all layer inputs.
     *
     * NOTE: in general, no workspaces should be activated externally for this method!
     * This method handles the workspace activation as required
     *
     * @param train             Training mode (true) or test/inference mode (false)
     * @param fwdPassType       Type of forward pass to perform (STANDARD or RNN_TIMESTEP only)
     * @param layerIndexes      Indexes of the layers to get the activations for
     * @param features          Input features for the network
     * @param fMask             Input/feature mask array. May be null.
     * @param lMasks            Labels mask array. May be null
     * @param clearLayerInputs  If true: the layer input fields will be cleared
     * @param detachedInputs    If true: the layer input fields will be detached. Usually used for external errors cases
     * @param outputWorkspace   Optional - if provided, outputs should be placed in this workspace. NOTE: this workspace
     *                          must be open
     * @return                  Output of the specified layers, detached from any workspace
     */
    protected INDArray[] outputOfLayersDetached(boolean train, @NonNull FwdPassType fwdPassType, @NonNull int[] layerIndexes, @NonNull INDArray[] features,
                                                INDArray[] fMask, INDArray[] lMasks, boolean clearLayerInputs, boolean detachedInputs, MemoryWorkspace outputWorkspace){
        if(features.length != numInputArrays){
            throw new IllegalArgumentException("Invalid number of input arrays: network has " + numInputArrays
                    + " inputs, got " + features.length + " input arrays");
        }
        for( int i=0; i<layerIndexes.length; i++ ) {
            if(layerIndexes[i] < 0 || layerIndexes[i] >= topologicalOrder.length) {
                throw new IllegalArgumentException("Invalid input index - index must be >= 0 and < " + topologicalOrder.length
                        + ", got index " + layerIndexes[i]);
            }
        }
        setInputs(features);
        setLayerMaskArrays(fMask, lMasks);

        MemoryWorkspace outputPrevious = null;
        if(outputWorkspace == null || outputWorkspace instanceof DummyWorkspace) {
            //Verify that no workspace is open externally
            WorkspaceUtils.assertNoWorkspacesOpen("Expected no workspace active before call to outputOfLayersDetached");
        } else {
            Preconditions.checkState(outputWorkspace.isScopeActive(), "Workspace \"" + outputWorkspace.getId() +
                    "\" was provided for the network/layer outputs. When provided, this workspace must be opened before " +
                    "calling the output method; furthermore, closing the workspace is the responsibility of the user");
            outputPrevious = outputWorkspace.getParentWorkspace();
        }


        //First: for each vertex, determine the highest index of the vertex that consumes it's output
        //Then: for each vertex, determine the forward pass step that each vertex's output has been fully consumed on
        //In other words, if vertex X -> Y and X -> Z, and topological sort order is X,a,Y,b,Z,
        //Then we know X's output activations have been fully consumed by step index 4 in the topological sort
        //thus vertexOutputsFullyConsumedByStep[X.index] = IndexOf(topologicalSort, Z.index)

        //Position in array: index of vertex. Value at position: the step (in topological order) that the activations
        // have been consumed by
        //Put another way: this is the step that it's safe to deallocate the layer's activations by closing the
        // corresponding workspace
        int[] vertexOutputsFullyConsumedByStep = new int[topologicalOrder.length];
        for(GraphVertex gv : vertices){
            int idx = gv.getVertexIndex();
            int maxStepOfOutputTo = -1;
            VertexIndices[] outputsTo = gv.getOutputVertices();
            if(outputsTo != null) {
                //May be null for final/output layers
                for (VertexIndices vi : outputsTo) {
                    int posInTopoSort = ArrayUtils.indexOf(topologicalOrder, vi.getVertexIndex());
                    if (posInTopoSort == -1) {
                        throw new IllegalStateException("Did not find vertex " + vi.getVertexIndex() + " in topological sort array");
                    }
                    maxStepOfOutputTo = Math.max(maxStepOfOutputTo, posInTopoSort);
                }
            } else {
                maxStepOfOutputTo = topologicalOrder.length-1;
            }
            vertexOutputsFullyConsumedByStep[idx] = maxStepOfOutputTo;
        }

        //Do forward pass according to the topological ordering of the network
        INDArray[] outputs = new INDArray[layerIndexes.length];
        int stopIndex = -1;
        for( int i=0; i<layerIndexes.length; i++ ){
            stopIndex = Math.max(stopIndex, ArrayUtils.indexOf(topologicalOrder, layerIndexes[i]));
        }
        List<LayerWorkspaceMgr> allWorkspaceManagers = new ArrayList<>();
        List<LayerWorkspaceMgr> freeWorkspaceManagers = new ArrayList<>();  //Basically used as a stack
        Map<MemoryWorkspace, LayerWorkspaceMgr> openActivationsWorkspaces = new IdentityHashMap<>();

        WorkspaceMode wsm = (train ? configuration.getTrainingWorkspaceMode() : configuration.getInferenceWorkspaceMode());
        boolean noWS = wsm == WorkspaceMode.NONE;
        LayerWorkspaceMgr allNone = noWS ? LayerWorkspaceMgr.noWorkspaces(helperWorkspaces) : null;
        List<MemoryWorkspace>[] closeAtEndIteraton = (List<MemoryWorkspace>[])new List[topologicalOrder.length];
        MemoryWorkspace initialWorkspace = Nd4j.getMemoryManager().getCurrentWorkspace();
        Throwable t = null;
        try {
            for (int i = 0; i <= stopIndex; i++) {
                GraphVertex current = vertices[topologicalOrder[i]];
                String vName = current.getVertexName();
                int vIdx = current.getVertexIndex();

                //First: determine what workspace manager we should use for forward pass in this vertex
                LayerWorkspaceMgr workspaceMgr;
                if (noWS) {
                    workspaceMgr = allNone;
                } else {
                    //First: is there a free forward pass workspace we can use?
                    if (freeWorkspaceManagers.size() > 0) {
                        workspaceMgr = freeWorkspaceManagers.remove(freeWorkspaceManagers.size() - 1);
                    } else {
                        //No existing free workspace managers for forward pass - create a new one...
                        String wsName = "WS_LAYER_ACT_" + allWorkspaceManagers.size();
                        workspaceMgr = LayerWorkspaceMgr.builder()
                                .with(ArrayType.INPUT, wsName, WS_LAYER_ACT_X_CONFIG)
                                .with(ArrayType.ACTIVATIONS, wsName, WS_LAYER_ACT_X_CONFIG)
                                .with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG)
                                .with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG)
                                .build();

                        if (detachedInputs) {
                            //Sometimes (like: external errors use cases) we don't want the activations/inputs to be
                            // in a workspace
                            workspaceMgr.setScopedOutFor(ArrayType.INPUT);
                            workspaceMgr.setScopedOutFor(ArrayType.ACTIVATIONS);
                        } else {
                            //Don't leverage out of async MultiDataSetIterator workspaces
                            if (features[0].isAttached()) {
                                workspaceMgr.setNoLeverageOverride(features[0].data().getParentWorkspace().getId());
                            }
                        }

                        allWorkspaceManagers.add(workspaceMgr);
                    }
                }
                workspaceMgr.setHelperWorkspacePointers(helperWorkspaces);

                //Is this one of the layers/vertices that we want the output for?
                boolean isRequiredOutput = false;
                String origWSAct = null;
                WorkspaceConfiguration origWSActConf = null;
                if (ArrayUtils.contains(layerIndexes, vIdx)) {
                    isRequiredOutput = true;

                    if (outputWorkspace != null && !(outputWorkspace instanceof DummyWorkspace)) {
                        //Place activations in user-specified workspace
                        origWSAct = workspaceMgr.getWorkspaceName(ArrayType.ACTIVATIONS);
                        origWSActConf = workspaceMgr.getConfiguration(ArrayType.ACTIVATIONS);
                        workspaceMgr.setWorkspace(ArrayType.ACTIVATIONS, outputWorkspace.getId(), outputWorkspace.getWorkspaceConfiguration());
                    } else {
                        //Standard case
                        if (!workspaceMgr.isScopedOut(ArrayType.ACTIVATIONS)) {
                            //Activations/output to return: don't want this in any workspace
                            origWSAct = workspaceMgr.getWorkspaceName(ArrayType.ACTIVATIONS);
                            origWSActConf = workspaceMgr.getConfiguration(ArrayType.ACTIVATIONS);
                            workspaceMgr.setScopedOutFor(ArrayType.ACTIVATIONS);
                        }
                    }
                }

                //Open the relevant workspace for the activations.
                //Note that this will be closed only once the current vertex's activations have been consumed
                MemoryWorkspace wsActivations = null;
                if (outputWorkspace == null || outputWorkspace instanceof DummyWorkspace || !isRequiredOutput) {    //Open WS if (a) no external/output WS (if present, it's already open), or (b) not being placed in external/output WS
                    wsActivations = workspaceMgr.notifyScopeEntered(ArrayType.ACTIVATIONS);
                    openActivationsWorkspaces.put(wsActivations, workspaceMgr);
                }

                //Note that because we're opening activation workspaces not in any defined order (i.e., workspace
                // use isn't simply nested), we'll manually override the previous workspace setting. Otherwise, when we
                // close these workspaces, the "current" workspace may be set to the incorrect one
                if (wsActivations != null)
                    wsActivations.setPreviousWorkspace(initialWorkspace);

                int closeableAt = vertexOutputsFullyConsumedByStep[vIdx];
                if (outputWorkspace == null || outputWorkspace instanceof DummyWorkspace || (wsActivations != null && !outputWorkspace.getId().equals(wsActivations.getId()))) {
                    if (closeAtEndIteraton[closeableAt] == null) {
                        closeAtEndIteraton[closeableAt] = new ArrayList<>();
                    }
                    closeAtEndIteraton[closeableAt].add(wsActivations);
                }


                try (MemoryWorkspace wsFFWorking = workspaceMgr.notifyScopeEntered(ArrayType.FF_WORKING_MEM)) {
                    VertexIndices[] inputsTo = current.getOutputVertices();

                    INDArray out;
                    if (current.isInputVertex()) {
                        out = features[vIdx];
                    } else {

                        if (fwdPassType == FwdPassType.STANDARD) {
                            //Standard feed-forward case
                            out = current.doForward(train, workspaceMgr);
                        } else if (fwdPassType == FwdPassType.RNN_TIMESTEP) {
                            if (current.hasLayer()) {
                                //Layer
                                INDArray input = current.getInputs()[0];
                                Layer l = current.getLayer();
                                if (l instanceof RecurrentLayer) {
                                    out = ((RecurrentLayer) l).rnnTimeStep(reshapeTimeStepInput(input), workspaceMgr);
                                } else if (l instanceof org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer && ((org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer) l).getUnderlying() instanceof RecurrentLayer) {
                                    RecurrentLayer rl = ((RecurrentLayer) ((org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer) l).getUnderlying());
                                    out = rl.rnnTimeStep(reshapeTimeStepInput(input), workspaceMgr);
                                } else if (l instanceof MultiLayerNetwork) {
                                    out = ((MultiLayerNetwork) l).rnnTimeStep(reshapeTimeStepInput(input));
                                } else {
                                    //non-recurrent layer
                                    out = current.doForward(train, workspaceMgr);
                                }
                            } else {
                                //GraphNode
                                out = current.doForward(train, workspaceMgr);
                            }
                        } else {
                            throw new IllegalArgumentException("Unsupported forward pass type for this method: " + fwdPassType);
                        }
                        validateArrayWorkspaces(workspaceMgr, out, ArrayType.ACTIVATIONS, vName, false, "Feed forward (inference)");
                    }

                    if (inputsTo != null) {  //Output vertices may not input to any other vertices
                        for (VertexIndices v : inputsTo) {
                            //Note that we don't have to do anything special here: the activations are always detached in
                            // this method
                            int inputToIndex = v.getVertexIndex();
                            int vIdxEdge = v.getVertexEdgeNumber();
                            vertices[inputToIndex].setInput(vIdxEdge, out, workspaceMgr);
                        }
                    }

                    if (clearLayerInputs) {
                        current.clear();
                    }

                    if (isRequiredOutput) {
                        outputs[ArrayUtils.indexOf(layerIndexes, vIdx)] = out;
                        if (origWSAct != null) {
                            //Reset the configuration, as we may reuse this workspace manager...
                            workspaceMgr.setWorkspace(ArrayType.ACTIVATIONS, origWSAct, origWSActConf);
                        }
                    }
                }

                //Close any activations workspaces that we no longer require
                //Note that activations workspaces can be closed only once the corresponding output activations have
                // been fully consumed
                if (closeAtEndIteraton[i] != null) {
                    for (MemoryWorkspace wsAct : closeAtEndIteraton[i]) {
                        wsAct.close();
                        LayerWorkspaceMgr canNowReuse = openActivationsWorkspaces.remove(wsAct);
                        freeWorkspaceManagers.add(canNowReuse);
                    }
                }
            }
        } catch (Throwable t2){
            t = t2;
        } finally {
            //Close all open workspaces... usually this list will be empty, but not if an exception is thrown
            //Though if stopIndex < numLayers, some might still be open
            for(MemoryWorkspace ws : openActivationsWorkspaces.keySet()){
                while (ws.isScopeActive()) {
                    //Edge case here: seems that scoping out can increase the tagScope of the current WS
                    //and if we hit an exception during forward pass, we aren't guaranteed to call close a sufficient
                    // number of times to actually close it, in all cases
                    try{
                        ws.close();
                    } catch (Throwable t2){
                        if(t != null){
                            log.error("Encountered second exception while trying to close workspace after initial exception");
                            log.error("Original exception:", t);
                            throw t2;
                        }
                    }
                }
            }
            Nd4j.getMemoryManager().setCurrentWorkspace(initialWorkspace);

            if(t != null){
                if(t instanceof RuntimeException){
                    throw ((RuntimeException)t);
                }
                throw new RuntimeException("Error during neural network forward pass", t);
            }

            if(outputWorkspace == null || outputWorkspace instanceof DummyWorkspace) {
                WorkspaceUtils.assertNoWorkspacesOpen("Expected no workspace active at the end of outputOfLayerDetached");
            } else {
                Preconditions.checkState(outputWorkspace.isScopeActive(), "Expected output workspace to still be open" +
                        "at end of outputOfLayerDetached, but ");
                outputWorkspace.setPreviousWorkspace(outputPrevious);
            }
        }

        return outputs;
    }

    private INDArray reshapeTimeStepInput(INDArray input) {
        if (input.rank() == 2) { // dynamically reshape to 3D input with one time-step.
            long[] inShape = input.shape();
            input = input.reshape(inShape[0], inShape[1], 1);
        }
        return input;
    }


    /**
     * Calculate the gradient of the network with respect to some external errors.
     * Note that this is typically used for things like reinforcement learning, not typical networks that include
     * an OutputLayer or RnnOutputLayer
     *
     * @param epsilons Epsilons (errors) at the output. Same order with which the output layers are defined in configuration setOutputs(String...)
     * @return Gradient for the network
     */
    public Gradient backpropGradient(INDArray... epsilons) {
        if (epsilons == null || epsilons.length != numOutputArrays)
            throw new IllegalArgumentException(
                    "Invalid input: must have epsilons length equal to number of output arrays");


        try {
            calcBackpropGradients(true, configuration.getBackpropType() == BackpropType.TruncatedBPTT, epsilons);
            return gradient;
        } catch (OutOfMemoryError e){
            CrashReportingUtil.writeMemoryCrashDump(this, e);
            throw e;
        }
    }

    /**
     * Do backprop (gradient calculation)
     *
     * @param truncatedBPTT    false: normal backprop. true: calculate gradients using truncated BPTT for RNN layers
     * @param externalEpsilons null usually (for typical supervised learning). If not null (and length > 0) then assume that
     *                         the user has provided some errors externally, as they would do for example in reinforcement
     *                         learning situations.
     */
    protected void calcBackpropGradients(boolean clearLayers, boolean truncatedBPTT, INDArray... externalEpsilons) {
        if (flattenedGradients == null) {
            initGradientsView();
        }

        /*
         Design for workspaces use in backprop for ComputationGraph is similar to MultiLayerNetwork and shares some
         features with outputOfLayersDetached

         Specifically:
         1. We assume forward pass has already been done, and hence layer input fields are set (with all arrays/activations in
            workspace WS_ALL_LAYERS_ACT if appropriate)
         2. We use a set of small workspaces to contain the activation gradients for a single layer
            These are opened once per layer, and are closed only once the corresponding activation gradients have been
            consumed by all layers
         */

        if(externalEpsilons == null || externalEpsilons.length == 0 && configuration.getTrainingWorkspaceMode() != WorkspaceMode.NONE){
            WorkspaceUtils.assertOpenAndActive(WS_ALL_LAYERS_ACT, "Expected workspace WS_ALL_LAYERS_ACT to be active and open" +
                    " in calcBackpropGradients when workspace mode is not set to NONE");
        }

        //Validate the network configuration for external errors - no output layers
        if(externalEpsilons != null && externalEpsilons.length > 0){
            List<String> outputLayers = configuration.getNetworkOutputs();
            for(String s : outputLayers ){
                GraphVertex gv = getVertex(s);
                if(gv instanceof LayerVertex && ((LayerVertex)gv).getLayer() instanceof IOutputLayer){
                    throw new IllegalStateException("Cannot perform backprop with external errors in conjunction with an output layer:" +
                            " output layers cannot use external errors for backprop. Layer name: " + s);
                }
            }

        }

        //Position in array: index of vertex. Value at position: the step (in topological order) that the activation
        // gradients of the specified vertex have been consumed by
        //Put another way: this is the step that it's safe to deallocate the layer's activation gradients by closing the
        // corresponding workspace
        //TODO we can probably cache this...
        int[] vertexActGradsFullyConsumedByStep = new int[topologicalOrder.length];
        for(GraphVertex gv : vertices){
            int idx = gv.getVertexIndex();
            int minStepOfInputFrom = Integer.MAX_VALUE;
            VertexIndices[] inputsFrom = gv.getInputVertices();
            if(inputsFrom != null) {
                //inputsFrom may be null for input vertex
                for (VertexIndices vi : inputsFrom) {
                    int posInTopoSort = ArrayUtils.indexOf(topologicalOrder, vi.getVertexIndex());
                    if (posInTopoSort == -1) {
                        throw new IllegalStateException("Did not find vertex " + vi.getVertexIndex() + " in topological sort array");
                    }
                    minStepOfInputFrom = Math.min(minStepOfInputFrom, posInTopoSort);
                }
            }

            if(minStepOfInputFrom == Integer.MAX_VALUE){
                //Input vertex, etc
                vertexActGradsFullyConsumedByStep[idx] = 0;
            } else {
                vertexActGradsFullyConsumedByStep[idx] = minStepOfInputFrom;
            }
        }


        boolean noWS = configuration.getInferenceWorkspaceMode() == WorkspaceMode.NONE;
        LayerWorkspaceMgr allNone = noWS ? LayerWorkspaceMgr.noWorkspaces(helperWorkspaces) : null;

        List<LayerWorkspaceMgr> allWorkspaceManagers = new ArrayList<>();
        List<LayerWorkspaceMgr> freeWorkspaceManagers = new ArrayList<>();  //Basically used as a stack
        Map<MemoryWorkspace, LayerWorkspaceMgr> openActivationsWorkspaces = new IdentityHashMap<>();
        List<MemoryWorkspace>[] closeAtEndIteraton = (List<MemoryWorkspace>[])new List[topologicalOrder.length];

        //Do backprop, in reverse topological order
        LinkedList<Triple<String, INDArray, Character>> gradients = new LinkedList<>();
        boolean[] setVertexEpsilon = new boolean[topologicalOrder.length]; //If true: already set epsilon for this vertex; later epsilons should be *added* to the existing one, not set
        MemoryWorkspace initialWorkspace = Nd4j.getMemoryManager().getCurrentWorkspace();

        boolean traceLog = log.isTraceEnabled();

        Throwable t = null;
        try {
            for (int i = topologicalOrder.length - 1; i >= 0; i--) {
                boolean hitFrozen = false;
                GraphVertex current = vertices[topologicalOrder[i]];
                int vIdx = current.getVertexIndex();
                String vertexName = current.getVertexName();

                if (traceLog) {
                    log.trace("About backprop: {} (\"{}\") - {}", i, vertexName, current.getClass().getSimpleName());
                }

                //FIXME: make the frozen vertex feature extraction more flexible
                if (current.hasLayer() && current.getLayer() instanceof FrozenLayer || current instanceof FrozenVertex) {
                    hitFrozen = true;
                }

                if (current.isInputVertex() || hitFrozen) {
                    //Close any activation gradient workspaces that we no longer require
                    //Note that activation gradient workspaces can be closed only once the corresponding activations
                    // gradients have been fully consumed
                    if (closeAtEndIteraton[i] != null) {
                        for (MemoryWorkspace wsAct : closeAtEndIteraton[i]) {
                            wsAct.close();
                            LayerWorkspaceMgr canNowReuse = openActivationsWorkspaces.remove(wsAct);
                            freeWorkspaceManagers.add(canNowReuse);
                        }
                    }
                    closeAtEndIteraton[i] = null;
                    continue;
                }


                //First: determine what workspace manager we should use for the activation gradients from this vertex
                LayerWorkspaceMgr workspaceMgr;
                if (noWS) {
                    workspaceMgr = allNone;
                } else {
                    //First: is there a free activation gradient workspace we can use?
                    if (freeWorkspaceManagers.size() > 0) {
                        workspaceMgr = freeWorkspaceManagers.remove(freeWorkspaceManagers.size() - 1);
                    } else {
                        //No existing free workspace managers for forward pass - create a new one...
                        String wsName = "WS_LAYER_ACT_" + allWorkspaceManagers.size();
                        workspaceMgr = LayerWorkspaceMgr.builder()
                                .with(ArrayType.INPUT, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG)
                                .with(ArrayType.ACTIVATION_GRAD, wsName, WS_LAYER_ACT_X_CONFIG)
                                .with(ArrayType.ACTIVATIONS, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG) //For forward pass in the context of BP
                                .with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG)
                                .with(ArrayType.BP_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG)
                                .with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG)
                                .with(ArrayType.RNN_BP_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG)
                                .build();

                        allWorkspaceManagers.add(workspaceMgr);
                    }
                }
                workspaceMgr.setHelperWorkspacePointers(helperWorkspaces);

                if (current.isOutputVertex()) {
                    //Two reasons for a vertex to be an output vertex:
                    //(a) it's an output layer (i.e., instanceof IOutputLayer), or
                    //(b) it's a normal layer, but it has been marked as an output layer for use in external errors - for reinforcement learning, for example

                    int thisOutputNumber = configuration.getNetworkOutputs().indexOf(current.getVertexName());
                    Layer currentLayer = current.getLayer();
                    if (currentLayer instanceof FrozenLayerWithBackprop) {
                        currentLayer = ((FrozenLayerWithBackprop) currentLayer).getInsideLayer();
                    }
                    if (currentLayer instanceof IOutputLayer) {
                        IOutputLayer outputLayer = (IOutputLayer) currentLayer;

                        INDArray currLabels = labels[thisOutputNumber];
                        outputLayer.setLabels(currLabels);
                    } else {
                        if ((externalEpsilons == null || externalEpsilons.length == 0)
                                && labels[thisOutputNumber] != null) {
                            throw new DL4JException("Layer \"" + current.getVertexName() + "\" of type "
                                    + current.getLayer().getClass().getSimpleName()
                                    + " is set as network output "
                                    + "(but isn't an IOutputLayer). Only IOutputLayer layers can be fit via backprop with"
                                    + " a labels array. ");
                        }
                        current.setEpsilon(externalEpsilons[thisOutputNumber]);
                        setVertexEpsilon[topologicalOrder[i]] = true;
                    }
                }

                //Actually execute backprop for the specified vertex
                //First: Open the relevant workspace for the activations.
                //Note that this will be closed only once the current vertex's activations have been consumed
                MemoryWorkspace wsActivationGrads = workspaceMgr.notifyScopeEntered(ArrayType.ACTIVATION_GRAD);
                openActivationsWorkspaces.put(wsActivationGrads, workspaceMgr);

                //Note that because we're opening activation gradient workspaces not in any defined order (i.e., workspace
                // use isn't simply nested), we'll manually override the previous workspace setting. Otherwise, when we
                // close these workspaces, the "current" workspace may be set to the incorrect one
                wsActivationGrads.setPreviousWorkspace(initialWorkspace);

                int closeableAt = vertexActGradsFullyConsumedByStep[vIdx];
                if (closeableAt >= 0) {
                    if (closeAtEndIteraton[closeableAt] == null) {
                        closeAtEndIteraton[closeableAt] = new ArrayList<>();
                    }
                    closeAtEndIteraton[closeableAt].add(wsActivationGrads);
                }

                Pair<Gradient, INDArray[]> pair;
                INDArray[] epsilons;
                try (MemoryWorkspace wsWorkingMem = workspaceMgr.notifyScopeEntered(ArrayType.BP_WORKING_MEM)) {
                    pair = current.doBackward(truncatedBPTT, workspaceMgr);
                    epsilons = pair.getSecond();

                    //Validate workspace location for the activation gradients:
                    //validateArrayWorkspaces(LayerWorkspaceMgr mgr, INDArray array, ArrayType arrayType, String vertexName, boolean isInputVertex, String op){
                    for (INDArray epsilon : epsilons) {
                        if (epsilon != null) {
                            //May be null for EmbeddingLayer, etc
                            validateArrayWorkspaces(workspaceMgr, epsilon, ArrayType.ACTIVATION_GRAD, vertexName, false, "Backprop");
                        }
                    }
                }

                //Inputs to the current GraphVertex:
                VertexIndices[] inputVertices = current.getInputVertices();

                //Set epsilons for the vertices that provide inputs to this vertex:
                if (inputVertices != null) {
                    int j = 0;
                    for (VertexIndices v : inputVertices) {
                        GraphVertex gv = vertices[v.getVertexIndex()];
                        if (setVertexEpsilon[gv.getVertexIndex()]) {
                            //This vertex: must output to multiple vertices... we want to add the epsilons here
                            INDArray currentEps = gv.getEpsilon();
                            gv.setEpsilon(currentEps.addi(epsilons[j++]));  //TODO is this always safe?
                        } else {
                            gv.setEpsilon(epsilons[j++]);
                        }
                        setVertexEpsilon[gv.getVertexIndex()] = true;
                    }
                }

                if (pair.getFirst() != null) {
                    Gradient g = pair.getFirst();
                    Map<String, INDArray> map = g.gradientForVariable();
                    LinkedList<Triple<String, INDArray, Character>> tempList = new LinkedList<>();
                    for (Map.Entry<String, INDArray> entry : map.entrySet()) {
                        String origName = entry.getKey();
                        String newName = current.getVertexName() + "_" + origName;
                        tempList.addFirst(new Triple<>(newName, entry.getValue(),
                                g.flatteningOrderForVariable(origName)));
                    }
                    for (Triple<String, INDArray, Character> triple : tempList)
                        gradients.addFirst(triple);
                }

                //Close any activation gradient workspaces that we no longer require
                //Note that activation gradient workspaces can be closed only once the corresponding activations
                // gradients have been fully consumed
                if (closeAtEndIteraton[i] != null) {
                    for (MemoryWorkspace wsAct : closeAtEndIteraton[i]) {
                        wsAct.close();
                        LayerWorkspaceMgr canNowReuse = openActivationsWorkspaces.remove(wsAct);
                        freeWorkspaceManagers.add(canNowReuse);
                    }
                    closeAtEndIteraton[i] = null;
                }

                if (traceLog) {
                    log.trace("Completed backprop: {} (\"{}\") - {}", i, vertexName, current.getClass().getSimpleName());
                }
            }
        } catch (Throwable t2){
            t = t2;
        } finally {
            //Close all open workspaces... usually this list will be empty, but not if an exception is thrown
            for(MemoryWorkspace ws : openActivationsWorkspaces.keySet()){
                try{
                    ws.close();
                } catch (Throwable t2){
                    if(t != null){
                        log.error("Encountered second exception while trying to close workspace after initial exception");
                        log.error("Original exception:", t);
                        throw t2;
                    }
                }
            }
            Nd4j.getMemoryManager().setCurrentWorkspace(initialWorkspace);

            if(t != null){
                if(t instanceof RuntimeException){
                    throw ((RuntimeException)t);
                }
                throw new RuntimeException("Error during neural network backpropagation calculation", t);
            }
        }

        //Now, add the gradients in the order we need them in for flattening (same as params order)
        Gradient gradient = new DefaultGradient(flattenedGradients);
        for (Triple<String, INDArray, Character> tr : gradients) {
            gradient.setGradientFor(tr.getFirst(), tr.getSecond(), tr.getThird());
        }

        this.gradient = gradient;

        if(truncatedBPTT && clearTbpttState){
            rnnClearPreviousState();
        }

        //Clear inputs and epsilons:
        if(clearLayers) {
            for (GraphVertex gv : vertices) {
                gv.clear();
            }
        }
    }

    @Override
    public ComputationGraph clone() {
        ComputationGraph cg = new ComputationGraph(configuration.clone());
        cg.init(params().dup(), false);
        if (solver != null) {
            //If  solver is null: updater hasn't been initialized -> getUpdater call will force initialization, however
            ComputationGraphUpdater u = this.getUpdater();
            INDArray updaterState = u.getStateViewArray();
            if (updaterState != null) {
                cg.getUpdater().setStateViewArray(updaterState.dup());
            }
        }
        cg.trainingListeners = this.trainingListeners;
        for (int i = 0; i < topologicalOrder.length; i++) {
            if (!vertices[topologicalOrder[i]].hasLayer())
                continue;
            String layerName = vertices[topologicalOrder[i]].getVertexName();
            if (getLayer(layerName) instanceof FrozenLayer) {
                cg.getVertex(layerName).setLayerAsFrozen();
            }
        }
        return cg;
    }


    public double calcRegularizationScore(boolean backpropParamsOnly){
        double scoreSum = 0.0;
        for (int i = 0; i < layers.length; i++) {
            scoreSum += layers[i].calcRegularizationScore(backpropParamsOnly);
        }
        return scoreSum;
    }

    /**
     * Set the trainingListeners for the ComputationGraph (and all layers in the network)
     */
    public void setListeners(Collection<TrainingListener> listeners) {
        if (layers == null)
            init();

        for (Layer l : layers) {
            l.setListeners(listeners);
        }

        if (solver != null) {
            solver.setListeners(listeners);
        }

        this.trainingListeners.clear();
        if (listeners != null) {
            this.trainingListeners.addAll(listeners);
        }
    }

    /**
     * Set the trainingListeners for the ComputationGraph (and all layers in the network)
     */
    public void setListeners(TrainingListener... listeners) {
        List<TrainingListener> list = new ArrayList<>();
        //Check: user might have done setListeners(null) thinking this would clear the current listeners.
        //This results in an TrainingListener[1] with a single null value -> results in a NPE later
        if (listeners != null && listeners.length > 0) {
            for (TrainingListener i : listeners) {
                if (i != null)
                    list.add(i);
            }
        }
        setListeners(list);
    }

    /**
     * This method ADDS additional TrainingListener to existing listeners
     *
     * @param listeners Listeners to add
     */
    @Override
    public void addListeners(TrainingListener... listeners) {
        if (this.trainingListeners == null) {
            setListeners(listeners);
            return;
        } else {
            List<TrainingListener> newListeners = new ArrayList<>(this.trainingListeners);   //To avoid immutable list issues
            Collections.addAll(newListeners, listeners);
            setListeners(newListeners);
        }

        if (solver != null) {
            solver.setListeners(this.trainingListeners);
        }
    }

    /**
     * Get the trainingListeners for the ComputationGraph
     */
    public Collection<TrainingListener> getListeners() {
        return trainingListeners;
    }

    /**
     * Get the ComputationGraphUpdater for the network. Creates one on demand, if required
     */
    public ComputationGraphUpdater getUpdater() {
        return getUpdater(true);
    }

    /**
     * Get the ComputationGraphUpdater for this network
     * @param initializeIfAbsent If true: create the updater if one is absent. False: return null if absent.
     * @return Updater
     */
    public ComputationGraphUpdater getUpdater(boolean initializeIfAbsent){
        if (solver == null && initializeIfAbsent) {
            solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this).build();
            solver.getOptimizer().setUpdaterComputationGraph(new ComputationGraphUpdater(this));
        }
        if(solver != null) {
            return solver.getOptimizer().getComputationGraphUpdater(initializeIfAbsent);
        }
        return null;
    }

    /**
     * Set the computationGraphUpdater for the network
     */
    public void setUpdater(ComputationGraphUpdater updater) {
        if (solver == null) {
            solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this).build();
        }
        solver.getOptimizer().setUpdaterComputationGraph(updater);
    }

    /**
     * Get the specified output layer, by index. The index of the output
     * layer may be 0 to {@link #getNumOutputArrays()}-1
     */
    public Layer getOutputLayer(int outputLayerIdx) {
        if (outputLayerIdx >= numOutputArrays)
            throw new IllegalArgumentException("Invalid index: cannot get output layer " + outputLayerIdx
                    + ", total number of network outputs = " + numOutputArrays);
        return getLayer(configuration.getNetworkOutputs().get(outputLayerIdx));
    }

    /**
     * @deprecated To be removed. Use {@link #params()}
     */
    @Deprecated
    public INDArray params(boolean backwardOnly) {
        return params();
    }

    /**
     * Sets the input and labels and returns a score for the prediction with respect to the true labels<br>
     * This is equivalent to {@link #score(DataSet, boolean)} with training==true.<br>
     * <b>NOTE:</b> this version of the score function can only be used with ComputationGraph networks that have
     * a single input and a single output.
     *
     * @param dataSet the data to score
     * @return the score for the given input,label pairs
     * @see #score(DataSet, boolean)
     */
    public double score(DataSet dataSet) {
        return score(dataSet, false);
    }

    /**
     * Sets the input and labels and returns a score for the prediction with respect to the true labels<br>
     * <b>NOTE:</b> this version of the score function can only be used with ComputationGraph networks that have
     * a single input and a single output. Use {@link #score(MultiDataSet, boolean)} for multiple input/output networks
     *
     * @param dataSet  the data to score
     * @param training whether score is being calculated at training time (true) or test time (false)
     * @return the score for the given input,label pairs
     * @see #score(DataSet, boolean)
     */
    public double score(DataSet dataSet, boolean training) {
        if (numInputArrays != 1 || numOutputArrays != 1)
            throw new UnsupportedOperationException("Cannot score ComputationGraph network with "
                    + " DataSet: network does not have 1 input and 1 output arrays");
        return score(ComputationGraphUtil.toMultiDataSet(dataSet), training);
    }

    /**
     * Score the network given the MultiDataSet, at test time
     */
    public double score(MultiDataSet dataSet) {
        return score(dataSet, false);
    }

    /**
     * Sets the input and labels and returns a score for the prediction with respect to the true labels<br>
     *
     * @param dataSet  the data to score
     * @param training whether score is being calculated at training time (true) or test time (false)
     * @return the score for the given input,label pairs
     */
    public double score(MultiDataSet dataSet, boolean training) {
        try{
            return scoreHelper(dataSet, training);
        } catch (OutOfMemoryError e){
            CrashReportingUtil.writeMemoryCrashDump(this, e);
            throw e;
        }
    }

    private double scoreHelper(MultiDataSet dataSet, boolean training){
        LayerWorkspaceMgr mgr;
        WorkspaceMode wsm = (training ? configuration.getTrainingWorkspaceMode() : configuration.getInferenceWorkspaceMode());
        if(wsm == WorkspaceMode.NONE){
            mgr = LayerWorkspaceMgr.noWorkspaces();
        } else {
            mgr = LayerWorkspaceMgr.builder()
                    .noWorkspaceFor(ArrayType.ACTIVATIONS)
                    .noWorkspaceFor(ArrayType.INPUT)
                    .with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG)
                    .with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG)
                    .build();
        }
        mgr.setHelperWorkspacePointers(helperWorkspaces);

        boolean hasMaskArrays = dataSet.hasMaskArrays();
        if (hasMaskArrays) {
            setLayerMaskArrays(dataSet.getFeaturesMaskArrays(), dataSet.getLabelsMaskArrays());
        }

        double score = 0.0;
        setInputs(dataSet.getFeatures());

        //Need to feed forward, but not the output layers
        try(MemoryWorkspace ws = mgr.notifyScopeEntered(ArrayType.ACTIVATIONS)){
            //TODO Can possibly optimize this, in terms of memory use/workspaces
            ffToLayerActivationsDetached(training, FwdPassType.STANDARD, false, vertices.length-1,
                    getOutputLayerIndices(), dataSet.getFeatures(), dataSet.getFeaturesMaskArrays(),dataSet.getLabelsMaskArrays(), false);

            INDArray[] labels = dataSet.getLabels();
            setLabels(labels);

            //Score: sum of the scores for the various output layers...
            double r = calcRegularizationScore(true);

            int i = 0;
            for (String s : configuration.getNetworkOutputs()) {
                GraphVertex gv = verticesMap.get(s);
                Layer outLayer = gv.getLayer();
                if (outLayer == null || !(outLayer instanceof IOutputLayer)) {
                    log.warn("Cannot calculate score: vertex \"" + s + "\" is not an output layer");
                    return 0.0;
                }

                IOutputLayer ol = (IOutputLayer) outLayer;
                ol.setLabels(labels[i++]);

                score += ((LayerVertex) gv).computeScore(r, training, mgr);

                //Only want to add l1/l2 once...
                r = 0.0;
            }
        }

        clearLayersStates();    //Clean up layer inputs/mask arrays - may be invalidated by workspace
        return score;
    }

    /**
     * Calculate the score for each example in a DataSet individually. Unlike {@link #score(DataSet)} and {@link #score(DataSet, boolean)}
     * this method does not average/sum over examples. This method allows for examples to be scored individually (at test time only), which
     * may be useful for example for autoencoder architectures and the like.<br>
     * Each row of the output (assuming addRegularizationTerms == true) is equivalent to calling score(DataSet) with a single example.
     *
     * @param data                   The data to score
     * @param addRegularizationTerms If true: add l1/l2 regularization terms (if any) to the score. If false: don't add regularization terms
     * @return An INDArray (column vector) of size input.numRows(); the ith entry is the score (loss value) of the ith example
     */
    public INDArray scoreExamples(DataSet data, boolean addRegularizationTerms) {
        if (numInputArrays != 1 || numOutputArrays != 1)
            throw new UnsupportedOperationException("Cannot score ComputationGraph network with "
                    + " DataSet: network does not have 1 input and 1 output arrays");
        return scoreExamples(ComputationGraphUtil.toMultiDataSet(data), addRegularizationTerms);
    }

    /**
     * Calculate the score for each example in a DataSet individually. Unlike {@link #score(MultiDataSet)} and {@link #score(MultiDataSet, boolean)}
     * this method does not average/sum over examples. This method allows for examples to be scored individually (at test time only), which
     * may be useful for example for autoencoder architectures and the like.<br>
     * Each row of the output (assuming addRegularizationTerms == true) is equivalent to calling score(MultiDataSet) with a single example.
     *
     * @param dataSet                The data to score
     * @param addRegularizationTerms If true: add l1/l2 regularization terms (if any) to the score. If false: don't add regularization terms
     * @return An INDArray (column vector) of size input.numRows(); the ith entry is the score (loss value) of the ith example
     */
    public INDArray scoreExamples(MultiDataSet dataSet, boolean addRegularizationTerms) {
        try{
            return scoreExamplesHelper(dataSet, addRegularizationTerms);
        } catch (OutOfMemoryError e){
            CrashReportingUtil.writeMemoryCrashDump(this, e);
            throw e;
        }
    }

    private INDArray scoreExamplesHelper(MultiDataSet dataSet, boolean addRegularizationTerms){
        LayerWorkspaceMgr mgr;
        if(configuration.getInferenceWorkspaceMode() == WorkspaceMode.NONE){
            mgr = LayerWorkspaceMgr.noWorkspaces();
        } else {
            mgr = LayerWorkspaceMgr.builder()
                    .with(ArrayType.ACTIVATIONS, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG)
                    .with(ArrayType.INPUT, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG)
                    .with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG)
                    .with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG)
                    .build();
        }
        mgr.setHelperWorkspacePointers(helperWorkspaces);

        boolean hasMaskArrays = dataSet.hasMaskArrays();
        if (hasMaskArrays) {
            setLayerMaskArrays(dataSet.getFeaturesMaskArrays(), dataSet.getLabelsMaskArrays());
        }

        INDArray out = null;
        setInputs(dataSet.getFeatures());

        //Need to feed forward, but not the output layers
        try(MemoryWorkspace ws = mgr.notifyScopeEntered(ArrayType.ACTIVATIONS)) {
            //TODO maybe optimize? We only need *some* of the activations in the WS...
            ffToLayerActivationsInWS(false, vertices.length - 1, getOutputLayerIndices(), FwdPassType.STANDARD, false,
                    dataSet.getFeatures(), dataSet.getFeaturesMaskArrays(), dataSet.getLabelsMaskArrays(), false);

            INDArray[] labels = dataSet.getLabels();
            setLabels(labels);


            double r = (addRegularizationTerms ? calcRegularizationScore(true) : 0.0);
            int i = 0;
            for (String s : configuration.getNetworkOutputs()) {
                GraphVertex gv = verticesMap.get(s);
                Layer outLayer = gv.getLayer();
                if (outLayer == null || !(outLayer instanceof IOutputLayer)) {
                    throw new UnsupportedOperationException(
                            "Cannot calculate score: vertex \"" + s + "\" is not an output layer");
                }

                IOutputLayer ol = (IOutputLayer) outLayer;
                ol.setLabels(labels[i++]);

                INDArray scoreCurrLayer;
                try(MemoryWorkspace wsFF = mgr.notifyScopeEntered(ArrayType.FF_WORKING_MEM)) {
                    scoreCurrLayer =((LayerVertex) gv).computeScoreForExamples(r, mgr);
                }
                if (out == null)
                    out = scoreCurrLayer.detach();
                else
                    out.addi(scoreCurrLayer);

                //Only want to add l1/l2 once...
                r = 0.0;
            }
        }

        if (dataSet.hasMaskArrays())
            clearLayerMaskArrays();
        clearLayersStates();
        return out;
    }


    //------------------------------------------------------
    //Model methods:

    @Override
    public void fit() {
        fit(inputs, labels, inputMaskArrays, labelMaskArrays);
    }

    @Override
    public void update(INDArray gradient, String paramType) {
        throw new UnsupportedOperationException("Not implemented");
    }

    @Override
    public void update(Gradient gradient) {
        if (gradient.gradient().length() != numParams(true))
            throw new IllegalArgumentException("Invalid input: expect gradients array of length " + numParams(true));
        for (Map.Entry<String, INDArray> entry : gradient.gradientForVariable().entrySet()) {
            String key = entry.getKey();
            INDArray val = entry.getValue();
            int idx = key.lastIndexOf('_');
            if (idx == -1)
                throw new IllegalStateException("Invalid param key: not have layer separator: \"" + key + "\"");
            String layerName = key.substring(0, idx);
            String paramType = key.split("_")[1];
            // Update graph gradient
            this.gradient.gradientForVariable().put(key, val);
            // Update layer params
            getLayer(layerName).update(val, paramType);
        }
        // Update layerwise gradient view
        setBackpropGradientsViewArray(gradient.gradient());
    }

    private void update(Task task) {
        if (!initDone) {
            initDone = true;
            Heartbeat heartbeat = Heartbeat.getInstance();
            task = ModelSerializer.taskByModel(this);
            Environment env = EnvironmentUtils.buildEnvironment();
            heartbeat.reportEvent(Event.STANDALONE, env, task);
        }
    }

    @Override
    public double score() {
        return score;
    }

    public void setScore(double score) {
        this.score = score;
    }

    @Override
    public INDArray params() {
        return flattenedParams;
    }

    @Override
    public INDArray updaterState() {
        return getUpdater() != null ? getUpdater().getUpdaterStateViewArray() : null;
    }

    @Override
    public long numParams() {
        return numParams(true);
    }

    @Override
    public long numParams(boolean backwards) {
        int nParams = 0;
        for (Layer layer : layers) {
            nParams += layer.numParams(backwards);
        }
        return nParams;
    }

    @Override
    public void setParams(INDArray params) {
        if (params == flattenedParams)
            return; //No op

        if (this.flattenedParams != null && this.flattenedParams.length() == params.length()) {
            this.flattenedParams.assign(params);
            return;
        }

        int idx = 0;
        for (int i = 0; i < topologicalOrder.length; i++) {
            if (!vertices[topologicalOrder[i]].hasLayer())
                continue;

            Layer layer = vertices[topologicalOrder[i]].getLayer();
            long range = layer.numParams();
            if (range <= 0)
                continue; //Some layers: no parameters (subsampling etc)
            INDArray get = params.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(idx, range + idx));
            layer.setParams(get);
            idx += range;
        }
    }

    @Override
    public void setParamsViewArray(INDArray gradient) {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    public INDArray getGradientsViewArray() {
        return flattenedGradients;
    }

    @Override
    public void setBackpropGradientsViewArray(INDArray gradient) {
        int paramsSoFar = 0;
        for (int i = 0; i < topologicalOrder.length; i++) {
            if (!vertices[topologicalOrder[i]].hasLayer())
                continue;

            Layer layer = vertices[topologicalOrder[i]].getLayer();
            long range = layer.numParams();
            if (range <= 0)
                continue; //Some layers: no parameters (subsampling etc)
            layer.setBackpropGradientsViewArray(gradient.get(NDArrayIndex.interval(0,0,true),
                    NDArrayIndex.interval(paramsSoFar, paramsSoFar + range)));
            paramsSoFar += range;
        }
    }

    @Override
    public void fit(INDArray data, LayerWorkspaceMgr workspaceMgr){
        throw new UnsupportedOperationException("Cannot pretrain ComputationGraph with single INDArray");
    }

    @Override
    public Gradient gradient() {
        return gradient;
    }

    @Override
    public Pair<Gradient, Double> gradientAndScore() {
        return new Pair<>(gradient(), score());
    }

    @Override
    public int batchSize() {
        //In 99+% of cases, the input and labels dimension 0 size should be identical
        //The only real exceptions: space to batch, and batch to space layers
        //In those cases, we should base it on the labels size, as this impacts gradient calculation
        return labels == null || labels[0] == null ? (int) inputs[0].size(0) : (int)labels[0].size(0);
    }

    @Override
    public NeuralNetConfiguration conf() {
        return defaultConfiguration;
    }

    @Override
    public void setConf(NeuralNetConfiguration conf) {
        throw new UnsupportedOperationException();
    }

    @Override
    public INDArray input() {
        if (numInputArrays == 1)
            return (inputs != null ? inputs[0] : null);
        else
            throw new UnsupportedOperationException(
                    "Cannot return single input: ComputationGraph  has multiple inputs");
    }

    @Override
    public ConvexOptimizer getOptimizer() {
        return solver.getOptimizer();
    }

    @Override
    public INDArray getParam(String paramName) {
        //        throw new UnsupportedOperationException("Not implemented");
        int idx = paramName.lastIndexOf('_');
        if (idx == -1)
            throw new IllegalStateException("Invalid param key: not have layer separator: \"" + paramName + "\"");
        String layerName = paramName.substring(0, idx);
        String paramType = paramName.substring(idx + 1);
        return getLayer(layerName).getParam(paramType);

    }

    @Override
    public Map<String, INDArray> paramTable() {
        return paramTable(false);
    }

    public Map<String, INDArray> paramTable(boolean backpropParamsOnly) {
        //Get all parameters from all layers/vertices
        Map<String, INDArray> allParams = new LinkedHashMap<>();
        for(GraphVertex gv : vertices){
            Map<String, INDArray> paramMap = gv.paramTable(backpropParamsOnly);
            for (Map.Entry<String, INDArray> entry : paramMap.entrySet()) {
                String newKey = gv.getVertexName() + "_" + entry.getKey();
                allParams.put(newKey, entry.getValue());
            }
        }
        return allParams;
    }

    @Override
    public void setParamTable(@NonNull Map<String, INDArray> paramTable) {
        Preconditions.checkArgument(paramTable.keySet().equals(paramTable().keySet()), "Cannot set param table: parameter set keys are not equal");
        Map<String,INDArray> current = paramTable();
        //Check shapes before doing partial assigment to avoid leaving net in incorrect state
        for(String s : current.keySet()){
            INDArray arrCurrent = current.get(s);
            INDArray arrNew = paramTable.get(s);
            val shapeCurrent = arrCurrent.shape();
            val shapeNew = arrNew.shape();
            Preconditions.checkState(Arrays.equals(shapeCurrent, shapeNew), "Cannot set parameters: shape array for " +
                    "parameter \"%s\" does not match existing shape: parameter shape = %s, new param shape = %s", s, shapeCurrent, arrNew);
        }

        for(String s : current.keySet()) {
            INDArray arrCurrent = current.get(s);
            INDArray arrNew = paramTable.get(s);
            arrCurrent.assign(arrNew);
        }
    }

    @Override
    public void setParam(String key, INDArray val) {
        //        throw new UnsupportedOperationException("Not implemented");
        int idx = key.lastIndexOf('_');
        if (idx == -1)
            throw new IllegalStateException("Invalid param key: not have layer separator: \"" + key + "\"");
        String layerName = key.substring(0, idx);
        String paramType = key.substring(idx + 1);
        getLayer(layerName).setParam(paramType, val);
    }

    @Override
    public void clear() {
        inputs = null;
        labels = null;
        inputMaskArrays = null;
        labelMaskArrays = null;
    }

    @Override
    public void applyConstraints(int iteration, int epoch) {
        for(Layer l : layers){
            l.applyConstraints(iteration, epoch);
        }
    }

    //------------------------------------------------------------------------------
    //RNN-specific functionality

    /**
     * If this ComputationGraph contains one or more RNN layers: conduct forward pass (prediction)
     * but using previous stored state for any RNN layers. The activations for the final step are
     * also stored in the RNN layers for use next time rnnTimeStep() is called.<br>
     * This method can be used to generate output one or more steps at a time instead of always having to do
     * forward pass from t=0. Example uses are for streaming data, and for generating samples from network output
     * one step at a time (where samples are then fed back into the network as input)<br>
     * If no previous state is present in RNN layers (i.e., initially or after calling rnnClearPreviousState()),
     * the default initialization (usually 0) is used.<br>
     * Supports mini-batch (i.e., multiple predictions/forward pass in parallel) as well as for single examples.<br>
     *
     * @param inputs Input to network. May be for one or multiple time steps. For single time step:
     *               input has shape [miniBatchSize,inputSize] or [miniBatchSize,inputSize,1]. miniBatchSize=1 for single example.<br>
     *               For multiple time steps: [miniBatchSize,inputSize,inputTimeSeriesLength]
     * @return Output activations. If output is RNN layer (such as RnnOutputLayer): if all inputs have shape [miniBatchSize,inputSize]
     * i.e., is 2d, then outputs have shape [miniBatchSize,outputSize] (i.e., also 2d) instead of [miniBatchSize,outputSize,1].<br>
     * Otherwise output is 3d [miniBatchSize,outputSize,inputTimeSeriesLength] when using RnnOutputLayer (or unmodified otherwise).
     */
    public INDArray[] rnnTimeStep(INDArray... inputs) {
        return rnnTimeStepHelper(null, inputs);
    }

    /**
     * See {@link #rnnTimeStep(INDArray...)} for details.<br>
     * If no memory workspace is provided, the output will be detached (not in any workspace).<br>
     * If a memory workspace is provided, the output activation array (i.e., the INDArray returned by this method)
     * will be placed in the specified workspace. This workspace must be opened by the user before calling this method -
     * and the user is responsible for (a) closing this workspace, and (b) ensuring the output array is not used out
     * of scope (i.e., not used after closing the workspace to which it belongs - as this is likely to cause either
     * an exception when used, or a crash).
     *
     * @param inputs          Input activations
     * @param outputWorkspace Output workspace. May be null
     * @return The output/activations from the network (either detached or in the specified workspace if provided)
     */
    public INDArray[] rnnTimeStep(MemoryWorkspace outputWorkspace, INDArray... inputs){
        try{
            return rnnTimeStepHelper(outputWorkspace, inputs);
        } catch (OutOfMemoryError e){
            CrashReportingUtil.writeMemoryCrashDump(this, e);
            throw e;
        }
    }

    private INDArray[] rnnTimeStepHelper(MemoryWorkspace outputWs, INDArray... inputs){
        boolean inputIs2d = true;
        for (INDArray i : inputs) {
            if (i.rank() != 2) {
                inputIs2d = false;
                break;
            }
        }

        INDArray[] outputs = outputOfLayersDetached(false, FwdPassType.RNN_TIMESTEP, getOutputLayerIndices(), inputs, null, null, true, false, outputWs);

        //As per MultiLayerNetwork.rnnTimeStep(): if inputs are all 2d, then outputs are all 2d
        if (inputIs2d) {
            for (int i = 0; i < outputs.length; i++) {
                if (outputs[i].rank() == 3 && outputs[i].size(2) == 1) {
                    //Return 2d output with shape [miniBatchSize,nOut]
                    // instead of 3d output with shape [miniBatchSize,nOut,1]
                    outputs[i] = outputs[i].tensorAlongDimension(0, 1, 0);
                }
            }
        }

        this.inputs = null;
        return outputs;
    }

    /**
     * Get the state of the RNN layer, as used in {@link #rnnTimeStep(INDArray...)}.
     *
     * @param layer Number/index of the layer.
     * @return Hidden state, or null if layer is not an RNN layer
     */
    public Map<String, INDArray> rnnGetPreviousState(int layer) {
        return rnnGetPreviousState(layers[layer].conf().getLayer().getLayerName());
    }

    /**
     * Get the state of the RNN layer, as used in {@link #rnnTimeStep(INDArray...)}.
     *
     * @param layerName name of the layer
     * @return Hidden state, or null if layer is not an RNN layer
     */
    public Map<String, INDArray> rnnGetPreviousState(String layerName) {
        Layer l = verticesMap.get(layerName).getLayer();
        if(l instanceof org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer){
            l = ((org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer)l).getUnderlying();
        }
        if (l == null || !(l instanceof RecurrentLayer))
            return null;
        return ((RecurrentLayer) l).rnnGetPreviousState();
    }

    /**
     * Get a map of states for ALL RNN layers, as used in {@link #rnnTimeStep(INDArray...)}.
     * Layers that are not RNN layers will not have an entry in the returned map
     *
     * @return Map of states (keyed by layer name) or null if layer is not an RNN layer
     * @see #rnnSetPreviousStates(Map)
     */
    public Map<String, Map<String, INDArray>> rnnGetPreviousStates() {
        Map<String, Map<String, INDArray>> states = new HashMap<>();
        for (Layer l : layers) {
            if(l instanceof org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer){
                l = ((org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer)l).getUnderlying();
            }
            if (l instanceof RecurrentLayer) {
                states.put(l.conf().getLayer().getLayerName(), ((RecurrentLayer) l).rnnGetPreviousState());
            }
        }
        return states;
    }

    /**
     * Set the state of the RNN layer, for use in {@link #rnnTimeStep(INDArray...)}
     *
     * @param layer The number/index of the layer.
     * @param state The state to set the specified layer to
     */
    public void rnnSetPreviousState(int layer, Map<String, INDArray> state) {
        rnnSetPreviousState(layers[layer].conf().getLayer().getLayerName(), state);
    }

    /**
     * Set the state of the RNN layer, for use in {@link #rnnTimeStep(INDArray...)}
     *
     * @param layerName The name of the layer.
     * @param state     The state to set the specified layer to
     */
    public void rnnSetPreviousState(String layerName, Map<String, INDArray> state) {
        Layer l = verticesMap.get(layerName).getLayer();
        if(l instanceof org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer){
            l = ((org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer)l).getUnderlying();
        }
        if (l == null || !(l instanceof RecurrentLayer)) {
            throw new UnsupportedOperationException(
                    "Layer \"" + layerName + "\" is not a recurrent layer. Cannot set state");
        }
        ((RecurrentLayer) l).rnnSetPreviousState(state);
    }

    /**
     * Set the states for all RNN layers, for use in {@link #rnnTimeStep(INDArray...)}
     *
     * @param previousStates The previous time step states for all layers (key: layer name. Value: layer states)
     * @see #rnnGetPreviousStates()
     */
    public void rnnSetPreviousStates(Map<String, Map<String, INDArray>> previousStates) {
        for (Map.Entry<String, Map<String, INDArray>> entry : previousStates.entrySet()) {
            rnnSetPreviousState(entry.getKey(), entry.getValue());
        }
    }

    /**
     * Clear the previous state of the RNN layers (if any), used in {@link #rnnTimeStep(INDArray...)}
     */
    public void rnnClearPreviousState() {
        if (layers == null)
            return;
        for (Layer layer : layers) {
            if (layer instanceof RecurrentLayer)
                ((RecurrentLayer) layer).rnnClearPreviousState();
            else if (layer instanceof MultiLayerNetwork) {
                ((MultiLayerNetwork) layer).rnnClearPreviousState();
            }
        }
    }

    /**
     * Fit the network using truncated BPTT
     */
    protected void doTruncatedBPTT(INDArray[] inputs, INDArray[] labels, INDArray[] featureMasks,
                                   INDArray[] labelMasks, LayerWorkspaceMgr workspaceMgr) {
        if (flattenedGradients == null) {
            initGradientsView();
        }

        //Approach used here to implement truncated BPTT: if input is 3d, split it. Otherwise: input is unmodified
        long timeSeriesLength = -1;
        for (INDArray in : inputs) {
            if (in.rank() != 3)
                continue;
            if (timeSeriesLength == -1)
                timeSeriesLength = in.size(2);
            else if (timeSeriesLength != in.size(2)) {
                log.warn("Cannot do TBPTT with time series of different lengths");
                return;
            }
        }
        for (INDArray out : labels) {
            if (out.rank() != 3)
                continue;
            if (timeSeriesLength == -1)
                timeSeriesLength = out.size(2);
            else if (timeSeriesLength != out.size(2)) {
                log.warn("Cannot do TBPTT with time series of different lengths");
                return;
            }
        }

        long fwdLen = configuration.getTbpttFwdLength();
        long nSubsets = timeSeriesLength / fwdLen;
        if (timeSeriesLength % fwdLen != 0)
            nSubsets++;

        rnnClearPreviousState();

        for (int i = 0; i < nSubsets; i++) {
            long startTimeIdx = i * fwdLen;
            long endTimeIdx = startTimeIdx + fwdLen;
            if (endTimeIdx > timeSeriesLength)
                endTimeIdx = timeSeriesLength;

            if (startTimeIdx > Integer.MAX_VALUE)
                throw new ND4JArraySizeException();
            List<INDArray[]> list = getSubsetsForTbptt((int) startTimeIdx, endTimeIdx, inputs, labels, featureMasks, labelMasks);

            setInputs(list.get(0));
            setLabels(list.get(1));
            setLayerMaskArrays(list.get(2), list.get(3));

            if (solver == null) {
                try (MemoryWorkspace wsO = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
                    solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this)
                            .build();
                }
            }
            solver.optimize(workspaceMgr);

            //Finally, update the state of the RNN layers:
            rnnUpdateStateWithTBPTTState();
        }

        if(clearTbpttState) {
            rnnClearPreviousState();
        }
        clearLayerMaskArrays();
    }

    private List<INDArray[]> getSubsetsForTbptt(int startTimeIdx, long endTimeIdx, INDArray[] inputs, INDArray[] labels,
                                                INDArray[] featureMasks, INDArray[] labelMasks){
        INDArray[] newInputs = new INDArray[inputs.length];
        INDArray[] newLabels = new INDArray[labels.length];
        INDArray[] newFeatureMasks = (featureMasks != null ? new INDArray[featureMasks.length] : null);
        INDArray[] newLabelMasks = (labelMasks != null ? new INDArray[labelMasks.length] : null);

        for (int j = 0; j < inputs.length; j++) {
            if (inputs[j].rank() != 3)
                newInputs[j] = inputs[j];
            else {
                newInputs[j] = inputs[j].get(NDArrayIndex.all(), NDArrayIndex.all(),
                        NDArrayIndex.interval(startTimeIdx, endTimeIdx));
            }
        }
        for (int j = 0; j < labels.length; j++) {
            if (labels[j].rank() != 3)
                newLabels[j] = labels[j];
            else {
                newLabels[j] = labels[j].get(NDArrayIndex.all(), NDArrayIndex.all(),
                        NDArrayIndex.interval(startTimeIdx, endTimeIdx));
            }
        }
        if (featureMasks != null) {
            for (int j = 0; j < featureMasks.length; j++) {
                if (featureMasks[j] == null)
                    continue;
                newFeatureMasks[j] = featureMasks[j].get(NDArrayIndex.all(),
                        NDArrayIndex.interval(startTimeIdx, endTimeIdx));
            }
        }
        if (labelMasks != null) {
            for (int j = 0; j < labelMasks.length; j++) {
                if (labelMasks[j] == null)
                    continue;
                newLabelMasks[j] = labelMasks[j].get(NDArrayIndex.all(),
                        NDArrayIndex.interval(startTimeIdx, endTimeIdx));
            }
        }

        return Arrays.asList(newInputs, newLabels, newFeatureMasks, newLabelMasks);
    }

    /**
     * Similar to rnnTimeStep and feedForward() methods. Difference here is that this method:<br>
     * (a) like rnnTimeStep does forward pass using stored state for RNN layers, and<br>
     * (b) unlike rnnTimeStep does not modify the RNN layer state<br>
     * Therefore multiple calls to this method with the same input should have the same output.<br>
     * Typically used during training only. Use rnnTimeStep for prediction/forward pass at test time.
     *
     * @param inputs            Input to network
     * @param training          Whether training or not
     * @param storeLastForTBPTT set to true if used as part of truncated BPTT training
     * @return Activations for each layer (including input, as per feedforward() etc)
     */
    public Map<String, INDArray> rnnActivateUsingStoredState(INDArray[] inputs, boolean training,
                                                             boolean storeLastForTBPTT) {
        return ffToLayerActivationsDetached(training, FwdPassType.RNN_ACTIVATE_WITH_STORED_STATE, storeLastForTBPTT, vertices.length-1,
                null, inputs, inputMaskArrays, labelMaskArrays, true);
    }

    /**
     * Set the mask arrays for features and labels. Mask arrays are typically used in situations such as one-to-many
     * and many-to-one learning with recurrent neural networks, as well as for supporting time series of varying lengths
     * within the same minibatch.<br>
     * For example, with RNN data sets with input of shape [miniBatchSize,nIn,timeSeriesLength] and outputs of shape
     * [miniBatchSize,nOut,timeSeriesLength], the features and mask arrays will have shape [miniBatchSize,timeSeriesLength]
     * and contain values 0 or 1 at each element (to specify whether a given input/example is present - or merely padding -
     * at a given time step).<br>
     * <b>NOTE</b>: This method is not usually used directly. Instead, the various feedForward and fit methods handle setting
     * of masking internally.
     *
     * @param featureMaskArrays Mask array for features (input)
     * @param labelMaskArrays   Mask array for labels (output)
     * @see #clearLayerMaskArrays()
     */
    public void setLayerMaskArrays(INDArray[] featureMaskArrays, INDArray[] labelMaskArrays) {
        this.clearLayerMaskArrays();
        this.inputMaskArrays = featureMaskArrays;
        this.labelMaskArrays = labelMaskArrays;

        if (featureMaskArrays != null) {
            if (featureMaskArrays.length != numInputArrays) {
                throw new IllegalArgumentException("Invalid number of feature mask arrays");
            }

            long minibatchSize = -1;
            for (INDArray i : featureMaskArrays) {
                if (i != null) {
                    minibatchSize = i.size(0);
                }
            }

            //Here: need to do forward pass through the network according to the topological ordering of the network

            Map<Integer, Pair<INDArray, MaskState>> map = new HashMap<>();
            for (int i = 0; i < topologicalOrder.length; i++) {
                GraphVertex current = vertices[topologicalOrder[i]];

                if (current.isInputVertex()) {
                    INDArray fMask = featureMaskArrays[current.getVertexIndex()];
                    map.put(current.getVertexIndex(), new Pair<>(fMask, MaskState.Active));
                } else {
                    VertexIndices[] inputVertices = current.getInputVertices();

                    //Now: work out the mask arrays to feed forward...
                    INDArray[] inputMasks = null; //new INDArray[inputVertices.length];
                    MaskState maskState = null;
                    for (int j = 0; j < inputVertices.length; j++) {
                        Pair<INDArray, MaskState> p = map.get(inputVertices[j].getVertexIndex());
                        if (p != null) {
                            if (inputMasks == null) {
                                inputMasks = new INDArray[inputVertices.length];
                            }
                            inputMasks[j] = p.getFirst();
                            if (maskState == null || maskState == MaskState.Passthrough) {
                                maskState = p.getSecond();
                            }
                        }
                    }

                    if (minibatchSize > Integer.MAX_VALUE)
                        throw new ND4JArraySizeException();
                    Pair<INDArray, MaskState> outPair =
                            current.feedForwardMaskArrays(inputMasks, maskState, (int)minibatchSize);
                    map.put(topologicalOrder[i], outPair);
                }
            }
        }

        if (labelMaskArrays != null) {
            if (labelMaskArrays.length != numOutputArrays) {
                throw new IllegalArgumentException("Invalid number of label mask arrays");
            }
            for (int i = 0; i < labelMaskArrays.length; i++) {
                if (labelMaskArrays[i] == null) {
                    // This output doesn't have a mask, we can skip it.
                    continue;
                }
                String outputName = configuration.getNetworkOutputs().get(i);
                GraphVertex v = verticesMap.get(outputName);
                Layer ol = v.getLayer();
                ol.setMaskArray(labelMaskArrays[i]);
            }
        }
    }

    /**
     * Remove the mask arrays from all layers.<br>
     * See {@link #setLayerMaskArrays(INDArray[], INDArray[])} for details on mask arrays.
     */
    public void clearLayerMaskArrays() {
        for (Layer layer : layers) {
            layer.setMaskArray(null);
        }
        this.inputMaskArrays = null;
        this.labelMaskArrays = null;
    }

    /**
     * Update the internal state of RNN layers after a truncated BPTT fit call
     */
    protected void rnnUpdateStateWithTBPTTState() {
        for (int i = 0; i < layers.length; i++) {
            if (layers[i] instanceof RecurrentLayer) {
                RecurrentLayer l = ((RecurrentLayer) layers[i]);
                l.rnnSetPreviousState(l.rnnGetTBPTTState());
            } else if (layers[i] instanceof MultiLayerNetwork) {
                ((MultiLayerNetwork) layers[i]).updateRnnStateWithTBPTTState();
            }
        }
    }

    /**
     * Evaluate the network (classification performance - single output ComputationGraphs only)
     *
     * @param iterator Iterator to evaluate on
     * @return Evaluation object; results of evaluation on all examples in the data set
     */
    public <T extends Evaluation> T evaluate(DataSetIterator iterator) {
        return (T)evaluate(iterator, (List<String>)null);
    }

    /**
     * Evaluate the network (classification performance - single output ComputationGraphs only)
     *
     * @param iterator Iterator to evaluate on
     * @return Evaluation object; results of evaluation on all examples in the data set
     */
    public <T extends Evaluation> T  evaluate(MultiDataSetIterator iterator) {
        return evaluate(iterator, (List<String>)null);
    }

    /**
     * Evaluate the network on the provided data set (single output ComputationGraphs only). Used for evaluating
     * the performance of classifiers
     *
     * @param iterator Data to undertake evaluation on
     * @return Evaluation object, summarizing the results of the evaluation on the provided DataSetIterator
     */
    public <T extends Evaluation> T  evaluate(DataSetIterator iterator, List<String> labelsList) {
        return evaluate(iterator, labelsList, 1);
    }

    /**
     * Evaluate the network on the provided data set (single output ComputationGraphs only). Used for evaluating
     * the performance of classifiers
     *
     * @param iterator Data to undertake evaluation on
     * @return Evaluation object, summarizing the results of the evaluation on the provided DataSetIterator
     */
    public <T extends Evaluation> T evaluate(MultiDataSetIterator iterator, List<String> labelsList) {
        return evaluate(iterator, labelsList, 1);
    }

    /**
     * Evaluate the network (for classification) on the provided data set, with top N accuracy in addition to standard accuracy.
     * For 'standard' accuracy evaluation only, use topN = 1
     *
     * @param iterator   Iterator (data) to evaluate on
     * @param labelsList List of labels. May be null.
     * @param topN       N value for top N accuracy evaluation
     * @return Evaluation object, summarizing the results of the evaluation on the provided DataSetIterator
     */
    public <T extends Evaluation> T evaluate(DataSetIterator iterator, List<String> labelsList, int topN) {
        if (labelsList == null)
            labelsList = iterator.getLabels();

        Layer outputLayer = getOutputLayer(0);
        if(getConfiguration().isValidateOutputLayerConfig()){
            OutputLayerUtil.validateOutputLayerForClassifierEvaluation(outputLayer.conf().getLayer(), Evaluation.class);
        }

        return (T)doEvaluation(iterator, new org.deeplearning4j.eval.Evaluation(labelsList, topN))[0];
    }

    /**
     * Evaluate the network (for classification) on the provided data set, with top N accuracy in addition to standard accuracy.
     * For 'standard' accuracy evaluation only, use topN = 1
     *
     * @param iterator   Iterator (data) to evaluate on
     * @param labelsList List of labels. May be null.
     * @param topN       N value for top N accuracy evaluation
     * @return Evaluation object, summarizing the results of the evaluation on the provided DataSetIterator
     */
    public <T extends Evaluation> T evaluate(MultiDataSetIterator iterator, List<String> labelsList, int topN) {
        Layer outputLayer = getOutputLayer(0);
        if(getConfiguration().isValidateOutputLayerConfig()){
            OutputLayerUtil.validateOutputLayerForClassifierEvaluation(outputLayer.conf().getLayer(), Evaluation.class);
        }
        return (T)doEvaluation(iterator, new org.deeplearning4j.eval.Evaluation(labelsList, topN))[0];
    }

    /**
     * Evaluate the (single output layer only) network for regression performance
     *
     * @param iterator Data to evaluate on
     * @return Regression evaluation
     */
    public <T extends RegressionEvaluation> T evaluateRegression(DataSetIterator iterator) {
        return evaluateRegression(iterator, null);
    }

    /**
     * Evaluate the (single output layer only) network for regression performance
     *
     * @param iterator Data to evaluate on
     * @return Regression evaluation
     */
    public <T extends RegressionEvaluation> T evaluateRegression(MultiDataSetIterator iterator) {
        return evaluateRegression(iterator, null);
    }

    /**
     * Evaluate the (single output layer only) network for regression performance
     *
     * @param iterator    Data to evaluate on
     * @param columnNames Column names for the regression evaluation. May be null.
     * @return Regression evaluation
     */
    public <T extends RegressionEvaluation> T evaluateRegression(DataSetIterator iterator, List<String> columnNames) {
        return (T)doEvaluation(iterator, new org.deeplearning4j.eval.RegressionEvaluation(columnNames))[0];
    }

    /**
     * Evaluate the (single output layer only) network for regression performance
     *
     * @param iterator Data to evaluate on
     * @return Regression evaluation
     */
    public <T extends RegressionEvaluation> T evaluateRegression(MultiDataSetIterator iterator, List<String> columnNames) {
        return (T)doEvaluation(iterator, new org.deeplearning4j.eval.RegressionEvaluation(columnNames))[0];
    }


    /**
     * @deprecated To be removed - use {@link #evaluateROC(DataSetIterator, int)} to enforce selection of appropriate ROC/threshold configuration
     */
    @Deprecated
    public <T extends ROC> T evaluateROC(DataSetIterator iterator) {
        return evaluateROC(iterator, 0);
    }

    /**
     * Evaluate the network (must be a binary classifier) on the specified data, using the {@link ROC} class
     *
     * @param iterator          Data to evaluate on
     * @param rocThresholdSteps Number of threshold steps to use with {@link ROC}
     * @return ROC evaluation on the given dataset
     */
    public <T extends ROC> T evaluateROC(DataSetIterator iterator, int rocThresholdSteps) {
        Layer outputLayer = getOutputLayer(0);
        if(getConfiguration().isValidateOutputLayerConfig()){
            OutputLayerUtil.validateOutputLayerForClassifierEvaluation(outputLayer.conf().getLayer(), ROC.class);
        }
        return (T)doEvaluation(iterator, new org.deeplearning4j.eval.ROC(rocThresholdSteps))[0];
    }

    /**
     * @deprecated To be removed - use {@link #evaluateROC(DataSetIterator, int)} to enforce selection of appropriate ROC/threshold configuration
     */
    @Deprecated
    public <T extends ROC> T evaluateROC(MultiDataSetIterator iterator) {
        return evaluateROC(iterator, 0);
    }

    /**
     * Evaluate the network (must be a binary classifier) on the specified data, using the {@link ROC} class
     *
     * @param iterator          Data to evaluate on
     * @param rocThresholdSteps Number of threshold steps to use with {@link ROC}
     * @return ROC evaluation on the given dataset
     */
    public <T extends ROC> T evaluateROC(MultiDataSetIterator iterator, int rocThresholdSteps) {
        Layer outputLayer = getOutputLayer(0);
        if(getConfiguration().isValidateOutputLayerConfig()){
            OutputLayerUtil.validateOutputLayerForClassifierEvaluation(outputLayer.conf().getLayer(), ROC.class);
        }
        return (T)doEvaluation(iterator, new org.deeplearning4j.eval.ROC(rocThresholdSteps))[0];
    }

    /**
     * @deprecated To be removed - use {@link #evaluateROCMultiClass(DataSetIterator, int)} to enforce selection of appropriate ROC/threshold configuration
     */
    @Deprecated
    public <T extends ROCMultiClass> T evaluateROCMultiClass(DataSetIterator iterator) {
        return evaluateROCMultiClass(iterator, 0);
    }

    /**
     * Evaluate the network on the specified data, using the {@link ROCMultiClass} class
     *
     * @param iterator          Data to evaluate on
     * @param rocThresholdSteps Number of threshold steps to use with {@link ROCMultiClass}
     * @return Multi-class ROC evaluation on the given dataset
     */
    public <T extends ROCMultiClass> T evaluateROCMultiClass(DataSetIterator iterator, int rocThresholdSteps) {
        Layer outputLayer = getOutputLayer(0);
        if(getConfiguration().isValidateOutputLayerConfig()){
            OutputLayerUtil.validateOutputLayerForClassifierEvaluation(outputLayer.conf().getLayer(), ROCMultiClass.class);
        }
        return (T)doEvaluation(iterator, new org.deeplearning4j.eval.ROCMultiClass(rocThresholdSteps))[0];
    }

    /**
     * Evaluate the network on the specified data, using the {@link ROCMultiClass} class
     *
     * @param iterator          Data to evaluate on
     * @param rocThresholdSteps Number of threshold steps to use with {@link ROCMultiClass}
     * @return Multi-class ROC evaluation on the given dataset
     */
    public <T extends ROCMultiClass> T evaluateROCMultiClass(MultiDataSetIterator iterator, int rocThresholdSteps) {
        Layer outputLayer = getOutputLayer(0);
        if(getConfiguration().isValidateOutputLayerConfig()){
            OutputLayerUtil.validateOutputLayerForClassifierEvaluation(outputLayer.conf().getLayer(), ROCMultiClass.class);
        }
        return (T)doEvaluation(iterator, new org.deeplearning4j.eval.ROCMultiClass(rocThresholdSteps))[0];
    }

    /**
     * Perform evaluation on the given data (DataSetIterator) with the given {@link IEvaluation} instance
     *
     * @param iterator   Test data to evaluate on
     * @param evaluations IEvaluation instances
     * @param <T>        Type of the IEvaluation instance
     * @return The input IEvaluation instance, after performing evaluation on the test data
     */
    public <T extends IEvaluation> T[] doEvaluation(DataSetIterator iterator, T... evaluations) {
        return doEvaluation(new MultiDataSetIteratorAdapter(iterator), evaluations);
    }

    /**
     * Perform evaluation on the given data (MultiDataSetIterator) with the given {@link IEvaluation} instance
     *
     * @param iterator    Test data to evaluate on
     * @param evaluations IEvaluation insntance
     * @param <T>         Type of the IEvaluation instance
     * @return The input IEvaluation instance, after performing evaluation on the test data
     */
    public <T extends IEvaluation> T[] doEvaluation(MultiDataSetIterator iterator, T... evaluations) {
        try{
            return doEvaluationHelper(iterator, evaluations);
        } catch (OutOfMemoryError e){
            CrashReportingUtil.writeMemoryCrashDump(this, e);
            throw e;
        }
    }

    /**
     * Perform evaluation for networks with multiple outputs.
     *
     * @param iterator    Data to evaluate
     * @param evaluations Evaluation instances. Key: the network output number (0 to numOutputs-1). Value: the IEvaluation
     *                    instances to perform evaluation with, for that output only. Note that not every output needs to
     *                    have an IEvaluation[] defined.
     * @return The same evaluation map, after performing evaluation
     */
    public <T extends IEvaluation> Map<Integer, T[]> evaluate(DataSetIterator iterator, Map<Integer,T[]> evaluations){
        return evaluate(new MultiDataSetIteratorAdapter(iterator), evaluations);
    }

    /**
     * Perform evaluation for networks with multiple outputs.
     *
     * @param iterator    Data to evaluate
     * @param evaluations Evaluation instances. Key: the network output number (0 to numOutputs-1). Value: the IEvaluation
     *                    instances to perform evaluation with, for that output only. Note that not every output needs to
     *                    have an IEvaluation[] defined.
     * @return The same evaluation map, after performing evaluation
     */
    public <T extends IEvaluation> Map<Integer, T[]> evaluate(MultiDataSetIterator iterator, Map<Integer,T[]> evaluations){
        try{
            return doEvaluationHelper(iterator, evaluations);
        } catch (OutOfMemoryError e){
            CrashReportingUtil.writeMemoryCrashDump(this, e);
            throw e;
        }
    }

    @SuppressWarnings("unchecked")
    @SafeVarargs
    private final <T extends IEvaluation> T[] doEvaluationHelper(MultiDataSetIterator iterator, T... evaluations) {
        Map<Integer,IEvaluation[]> map = Collections.singletonMap(0, (IEvaluation[])evaluations);
        return (T[])doEvaluationHelper(iterator, map).get(0);
    }

    private <T extends IEvaluation> Map<Integer,T[]> doEvaluationHelper(MultiDataSetIterator iterator, Map<Integer, T[]> evaluations){
        if (layers == null || !(getOutputLayer(0) instanceof IOutputLayer)) {
            throw new IllegalStateException("Cannot evaluate network with no output layer");
        }

        WorkspaceUtils.assertNoWorkspacesOpen("Expected no external workspaces open at start of evaluation (doEvaluationHelper)");

        if (iterator.resetSupported() && !iterator.hasNext())
            iterator.reset();

        MultiDataSetIterator iter =
                iterator.asyncSupported() ? new AsyncMultiDataSetIterator(iterator, 2, true) : iterator;

        WorkspaceMode cMode = configuration.getTrainingWorkspaceMode();
        configuration.setTrainingWorkspaceMode(configuration.getInferenceWorkspaceMode());

        boolean useRnnSegments = (configuration.getBackpropType() == BackpropType.TruncatedBPTT);

        MemoryWorkspace outputWs;
        if(getConfiguration().getInferenceWorkspaceMode() == WorkspaceMode.ENABLED){
            outputWs = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(WS_ALL_LAYERS_ACT_CONFIG, WS_OUTPUT_MEM);
        } else {
            outputWs = new DummyWorkspace();
        }

        while (iter.hasNext()) {
            MultiDataSet next = iter.next();

            if (next.getFeatures() == null || next.getLabels() == null)
                continue;

            if (!useRnnSegments) {
                //Standard/non-RNN case

                //Assuming single output here
                INDArray[] features = next.getFeatures();
                INDArray[] featuresMasks = next.getFeaturesMaskArrays();
                INDArray[] labels = next.getLabels();
                INDArray[] labelMasks = next.getLabelsMaskArrays();

                try (MemoryWorkspace ws = outputWs.notifyScopeEntered()) {
                    INDArray[] out = outputOfLayersDetached(false, FwdPassType.STANDARD, getOutputLayerIndices(), features, featuresMasks, labelMasks, true, false, ws);

                    for (Integer i : evaluations.keySet()) {
                        Preconditions.checkState(i >= 0 && i <labels.length, "Invalid output index: evaluation/output indices must be between 0" +
                                " and numOutputs-1 (%s), got index %s", numOutputArrays, (int)i);
                        IEvaluation[] evalsThisOutput = evaluations.get(i);
                        if (evalsThisOutput == null)
                            continue;

                        Preconditions.checkState(i >= 0 && i < getNumOutputArrays(), "Invalid output index: indices for outputs " +
                                "must be between 0 and %s inclusive - found index %s", numOutputArrays, (int) i);
                        INDArray currOut = out[i];
                        INDArray currLabel = labels[i];

                        try (MemoryWorkspace wsO = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
                            for (IEvaluation evaluation : evalsThisOutput)
                                evaluation.eval(currLabel, currOut, next.getLabelsMaskArray(i));
                        }
                    }
                }


            } else {
                rnnClearPreviousState();

                int fwdLen = configuration.getTbpttFwdLength();
                long tsLength = -1;
                long nF = next.getFeatures().length;
                for (int i = 0; i < nF; i++) {
                    if (next.getFeatures(i).rank() == 3) {
                        tsLength = next.getFeatures(i).size(2);
                    }
                }
                if (tsLength < 0) {
                    throw new IllegalStateException("Invalid configuration: detected TBPTT backprop type without" +
                            " time series features");
                }

                long nSubsets = tsLength / fwdLen;
                if (tsLength % fwdLen != 0)
                    nSubsets++; //Example: 100 fwdLen with timeSeriesLength=120 -> want 2 subsets (1 of size 100, 1 of size 20)
                for (int i = 0; i < nSubsets; i++) {
                    int startTimeIdx = i * fwdLen;
                    long endTimeIdx = Math.min(startTimeIdx + fwdLen, tsLength);

                    List<INDArray[]> subset = getSubsetsForTbptt(startTimeIdx, endTimeIdx, next.getFeatures(),
                            next.getLabels(), next.getFeaturesMaskArrays(), next.getLabelsMaskArrays());
                    setLayerMaskArrays(subset.get(2), subset.get(3));

                    try (MemoryWorkspace ws = outputWs.notifyScopeEntered()) {
                        INDArray[] outSub = rnnTimeStep(ws, subset.get(0));

                        for (Integer idx : evaluations.keySet()) {
                            IEvaluation[] evalsThisOutput = evaluations.get(idx);
                            if (evalsThisOutput == null)
                                continue;

                            INDArray labelSub = (subset.get(1) == null ? null : subset.get(1)[idx]);
                            INDArray maskSub = subset.get(3) == null ? null : subset.get(3)[idx];
                            INDArray currOut = outSub[idx];
                            try (MemoryWorkspace wsO = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
                                for (IEvaluation evaluation : evalsThisOutput)
                                    evaluation.eval(labelSub, currOut, maskSub);
                            }
                        }
                    }
                }

                rnnClearPreviousState();
            }

            //Clear inputs, masks etc. Important to avoid leaking invalidated/out of scope arrays between iterations
            clearLayersStates();
        }

        if (iterator.asyncSupported())
            ((AsyncMultiDataSetIterator) iter).shutdown();

        configuration.setTrainingWorkspaceMode(cMode);

        return (Map<Integer, T[]>) evaluations;
    }

    /**
     * String detailing the architecture of the computation graph.
     * Vertices are printed in a topological sort order.
     * Columns are Vertex Names with layer/vertex type, nIn, nOut, Total number of parameters and the Shapes of the parameters
     * And the inputs to the vertex
     * Will also give information about frozen layers/vertices, if any.
     *
     * @return Summary as a string
     * @see #memoryInfo(int, InputType...)
     */
    public String summary() {
        return summary((InputType[])null);
    }

    /**
     * String detailing the architecture of the computation graph.
     * Will also display activation size when given an input type.
     * Vertices are printed in a topological sort order.
     * Columns are Vertex Names with layer/vertex type, nIn, nOut, Total number of parameters and the Shapes of the parameters
     * And the inputs to the vertex
     * Will also give information about frozen layers/vertices, if any.
     *
     * @return Summary as a string
     * @see #memoryInfo(int, InputType...)
     */
    public String summary(InputType... inputTypes) {
        StringBuilder ret = new StringBuilder();
        ret.append("\n");

        int frozenParams = 0;
        Map<String, InputType> vertexOutputs = new HashMap<>(); //vertex name and output types
        int currLayerIdx = -1;

        List<String[]> lines = new ArrayList<>();
        if(inputTypes == null){
            lines.add(new String[]{"VertexName (VertexType)", "nIn,nOut", "TotalParams", "ParamsShape", "Vertex Inputs"});
        } else {
            lines.add(new String[]{"VertexName (VertexType)", "nIn,nOut", "TotalParams", "ParamsShape", "Vertex Inputs", "InputShape", "OutputShape"});
        }
        int[] maxLength = new int[inputTypes == null || inputTypes.length == 0 ? 5 : 7];
        String[] header = lines.get(0);
        for( int i=0; i<header.length; i++ ){
            maxLength[i] = header[i].length();
        }

        if(topologicalOrder == null){
            GraphIndices indices = calculateIndices();
            topologicalOrder = indices.getTopologicalSortOrder();
        }

        for (int currVertexIdx : topologicalOrder) {

            GraphVertex currentVertex = vertices[currVertexIdx];
            String currentVertexName = currentVertex.getVertexName();

            //String vars for print
            String[] classNameArr = currentVertex.getClass().toString().split("\\.");
            String className = classNameArr[classNameArr.length - 1];
            String connections = "-";
            String inShape = "-";
            String outShape = "-";
            String paramCount = "-";
            String in = "-";
            String out = "-";
            String paramShape = "-";
            if (currentVertex.isInputVertex()) {
                if (inputTypes != null) vertexOutputs.put(currentVertexName, inputTypes[configuration.getNetworkInputs().indexOf(currentVertexName)]); //for input vertices the outputs are just the input types (only layer vertices have preprocessing?)
            } else {
                connections = configuration.getVertexInputs().get(currentVertexName).toString();
                List<InputType> inputTypeList = new ArrayList<>();
                if (currentVertex.hasLayer()) {
                    Layer currentLayer = ((LayerVertex) currentVertex).getLayer();
                    classNameArr = currentLayer.getClass().getName().split("\\.");
                    className = classNameArr[classNameArr.length - 1];
                    paramCount = String.format("%,d", currentLayer.numParams());
                    //layer with params
                    if (currentLayer.numParams() > 0) {
                        paramShape = "";
                        if (currentLayer instanceof BidirectionalLayer) { // Bidirectional layer is not an FFL
                            BidirectionalLayer bi = (BidirectionalLayer) currentLayer;
                            in = String.valueOf(((Bidirectional)bi.conf().getLayer()).getNIn());
                            out = String.valueOf(((Bidirectional)bi.conf().getLayer()).getNOut());
                        } else {
                            try {
                                in = String.valueOf(((FeedForwardLayer) currentLayer.conf().getLayer()).getNIn());
                                out = String.valueOf(((FeedForwardLayer) currentLayer.conf().getLayer()).getNOut());
                            }
                            catch (Exception e) { // Some layers, like PReLU, are just BaseLayers (but have parameters)
                            }
                        }
                        List<String> paraNames = currentLayer.conf().variables();
                        for (String aP : paraNames) {
                            String paramS = ArrayUtils.toString(currentLayer.paramTable().get(aP).shape());
                            paramShape += aP + ":" + paramS + ", ";
                        }
                        paramShape = paramShape.subSequence(0, paramShape.lastIndexOf(",")).toString();
                    }
                    //frozen layer
                    if (currentLayer instanceof FrozenLayer) {
                        frozenParams += currentLayer.numParams();
                        classNameArr = ((FrozenLayer) currentLayer).getInsideLayer().getClass().getName().split("\\.");
                        className = "Frozen " + classNameArr[classNameArr.length - 1];
                    }

                    if (inputTypes != null) {
                        //get input type
                        String inputVertexName = vertices[currentVertex.getInputVertices()[0].getVertexIndex()].getVertexName();
                        InputType currentInType = vertexOutputs.get(inputVertexName);
                        inShape = currentInType.toString();
                        inputTypeList.add(currentInType);

                        InputPreProcessor layerVertexPreProcesor = ((org.deeplearning4j.nn.conf.graph.LayerVertex)configuration.getVertices().get(currentVertexName)).getPreProcessor();
                        if (layerVertexPreProcesor != null) {
                            inShape += "-->" + layerVertexPreProcesor.getOutputType(currentInType);
                        }
                    }
                    currLayerIdx++;
                } else {
                    //get input type
                    if (inputTypes != null) {
                        VertexIndices[] inputVertices = currentVertex.getInputVertices();
                        if (inputVertices != null) {
                            for (int i = 0; i < inputVertices.length; i++) {
                                GraphVertex thisInputVertex = vertices[inputVertices[i].getVertexIndex()];
                                inputTypeList.add(vertexOutputs.get(thisInputVertex.getVertexName()));
                            }
                        }
                    }
                }
                if (inputTypes != null) {
                    InputType currentVertexOutputType = configuration.getVertices().get(currentVertexName).getOutputType(currLayerIdx, inputTypeList.toArray(new InputType[inputTypeList.size()]));
                    outShape = currentVertexOutputType.toString();
                    vertexOutputs.put(currentVertexName, currentVertexOutputType);
                }
            }

            //Add on to summary string
            String[] line;
            if (inputTypes == null) {
                line = new String[]{currentVertexName + " (" + className + ")", in + "," + out, paramCount, paramShape, connections};
            } else {
                line = new String[]{currentVertexName + " (" + className + ")", in + "," + out, paramCount, paramShape, connections, inShape, outShape};
            }
            for( int i=0; i<line.length; i++ ){
                maxLength[i] = Math.max(maxLength[i], line[i] == null ? 0 : line[i].length());
            }
            lines.add(line);
        }

        StringBuilder sbFormat = new StringBuilder();
        int totalLength = 0;
        int pos = 0;
        for(int length : maxLength){
            int currLength;
            if(pos++ == maxLength.length-1){
                currLength = length;
            } else {
                currLength = length+3;
            }
            sbFormat.append("%-").append(currLength).append("s");
            totalLength += currLength;
        }
        sbFormat.append("\n");
        String format = sbFormat.toString();



        ret.append(StringUtils.repeat("=", totalLength))
                .append("\n");

        boolean first = true;
        for(String[] line : lines){
            String formatted = String.format(format, (Object[])line);
            ret.append(formatted);
            if(first){
                ret.append(StringUtils.repeat("=", totalLength)).append("\n");
                first = false;
            }
        }

        ret.append(StringUtils.repeat("-", totalLength))
                .append(String.format("\n%30s %,d", "Total Parameters: ", params().length()))
                .append(String.format("\n%30s %,d", "Trainable Parameters: ", params().length() - frozenParams))
                .append(String.format("\n%30s %,d", "Frozen Parameters: ", frozenParams))
                .append("\n")
                .append(StringUtils.repeat("=", totalLength))
                .append("\n");

        return ret.toString();
    }

    /**
     * Generate information regarding memory use for the network, for the given input types and minibatch size.
     * Note that when using workspaces or CuDNN, the network should be trained for some iterations so that the memory
     * workspaces have time to initialize. Without this, the memory requirements during training may be underestimated.
     *
     * Note also that this is the same information that is generated during an OOM crash when training or performing
     * inference.
     *
     * @param minibatch    Minibatch size to estimate memory for
     * @param inputTypes   Input types to the network
     * @return A String with information about network memory use information
     */
    public String memoryInfo(int minibatch, InputType... inputTypes){
        return CrashReportingUtil.generateMemoryStatus(this, minibatch, inputTypes);
    }

    /**
     * This method just makes sure there's no state preserved within layers
     */
    public void clearLayersStates() {
        for (Layer layer : layers) {
            layer.clear();
            layer.clearNoiseWeightParams();
        }

        for (GraphVertex vertex : vertices) {
            vertex.clearVertex();
        }
    }

    /**
     * Increment the epoch count (in the underlying {@link ComputationGraphConfiguration} by 1).
     * Note that this is done <i>automatically</i> when using iterator-based fitting methods, such as
     * {@link #fit(DataSetIterator)} or {@link #fit(MultiDataSet)}. However, when using non-iterator fit methods
     * (DataSet, MultiDataSet, INDArrays etc), the network has no way to know when one epoch ends and another starts.
     * In such situations, this method can be used to increment the epoch counter.<br>
     * Note that the epoch counter is used for situations such as some learning rate schedules, and the like.
     *
     * The current epoch count can be obtained using {@code ComputationGraph.getConfiguration().getEpochCount()}
     */
    public void incrementEpochCount(){
        configuration.setEpochCount(configuration.getEpochCount() + 1);
        synchronizeIterEpochCounts();
    }

    protected void synchronizeIterEpochCounts(){
        //TODO: this is necessrry for some schedules - but the redundant values are a little ugly...
        int currIter = getConfiguration().getIterationCount();
        int currEpoch = getConfiguration().getEpochCount();
        for(Layer l : layers){
            l.setIterationCount(currIter);
            l.setEpochCount(currEpoch);
        }
    }

    /**
     * Returns the number of iterations (parameter updates) that the ComputationGraph has done
     * @return Number of iterations
     */
    public int getIterationCount(){
        return configuration.getIterationCount();
    }

    /**
     * Returns the number of epochs that the ComputationGraph has done.
     * Note that the epoch count is incremented only when {@link #fit(DataSetIterator)}, {@link #fit(MultiDataSetIterator)},
     * {@link #fit(DataSetIterator, int)} or {@link #fit(MultiDataSetIterator, int)} are used.
     * The epoch count can also be manually incremented using {@link #incrementEpochCount()}
     * @return Number of epochs
     */
    public int getEpochCount(){
        return configuration.getEpochCount();
    }

    /**
     * Save the ComputationGraph to a file. Restore using {@link #load(File, boolean)}.
     * Note that this saves the updater (i.e., the state array for momentum/Adam/rmsprop etc), which is desirable
     * if further training will be undertaken.
     *
     * @param f File to save the network to
     * @see ModelSerializer ModelSerializer for more details (and saving/loading via streams)
     * @see #save(File, boolean)
     */
    public void save( File f ) throws IOException {
        save(f, true);
    }

    /**
     * Save the ComputationGraph to a file. Restore using {@link #load(File, boolean)}.
     *
     * @param f File to save the network to
     * @param saveUpdater If true: save the updater (i.e., the state array for momentum/Adam/rmsprop etc), which should
     *                    usually be saved if further training is required
     * @see ModelSerializer ModelSerializer for more details (and saving/loading via streams)
     * @see #save(File, boolean)
     */
    public void save(File f, boolean saveUpdater) throws IOException{
        ModelSerializer.writeModel(this, f, saveUpdater);
    }

    /**
     * Restore a ComputationGraph to a file, saved using {@link #save(File)} or {@link ModelSerializer}
     * @param f File to load the network from
     * @param loadUpdater If true: load the updater if it is available (i.e., the state array for momentum/Adam/rmsprop
     *                   etc) - use <i>false</i> if no further training is required, or <i>true</i> if further training
     *                    will be undertaken
     * @see ModelSerializer ModelSerializer for more details (and saving/loading via streams)
     */
    public static ComputationGraph load(File f, boolean loadUpdater) throws IOException {
        return ModelSerializer.restoreComputationGraph(f, loadUpdater);
    }

    /**
     * Return a copy of the network with the parameters and activations set to use the specified (floating point) data type.
     * If the existing datatype is the same as the requested dataype, the original network will be returned unchanged.
     * Only floating point datatypes (DOUBLE, FLOAT, HALF) may be used.
     *
     * @param dataType Datatype to convert the network to
     * @return The network, set to use the specified datatype for the parameters and activations
     */
    public ComputationGraph convertDataType(@NonNull DataType dataType){
        Preconditions.checkState(dataType.isFPType(), "Invalid DataType: %s. Can only convert network to a floating point type", dataType);
        if(dataType == params().dataType()){
            return this;
        }

        try(MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            INDArray newParams = params().castTo(dataType);
            String jsonConfig = getConfiguration().toJson();
            ComputationGraphConfiguration newConf = ComputationGraphConfiguration.fromJson(jsonConfig);
            newConf.setDataType(dataType);
            ComputationGraph newNet = new ComputationGraph(newConf);
            newNet.init(newParams, false);

            Updater u = getUpdater(false);
            if(u != null && u.getStateViewArray() != null){
                INDArray oldUpdaterState = u.getStateViewArray();
                newNet.getUpdater(true).getStateViewArray().assign(oldUpdaterState);
            }
            return newNet;
        }
    }

    /**
     * Set the learning rate for all layers in the network to the specified value. Note that if any learning rate
     * schedules are currently present, these will be removed in favor of the new (fixed) learning rate.<br>
     * <br>
     * <b>Note</b>: <i>This method not free from a performance point of view</i>: a proper learning rate schedule
     * should be used in preference to calling this method at every iteration.
     *
     * @param newLr New learning rate for all layers
     * @see #setLearningRate(ISchedule)
     * @see #setLearningRate(String, double)
     */
    public void setLearningRate(double newLr) {
        NetworkUtils.setLearningRate(this, newLr);
    }

    /**
     * Set the learning rate schedule for all layers in the network to the specified schedule.
     * This schedule will replace any/all existing schedules, and also any fixed learning rate values.<br>
     * Note that the iteration/epoch counts will <i>not</i> be reset. Use {@link ComputationGraphConfiguration#setIterationCount(int)}
     * and {@link ComputationGraphConfiguration#setEpochCount(int)} if this is required
     *
     * @param newLr New learning rate schedule for all layers
     * @see #setLearningRate(ISchedule)
     * @see #setLearningRate(String, double)
     */
    public void setLearningRate(ISchedule newLr) {
        NetworkUtils.setLearningRate(this, newLr);
    }

    /**
     * Set the learning rate for a single layer in the network to the specified value. Note that if any learning rate
     * schedules are currently present, these will be removed in favor of the new (fixed) learning rate.<br>
     * <br>
     * <b>Note</b>: <i>This method not free from a performance point of view</i>: a proper learning rate schedule
     * should be used in preference to calling this method at every iteration. Note also that
     * {@link #setLearningRate(double)} should also be used in preference, when all layers need to be set to a new LR
     *
     * @param layerName Name of the layer to set the LR for
     * @param newLr     New learning rate for a single layer
     * @see #setLearningRate(ISchedule)
     * @see #setLearningRate(String, double)
     */
    public void setLearningRate(String layerName, double newLr) {
        NetworkUtils.setLearningRate(this, layerName, newLr);
    }

    /**
     * Set the learning rate schedule for a single layer in the network to the specified value.<br>
     * Note also that {@link #setLearningRate(ISchedule)} should also be used in preference, when all layers need
     * to be set to a new LR schedule.<br>
     * This schedule will replace any/all existing schedules, and also any fixed learning rate values.<br>
     * Note also that the iteration/epoch counts will <i>not</i> be reset. Use {@link ComputationGraphConfiguration#setIterationCount(int)}
     * and {@link ComputationGraphConfiguration#setEpochCount(int)} if this is required
     *
     * @param layerName Name of the layer to set the LR schedule for
     * @param newLr     New learning rate for a single layer
     * @see #setLearningRate(ISchedule)
     * @see #setLearningRate(String, double)
     */
    public void setLearningRate(String layerName, ISchedule newLr) {
        NetworkUtils.setLearningRate(this, layerName, newLr);
    }

    /**
     * Get the current learning rate, for the specified layer, from the network.
     * Note: If the layer has no learning rate (no parameters, or an updater without a learning rate) then null is returned
     * @param layerName   Layer name
     * @return Learning rate for the specified layer, or null
     */
    public Double getLearningRate(String layerName){
        return NetworkUtils.getLearningRate(this, layerName);
    }

    /**
     * Return the layer size (number of units) for the specified layer.
     * Note that the meaning of the "layer size" can depend on the type of layer. For example:<br>
     * - DenseLayer, OutputLayer, recurrent layers: number of units (nOut configuration option)<br>
     * - ConvolutionLayer: the channels (number of channels)<br>
     * - Subsampling layers, global pooling layers, etc: size of 0 is always returned<br>
     *
     * @param layer Index of the layer to get the size of. Must be in range 0 to nLayers-1 inclusive
     * @return Size of the layer
     */
    public long layerSize(int layer) {
        if (layer < 0 || layer > layers.length) {
            throw new IllegalArgumentException("Invalid layer index: " + layer + ". Layer index must be between 0 and "
                    + (layers.length - 1) + " inclusive");
        }
        return layerSize(layers[layer].conf().getLayer().getLayerName());
    }

    /**
     * Return the input size (number of inputs) for the specified layer.<br>
     * Note that the meaning of the "input size" can depend on the type of layer. For example:<br>
     * - DenseLayer, OutputLayer, etc: the feature vector size (nIn configuration option)<br>
     * - Recurrent layers: the feature vector size <i>per time step</i> (nIn configuration option)<br>
     * - ConvolutionLayer: the channels (number of channels)<br>
     * - Subsampling layers, global pooling layers, etc: size of 0 is always returned<br>
     *
     * @param layer Index of the layer to get the size of. Must be in range 0 to nLayers-1 inclusive
     * @return Size of the layer
     */
    public long layerInputSize(int layer) {
        if (layer < 0 || layer > layers.length) {
            throw new IllegalArgumentException("Invalid layer index: " + layer + ". Layer index must be between 0 and "
                    + (layers.length - 1) + " inclusive");
        }
        return layerInputSize(layers[layer].conf().getLayer().getLayerName());
    }

    /**
     * Return the layer size (number of units) for the specified layer.<br>
     * Note that the meaning of the "layer size" can depend on the type of layer. For example:<br>
     * - DenseLayer, OutputLayer, recurrent layers: number of units (nOut configuration option)<br>
     * - ConvolutionLayer: the channels (number of channels)<br>
     * - Subsampling layers, global pooling layers, etc: size of 0 is always returned<br>
     *
     * @param layerName Name of the layer to get the size of
     * @return Size of the layer
     */
    public long layerSize(String layerName) {
        Layer l = getLayer(layerName);
        if(l == null){
            throw new IllegalArgumentException("No layer with name \"" + layerName + "\" exists");
        }
        org.deeplearning4j.nn.conf.layers.Layer conf = l.conf().getLayer();
        if (conf == null || !(conf instanceof FeedForwardLayer)) {
            return 0;
        }
        FeedForwardLayer ffl = (FeedForwardLayer) conf;

        return ffl.getNOut();
    }

    /**
     * Return the input size (number of inputs) for the specified layer.<br>
     * Note that the meaning of the "input size" can depend on the type of layer. For example:<br>
     * - DenseLayer, OutputLayer, etc: the feature vector size (nIn configuration option)<br>
     * - Recurrent layers: the feature vector size <i>per time step</i> (nIn configuration option)<br>
     * - ConvolutionLayer: the channels (number of channels)<br>
     * - Subsampling layers, global pooling layers, etc: size of 0 is always returned<br>
     *
     * @param layerName Name of the layer to get the size of
     * @return Size of the layer
     */
    public long layerInputSize(String layerName) {
        Layer l = getLayer(layerName);
        if(l == null){
            throw new IllegalArgumentException("No layer with name \"" + layerName + "\" exists");
        }
        org.deeplearning4j.nn.conf.layers.Layer conf = l.conf().getLayer();
        if (conf == null || !(conf instanceof FeedForwardLayer)) {
            return 0;
        }
        FeedForwardLayer ffl = (FeedForwardLayer) conf;

        return ffl.getNIn();
    }

    /**
     * Indicates whether some other object is "equal to" this one.
     * <p>
     * The {@code equals} method implements an equivalence relation
     * on non-null object references:
     * <ul>
     * <li>It is <i>reflexive</i>: for any non-null reference value
     * {@code x}, {@code x.equals(x)} should return
     * {@code true}.
     * <li>It is <i>symmetric</i>: for any non-null reference values
     * {@code x} and {@code y}, {@code x.equals(y)}
     * should return {@code true} if and only if
     * {@code y.equals(x)} returns {@code true}.
     * <li>It is <i>transitive</i>: for any non-null reference values
     * {@code x}, {@code y}, and {@code z}, if
     * {@code x.equals(y)} returns {@code true} and
     * {@code y.equals(z)} returns {@code true}, then
     * {@code x.equals(z)} should return {@code true}.
     * <li>It is <i>consistent</i>: for any non-null reference values
     * {@code x} and {@code y}, multiple invocations of
     * {@code x.equals(y)} consistently return {@code true}
     * or consistently return {@code false}, provided no
     * information used in {@code equals} comparisons on the
     * objects is modified.
     * <li>For any non-null reference value {@code x},
     * {@code x.equals(null)} should return {@code false}.
     * </ul>
     * <p>
     * The {@code equals} method for class {@code Object} implements
     * the most discriminating possible equivalence relation on objects;
     * that is, for any non-null reference values {@code x} and
     * {@code y}, this method returns {@code true} if and only
     * if {@code x} and {@code y} refer to the same object
     * ({@code x == y} has the value {@code true}).
     * <p>
     * Note that it is generally necessary to override the {@code hashCode}
     * method whenever this method is overridden, so as to maintain the
     * general contract for the {@code hashCode} method, which states
     * that equal objects must have equal hash codes.
     *
     * @param obj the reference object with which to compare.
     * @return {@code true} if this object is the same as the obj
     * argument; {@code false} otherwise.
     * @see #hashCode()
     * @see HashMap
     */
    @Override
    public boolean equals(Object obj) {
        if (obj == null)
            return false;
        if (obj instanceof ComputationGraph) {
            ComputationGraph network = (ComputationGraph) obj;
            boolean paramsEquals = network.params().equals(params());
            boolean confEquals = getConfiguration().equals(network.getConfiguration());
            boolean updaterEquals = getUpdater().equals(network.getUpdater());
            return paramsEquals && confEquals && updaterEquals;
        }
        return false;
    }

    private void writeObject(ObjectOutputStream oos) throws IOException {
        ModelSerializer.writeModel(this, oos, true);
    }

    private void readObject(ObjectInputStream ois) throws ClassNotFoundException, IOException {
        val cg = ModelSerializer.restoreComputationGraph(ois, true);

        this.defaultConfiguration = cg.defaultConfiguration.clone();
        this.configuration = cg.configuration.clone();
        this.init();
        this.flattenedParams.assign(cg.flattenedParams);

        if (cg.getUpdater() != null && cg.getUpdater(false).getStateViewArray() != null)
            this.getUpdater(true).getStateViewArray().assign(cg.getUpdater(false).getStateViewArray());
    }
}