deeplearning4j/deeplearning4j

View on GitHub
nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/AbstractSession.java

Summary

Maintainability
F
2 wks
Test Coverage
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.nd4j.autodiff.samediff.internal;

import lombok.*;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.listeners.At;
import org.nd4j.autodiff.listeners.Listener;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.VariableType;
import org.nd4j.autodiff.samediff.config.ExecutionResult;
import org.nd4j.autodiff.samediff.config.SDValue;
import org.nd4j.autodiff.samediff.config.SDValueType;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.*;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.common.function.Predicate;

import java.util.*;
import java.util.stream.Collectors;

import static org.nd4j.imports.VariableUtils.stripVarSuffix;

@Slf4j
public abstract class AbstractSession<T, O> {

    /**
     * All execution in SameDiff happens in a frame; this is the name of the
     * main/outer frame - i.e., the "default" frame.
     * Other frames (such as those created for loops) may be nested within this frame.
     */
    public static final String OUTER_FRAME = "main";

    /** The SameDiff instance whose graph this session executes. */
    protected final SameDiff sameDiff;

    /**
     * Calculated values for variables.
     * Key: variable (at a given frame + iteration). Value: the calculated output for that variable.
     * Exposed via a Lombok-generated getter.
     */
    @Getter
    protected final Map<VarId, SDValue> nodeValueOutputs = new LinkedHashMap<>();

    /*
     * The dependency tracker is responsible for determining which ops (at which
     * frame/iteration) can be executed next, given what has been executed so far.
     * For static graphs, such an abstraction would not be necessary; for dynamic
     * graphs (i.e., nested loops of arbitrary number of iterations and depth - and
     * also switch ops, which can cause whole subgraphs to not be executed) it is.
     * Note: an ExecStep represents one step of execution - some steps are as
     * simple as "execute an op (at the given frame/iter)".
     * It works by adding dependencies (X -> Y - such as "op Y depends on the
     * output of op X") and then marking them as satisfied ("op X has been
     * calculated"). Once all dependencies for an execution step have been
     * satisfied, the execution step is added to a queue - outputs of which can be
     * accessed with dt.getNewAllSatisfied() and dt.getNewAllSatisfiedList() -
     * at which point it is removed from the dependency tracker.
     */
    protected final DependencyTracker<ExecStep, ExecStep> dt = new DependencyTracker<>();

    /**
     * Contains variables we *might* need to execute in the process of getting the
     * outputs we want.
     * Variables not in this set are definitely not needed to get the requested
     * output variables, but variables that are in this set may still not be
     * executed, depending on the graph structure - i.e., switch ops, etc.
     */
    protected final Set<String> subgraph = new LinkedHashSet<>();
    /**
     * As per the subgraph set, but for ops instead of variables.
     */
    protected final Set<String> subgraphOps = new LinkedHashSet<>();

    /**
     * Contains the names of ops that don't have any inputs. Kept because ops are
     * normally triggered for execution when all their inputs have been
     * calculated; for these ops we trigger that step manually during execution
     * initialization.
     */
    protected final Set<String> zeroInputOpsInSubgraph = new HashSet<>();

    /**
     * Creates a session that executes the graph of the given SameDiff instance.
     *
     * @param sameDiff the graph to execute; Lombok's {@code @NonNull} rejects null
     */
    public AbstractSession(@NonNull SameDiff sameDiff) {
        // Bind this session to its graph; all subsequent lookups go through this reference
        this.sameDiff = sameDiff;
    }

    /**
     * Checks whether a value has already been calculated for the given variable
     * at the given frame/iteration.
     *
     * @param variable        name of the variable
     * @param frame           name of the frame the value was calculated in
     * @param iteration       iteration within that frame
     * @param parentFrameIter the parent frame/iteration (null for the outer frame)
     * @return true if a value is stored for that exact variable/frame/iteration key
     */
    public boolean contains(String variable, String frame, int iteration, FrameIter parentFrameIter) {
        return nodeValueOutputs.containsKey(new VarId(variable, frame, iteration, parentFrameIter));
    }

    /**
     * Gets a previously calculated output; fails if the output does not exist.
     * Equivalent to {@code get(variable, frame, iteration, parentFrameIter, true)}.
     *
     * @param variable        name of the variable
     * @param frame           name of the frame the value was calculated in
     * @param iteration       iteration within that frame
     * @param parentFrameIter the parent frame/iteration (null for the outer frame)
     * @return the stored value (never null - absence fails a Preconditions.checkNotNull check)
     */
    public SDValue get(String variable, String frame, int iteration, FrameIter parentFrameIter) {
        final boolean mustExist = true;
        return get(variable, frame, iteration, parentFrameIter, mustExist);
    }

    /**
     * Gets a previously calculated output.
     *
     * @param variable         name of the variable
     * @param frame            name of the frame the value was calculated in
     * @param iteration        iteration within that frame
     * @param parentFrameIter  the parent frame/iteration (null for the outer frame)
     * @param enforceExistence if true, fail via Preconditions.checkNotNull when no value
     *                         is stored; if false, return null in that case
     * @return the stored value, or null when absent and {@code enforceExistence} is false
     */
    public SDValue get(String variable, String frame, int iteration, FrameIter parentFrameIter,
            boolean enforceExistence) {
        // TODO: cache and reuse VarId keys to avoid allocating a new object on every lookup
        SDValue stored = nodeValueOutputs.get(new VarId(variable, frame, iteration, parentFrameIter));
        if (!enforceExistence) {
            return stored;
        }
        Preconditions.checkNotNull(stored, "No output found for variable %s (frame %s, iteration %s)", variable, frame,
                iteration);
        return stored;
    }

    /**
     * Get the output of the session - i.e., perform inference/forward pass and
     * return the outputs for the specified variables.
     * <p>
     * Delegates to the {@link ExecutionResult}-returning overload with no additional
     * (non-array) placeholder values, then converts the result to a map of arrays.
     *
     * @param variables           Names of the variables we want the
     *                            arrays/activations for
     * @param placeholderValues   The placeholder values (if any). May be null.
     * @param batch               The batch data, used to call Listener.opExecution
     * @param requiredActivations Additional activations that are required. Won't be
     *                            output, but opExecution will be called. May be
     *                            null.
     * @param listeners           Listeners passed through to op execution
     * @param at                  Passed through unchanged; may be null, in which case the
     *                            delegate substitutes {@code At.defaultAt()}
     * @return The specified variable values, keyed by variable name. When the result
     *         holds {@link SDValue}s, each is unwrapped to its tensor value; a null
     *         value (e.g. an output that was not produced) maps to a null array
     *         instead of causing a NullPointerException.
     * @throws IllegalStateException if the execution result holds neither tensors nor values
     */
    @SuppressWarnings("unchecked") // T is the array type produced by concrete sessions
    public Map<String, T> output(@NonNull List<String> variables, Map<String, T> placeholderValues,
            MultiDataSet batch, Collection<String> requiredActivations, List<Listener> listeners, At at) {
        ExecutionResult output = output(variables, placeholderValues, Collections.emptyMap(), batch,
                requiredActivations, listeners, at);
        if (output.hasSingle()) {
            return (Map<String, T>) output.getOutputs();
        }

        if (output.hasValues()) {
            Map<String, SDValue> outputs = output.getValueOutputs();
            Map<String, INDArray> ret = new LinkedHashMap<>();
            for (Map.Entry<String, SDValue> entry : outputs.entrySet()) {
                SDValue value = entry.getValue();
                // Values may legitimately be null (see the value-output handling in the
                // main execution loop); keep the key but don't dereference the value
                ret.put(entry.getKey(), value == null ? null : value.getTensorValue());
            }
            return (Map<String, T>) ret;
        }

        throw new IllegalStateException("No result output! Expected values or tensors.");
    }

