deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/KerasModel.java
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.nn.modelimport.keras;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasModelUtils;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasOptimizerUtils;
import org.deeplearning4j.nn.conf.*;
import org.deeplearning4j.nn.conf.graph.PreprocessorVertex;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.Convolution3D;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLambdaLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.modelimport.keras.layers.recurrent.KerasLSTM;
import org.deeplearning4j.nn.modelimport.keras.layers.recurrent.KerasRnnUtils;
import org.deeplearning4j.nn.modelimport.keras.layers.recurrent.KerasSimpleRnn;
import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;
import org.deeplearning4j.nn.modelimport.keras.config.KerasModelConfiguration;
import org.deeplearning4j.nn.modelimport.keras.layers.KerasInput;
import org.deeplearning4j.nn.modelimport.keras.layers.KerasLoss;
import org.deeplearning4j.nn.modelimport.keras.layers.core.KerasLambda;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasModelBuilder;
import org.deeplearning4j.util.Convolution3DUtils;
import org.deeplearning4j.util.ConvolutionUtils;
import org.nd4j.common.primitives.Counter;
import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.shade.guava.collect.Lists;
import java.io.IOException;
import java.util.*;
import static org.deeplearning4j.nn.modelimport.keras.KerasLayer.customLayers;
import static org.deeplearning4j.nn.modelimport.keras.KerasLayer.lambdaLayers;
/**
 * Importer for a Keras Functional API model. Parses the Keras model configuration (JSON or YAML),
 * optionally a training configuration and HDF5 weights, and converts the result into a DL4J
 * {@link ComputationGraphConfiguration} / {@link ComputationGraph}.
 */
@Slf4j
@Data
public class KerasModel {
    protected static KerasModelConfiguration config = new KerasModelConfiguration();
    protected KerasModelBuilder modelBuilder = new KerasModelBuilder(config);
    protected String className; // Keras model class name
    protected boolean enforceTrainingConfig; // whether to build model in training mode
    protected Map<String, KerasLayer> layers; // map from layer name to KerasLayer
    protected List<KerasLayer> layersOrdered; // ordered list of layers
    protected Map<String, InputType> outputTypes; // inferred output types for all layers
    protected ArrayList<String> inputLayerNames; // list of input layers
    protected ArrayList<String> outputLayerNames; // list of output layers
    protected boolean useTruncatedBPTT = false; // whether to use truncated BPTT
    protected int truncatedBPTT = 0; // truncated BPTT value
    protected int kerasMajorVersion; // Keras major version determined from the model config
    protected String kerasBackend; // Keras backend determined from the model config (may be null)
    protected KerasLayer.DimOrder dimOrder = null; // dim ordering override; also updated during type inference
    protected IUpdater optimizer = null; // updater mapped from the training config, if one was imported

    /**
     * No-argument constructor; leaves all model state unset.
     */
    public KerasModel() {
    }

    /**
     * Returns the builder associated with this model.
     *
     * @return the model builder
     */
    public KerasModelBuilder modelBuilder() {
        return this.modelBuilder;
    }

    /**
     * (Recommended) Builder-pattern constructor for (Functional API) Model.
     *
     * @param modelBuilder builder object
     * @throws IOException                             IO exception
     * @throws InvalidKerasConfigurationException     Invalid Keras config
     * @throws UnsupportedKerasConfigurationException Unsupported Keras config
     */
    public KerasModel(KerasModelBuilder modelBuilder)
            throws UnsupportedKerasConfigurationException, IOException, InvalidKerasConfigurationException {
        this(modelBuilder.getModelJson(), modelBuilder.getModelYaml(), modelBuilder.getWeightsArchive(),
                modelBuilder.getWeightsRoot(), modelBuilder.getTrainingJson(), modelBuilder.getTrainingArchive(),
                modelBuilder.isEnforceTrainingConfig(), modelBuilder.getInputShape(), modelBuilder.getDimOrder());
    }

    /**
     * (Not recommended) Constructor for (Functional API) Model from model configuration
     * (JSON or YAML), training configuration (JSON), weights, and "training mode"
     * boolean indicator. When built in training mode, certain unsupported configurations
     * (e.g., unknown regularizers) will throw Exceptions. When enforceTrainingConfig=false, these
     * will generate warnings but will be otherwise ignored.
     *
     * @param modelJson             model configuration JSON string
     * @param modelYaml             model configuration YAML string
     * @param weightsArchive        HDF5 archive holding the weights; weights are imported only when non-null
     * @param weightsRoot           root path inside the weights archive
     * @param trainingJson          training configuration JSON; only read when enforceTrainingConfig is true
     * @param trainingArchive       training archive (not read by this constructor)
     * @param enforceTrainingConfig whether to enforce training-related configurations
     * @param inputShape            input shape override; when null the first layer's input shape is used
     * @param dimOrder              when non-null, forced into every layer config as the dim ordering
     * @throws IOException                             IO exception
     * @throws InvalidKerasConfigurationException     Invalid Keras config
     * @throws UnsupportedKerasConfigurationException Unsupported Keras config
     */
    @SuppressWarnings("unchecked")
    protected KerasModel(String modelJson, String modelYaml, Hdf5Archive weightsArchive, String weightsRoot,
                         String trainingJson, Hdf5Archive trainingArchive, boolean enforceTrainingConfig,
                         int[] inputShape, KerasLayer.DimOrder dimOrder)
            throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        Map<String, Object> modelConfig = KerasModelUtils.parseModelConfig(modelJson, modelYaml);
        this.kerasMajorVersion = KerasModelUtils.determineKerasMajorVersion(modelConfig, config);
        this.kerasBackend = KerasModelUtils.determineKerasBackend(modelConfig, config);
        this.enforceTrainingConfig = enforceTrainingConfig;
        this.dimOrder = dimOrder;

        /* Determine model configuration type. */
        if (!modelConfig.containsKey(config.getFieldClassName()))
            throw new InvalidKerasConfigurationException(
                    "Could not determine Keras model class (no " + config.getFieldClassName() + " field found)");
        this.className = (String) modelConfig.get(config.getFieldClassName());
        if (!this.className.equals(config.getFieldClassNameModel())
                && !this.className.equals(config.getFieldNameClassFunctional()))
            throw new InvalidKerasConfigurationException(
                    "Expected model class name " + config.getFieldClassNameModel() + " or "
                            + config.getFieldNameClassFunctional() + " (found " + this.className + ")");

        /* Retrieve lists of input and output layers, layer configurations. */
        if (!modelConfig.containsKey(config.getModelFieldConfig()))
            throw new InvalidKerasConfigurationException("Could not find model configuration details (no "
                    + config.getModelFieldConfig() + " in model config)");
        Map<String, Object> layerLists = (Map<String, Object>) modelConfig.get(config.getModelFieldConfig());

        /* Construct list of input layers (only the first element of each entry, the name, is used). */
        if (!layerLists.containsKey(config.getModelFieldInputLayers()))
            throw new InvalidKerasConfigurationException("Could not find list of input layers (no "
                    + config.getModelFieldInputLayers() + " field found)");
        this.inputLayerNames = new ArrayList<>();
        for (Object inputLayerNameObj : (List<Object>) layerLists.get(config.getModelFieldInputLayers()))
            this.inputLayerNames.add((String) ((List<Object>) inputLayerNameObj).get(0));

        /* Construct list of output layers (only the first element of each entry, the name, is used). */
        if (!layerLists.containsKey(config.getModelFieldOutputLayers()))
            throw new InvalidKerasConfigurationException("Could not find list of output layers (no "
                    + config.getModelFieldOutputLayers() + " field found)");
        this.outputLayerNames = new ArrayList<>();
        for (Object outputLayerNameObj : (List<Object>) layerLists.get(config.getModelFieldOutputLayers()))
            this.outputLayerNames.add((String) ((List<Object>) outputLayerNameObj).get(0));

        /* Process layer configurations. */
        if (!layerLists.containsKey(config.getModelFieldLayers()))
            throw new InvalidKerasConfigurationException(
                    "Could not find layer configurations (no " + config.getModelFieldLayers() + " field found)");
        Pair<Map<String, KerasLayer>, List<KerasLayer>> layerPair =
                prepareLayers((List<Object>) layerLists.get(config.getModelFieldLayers()));
        this.layers = layerPair.getFirst();
        this.layersOrdered = layerPair.getSecond();