    /**
     * Get the output of the session - i.e., perform inference/forward pass and
     * return the outputs for the specified variables
     *
     * @param variables              Name of the variables we want the
     *                               arrays/activations for
     * @param placeholderValues      The placeholder values (if any). May be null.
     * @param otherPlaceHolderValues other placeholder values that may not be
     *                               ndarrays.
     * @param batch                  The batch data, used to call
     *                               Listener.opExecution
     * @param requiredActivations    Additional activations that are required. Won't
     *                               be output, but opExecution will be called. May
     *                               be null.
     * @return The specified variable values, optionally in the specified workspace
     */
    public ExecutionResult output(@NonNull List<String> variables,
            Map<String, T> placeholderValues,
            Map<String, SDValue> otherPlaceHolderValues,
            MultiDataSet batch,
            Collection<String> requiredActivations,
            List<Listener> listeners, At at) {
        Preconditions.checkState(!variables.isEmpty() || !requiredActivations.isEmpty(),
                "Variables to perform forward pass for must not be empty");

        // ensure all placeholders are in a mutable map
        otherPlaceHolderValues = new LinkedHashMap<>(otherPlaceHolderValues);

        // ensure all placeholders passed in are placed with the other placeholder
        // values for consistency
        // later in execution we only use other place holder values
        if (placeholderValues != null && !placeholderValues.isEmpty()) {
            for (Map.Entry<String, T> placeHolderValue : placeholderValues.entrySet()) {
                if (otherPlaceHolderValues.containsKey(placeHolderValue.getKey())) {
                    throw new IllegalArgumentException(
                            "Unable to determine which placeholder to use. Please ensure all names across both placeholders are unique");
                }

                otherPlaceHolderValues.put(placeHolderValue.getKey(),
                        SDValue.create((INDArray) placeHolderValue.getValue()));
            }
        }

        if (requiredActivations == null)
            requiredActivations = Collections.emptySet();

        if (at == null)
            at = At.defaultAt();

        // Step 0: validation - that variables exist, placeholders have arrays, etc
        for (String s : variables) {
            Preconditions.checkState(sameDiff.variableMap().containsKey(s),
                    "Requested output variable %s does not exist in SameDiff instance", s);
        }

        Set<String> reqOutputVariablesSet = new LinkedHashSet<>(variables);

        placeholderValues = preprocessPlaceholders(placeholderValues, at);
        otherPlaceHolderValues = preprocessValuePlaceholders(otherPlaceHolderValues, at);

        // Clear state from past iterations, if any
        dt.clear();
        subgraph.clear();
        subgraphOps.clear();

        // Step 1: determine subgraph structure we actually need to execute
        // Basic plan: work backwards from the variables we want, based on the graph
        // structure, to work out what
        // we actually need to execute
        // TODO we'll optimize this and cache the results, only recalculating if the
        // graph structure changes
        Set<String> userRequestedUnique = new LinkedHashSet<>(variables);
        Set<String> allRequired = new LinkedHashSet<>(requiredActivations);
        allRequired.addAll(variables);
        initSubgraph(allRequired);

        // Step 2: Check that we have required placeholders
        List<String> phNames = sameDiff.inputs();
        Set<String> presentPlaceholders = new HashSet<>();
        // add all placeholder values together
        if (placeholderValues != null && !placeholderValues.isEmpty())
            presentPlaceholders.addAll(placeholderValues.keySet());
        if (otherPlaceHolderValues != null && !otherPlaceHolderValues.isEmpty())
            presentPlaceholders.addAll(otherPlaceHolderValues.keySet());

        if (presentPlaceholders.isEmpty() || !presentPlaceholders.containsAll(phNames)) {
            /*
             * We only have a subset of all placeholders
             * Validate that we have all *required* placeholder values. Some might not be
             * needed to calculate the requested outputs
             * A placeholder is required if:
             * (a) It's one of the requested outputs
             * (b) It's required to calculate any of the ops in the subgraph
             * For example, we might have a label placeholder, and we're doing inference not
             * training
             */
            for (String s : phNames) {
                boolean required = false;
                if (variables.contains(s)) {
                    required = true;
                }
                if (!required) {
                    Variable v = sameDiff.getVariables().get(s);
                    if (v.getInputsForOp() != null) {
                        for (String s2 : v.getInputsForOp()) {
                            if (subgraph.contains(s2)) {
                                // Placeholder is required
                                required = true;
                                break;
                            }
                        }
                    }
                }

                if (required && (presentPlaceholders.isEmpty() || !presentPlaceholders.contains(s))) {
                    throw new IllegalStateException(
                            "An input placeholder \"" + s + "\" is required to calculate the requested outputs," +
                                    " but a placeholder value was not provided");
                }
            }
        }

        // Step 3: Mark the (required) variables, constants and placeholders as
        // available via dependency tracker
        // And also any "zero dependency" ops - i.e., those without any inputs
        ExecStep start = new ExecStep(ExecType.EXEC_START, "", null); // Dummy dependency to trigger the variables and
                                                                      // constants
        for (SDVariable v : sameDiff.variables()) {
            VariableType vt = v.getVariableType();
            if (vt == VariableType.VARIABLE || vt == VariableType.CONSTANT) {
                ExecType et = vt == VariableType.VARIABLE ? ExecType.VARIABLE : ExecType.CONSTANT;
                ExecStep es = new ExecStep(et, v.name(), new FrameIter(OUTER_FRAME, 0, null));
                dt.addDependency(es, start);

                Variable var = sameDiff.getVariables().get(v.name());
                if (var.getControlDeps() != null) {
                    addVarControlDeps(es, var); // Before this variable can be considered available for use, we need
                                                // specified op to be executed
                }
            }
        }

        for (String s : phNames) {
            ExecStep es = new ExecStep(ExecType.PLACEHOLDER, s, new FrameIter(OUTER_FRAME, 0, null));
            dt.addDependency(es, start);

            Variable var = sameDiff.getVariables().get(s);
            if (var.getControlDeps() != null) {
                addVarControlDeps(es, var); // Before this variable can be considered available for use, we need
                                            // specified op to be executed
            }
        }

        for (String s : zeroInputOpsInSubgraph) {
            ExecStep es = new ExecStep(ExecType.OP, s, new FrameIter(OUTER_FRAME, 0, null));
            dt.addDependency(es, start);
        }
        dt.markSatisfied(start, true);

        // Step 4: execute in any order, but not switching to new frame/iteration until
        // all from current frame/iter ops
        // are done - until we have all required nodeOutputs
        /*
         * The idea is simple: we start off with a set of "available to execute"
         * variables - just the placeholders,
         * constants and variables (assuming no control dependencies) at the start of
         * execution.
         *
         * Then, we remove an "available to execute" node and execute it. Execution may
         * be:
         * (a) For constants, variable type SDVariables, and placeholders: just look up
         * the value
         * (b) For variables as outputs of ops: actually execute the op
         *
         * After execution, we look at the graph structure and determine what that now
         * executed/calculated variable is
         * an input to. If all inputs are available for the op, we mark all output
         * variables of that op as available for execution.
         * Both parts of this (tracking dependencies, and also what's now available to
         * execute) are handled in the dependency tracker
         *
         * We stop computation once all the required outputs are available. At this
         * point, subgraph may NOT be empty - for example,
         * switch ops may cause entire branches of the graph to be skipped.
         */

        Map<String, SDValue> outValues = new LinkedHashMap<>();
        Set<String> allExecuted = new LinkedHashSet<>();
        int step = 0; // Number of execution steps
        // Next 3: current execution frame
        String currentFrame = OUTER_FRAME;
        int currentFrameIter = 0;
        FrameIter currParentFrame = null;
        ExecStepPredicate predicate = new ExecStepPredicate();
        while (allExecuted.size() < allRequired.size()) {
            if (!dt.hasNewAllSatisfied()) {
                execFailed(userRequestedUnique, outValues, allRequired, allExecuted, step);
                // note execFailed will not always throw an exception if a user required all
                // variables from
                // outputAll. A common case is conditional paths not being executed. This will
                // just ensure that
                // no other exceptions are thrown.
                break;

            }

            // Get variable in the current frame/iteration and execute it's corresponding op
            // If no more ops exist for the current frame/iter, we'll switch to the next
            // frame/iter
            // The idea is to not mix the order of execution of ops in different
            // frames/iters - i.e., finish the current
            // frame/iter before starting the next one
            predicate.setCurrentFrame(currentFrame);
            predicate.setCurrentFrameIter(currentFrameIter);
            predicate.setCurrParentFrame(currParentFrame);

            ExecStep es = dt.getFirstNewAllSatisfiedMatching(predicate);
            if (es == null) {
                // We must have finished the current frame/iter, and are switching to the next
                // one
                es = dt.getNewAllSatisfied();
            }

            currentFrame = es.getFrameIter().getFrame();
            currentFrameIter = es.getFrameIter().getIteration();
            currParentFrame = es.getFrameIter().getParentFrame();

            log.trace("Beginning execution step {}: {}", step, es);

            FrameIter outFrameIter;
            boolean skipDepUpdate = false; // Only used for Switch ops, which have slightly different handling...
            boolean skipMarkSatisfied = false; // Only for enter ops, because of different frame/iter
            if (es.getType() == ExecType.CONSTANT || es.getType() == ExecType.VARIABLE) {
                VarId vid = new VarId(es.getName(), OUTER_FRAME, 0, null);
                T arr = getConstantOrVariable(es.getName());
                Preconditions.checkNotNull(arr, "Encountered null placeholder array for constant: %s", vid);
                putNodeValue(SDValue.create((INDArray) arr), vid);
                outFrameIter = new FrameIter(OUTER_FRAME, 0, null);
                if (userRequestedUnique.contains(es.getName())) {
                    // User requested const/variable as one of the outputs
                    outValues.put(es.getName(), SDValue.create((INDArray) arr));
                }

                if (allRequired.contains(es.getName())) {
                    allExecuted.add(es.getName());
                }
            } else if (es.getType() == ExecType.PLACEHOLDER) {
                VarId vid = new VarId(es.getName(), OUTER_FRAME, 0, null);
                if (placeholderValues != null && placeholderValues.containsKey(es.getName())) {
                    T phVal = placeholderValues == null ? null : placeholderValues.get(es.getName());
                    SDValue valueCreate = SDValue.create((INDArray) phVal);
                    putNodeValue(valueCreate, vid);
                } else if (otherPlaceHolderValues != null && otherPlaceHolderValues.containsKey(es.getName())) {
                    SDValue value = otherPlaceHolderValues.get(es.getName());
                    switch (value.getSdValueType()) {
                        default:
                            putNodeValue(value, vid);
                            break;
                        case DICT:
                            throw new UnsupportedOperationException("Unable to process dictionary types.");
                    }
                } else {
                    putNodeValue(null, vid);
                }

                outFrameIter = new FrameIter(OUTER_FRAME, 0, null);
                if (allRequired.contains(es.getName())) {
                    Preconditions.checkState(placeholderValues != null
                            && !placeholderValues.containsKey(es.getName())
                            || otherPlaceHolderValues != null &&
                                    otherPlaceHolderValues.containsKey(es.getName()),
                            "No array was provided for the placeholder variable \"%s\" that is required for execution",
                            es.getName());
                    // User requested placeholder value as one of the outputs
                    if (placeholderValues.containsKey(es.getName()))
                        outValues.put(es.getName(), SDValue.create((INDArray) placeholderValues.get(es.getName())));
                    else if (otherPlaceHolderValues.containsKey(es.getName())) {
                        outValues.put(es.getName(), otherPlaceHolderValues.get(es.getName()));
                    }
                }

                if (allRequired.contains(es.getName())) {
                    allExecuted.add(es.getName());
                }
            } else if (es.getType() == ExecType.OP) {
                String opName = es.getName();
                SameDiffOp op = sameDiff.getOps().get(opName);
                DifferentialFunction o = op.getOp();

                if (o instanceof Enter) {
                    // Enter op: output is variable in a new (specified) frame, iteration 0.
                    // Parent is current (input) frame
                    String outFrame = ((Enter) o).getFrameName();
                    outFrameIter = new FrameIter(outFrame, 0, es.getFrameIter());

                } else if (o instanceof Exit) {
                    outFrameIter = getExitIter(es);
                } else if (o instanceof NextIteration) {
                    // NextIteration op: forwards its single input to its output variable in the
                    // current frame, but increments the iteration number
                    outFrameIter = es.getFrameIter().clone();
                    outFrameIter.setIteration(outFrameIter.getIteration());
                } else {
                    // Standard ops - output variable has same frame and iteration number as the
                    // input(s)
                    // Also loopCond, merge, while, etc
                    outFrameIter = es.getFrameIter();
                }

                // Resolve the inputs to this execution step (op) to actual arrays
                Set<VarId> inputs = null;
                Set<VarId> allIterInputs = null;
                Set<String> constAndPhInputs = null;
                DependencyList<ExecStep, ExecStep> dl = dt.getDependencies(es);

                List<String> inputNames = op.getInputsToOp();
                if (inputNames != null && !inputNames.isEmpty()) {
                    inputs = new LinkedHashSet<>();
                    allIterInputs = new LinkedHashSet<>();
                    constAndPhInputs = new LinkedHashSet<>();
                    Iterable<ExecStep> deps = dl.getDependencies();
                    if (deps != null) {
                        for (ExecStep dep : deps) {
                            switch (dep.getType()) {
                                case OP:
                                case SWITCH_L:
                                case SWITCH_R:
                                    // The current execution step depends on one output of the op "dep"
                                    SameDiffOp toExecOp = sameDiff.getOps().get(es.getName());
                                    List<String> inputsToExecOp = toExecOp.getInputsToOp();
                                    SameDiffOp inputOp = sameDiff.getOps().get(dep.getName());
                                    List<String> inputOpOutNames = inputOp.getOutputsOfOp();
                                    for (String s : inputsToExecOp) {
                                        if (inputOpOutNames.contains(s)) {
                                            VarId vid = new VarId(s, dep.getFrameIter().getFrame(),
                                                    dep.getFrameIter().getIteration(),
                                                    dep.getFrameIter().getParentFrame());
                                            inputs.add(vid);
                                        }
                                    }
                                    break;
                                case VARIABLE:
                                    inputs.add(new VarId(dep.getName(), dep.getFrameIter().getFrame(),
                                            dep.getFrameIter().getIteration(), dep.getFrameIter().getParentFrame()));
                                    break;
                                case CONSTANT:
                                case PLACEHOLDER:
                                    constAndPhInputs.add(dep.getName());
                                    break;
                                default:
                                    throw new UnsupportedOperationException("Not yet implemented: " + dep.getType());
                            }
                        }
                    }
                }

                // Do execution of the op, in 2 steps
                // (a) "Parameterize" the op - i.e., find and set the arrays on the op, allocate
                // outputs, etc ready for execution
                // (b) actually execute the operation
                O parameterizedOp = getAndParameterizeOp(opName, outFrameIter, inputs, allIterInputs, constAndPhInputs,
                        placeholderValues, reqOutputVariablesSet, otherPlaceHolderValues);
                ExecutionResult opOutputValues = getOutputs(parameterizedOp, outFrameIter, inputs, allIterInputs,
                        constAndPhInputs, listeners, at, batch, reqOutputVariablesSet, otherPlaceHolderValues);
                List<String> opOutVarNames = op.getOutputsOfOp();

                int lengthToCheck = opOutputValues.numResults();
                if (!opOutVarNames.isEmpty() && opOutputValues.hasSingle()) {
                    Preconditions.checkState(lengthToCheck == opOutVarNames.size(),
                            "Unexpected number of outputs from executed op %s:" +
                                    " got %s outputs when %s outputs were expected (%s)",
                            parameterizedOp.getClass().getSimpleName(), opOutputValues.numResults(),
                            opOutVarNames.size(), opOutVarNames);
                }
                // Store the op outputs
                for (int i = 0; i < lengthToCheck; i++) {
                    if (opOutputValues.hasSingle() && opOutputValues.resultAt(i) == null
                            || opOutputValues.hasValues() && !opOutputValues.valueExistsAtIndex(i)
                                    && op.getOp() instanceof Switch) {
                        // Switch op only forwards the input to one of the outputs
                        continue;
                    }

                    // control flow ops are actually variables from the input forwarding to the next
                    // frame
                    String n = opOutVarNames.get(i);

                    VarId vid = new VarId(n, outFrameIter.getFrame(), outFrameIter.getIteration(),
                            outFrameIter.getParentFrame());
                    if (opOutputValues.hasValues()) {
                        SDValue sdValue = opOutputValues.valueWithKeyAtIndex(i, false);
                        // values can be null
                        if (sdValue != null)
                            switch (sdValue.getSdValueType()) {
                                case LIST:
                                    // tensor array op
                                    // note: we leave this out since we already update node value outputs earlier
                                    putNodeValue(sdValue, vid);
                                    break;

                                case TENSOR:
                                    putNodeValue(sdValue, vid);
                                    // tensorflow import case where 2 input names are the same and 1 output will be
                                    // null
                                    if (op.getOp() instanceof Switch && inputNames.size() > 1
                                            && inputNames.get(0).equals(inputNames.get(1))) {
                                        putNodeValue(sdValue, vid);
                                        putNodeValue(sdValue, outFrameIter.toVarId(vid.getVariable() + ":1"));
                                    } else {
                                        putNodeValue(sdValue, vid);
                                    }
                                    break;
                            }

                        if (userRequestedUnique.contains(n)) {
                            outValues.put(n, sdValue);
                        }

                    } else {
                        SDValue currValueOutput = SDValue.create(opOutputValues.resultAt(i));
                        putNodeValue(currValueOutput, vid);
                        // ensure a singular value is populated in case the user uses the node value
                        // outputs
                        if (userRequestedUnique.contains(n)) {
                            outValues.put(n, currValueOutput);
                        }

                    }

                    if (allRequired.contains(n)) {
                        allExecuted.add(n);
                    }
                }

                // Post execution: update dependency tracker so we know what is available to
                // execute next, given we now
                // have these new values
                if (o instanceof Switch) {
                    /*
                     * Switch is a special case: only one output/branch is considered to exist post
                     * execution.
                     * Unlike every other type of op, only 1 of 2 output arrays is actually
                     * executed.
                     * For dependency tracking purposes, this is why we have SWITCH_L and _R
                     * execution types.
                     * If we just depended on the op, the dependency tracker would incorrectly
                     * conclude that ops relying on
                     * both branches (i.e., including the unavailable one) can now be executed
                     */
                    skipDepUpdate = true;
                    skipMarkSatisfied = true;
                    String[] argNames = o.argNames();
                    // tensorflow import case: this means we output a list with a single name and
                    // need to extract the null value from that singular list
                    if (argNames[0].equals(argNames[1])) {
                        SDValue sdValue = opOutputValues.getValueOutputs().get(argNames[0]);
                        List<INDArray> inputList = sdValue.getListValue();
                        int nullCount = (inputList.get(0) != null ? 1 : 0) + (inputList.get(1) != null ? 1 : 0);
                        Preconditions.checkState(nullCount == 1,
                                "Expected exactly one output to be present for switch ops, got %s", nullCount);
                        boolean left = inputList.get(0) != null;

                        ExecStep branch;
                        if (left) {
                            branch = new ExecStep(ExecType.SWITCH_L, es.getName(), es.getFrameIter());
                        } else {
                            branch = new ExecStep(ExecType.SWITCH_R, es.getName(), es.getFrameIter());
                        }
                        updateDescendantDeps(branch, outFrameIter);
                        dt.markSatisfied(branch, true);
                    } else {
                        int nullCount = (opOutputValues.valueExistsAtIndex(0) ? 1 : 0)
                                + (opOutputValues.valueExistsAtIndex(1) ? 1 : 0);
                        Preconditions.checkState(nullCount == 1,
                                "Expected exactly one output to be present for switch ops, got %s", nullCount);
                        boolean left = opOutputValues.valueExistsAtIndex(0);
                        ExecStep branch;
                        if (left) {
                            branch = new ExecStep(ExecType.SWITCH_L, es.getName(), es.getFrameIter());
                        } else {
                            branch = new ExecStep(ExecType.SWITCH_R, es.getName(), es.getFrameIter());
                        }
                        updateDescendantDeps(branch, outFrameIter);
                        dt.markSatisfied(branch, true);
                    }

                } else if (o instanceof Enter) {
                    // Enter op: we want to say that the inner frame is executed...
                    skipDepUpdate = true;
                    skipMarkSatisfied = true;
                    Enter e = (Enter) o;
                    FrameIter fi = new FrameIter(e.getFrameName(), 0, es.getFrameIter());
                    ExecStep exec = new ExecStep(ExecType.OP, es.getName(), fi);
                    updateDescendantDeps(exec, fi);
                    dt.markSatisfied(exec, true);
                } else if (o instanceof Exit) {
                    // Exit op: we want to say that the parent frame is executed...
                    skipDepUpdate = true;
                    skipMarkSatisfied = true;
                    FrameIter fi = es.getFrameIter().getParentFrame();
                    ExecStep exec = new ExecStep(ExecType.OP, es.getName(), fi);
                    updateDescendantDeps(exec, fi);
                    dt.markSatisfied(exec, true);
                }

                /*
                 * Edge case for TensorFlow import control dependencies: for some reason, TF
                 * allows op control dependencies
                 * like /while/x -> SomeConstant - i.e., a constant depending on something
                 * inside a scope.
                 * This should be handled with an enter op, but TF doesn't always use this :/
                 * Note that this is equivalent to marking the control dependency as satisfied
                 * on the first iteration
                 * TODO double check that this is exactly the same behaviour as TF - otherwise
                 * this approach might fail in
                 * some rare cases that rely on the constant/variable not being available
                 */
                List<String> cdFor = op.getControlDepFor();
                if (cdFor != null) {
                    ExecStep cdEs = new ExecStep(ExecType.CONTROL_DEP, opName, null);
                    if (!dt.isSatisfied(cdEs)) {
                        dt.markSatisfied(cdEs, true);
                    }
                }

            } else {
                // Should never happen
                throw new RuntimeException("Unknown ExecStep: " + es);
            }

            // Standard ops
            if (!skipDepUpdate) {
                updateDescendantDeps(es, outFrameIter);
            }
            if (!skipMarkSatisfied) {
                dt.markSatisfied(es, true);
            }

            step++;
        }

        // TODO we should clear the node outputs map to get rid of the invalid (closed,
        // out of workspace, etc) arrays

        outValues = postProcessOutputValues(outValues);
        return ExecutionResult.builder()
                .valueOutputs(outValues).build();
    }

    /**
     * Compute the frame/iteration that the output of an Exit op belongs to.
     * An Exit op forwards its input out of the current frame, so the result is a
     * fresh FrameIter with the same frame name, iteration and parent link as the
     * parent frame of the step's current frame.
     *
     * @param es Execution step of the Exit op
     * @return A new FrameIter equal in content to the parent of {@code es}'s frame
     */
    private FrameIter getExitIter(ExecStep es) {
        FrameIter parent = es.getFrameIter().getParentFrame();
        return new FrameIter(parent.getFrame(), parent.getIteration(), parent.getParentFrame());
    }

    /**
     * Register the control dependencies (op -> variable) of a variable with the
     * dependency tracker: the variable's execution step is made to depend on a
     * CONTROL_DEP step for every op listed in {@link Variable#getControlDeps()}.
     * Does nothing when the variable has no control dependencies (null list).
     *
     * @param es Execution step for the variable
     * @param v  Variable whose control dependencies should be registered
     */
    protected void addVarControlDeps(ExecStep es, Variable v) {
        List<String> controlDepOps = v.getControlDeps();
        if (controlDepOps == null) {
            return;
        }
        // Before this variable can be considered available for use, each of these ops must have executed
        for (String opName : controlDepOps) {
            dt.addDependency(es, new ExecStep(ExecType.CONTROL_DEP, opName, null));
        }
    }