        /* Import training configuration. */
        if (enforceTrainingConfig) {
            if (trainingJson != null)
                importTrainingConfiguration(trainingJson);
            else
                log.warn("If enforceTrainingConfig is true, a training " +
                        "configuration object has to be provided. Usually the only practical way to do this is to store" +
                        " your keras model with `model.save('model_path.h5')`. If you store model config and weights" +
                        " separately no training configuration is attached.");
        }

        /* Fall back to the first layer's input shape; guard against a model with no layers. */
        if (inputShape == null && !layersOrdered.isEmpty())
            inputShape = layersOrdered.get(0).inputShape;

        /* Infer output types for each layer. */
        this.outputTypes = inferOutputTypes(inputShape);

        /* Store weights in layers. */
        if (weightsArchive != null)
            KerasModelUtils.importWeights(weightsArchive, weightsRoot, layers, kerasMajorVersion, kerasBackend);
    }

    /**
     * Helper method called from constructor. Converts layer configuration
     * JSON into KerasLayer objects.
     *
     * <p>When a {@link KerasLambda} takes as input a layer that appears later in the layer list
     * (a potential cycle), a renamed copy of the lambda fed only by that later layer is inserted
     * directly after it, and the later input is removed from the original lambda. Consumers of the
     * original lambda may then be rewired to the copy chosen by {@link #findNearestNodesTo}.
     *
     * @param layerConfigs List of Keras layer configurations
     * @return map from layer name to layer, and the layers in (possibly extended) order
     * @throws InvalidKerasConfigurationException     Invalid Keras config
     * @throws UnsupportedKerasConfigurationException if a non-Lambda layer consumes a later layer
     */
    @SuppressWarnings("unchecked")
    Pair<Map<String, KerasLayer>, List<KerasLayer>> prepareLayers(List<Object> layerConfigs)
            throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        Map<String, KerasLayer> layerMap = new HashMap<>(); // map from layer name to KerasLayer
        List<KerasLayer> orderedLayers = new ArrayList<>();
        for (Object layerConfig : layerConfigs) {
            Map<String, Object> layerConfigMap = (Map<String, Object>) layerConfig;
            // Append major keras version and backend to each layer config.
            layerConfigMap.put(config.getFieldKerasVersion(), this.kerasMajorVersion);
            if (kerasMajorVersion == 2 && this.kerasBackend != null)
                layerConfigMap.put(config.getFieldBackend(), this.kerasBackend);
            KerasLayerConfiguration kerasLayerConf = new KerasLayer(this.kerasMajorVersion).conf;
            if (dimOrder != null) { // Force override of dim ordering with value from model builder
                String dimOrderString;
                if (dimOrder == KerasLayer.DimOrder.TENSORFLOW)
                    dimOrderString = kerasLayerConf.getDIM_ORDERING_TENSORFLOW();
                else if (dimOrder == KerasLayer.DimOrder.THEANO)
                    dimOrderString = kerasLayerConf.getDIM_ORDERING_THEANO();
                else
                    throw new InvalidKerasConfigurationException("Invalid data format / dim ordering");
                layerConfigMap.put(kerasLayerConf.getLAYER_FIELD_DIM_ORDERING(), dimOrderString);
            }
            KerasLayer layer = KerasLayerUtils.getKerasLayerFromConfig(
                    layerConfigMap, this.enforceTrainingConfig, kerasLayerConf, customLayers, lambdaLayers, layerMap);
            orderedLayers.add(layer);
            layerMap.put(layer.getLayerName(), layer);
            if (layer instanceof KerasLSTM)
                this.useTruncatedBPTT = this.useTruncatedBPTT || ((KerasLSTM) layer).getUnroll();
            if (layer instanceof KerasSimpleRnn)
                this.useTruncatedBPTT = this.useTruncatedBPTT || ((KerasSimpleRnn) layer).getUnroll();
        }

        // Layer names in config order. Index-aligned with orderedLayers: every insertion made into
        // `names` below is recorded and replayed on orderedLayers in the same order.
        List<String> names = new ArrayList<>(orderedLayers.size());
        // Names of lambda nodes
        Set<String> lambdaNames = new HashSet<>();
        for (KerasLayer kerasLayer : orderedLayers) {
            names.add(kerasLayer.getLayerName());
            if (kerasLayer instanceof KerasLambda)
                lambdaNames.add(kerasLayer.getLayerName());
        }

        // Lambda name -> names of the layers consuming it (by name, since indices shift as copies are inserted)
        Map<String, List<String>> nodesOutputToForLambdas = new HashMap<>();
        // Original lambda name -> names of the copies created for it
        Map<String, List<String>> replacementNamesForLambda = new HashMap<>();
        // (index, layer) insertions, in the exact order they were applied to `names`
        List<Pair<Integer, KerasLayer>> insertions = new ArrayList<>();

        // NOTE(review): `names` grows inside this loop, so for later i, names.get(i) can be an inserted
        // copy and the trailing original layers are not visited. Behavior kept as-is — confirm intent.
        int originalCount = orderedLayers.size();
        for (int i = 0; i < originalCount; i++) {
            KerasLayer kerasLayer = layerMap.get(names.get(i));
            List<String> removed = new ArrayList<>();
            for (String input : new ArrayList<>(kerasLayer.getInboundLayerNames())) {
                // Record that this layer consumes a lambda
                if (lambdaNames.contains(input))
                    nodesOutputToForLambdas.computeIfAbsent(input, k -> new ArrayList<>())
                            .add(kerasLayer.getLayerName());

                int indexOfInput = names.indexOf(input);
                if (indexOfInput <= i)
                    continue;

                // Potential cycle: this layer consumes a layer that appears later in the ordering.
                if (!(kerasLayer instanceof KerasLambda))
                    throw new UnsupportedKerasConfigurationException("Layer " + kerasLayer.getLayerName()
                            + " takes input " + input + " which appears later in the layer list;"
                            + " this is only supported for Lambda layers");
                KerasLambda originalLambda = (KerasLambda) kerasLayer;
                String newName = kerasLayer.getLayerName() + "-" + input;

                Map<String, Object> configCopy = new HashMap<String, Object>(kerasLayer.originalLayerConfig);
                configCopy.put(kerasLayer.conf.getLAYER_FIELD_NAME(), newName);
                SameDiffLambdaLayer sameDiffLambdaLayer =
                        (SameDiffLambdaLayer) originalLambda.getSameDiffLayer().clone();
                sameDiffLambdaLayer.setLayerName(newName);
                KerasLambda kerasLambda = new KerasLambda(configCopy, sameDiffLambdaLayer);
                kerasLambda.layerName = newName;
                kerasLambda.setInboundLayerNames(new ArrayList<>(Arrays.asList(input)));
                layerMap.put(newName, kerasLambda);
                replacementNamesForLambda.computeIfAbsent(originalLambda.layerName, k -> new ArrayList<>())
                        .add(newName);

                // Place the copy right after its (later) input
                int indexOfNewLayer = indexOfInput + 1;
                names.add(indexOfNewLayer, newName);
                insertions.add(new Pair<Integer, KerasLayer>(indexOfNewLayer, kerasLambda));
                removed.add(input);
                log.info("Found input {} at keras node {} with potential cycle.", input, kerasLayer.getLayerName());
            }
            kerasLayer.getInboundLayerNames().removeAll(removed);
        }

        // Replay insertions in application order so orderedLayers matches `names`
        // (a map keyed by index would lose copies that share an insertion index).
        for (Pair<Integer, KerasLayer> insertion : insertions)
            orderedLayers.add(insertion.getFirst(), insertion.getSecond());

        // Rewire consumers of each copied lambda; `names` is the order used for distance ranking.
        for (Map.Entry<String, List<String>> replacementEntry : replacementNamesForLambda.entrySet()) {
            String originalLambdaName = replacementEntry.getKey();
            List<String> consumers = nodesOutputToForLambdas.get(originalLambdaName);
            if (consumers == null || consumers.isEmpty())
                continue; // nothing consumes this lambda, so there is nothing to rewire
            Set<String> processed = new HashSet<>();
            for (String nodeName : consumers) {
                KerasLayer consumer = layerMap.get(nodeName);
                List<String> inbound = consumer.getInboundLayerNames();

                // Layers that take an already-processed consumer as input need modification
                boolean shouldBeOriginal = true;
                for (String process : processed) {
                    if (inbound.contains(process)) {
                        shouldBeOriginal = false;
                        break;
                    }
                }
                List<String> nearestNodes = findNearestNodesTo(
                        originalLambdaName, nodeName, replacementEntry.getValue(), names, 2);
                // If the original isn't among the 2 closest nodes, it should not feed this consumer
                if (consumers.size() > 1 && !nearestNodes.contains(originalLambdaName))
                    shouldBeOriginal = false;

                processed.add(nodeName);
                if (shouldBeOriginal)
                    continue;

                int replaceIdx = inbound.indexOf(originalLambdaName);
                if (replaceIdx >= 0)
                    inbound.set(replaceIdx, nearestNodes.get(0));
            }
        }

        // Rebuild the name map from the final ordering
        layerMap.clear();
        for (KerasLayer kerasLayer : orderedLayers)
            layerMap.put(kerasLayer.getLayerName(), kerasLayer);
        return new Pair<>(layerMap, orderedLayers);
    }

    /**
     * Ranks {@code original} and each of {@code targetedNodes} by absolute index distance from
     * {@code target} within {@code topoSortNodes}, keeping the {@code k} closest.
     * Scores are negated distances, so the highest score is the nearest node.
     * NOTE(review): assumes {@code Counter.keySetSorted()} orders keys by descending count (nearest first).
     *
     * @param original      name of the original node, always ranked
     * @param target        node whose distance is measured
     * @param targetedNodes candidate nodes
     * @param topoSortNodes node ordering used to measure distance
     * @param k             number of nodes to keep
     * @return up to {@code k} node names, as returned by {@code Counter.keySetSorted()}
     */
    List<String> findNearestNodesTo(String original, String target, List<String> targetedNodes,
                                    List<String> topoSortNodes, int k) {
        int idx = topoSortNodes.indexOf(target);
        Counter<String> rankByDistance = new Counter<>();
        for (String targeted : targetedNodes) {
            int diff = Math.abs(topoSortNodes.indexOf(targeted) - idx);
            // note we want the top k ranked by the least distance
            rankByDistance.incrementCount(targeted, -diff);
        }
        int diff = Math.abs(topoSortNodes.indexOf(original) - idx);
        rankByDistance.incrementCount(original, -diff);
        rankByDistance.keepTopNElements(k);
        return rankByDistance.keySetSorted();
    }

    /**
     * Extracts the optimizer section of a Keras training configuration.
     *
     * @param trainingConfig parsed training configuration
     * @return optimizer configuration map
     * @throws InvalidKerasConfigurationException if the optimizer field is missing
     */
    @SuppressWarnings("unchecked")
    Map<String, Object> getOptimizerConfig(Map<String, Object> trainingConfig) throws InvalidKerasConfigurationException {
        if (!trainingConfig.containsKey(config.getOptimizerConfig()))
            throw new InvalidKerasConfigurationException("Field "
                    + config.getOptimizerConfig() + " missing from training config");
        return (Map<String, Object>) trainingConfig.get(config.getOptimizerConfig());
    }

    /**
     * Helper method called from constructor. Incorporate training configuration details into model.
     * Includes loss function, optimization details, etc.
     *
     * <p>The loss may be given as a single name (applied to every output layer), a tf.keras loss
     * object {@code {"config": {"name": ...}}} (applied to every output layer), a map from output
     * layer name to loss, or a list of losses matching the output layers in order. When at least one
     * loss layer is created, the output layers are replaced by the loss layers.
     *
     * @param trainingConfigJson JSON containing Keras training configuration
     * @throws IOException                             IO exception
     * @throws InvalidKerasConfigurationException     Invalid Keras config
     * @throws UnsupportedKerasConfigurationException Unsupported Keras config
     */
    @SuppressWarnings("unchecked")
    void importTrainingConfiguration(String trainingConfigJson)
            throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        Map<String, Object> trainingConfig = KerasModelUtils.parseJsonString(trainingConfigJson);
        Map<String, Object> optimizerConfig = getOptimizerConfig(trainingConfig);
        this.optimizer = KerasOptimizerUtils.mapOptimizer(optimizerConfig);

        /* Add loss layers for each loss function. */
        if (!trainingConfig.containsKey(config.getTrainingLoss()))
            throw new InvalidKerasConfigurationException("Could not determine training loss function (no "
                    + config.getTrainingLoss() + " field found in training config)");
        Object kerasLossObj = trainingConfig.get(config.getTrainingLoss());
        List<KerasLayer> lossLayers = new ArrayList<>();
        if (kerasLossObj instanceof String) {
            String kerasLoss = (String) kerasLossObj;
            for (String outputLayerName : this.outputLayerNames)
                lossLayers.add(new KerasLoss(outputLayerName + "_loss", outputLayerName, kerasLoss));
        } else if (kerasLossObj instanceof Map) {
            Map<String, Object> kerasLossMap = (Map<String, Object>) kerasLossObj;
            if (kerasLossMap.containsKey("config")) {
                // tf.keras double nesting: a single loss object, applied to every output like a plain name
                String kerasLoss = resolveLossName(kerasLossMap);
                for (String outputLayerName : this.outputLayerNames)
                    lossLayers.add(new KerasLoss(outputLayerName + "_loss", outputLayerName, kerasLoss));
            } else {
                // Per-output losses keyed by output layer name
                for (Map.Entry<String, Object> entry : kerasLossMap.entrySet()) {
                    String outputLayerName = entry.getKey();
                    lossLayers.add(new KerasLoss(outputLayerName + "_loss", outputLayerName,
                            resolveLossName(entry.getValue())));
                }
            }
        } else if (kerasLossObj instanceof List) {
            // Per-output losses given positionally
            List<Object> kerasLossList = (List<Object>) kerasLossObj;
            if (kerasLossList.size() != this.outputLayerNames.size())
                throw new InvalidKerasConfigurationException("Number of Keras losses (" + kerasLossList.size()
                        + ") does not match number of output layers (" + this.outputLayerNames.size() + ")");
            for (int i = 0; i < kerasLossList.size(); i++) {
                String outputLayerName = this.outputLayerNames.get(i);
                lossLayers.add(new KerasLoss(outputLayerName + "_loss", outputLayerName,
                        resolveLossName(kerasLossList.get(i))));
            }
        } else {
            throw new UnsupportedKerasConfigurationException("Unsupported Keras loss configuration: " + kerasLossObj);
        }

        // Keep the declared outputs rather than leaving the model with none
        if (lossLayers.isEmpty()) {
            log.warn("No Keras loss layers were created from the training config; keeping output layers {}",
                    this.outputLayerNames);
            return;
        }

        /* Add loss layers to output layer list and layer graph. */
        this.outputLayerNames.clear();
        for (KerasLayer lossLayer : lossLayers) {
            this.layersOrdered.add(lossLayer);
            this.layers.put(lossLayer.getLayerName(), lossLayer);
            this.outputLayerNames.add(lossLayer.getLayerName());
        }
    }

    /**
     * Resolves a serialized Keras loss to its name: either a plain string, or a map whose
     * {@code "config"} entry is a map containing {@code "name"}.
     *
     * @param kerasLoss serialized loss entry
     * @return loss name
     * @throws InvalidKerasConfigurationException if no loss name can be determined
     */
    @SuppressWarnings("unchecked")
    private String resolveLossName(Object kerasLoss) throws InvalidKerasConfigurationException {
        if (kerasLoss instanceof String)
            return (String) kerasLoss;
        if (kerasLoss instanceof Map) {
            Object lossConfig = ((Map<String, Object>) kerasLoss).get("config");
            if (lossConfig instanceof Map) {
                Object name = ((Map<String, Object>) lossConfig).get("name");
                if (name != null)
                    return name.toString();
            }
        }
        throw new InvalidKerasConfigurationException("Unknown Keras loss " + kerasLoss);
    }

    /**
     * Helper method called from constructor. Infers and records output type
     * for every layer.
     *
     * <p>For input layers, the dim ordering (and for RNNs a missing input shape) is taken from the
     * layer that follows the input in layer order, when there is one.
     *
     * @param inputShape input shape applied to input layers that have none
     * @return map from layer name to inferred output type
     */
    Map<String, InputType> inferOutputTypes(int[] inputShape)
            throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        Map<String, InputType> outputTypes = new HashMap<>();
        int layerCount = this.layersOrdered.size();
        for (int kerasLayerIdx = 0; kerasLayerIdx < layerCount; kerasLayerIdx++) {
            KerasLayer layer = this.layersOrdered.get(kerasLayerIdx);
            InputType outputType;
            if (layer instanceof KerasInput) {
                if (inputShape != null && layer.inputShape == null) {
                    layer.inputShape = inputShape;
                }
                KerasInput kerasInput = (KerasInput) layer;
                // The layer following this input in layer order (none if the input is last)
                KerasLayer nextKerasLayer = kerasLayerIdx + 1 < layerCount
                        ? this.layersOrdered.get(kerasLayerIdx + 1) : null;
                Layer layer1 = nextKerasLayer == null ? null : nextKerasLayer.layer;
                // no dim order, try to pull it from the next layer if there is one
                if (layer1 != null && ConvolutionUtils.layerHasConvolutionLayout(layer1)) {
                    CNN2DFormat formatForLayer = ConvolutionUtils.getFormatForLayer(layer1);
                    if (formatForLayer == CNN2DFormat.NCHW) {
                        dimOrder = KerasLayer.DimOrder.THEANO;
                    } else if (formatForLayer == CNN2DFormat.NHWC) {
                        dimOrder = KerasLayer.DimOrder.TENSORFLOW;
                    } else {
                        dimOrder = KerasLayer.DimOrder.NONE;
                    }
                } else if (layer1 != null && Convolution3DUtils.layerHasConvolution3DLayout(layer1)) {
                    Convolution3D.DataFormat dataFormat = Convolution3DUtils.getFormatForLayer(layer1);
                    if (dataFormat == Convolution3D.DataFormat.NCDHW) {
                        dimOrder = KerasLayer.DimOrder.THEANO;
                    } else if (dataFormat == Convolution3D.DataFormat.NDHWC) {
                        dimOrder = KerasLayer.DimOrder.TENSORFLOW;
                    } else {
                        dimOrder = KerasLayer.DimOrder.NONE;
                    }
                } else if (nextKerasLayer != null && KerasRnnUtils.isRnnLayer(nextKerasLayer)) {
                    if (kerasInput.inputShape == null)
                        kerasInput.inputShape = nextKerasLayer.inputShape;
                }
                if (dimOrder != null)
                    layer.setDimOrder(dimOrder);
                outputType = layer.getOutputType();
                this.truncatedBPTT = kerasInput.getTruncatedBptt();
            } else {
                List<InputType> inputTypes = new ArrayList<>();
                for (String inboundLayerName : layer.getInboundLayerNames())
                    if (outputTypes.containsKey(inboundLayerName))
                        inputTypes.add(outputTypes.get(inboundLayerName));
                // Note: with no known inbound types, toArray(new InputType[1]) yields a one-element array holding null
                outputType = layer.getOutputType(inputTypes.toArray(new InputType[1]));
            }
            outputTypes.put(layer.getLayerName(), outputType);
        }
        return outputTypes;
    }

    /**
     * Configure a ComputationGraph from this Keras Model configuration.
     *
     * @return ComputationGraph configuration
     * @throws InvalidKerasConfigurationException     if the model class is incompatible
     * @throws UnsupportedKerasConfigurationException if a layer cannot be mapped
     */
    public ComputationGraphConfiguration getComputationGraphConfiguration()
            throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        if (!this.className.equals(config.getFieldClassNameModel())
                && !this.className.equals(config.getFieldClassNameSequential())
                && !this.className.equals(config.getFieldNameClassFunctional()))
            throw new InvalidKerasConfigurationException(
                    "Keras model class name " + this.className + " incompatible with ComputationGraph");

        NeuralNetConfiguration.Builder netBuilder = new NeuralNetConfiguration.Builder();
        if (optimizer != null) {
            netBuilder.updater(optimizer);
        }

        ComputationGraphConfiguration.GraphBuilder graphBuilder = netBuilder.graphBuilder();
        // NOTE: normally this is disallowed in DL4J. However, in Keras you can create disconnected graph vertices.
        // The responsibility for doing this correctly is that of the Keras user.
        graphBuilder.allowDisconnected(true);

        /* Add input and output layer names to ComputationGraph. */
        graphBuilder.addInputs(this.inputLayerNames.toArray(new String[0]));
        graphBuilder.setOutputs(this.outputLayerNames.toArray(new String[0]));

        Map<String, InputPreProcessor> preprocessors = new HashMap<>();
        List<InputType> initialInputTypes = new ArrayList<>();
        int lastIdx = this.layersOrdered.size() - 1;

        /* Add layersOrdered one at a time. */
        for (int idx = 0; idx <= lastIdx; idx++) {
            KerasLayer layer = this.layersOrdered.get(idx);
            List<String> inboundLayerNames = layer.getInboundLayerNames();
            String[] inboundLayerNamesArray = inboundLayerNames.toArray(new String[0]);

            /* Collect inbound InputTypes. */
            List<InputType> inboundTypeList = new ArrayList<>();
            for (String inboundLayerName : inboundLayerNames) {
                KerasLayer prevLayer = this.layers.get(inboundLayerName);
                if (prevLayer.isInputPreProcessor()) {
                    // NOTE(review): the results of these calls are not stored; they are kept because the
                    // helpers may validate or mutate state — confirm whether they can be removed.
                    InputType inboundType = this.outputTypes.get(inboundLayerName);
                    InputPreProcessor inboundPreprocessor = prevLayer.getInputPreprocessor(inboundType);
                    KerasModelUtils.setDataFormatIfNeeded(inboundPreprocessor, layer);
                    inboundPreprocessor.getOutputType(inboundType);
                }
                if (this.outputTypes.containsKey(inboundLayerName))
                    inboundTypeList.add(this.outputTypes.get(inboundLayerName));
            }

            InputPreProcessor preprocessor = layer.getInputPreprocessor(inboundTypeList.toArray(new InputType[0]));
            // don't add pre processor if there isn't anymore output, edge case for final layer
            if (idx == lastIdx) {
                preprocessor = null;
            }

            if (layer.isLayer()) {
                if (preprocessor != null)
                    preprocessors.put(layer.getLayerName(), preprocessor);
                graphBuilder.addLayer(layer.getLayerName(), layer.getLayer(), inboundLayerNamesArray);
            } else if (layer.isVertex()) { // Ignore "preprocessor" layers for now
                if (preprocessor != null)
                    preprocessors.put(layer.getLayerName(), preprocessor);
                graphBuilder.addVertex(layer.getLayerName(), layer.getVertex(), inboundLayerNamesArray);
            } else if (layer.isInputPreProcessor()) {
                if (preprocessor == null)
                    throw new UnsupportedKerasConfigurationException("Layer " + layer.getLayerName()
                            + " could not be mapped to Layer, Vertex, or InputPreProcessor");
                graphBuilder.addVertex(layer.getLayerName(), new PreprocessorVertex(preprocessor),
                        inboundLayerNamesArray);
            }

            if (layer instanceof KerasInput) {
                initialInputTypes.add(this.outputTypes.get(layer.layerName));
            }
        }
        graphBuilder.setInputPreProcessors(preprocessors);

        /* Whether to use standard backprop (or BPTT) or truncated BPTT. */
        if (this.useTruncatedBPTT && this.truncatedBPTT > 0)
            graphBuilder.backpropType(BackpropType.TruncatedBPTT).tBPTTForwardLength(truncatedBPTT)
                    .tBPTTBackwardLength(truncatedBPTT);
        else
            graphBuilder.backpropType(BackpropType.Standard);

        ComputationGraphConfiguration build = graphBuilder.build();
        // note we don't forcibly over ride inputs when doing keras import. They are already set.
        build.addPreProcessors(false, false, initialInputTypes.toArray(new InputType[0]));
        return build;
    }

    /**
     * Build a ComputationGraph from this Keras Model configuration and import weights.
     *
     * @return ComputationGraph
     */
    public ComputationGraph getComputationGraph()
            throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        return getComputationGraph(true);
    }

    /**
     * Build a ComputationGraph from this Keras Model configuration and (optionally) import weights.
     *
     * @param importWeights whether to import weights
     * @return ComputationGraph
     */
    public ComputationGraph getComputationGraph(boolean importWeights)
            throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        ComputationGraph model = new ComputationGraph(getComputationGraphConfiguration());
        model.init();
        if (importWeights)
            model = (ComputationGraph) KerasModelUtils.copyWeightsToModel(model, this.layers);
        return model;
    }
}