    /**
     * Fetch the value currently stored for the given variable instance.
     *
     * @param tArr Variable/frame/iteration identifier
     * @return The stored value, or null if nothing is stored under that key (per {@code Map#get})
     */
    protected SDValue getSdValue(VarId tArr) {
        SDValue stored = nodeValueOutputs.get(tArr);
        return stored;
    }

    /**
     * Replace (not insert) the element at position {@code i} of {@code l} with {@code sub}.
     *
     * @param l   List to modify
     * @param i   Index of the element to overwrite
     * @param sub New element
     */
    protected void setArrayAtIndex(List<INDArray> l, int i, INDArray sub) {
        l.set(i, sub);
    }

    /**
     * Store (or overwrite) the value produced for the given variable instance.
     *
     * @param sdValue Value to store
     * @param varId   Variable/frame/iteration identifier used as the key
     */
    protected void putNodeValue(SDValue sdValue, VarId varId) {
        nodeValueOutputs.put(varId, sdValue);
    }

    /**
     * Get the tensor value stored for the given variable instance.
     *
     * @param varId Variable/frame/iteration identifier
     * @return The stored tensor, or null if no value is stored for {@code varId}
     *         or the stored value has no tensor
     */
    protected INDArray getTensorFromOutputs(VarId varId) {
        if (!nodeValueOutputs.containsKey(varId)) {
            return null;
        }
        SDValue stored = getSdValue(varId);
        return stored.getTensorValue();
    }

    /**
     * Execution failed - can't calculate all requested outputs, and there's nothing
     * left to calculate.
     * Logs a warning (this method does not throw) describing how many requested
     * outputs and required variables are still outstanding, and listing the missing
     * requested outputs (all of them if there are at most 10, otherwise the first 10).
     *
     * @param userRequestedUnique All outputs that the user requested
     * @param out                 Current outputs, keyed by variable name
     * @param allRequired         All variables required to be executed
     * @param allExecuted         Required variables that have been executed so far
     * @param step                Execution step
     */
    protected void execFailed(Set<String> userRequestedUnique, Map<String, SDValue> out, Set<String> allRequired,
            Set<String> allExecuted, int step) {
        // Requested outputs that were never produced, in request order
        Set<String> missing = new LinkedHashSet<>();
        for (String s : userRequestedUnique) {
            if (!out.containsKey(s)) {
                missing.add(s);
            }
        }
        // Count the actual missing set rather than relying on size arithmetic,
        // which would be wrong if 'out' ever held entries that were not requested
        int missingCount = missing.size();
        // Previously computed as executed - required, which is negative. Clamp at zero
        // in case a caller ever adds non-required names to allExecuted.
        int requiredRemaining = Math.max(0, allRequired.size() - allExecuted.size());

        StringBuilder sb = new StringBuilder();
        sb.append("No variable are available for execution at step ")
                .append(step).append(": ").append(missingCount).append(" requested output values remaining, ")
                .append(requiredRemaining).append(" variables required to be executed remaining");

        if (missingCount <= 10) {
            sb.append(". Missing variables: ");
            sb.append(missing);
        } else {
            sb.append(". First 10 missing variables: ");
            Iterator<String> iter = missing.iterator();
            for (int i = 0; i < 10 && iter.hasNext(); i++) {
                if (i > 0)
                    sb.append(",");
                sb.append(iter.next());
            }
        }

        log.warn(
                "Not all required variables were executed. This may be due to conditionals. Missing variables include: "
                        + sb.toString());

    }

    /**
     * Update the descendant dependencies after an execution step has completed.
     * For every op in the subgraph that consumes something the step just made
     * available, all of that op's dependencies (X, Y, Z, ... -> A) are registered
     * with the dependency tracker for the given output frame/iteration.
     * <ul>
     * <li>OP: consumers of every output variable, both as data inputs and as control-dependency targets</li>
     * <li>VARIABLE / CONSTANT / PLACEHOLDER: data-input consumers of that variable</li>
     * <li>SWITCH_L / SWITCH_R: data-input consumers of the taken branch's output only</li>
     * </ul>
     *
     * @param justExecuted The execution step that has just completed
     * @param outFrameIter The frame/iteration of the output
     * @throws UnsupportedOperationException for any other execution type
     */
    protected void updateDescendantDeps(ExecStep justExecuted, FrameIter outFrameIter) {
        ExecType type = justExecuted.getType();
        String stepName = justExecuted.getName();

        if (type == ExecType.OP) {
            SameDiffOp executedOp = sameDiff.getOps().get(stepName);
            for (String outName : executedOp.getOutputsOfOp()) {
                Variable outVar = sameDiff.getVariables().get(outName);
                if (outVar == null) {
                    continue;
                }
                // We've just executed X with X -> Y; there may also be Z -> Y that Y needs
                addDepsForSubgraphConsumers(outVar.getInputsForOp(), outFrameIter);
                // Ops that have a control dependency on this output variable
                addDepsForSubgraphConsumers(outVar.getControlDepsForOp(), outFrameIter);
            }
            return;
        }

        if (type == ExecType.VARIABLE || type == ExecType.CONSTANT || type == ExecType.PLACEHOLDER) {
            Variable available = sameDiff.getVariables().get(stepName);
            if (available != null) {
                addDepsForSubgraphConsumers(available.getInputsForOp(), outFrameIter);
            }
            return;
        }

        if (type == ExecType.SWITCH_L || type == ExecType.SWITCH_R) {
            // Only the branch that was actually produced has consumers to update
            List<String> switchOutputs = sameDiff.getOps().get(stepName).getOutputsOfOp();
            String takenBranch = switchOutputs.get(type == ExecType.SWITCH_L ? 0 : 1);
            Variable branchVar = sameDiff.getVariables().get(takenBranch);
            if (branchVar != null) {
                addDepsForSubgraphConsumers(branchVar.getInputsForOp(), outFrameIter);
            }
            return;
        }

        throw new UnsupportedOperationException("Unknown or not yet implemented exec type: " + justExecuted);
    }

    /**
     * For every op in {@code consumerOps} that is part of the subgraph being executed,
     * register all of that op's dependencies at the given frame/iteration.
     * A null list means there are no consumers.
     */
    private void addDepsForSubgraphConsumers(List<String> consumerOps, FrameIter frameIter) {
        if (consumerOps == null) {
            return;
        }
        for (String consumer : consumerOps) {
            if (subgraphOps.contains(consumer)) {
                addDependenciesForOp(consumer, frameIter);
            }
        }
    }

    /**
     * Suppose operation X has just been executed.
     * For X -> someOp, add all dependencies for someOp, i.e., all Z -> someOp
     * (which includes X, but may not only be X)
     * <p>
     * Dependencies are added only once per op and frame/iteration, except for
     * NextIteration ops. Their dependencies are always added, targeting the op
     * instance at iteration + 1.
     * Merge ops get an OR dependency on their two inputs, since either one suffices.
     * Op control dependencies are added in every case.
     *
     * @param opName       Name of the op
     * @param depFrameIter Frame/iteration of the op instance to be executed
     */
    protected void addDependenciesForOp(String opName, FrameIter depFrameIter) {
        SameDiffOp op = sameDiff.getOps().get(opName);
        List<String> inputs = op.getInputsToOp();
        List<String> cdOps = op.getControlDeps();
        boolean isNextIteration = op.getOp() instanceof NextIteration;

        ExecStep es = new ExecStep(ExecType.OP, opName, depFrameIter);
        if (!isNextIteration && dt.hasDependency(es)) {
            // Already processed this once. We only add dependencies once per op (for a
            // given frame/iteration)
            return;
        }

        if (op.getOp() instanceof Merge) {
            // Merge ops are a special case: they can be executed with EITHER ONE of the
            // inputs available - unlike every other op, we don't need all inputs, just one
            Variable v0 = sameDiff.getVariables().get(inputs.get(0));
            Variable v1 = sameDiff.getVariables().get(inputs.get(1));
            // If there is no Variable under the input name (e.g. after a rename), pass the raw
            // name through so getExecStepForVar can apply its op-name redirect instead of NPE-ing here
            String in0 = v0 == null ? inputs.get(0) : v0.getName();
            String in1 = v1 == null ? inputs.get(1) : v1.getName();

            ExecStep or0 = getExecStepForVar(in0, depFrameIter);
            ExecStep or1 = getExecStepForVar(in1, depFrameIter);
            dt.addOrDependency(es, or0, or1);
        } else {
            if (isNextIteration) {
                // For NextIteration, dependencies should be of the form X(iter) -> NextIter(iter+1)
                FrameIter fi = depFrameIter.clone();
                fi.setIteration(fi.getIteration() + 1);
                es = new ExecStep(ExecType.OP, opName, fi);
            }
            if (inputs != null) {
                for (String s : inputs) {
                    dt.addDependency(es, getExecStepForVar(s, depFrameIter));
                }
            }
        }

        // Control dependencies target the same step as the inputs (iter+1 for NextIteration)
        if (cdOps != null) {
            for (String s : cdOps) {
                dt.addDependency(es, getExecStepForVar(s, depFrameIter));
            }
        }

    }

    /**
     * Get the ExecStep for the given variable, given execution is happening at the
     * specified frame/iteration.
     * <p>
     * VARIABLE, PLACEHOLDER and CONSTANT variables map to a step in the outer frame
     * (iteration 0). Any other variable maps to the step of the op that produces it.
     * Switch outputs map to SWITCH_L / SWITCH_R depending on the branch. Outputs of
     * constant Enter ops map to iteration 0, recursively through directly nested
     * constant Enter ops.
     *
     * @param varName   Variable name; if no variable has this name but an op does,
     *                  that op's first output is used (rename redirect)
     * @param frameIter Frame/iteration at which execution is happening
     * @return The execution step that makes the variable available
     * @throws IllegalArgumentException if neither a variable nor a redirectable op with that name exists
     * @throws IllegalStateException    if the producing op of an array-type variable cannot be found
     */
    protected ExecStep getExecStepForVar(String varName, FrameIter frameIter) {
        Variable v = sameDiff.getVariables().get(varName);
        if (v == null) {
            SameDiffOp op = sameDiff.getOps().get(varName);
            if (op != null) {
                // redirect because of rename
                v = sameDiff.getVariables().get(op.getOutputsOfOp().get(0));
            }
            if (v == null) {
                throw new IllegalArgumentException("Variable name " + varName + " not found! Renamed?");
            }
        }
        VariableType vt = v.getVariable().getVariableType();
        if (vt == VariableType.VARIABLE) {
            return new ExecStep(ExecType.VARIABLE, v.getVariable().name(), new FrameIter(OUTER_FRAME, 0, null));
        } else if (vt == VariableType.PLACEHOLDER) {
            return new ExecStep(ExecType.PLACEHOLDER, v.getVariable().name(), new FrameIter(OUTER_FRAME, 0, null));
        } else if (vt == VariableType.CONSTANT) {
            return new ExecStep(ExecType.CONSTANT, v.getVariable().name(), new FrameIter(OUTER_FRAME, 0, null));
        }

        // Array type. Must be output of an op
        if (v.getOutputOfOp() == null) {
            Variable stripped = sameDiff.getVariables().get(stripVarSuffix(v.getName()));
            if (stripped == null) {
                throw new IllegalStateException(
                        "Samediff output op named " + v.getName() + " did not have any ops associated with it.");
            }
            v = stripped;
        }

        String outOfOp = v.getOutputOfOp();
        SameDiffOp sdo = sameDiff.getOps().get(outOfOp);

        if (sdo == null) {
            throw new IllegalStateException(
                    "Samediff output op named " + v.getName() + " did not have any ops associated with it.");
        }

        if (sdo.getOp() instanceof Switch) {
            // For dependency tracking purposes, we track left and right output branches of
            // switch op separately. Otherwise, ops depending on both branches would be marked
            // as available if we just relied on "op has been executed"
            List<String> opOutputs = sdo.getOutputsOfOp();
            int idx = opOutputs.indexOf(v.getName());
            if (idx == 0) {
                // Left branch
                return new ExecStep(ExecType.SWITCH_L, outOfOp, frameIter);
            } else if (idx == 1) {
                // Right branch
                return new ExecStep(ExecType.SWITCH_R, outOfOp, frameIter);
            } else {
                // Should never happen
                throw new IllegalStateException(
                        "Expected variable \"" + v.getName() + "\" to be an output of operation \"" +
                                outOfOp + "\", but op output variables are: " + opOutputs);
            }
        } else if (sdo.getOp() instanceof Enter) {
            Enter e = (Enter) sdo.getOp();

            // For enter ops, "constant=true" enter ops are available for ALL iterations,
            // hence use iter=0.
            // For constant=false, these are only available at iteration 0 - so use the
            // *current* iteration, same as all other ops (which, in this case, won't be
            // triggered on iter > 0 - as desired/expected)
            if (e.isConstant()) {
                FrameIter fi = frameIter.clone();
                fi.setIteration(0);

                // Nested constant enter case: iteration 0 all the way down, for as long as the
                // input chain consists of constant Enter ops.
                // NOTE(review): parent frames are mutated in place below; this assumes
                // FrameIter.clone() deep-copies the parent chain - TODO confirm
                String inVarName = sdo.getInputsToOp().get(0);
                FrameIter parentFrame = fi.getParentFrame();
                while (parentFrame != null) {
                    Variable inVar = sameDiff.getVariables().get(inVarName);
                    if (inVar == null || inVar.getOutputOfOp() == null) {
                        break;
                    }
                    SameDiffOp producer = sameDiff.getOps().get(inVar.getOutputOfOp());
                    if (producer == null || !(producer.getOp() instanceof Enter)) {
                        break;
                    }
                    // Check the NESTED Enter op (previously this re-checked the outer one, which is
                    // always constant here, so non-constant nested Enters were never detected)
                    Enter nestedEnter = (Enter) producer.getOp();
                    if (!nestedEnter.isConstant()) {
                        break;
                    }
                    parentFrame.setIteration(0);
                    parentFrame = parentFrame.getParentFrame();
                    inVarName = producer.getInputsToOp().get(0);
                }

                return new ExecStep(ExecType.OP, outOfOp, fi);
            }

            // Intentional fall-through to default case
        }
        return new ExecStep(ExecType.OP, outOfOp, frameIter);
    }

    /**
     * Initialize the subgraph - the subgraph and subgraphOps sets.
     * This works out what ops and variables we might need to execute to get the
     * requested outputs. In general, this is a subset of the graph.
     * <p>
     * Starting from the requested variables, the graph is walked backwards
     * (breadth-first). Each reached variable is added to {@code subgraph} and its
     * producing op to {@code subgraphOps}. The walk continues into the op's inputs and
     * control dependencies, and into the variable's own control dependencies.
     * Producing ops with no inputs and no variable control dependencies are also
     * recorded in {@code zeroInputOpsInSubgraph}.
     *
     * @param variables Set of output variables we need
     * @throws IllegalStateException if a reached name resolves to neither a variable
     *                               nor a renamed op with a variable output
     */
    protected void initSubgraph(Set<String> variables) {
        // Step 1: determine subgraph structure we actually need to execute
        Queue<String> processingQueue = new LinkedList<>(variables);

        // Note subgraph initially should include placeholders and constants
        while (!processingQueue.isEmpty()) {
            String varName = processingQueue.remove();
            DifferentialFunction producingOp = sameDiff.getVariableOutputOp(varName);
            String opName = stripVarSuffix(producingOp == null ? null : producingOp.getOwnName());
            // Inputs of the producing op; null when the variable is not the output of an op
            String[] opInputs = opName == null ? null : sameDiff.getInputsForOp(sameDiff.getOpById(opName));

            if (!subgraph.contains(varName)) {
                Variable currVar = sameDiff.getVariables().get(varName);
                log.trace("Adding {} to subgraph for output.", varName);
                if (currVar == null && opName == null) {
                    // probably renamed, redirect to new name
                    SameDiffOp renamedOp = sameDiff.getOps().get(varName);
                    if (renamedOp != null) {
                        currVar = sameDiff.getVariables().get(renamedOp.getOutputsOfOp().get(0));
                    }
                }
                if (currVar == null) {
                    throw new IllegalStateException("No variable found with name " + varName + "!");
                }

                List<String> controlDeps = currVar.getControlDeps();
                int numInputs = (opInputs == null ? 0 : opInputs.length);
                if (controlDeps != null) {
                    // Also count variable control dependencies as inputs - even a constant may not
                    // be available for use until after execution of some other ops (for example,
                    // in conditional operations)
                    numInputs += controlDeps.size();
                }
                if (numInputs == 0 && opName != null) {
                    zeroInputOpsInSubgraph.add(opName);
                }

                subgraph.add(varName);

                if (opName != null) {
                    subgraphOps.add(opName);
                }

                if (controlDeps != null) {
                    // If variable has control dependencies, it's not available right away... to
                    // make it available, we need the "inputs" to be available first. This is
                    // mainly used for TF import.
                    for (String s : controlDeps) {
                        if (!subgraph.contains(s)) {
                            processingQueue.add(s);
                        }
                    }
                }

            }

            if (opName != null) {
                // To execute op - and hence get this variable: need inputs to that op
                if (opInputs != null) {
                    for (String s2 : opInputs) {
                        if (!subgraph.contains(s2)) {
                            processingQueue.add(s2);
                        }
                    }
                }

                // To execute op - and hence get this variable - we also need control deps
                SameDiffOp sdOp = sameDiff.getOps().get(opName);
                List<String> opControlDeps = sdOp == null ? null : sdOp.getControlDeps();
                if (opControlDeps != null) {
                    for (String s2 : opControlDeps) {
                        if (!subgraph.contains(s2)) {
                            processingQueue.add(s2);
                        }
                    }
                }
            }
        }
    }

    /**
     * Hook for preprocessing {@link SDValue} placeholder values before execution.
     * The base implementation returns the map it was given, unchanged.
     * Subclasses (for example InferenceSession) may override this, e.g. to cast values.
     *
     * @param placeholders Placeholder values to preprocess
     * @param at           Current execution information
     * @return The (possibly transformed) placeholder values
     */
    protected Map<String, SDValue> preprocessValuePlaceholders(Map<String, SDValue> placeholders, At at) {
        return placeholders;
    }

    /**
     * Hook for preprocessing placeholder values before execution.
     * The base implementation returns the map it was given, unchanged.
     * Subclasses (for example InferenceSession) may override this, e.g. to cast values.
     *
     * @param placeholders Placeholder values to preprocess
     * @param at           Current execution information
     * @return The (possibly transformed) placeholder values
     */
    protected Map<String, T> preprocessPlaceholders(Map<String, T> placeholders, At at) {
        return placeholders;
    }

    /**
     * Post process the session output values, if required.
     * Override if required in session subclasses.
     * <p>
     * The base implementation marks every array held by the outputs (tensor values,
     * list elements and dict values) as not closeable, then returns the same map.
     * Null values, null contents and null list elements are skipped. A Switch op, for
     * example, yields a list in which the untaken branch is null.
     *
     * @param output Output to be returned to the user
     * @return Post processed output (the same map instance)
     */
    protected Map<String, SDValue> postProcessOutputValues(Map<String, SDValue> output) {
        for (SDValue value : output.values()) {
            if (value == null || value.getSdValueType() == null) {
                continue;
            }
            switch (value.getSdValueType()) {
                case DICT:
                    Map<String, INDArray> dict = value.getDictValue();
                    if (dict != null) {
                        for (INDArray arr : dict.values()) {
                            if (arr != null) {
                                arr.setCloseable(false);
                            }
                        }
                    }
                    break;
                case LIST:
                    List<INDArray> list = value.getListValue();
                    if (list != null) {
                        for (INDArray arr : list) {
                            if (arr != null) {
                                arr.setCloseable(false);
                            }
                        }
                    }
                    break;
                case TENSOR:
                    INDArray tensor = value.getTensorValue();
                    if (tensor != null) {
                        tensor.setCloseable(false);
                    }
                    break;
                default:
                    break;
            }
        }

        return output;
    }
    /**
     * Hook for post-processing the session output before it is returned to the user.
     * The base implementation returns the map it was given, unchanged; session
     * subclasses may override it.
     *
     * @param output Output to be returned to the user
     * @return The (possibly transformed) output
     */
    protected Map<String, T> postProcessOutput(Map<String, T> output) {
        return output;
    }

    /**
     * Get the constant or variable output - for example, constant array or constant
     * shape.
     * Note that both constants and variables (i.e., VariableType.CONSTANT and
     * VariableType.VARIABLE) are the same for all frames and iterations, so no
     * frame/iteration argument is needed.
     * Implemented by concrete session subclasses.
     *
     * @param variableName The name of the constant or variable to get the value for
     * @return The value of the constant or variable
     */
    public abstract T getConstantOrVariable(String variableName);

    /**
     * Get the parameterized op to execute - for example, the
     * op/DifferentialFunction with all inputs set.
     * Implemented by concrete session subclasses.
     *
     * @param opName            Name of the op
     * @param frameIter         The frame and iteration of the op outputs
     * @param inputs            The inputs to the op (excluding
     *                          constants/placeholders) - for the specific frame +
     *                          iteration
     * @param allIterInputs     The inputs - those that are not iteration-specific
     *                          (mainly Enter op vars, which might be used in all
     *                          iterations but are only executed once on iter 0)
     * @param constAndPhInputs  The constant and placeholder inputs - used for all
     *                          frames/iterations
     * @param placeholderValues Placeholder values for this execution - presumably
     *                          keyed by placeholder name (TODO confirm)
     * @param allReqVariables   All required variables requested for the current
     *                          session execution (not just the current op outputs)
     * @param otherPlaceholders Additional placeholder values held as {@link SDValue}
     *                          (NOTE(review): presumably non-tensor, e.g. list/dict,
     *                          placeholders keyed by name - confirm in implementations)
     * @return The parameterized op
     */
    public abstract O getAndParameterizeOp(String opName, FrameIter frameIter, Set<VarId> inputs,
            Set<VarId> allIterInputs, Set<String> constAndPhInputs,
            Map<String, T> placeholderValues, Set<String> allReqVariables, Map<String, SDValue> otherPlaceholders);

    /**
     * Execute the op - calculate INDArrays, or shape info, etc.
     * Implemented by concrete session subclasses.
     *
     * @param op                Operation to execute. This should be parameterized
     *                          (i.e., all inputs set)
     * @param outputFrameIter   The frame and iteration of the outputs
     * @param inputs            The specific input arrays for the op
     * @param allIterInputs     Inputs that are not iteration-specific (e.g. Enter op
     *                          outputs used on all iterations)
     * @param constAndPhInputs  The constant and placeholder inputs - used for all
     *                          frames/iterations
     * @param listeners         Listeners registered for this execution
     * @param at                Current execution information
     * @param batch             The current batch - presumably may be null outside
     *                          of training (TODO confirm)
     * @param allReqVariables   All required variables requested for the current
     *                          session execution (not just the current op outputs)
     * @param otherPlaceHolders Additional placeholder values held as {@link SDValue}
     *                          (NOTE(review): presumably keyed by name - confirm)
     * @return The outputs of the op
     */
    public abstract ExecutionResult getOutputs(O op, FrameIter outputFrameIter, Set<VarId> inputs,
            Set<VarId> allIterInputs, Set<String> constAndPhInputs,
            List<Listener> listeners, At at, MultiDataSet batch, Set<String> allReqVariables,
            Map<String, SDValue> otherPlaceHolders);

    /**
     * Find the VarId with the given variable name, searching {@code varIds} first and
     * then {@code varIds2}. Either collection may be null, in which case it is skipped.
     * Only one VarId with that name is expected to exist.
     *
     * @param name                Variable name to look for
     * @param varIds              First collection to search (may be null)
     * @param varIds2             Second collection to search (may be null)
     * @param exceptionOnNotFound Whether to throw instead of returning null when nothing matches
     * @return The first matching VarId, or null if none matches and {@code exceptionOnNotFound} is false
     * @throws RuntimeException if nothing matches and {@code exceptionOnNotFound} is true
     */
    protected static VarId lookup(String name, Collection<VarId> varIds, Collection<VarId> varIds2,
            boolean exceptionOnNotFound) {
        VarId found = null;
        if (varIds != null) {
            found = lookup(name, varIds, false);
        }
        if (found == null && varIds2 != null) {
            found = lookup(name, varIds2, false);
        }
        if (found != null || !exceptionOnNotFound) {
            return found;
        }
        throw new RuntimeException("Could not find VarId for input \"" + name + "\"");
    }

    /**
     * Get the list of {@link INDArray}s (tensor array contents) associated with the
     * given variable, as seen from the specified frame and iteration.
     * <p>
     * The candidates are the inputs of the variable's producing op at
     * (frame, iteration, parentFrame), plus every currently stored LIST-valued node
     * output. The producing op's own name is looked up first. If that fails, the name
     * of the op's first argument (the variable representing the tensor array) is
     * looked up. If no list is found and {@code parentFrame} is non-null, the lookup
     * is retried in the outer frame at iteration 0.
     *
     * @param name        the variable name
     * @param frame       name of the frame to look in
     * @param iteration   iteration within {@code frame}
     * @param parentFrame parent of {@code frame}; null for the outer frame
     * @return the list of {@link INDArray}, or null if none could be found
     */
    public List<INDArray> getTensorArraysInSession(String name, String frame, int iteration, FrameIter parentFrame) {
        DifferentialFunction op = sameDiff.getVariableOutputOp(name);
        if (op == null)
            return null;
        String[] inputs = sameDiff.getInputsForOp(op);
        Set<VarId> varIds = new LinkedHashSet<>();
        if (inputs != null) {
            for (String input : inputs) {
                varIds.add(new VarId(input, frame, iteration, parentFrame));
            }
        }

        varIds.addAll(nodeValueOutputs.entrySet().stream()
                .filter(entry -> entry.getValue() != null
                        && entry.getValue().getSdValueType() == SDValueType.LIST)
                .map(Map.Entry::getKey)
                .collect(Collectors.toList()));

        VarId lookup = lookup(op.getOwnName(), varIds, false);
        if (lookup == null && op.args().length > 0) {
            SDVariable inTensorArray = op.arg(0); // Dummy variable representing the tensor array
            lookup = lookup(inTensorArray.name(), varIds, false);
        }

        // Previously every non-recursive path returned null, even when the list was found
        SDValue stored = lookup == null ? null : nodeValueOutputs.get(lookup);
        List<INDArray> ret = stored == null ? null : stored.getListValue();
        if (ret == null && parentFrame != null)
            return getTensorArraysInSession(name);
        return ret;
    }

    /**
     * Get the list of {@link INDArray}s associated with the given variable, looked up
     * in the outer (top-level) frame at iteration 0 with no parent frame.
     *
     * @param name the variable name
     * @return the list of {@link INDArray}
     */
    public List<INDArray> getTensorArraysInSession(String name) {
        final int firstIteration = 0;
        final FrameIter noParent = null;
        return getTensorArraysInSession(name, OUTER_FRAME, firstIteration, noParent);
    }

    /**
     * Find the VarId with the given variable name in a collection. Only one VarId
     * with that name is expected to exist; the first match in iteration order is returned.
     *
     * @param name                Variable name to look for
     * @param varIds              Collection to search
     * @param exceptionOnNotFound Whether to throw instead of returning null when nothing matches
     * @return The first matching VarId, or null if none matches and {@code exceptionOnNotFound} is false
     * @throws RuntimeException if nothing matches and {@code exceptionOnNotFound} is true
     */
    protected static VarId lookup(String name, Collection<VarId> varIds, boolean exceptionOnNotFound) {
        Iterator<VarId> candidates = varIds.iterator();
        while (candidates.hasNext()) {
            VarId candidate = candidates.next();
            if (candidate.getVariable().equals(name)) {
                return candidate;
            }
        }
        if (!exceptionOnNotFound) {
            return null;
        }
        throw new RuntimeException("Could not find VarId to input " + name);
    }

    /**
     * VarId: identifies one value of a variable, i.e. the variable name together with the frame it lives in, the
     * iteration within that frame, and the parent frame. Frames can be nested, which generally represents nested
     * loop situations.<br>
     * Used in two places:<br>
     * (a) to identify variables that are available for execution<br>
     * (b) to store results<br>
     * Getters, setters, equals and hashCode are generated by Lombok {@code @Data}; toString is written by hand.
     */
    @Data
    public static class VarId {
        private String variable;
        private String frame;
        private int iteration;
        private FrameIter parentFrame;

        /**
         * @param variable    name of the variable
         * @param frame       name of the frame holding this value
         * @param iteration   iteration within {@code frame}
         * @param parentFrame enclosing frame/iteration; may be null (as used by {@link #createDefault(String)})
         */
        public VarId(String variable, String frame, int iteration, FrameIter parentFrame) {
            this.parentFrame = parentFrame;
            this.iteration = iteration;
            this.frame = frame;
            this.variable = variable;
        }

        /**
         * Create an id for a variable in the default outer frame: iteration 0 and no parent frame.
         *
         * @param name the name of the variable to create an id for
         * @return a VarId for {@code name} in the outer frame
         */
        public static VarId createDefault(String name) {
            final int firstIteration = 0;
            return new VarId(name, OUTER_FRAME, firstIteration, null);
        }

        @Override
        public String toString() {
            StringBuilder sb = new StringBuilder("VarId(\"");
            sb.append(variable).append("\",\"").append(frame).append("\",")
                    .append(iteration).append(",parent=").append(parentFrame).append(")");
            return sb.toString();
        }

        /**
         * @return a FrameIter built from this id's frame, iteration and parent frame
         */
        public FrameIter toFrameIter() {
            FrameIter frameIter = new FrameIter(frame, iteration, parentFrame);
            return frameIter;
        }
    }

    /**
     * ExecType: the kind of work an {@link ExecStep} represents.
     */
    protected enum ExecType {
        /** Operation execution. */
        OP,
        /** Variable "execution", mainly used to trigger ops that depend on the variable. */
        VARIABLE,
        /** As per {@link #VARIABLE}, for constants. */
        CONSTANT,
        /** As per {@link #VARIABLE}, for placeholders. */
        PLACEHOLDER,
        /**
         * Left branch of a switch. SWITCH_L and SWITCH_R are a workaround for the fact that only one of the switch
         * branches (left or right) will ever be available; without them, once the switch op is executed we would
         * (incorrectly) conclude that *both* branches can be executed.
         */
        SWITCH_L,
        /** Right branch of a switch; see {@link #SWITCH_L}. */
        SWITCH_R,
        /** Start of execution. */
        EXEC_START,
        /**
         * Control dependency for an op. Used for TF import, due to its odd "constant depends on op in a frame"
         * behaviour.
         */
        CONTROL_DEP
    }

    ;

    /**
     * ExecStep: a single execution step for one op (or variable/constant etc.) at a specific frame/iteration.
     * Getters, equals and hashCode (over type, name and frameIter) are generated by Lombok.
     */
    @Getter
    @EqualsAndHashCode
    protected static class ExecStep {
        protected final ExecType type;
        protected final String name;
        protected final FrameIter frameIter;

        /**
         * @param execType  kind of step; must not be null
         * @param name      name of the op/variable this step is for; must not be null
         * @param frameIter frame and iteration the step executes in
         */
        protected ExecStep(@NonNull ExecType execType, @NonNull String name, FrameIter frameIter) {
            this.frameIter = frameIter;
            this.name = name;
            this.type = execType;
        }

        /**
         * @return a VarId with this step's name and the frame, iteration and parent frame of {@link #frameIter}
         */
        protected VarId toVarId() {
            FrameIter fi = frameIter;
            return new VarId(name, fi.getFrame(), fi.getIteration(), fi.getParentFrame());
        }

        @Override
        public String toString() {
            return new StringBuilder("ExecStep(")
                    .append(type)
                    .append(",name=\"")
                    .append(name)
                    .append("\",")
                    .append(frameIter)
                    .append(")")
                    .toString();
        }

    }

    /**
     * Used in getting the next ExecStep that matches the specified (current) frame/iteration: a step matches when
     * its frame name, iteration and parent frame all equal the current ones held by this predicate.
     */
    @Data
    @AllArgsConstructor
    @NoArgsConstructor
    protected class ExecStepPredicate implements Predicate<ExecStep> {

        protected String currentFrame;
        protected int currentFrameIter;
        protected FrameIter currParentFrame;

        /**
         * @param execStep the step to test
         * @return true if the step's frame, iteration and parent frame equal the current frame, iteration and
         *         parent frame
         */
        @Override
        public boolean test(ExecStep execStep) {
            FrameIter stepFrameIter = execStep.getFrameIter();
            return currentFrame.equals(stepFrameIter.getFrame())
                    && currentFrameIter == stepFrameIter.getIteration()
                    // Null-safe comparison: both null is a match; exactly one null is a mismatch (not an NPE)
                    && Objects.equals(currParentFrame, stepFrameIter.getParentFrame());
        }
    }

    ;
}