// deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasModelUtils.java
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.nn.modelimport.keras.utils;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.bytedeco.hdf5.Group;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.modelimport.keras.Hdf5Archive;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
import org.deeplearning4j.nn.modelimport.keras.config.KerasModelConfiguration;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.layers.wrappers.KerasBidirectional;
import org.deeplearning4j.preprocessors.ReshapePreprocessor;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.shade.jackson.core.type.TypeReference;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import org.nd4j.shade.jackson.dataformat.yaml.YAMLFactory;
import java.io.IOException;
import java.util.*;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
@Slf4j
public class KerasModelUtils {
/**
 * Set the {@link org.deeplearning4j.nn.conf.DataFormat}
 * on a {@link ReshapePreprocessor} so that data-format information
 * propagates properly during model import for cases like reshapes.
 *
 * <p>No-op unless {@code inputPreProcessor} is a {@link ReshapePreprocessor} and
 * {@code currLayer} wraps an actual layer with a known dim order.
 *
 * @param inputPreProcessor preprocessor to update (only ReshapePreprocessor is handled)
 * @param currLayer         Keras layer whose underlying DL4J layer supplies the format
 */
public static void setDataFormatIfNeeded(InputPreProcessor inputPreProcessor, KerasLayer currLayer) {
    // Guard clauses: only reshape preprocessors carry a settable format, and we
    // need a concrete layer with a declared dim order to read the format from.
    if (!(inputPreProcessor instanceof ReshapePreprocessor))
        return;
    if (!currLayer.isLayer() || currLayer.getDimOrder() == null)
        return;

    ReshapePreprocessor reshape = (ReshapePreprocessor) inputPreProcessor;
    Layer underlying = currLayer.getLayer();
    // Check the 3D convolution subclasses before the generic ConvolutionLayer case,
    // since they expose their own data-format accessors.
    if (underlying instanceof Convolution3D) {
        reshape.setFormat(((Convolution3D) underlying).getDataFormat());
    } else if (underlying instanceof Deconvolution3D) {
        reshape.setFormat(((Deconvolution3D) underlying).getDataFormat());
    } else if (underlying instanceof ConvolutionLayer) {
        reshape.setFormat(((ConvolutionLayer) underlying).getCnn2dDataFormat());
    } else if (underlying instanceof BaseRecurrentLayer) {
        reshape.setFormat(((BaseRecurrentLayer) underlying).getRnnDataFormat());
    }
}
/**
 * Helper function to import weights from nested Map into existing model. Depends critically
 * on matched layer and parameter names. In general this seems to be straightforward for most
 * Keras models and layersOrdered, but there may be edge cases.
 *
 * @param model       DL4J Model interface (MultiLayerNetwork or ComputationGraph)
 * @param kerasLayers mapping from layer name to the imported KerasLayer holding its weights
 * @return DL4J Model interface (same instance, weights copied in)
 * @throws InvalidKerasConfigurationException if a model layer has no stored weights, or
 *         stored weights with parameters exist for a layer that is not in the model
 */
public static Model copyWeightsToModel(Model model, Map<String, KerasLayer> kerasLayers)
        throws InvalidKerasConfigurationException {
    // Extract the flat layer array from whichever network type we were given.
    org.deeplearning4j.nn.api.Layer[] modelLayers = model instanceof MultiLayerNetwork
            ? ((MultiLayerNetwork) model).getLayers()
            : ((ComputationGraph) model).getLayers();

    // Track which imported Keras layers remain unmatched after the copy pass.
    Set<String> unmatched = new HashSet<>(kerasLayers.keySet());
    for (org.deeplearning4j.nn.api.Layer modelLayer : modelLayers) {
        String name = modelLayer.conf().getLayer().getLayerName();
        if (!kerasLayers.containsKey(name))
            throw new InvalidKerasConfigurationException(
                    "No weights found for layer in model (named " + name + ")");
        kerasLayers.get(name).copyWeightsToLayer(modelLayer);
        unmatched.remove(name);
    }

    // Any leftover imported layer that actually carries parameters indicates a mismatch.
    for (String name : unmatched) {
        if (kerasLayers.get(name).getNumParams() > 0)
            throw new InvalidKerasConfigurationException(
                    "Attempting to copy weights for layer not in model (named " + name + ")");
    }
    return model;
}
/**
 * Determine Keras major version
 *
 * @param modelConfig parsed model configuration for keras model
 * @param config basic model configuration (KerasModelConfiguration)
 * @return Major Keras version (1 or 2)
 * @throws InvalidKerasConfigurationException Invalid Keras config (missing or unreadable version string)
 */
public static int determineKerasMajorVersion(Map<String, Object> modelConfig, KerasModelConfiguration config)
        throws InvalidKerasConfigurationException {
    int kerasMajorVersion;
    if (!modelConfig.containsKey(config.getFieldKerasVersion())) {
        // Old Keras 1.x HDF5 files do not record a version field at all, so default to 1.
        log.warn("Could not read keras version used (no "
                + config.getFieldKerasVersion() + " field found) \n"
                + "assuming keras version is 1.0.7 or earlier."
        );
        kerasMajorVersion = 1;
    } else {
        String kerasVersionString = (String) modelConfig.get(config.getFieldKerasVersion());
        // Guard against null/empty values: previously an empty string reached charAt(0)
        // and threw StringIndexOutOfBoundsException instead of a clean config error.
        if (kerasVersionString == null || kerasVersionString.isEmpty()
                || !Character.isDigit(kerasVersionString.charAt(0))) {
            throw new InvalidKerasConfigurationException(
                    "Keras version was not readable (" + config.getFieldKerasVersion() + " provided)"
            );
        }
        // Major version is the leading digit of e.g. "2.2.4".
        kerasMajorVersion = Character.getNumericValue(kerasVersionString.charAt(0));
    }
    return kerasMajorVersion;
}
/**
 * Determine Keras backend
 *
 * @param modelConfig parsed model configuration for keras model
 * @param config basic model configuration (KerasModelConfiguration)
 * @return Keras backend string, or null when the backend field is absent
 */
public static String determineKerasBackend(Map<String, Object> modelConfig, KerasModelConfiguration config) {
    if (!modelConfig.containsKey(config.getFieldBackend())) {
        // TODO: H5 files unfortunately do not seem to have this property in keras 1.
        log.warn("Could not read keras backend used (no "
                + config.getFieldBackend() + " field found) \n"
        );
        return null;
    }
    return (String) modelConfig.get(config.getFieldBackend());
}
/**
 * Strip the layer-name prefix and backend-generated suffixes from a raw HDF5
 * parameter name, leaving only the bare parameter name (e.g. "W" or "kernel").
 *
 * @param parameter    raw parameter name as stored in the HDF5 archive
 * @param fragmentList layer name split on "/"; the last fragment is the prefix to remove
 * @return the cleaned parameter name
 */
private static String findParameterName(String parameter, String[] fragmentList) {
    // Quote the layer-name fragment before compiling: layer names may contain regex
    // metacharacters (e.g. '.', '+', '('), which previously caused a
    // PatternSyntaxException or removed the wrong characters.
    Matcher layerNameMatcher =
            Pattern.compile(Pattern.quote(fragmentList[fragmentList.length - 1])).matcher(parameter);
    String parameterNameFound = layerNameMatcher.replaceFirst("");
    /* Usually layer name is separated from parameter name by an underscore. */
    Matcher paramNameMatcher = Pattern.compile("^_(.+)$").matcher(parameterNameFound);
    if (paramNameMatcher.find())
        parameterNameFound = paramNameMatcher.group(1);
    /* TensorFlow backend often appends ":" followed by one or more digits to parameter names. */
    Matcher tfSuffixMatcher = Pattern.compile(":\\d+?$").matcher(parameterNameFound);
    if (tfSuffixMatcher.find())
        parameterNameFound = tfSuffixMatcher.replaceFirst("");
    /* TensorFlow backend also may append "_" followed by one or more digits to parameter names.*/
    Matcher tfParamNbMatcher = Pattern.compile("_\\d+$").matcher(parameterNameFound);
    if (tfParamNbMatcher.find())
        parameterNameFound = tfParamNbMatcher.replaceFirst("");
    return parameterNameFound;
}
/**
 * Store weights to import with each associated Keras layer.
 *
 * Reads parameter datasets out of the HDF5 archive and attaches them to the
 * matching KerasLayer via setWeights. Matching is name-based: layer group names
 * in the archive must line up with the keys of {@code layers}.
 *
 * @param weightsArchive Hdf5Archive
 * @param weightsRoot root of weights in HDF5 archive (may be null when weights sit at the top level)
 * @param layers map from layer name to the KerasLayer that should receive the weights
 * @param kerasVersion major Keras version (1 or 2); affects where datasets live in the archive
 * @param backend Keras backend string (e.g. "theano"); affects bidirectional group naming
 * @throws InvalidKerasConfigurationException Invalid Keras configuration
 * @throws UnsupportedKerasConfigurationException Unsupported Keras configuration
 */
public static void importWeights(Hdf5Archive weightsArchive, String weightsRoot, Map<String, KerasLayer> layers,
int kerasVersion, String backend)
throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
// check to ensure naming scheme doesn't include forward slash
boolean includesSlash = false;
for (String layerName : layers.keySet()) {
if (layerName.contains("/"))
includesSlash = true;
}
// Serialize all archive access class-wide — presumably because the native HDF5
// bindings are not thread-safe; NOTE(review): confirm, this mirrors the
// Hdf5Archive.LOCK_OBJECT lock used below.
synchronized (KerasModelUtils.class) {
List<String> layerGroups;
if (!includesSlash) {
layerGroups = weightsRoot != null ? weightsArchive.getGroups(weightsRoot) : weightsArchive.getGroups();
} else {
// With '/' in layer names, HDF5 treats each fragment as a nested group, so
// listing groups is unreliable; fall back to the known layer names instead.
layerGroups = new ArrayList<>(layers.keySet());
}
/* Set weights in KerasLayer for each entry in weights map. */
for (String layerName : layerGroups) {
if(layerName.equals(KerasModelConfiguration.topLevelModelWeights)) {
//new way of saving parameter weights
synchronized(Hdf5Archive.LOCK_OBJECT) {
Group[] rootGroup = weightsArchive.openGroups(weightsRoot + "/" + layerName);
// NOTE(review): the opened group handles are only closed when the group is
// empty; on the non-empty path they appear to be left open — confirm whether
// closing happens elsewhere or this is a native-handle leak.
if(rootGroup[0].getNumObjs() < 1)
weightsArchive.closeGroups(rootGroup);
}
}
else {
//older layers where weights are stored per layer
List<String> layerParamNames;
// there's a bug where if a layer name contains a forward slash, the first fragment must be appended
// to the name of the dataset; it appears h5 interprets the forward slash as a data group
String[] layerFragments = layerName.split("/");
// Find nested groups when using Tensorflow
String rootPrefix = weightsRoot != null ? weightsRoot + "/" : "";
List<String> attributeStrParts = new ArrayList<>();
// "weight_names" lists the stored parameter paths; a ":<digits>" suffix in it
// signals TensorFlow-style nested variable groups.
String attributeStr = weightsArchive.readAttributeAsString(
"weight_names", rootPrefix + layerName
);
String attributeJoinStr;
Matcher attributeMatcher = Pattern.compile(":\\d+").matcher(attributeStr);
Boolean foundTfGroups = attributeMatcher.find();
if (foundTfGroups) {
// Collect the leading path fragments up to (excluding) the first ":<digits>"
// fragment; these form the nested group path holding the datasets.
for (String part : attributeStr.split("/")) {
part = part.trim();
if (part.length() == 0)
break;
Matcher tfSuffixMatcher = Pattern.compile(":\\d+").matcher(part);
if (tfSuffixMatcher.find())
break;
attributeStrParts.add(part);
}
attributeJoinStr = StringUtils.join(attributeStrParts, "/");
} else {
attributeJoinStr = layerFragments[0];
}
// Candidate group path for this layer's datasets (layer name + nested path).
String baseAttributes = layerName + "/" + attributeJoinStr;
if (layerFragments.length > 1) {
// Slash-containing layer name: try the nested path first, fall back to
// the plain layer name if the nested group does not exist.
try {
layerParamNames = weightsArchive.getDataSets(rootPrefix + baseAttributes);
} catch (Exception e) {
layerParamNames = weightsArchive.getDataSets(rootPrefix + layerName);
}
} else {
if (foundTfGroups) {
layerParamNames = weightsArchive.getDataSets(rootPrefix + baseAttributes);
} else {
if (kerasVersion == 2) {
// Theano-backend bidirectional layers nest weights one level deeper
// under a "forward..." group; extend the path accordingly.
if (backend.equals("theano") && layerName.contains("bidirectional")) {
for (String part : attributeStr.split("/")) {
if (part.contains("forward"))
baseAttributes = baseAttributes + "/" + part;
}
}
if (layers.get(layerName).getNumParams() > 0) {
try {
layerParamNames = weightsArchive.getDataSets(rootPrefix + baseAttributes);
} catch (Exception e) {
// Best-effort: a parameterized layer without a weight group is
// tolerated here and only reported as a warning.
log.warn("No HDF5 group with weights found for layer with name "
+ layerName + ", continuing import.");
layerParamNames = Collections.emptyList();
}
} else {
layerParamNames = weightsArchive.getDataSets(rootPrefix + layerName);
}
} else {
// Keras 1: datasets live directly under the layer group.
layerParamNames = weightsArchive.getDataSets(rootPrefix + layerName);
}
}
}
if (layerParamNames.isEmpty())
continue;
if (!layers.containsKey(layerName))
throw new InvalidKerasConfigurationException(
"Found weights for layer not in model (named " + layerName + ")");
KerasLayer layer = layers.get(layerName);
// NOTE(review): due to the unbraced nested 'if', a parameter-count mismatch only
// raises for Keras-2 bidirectional layers whose doubled dataset count also
// mismatches; plain mismatches fall through silently — confirm this is intended.
if (layerParamNames.size() != layer.getNumParams())
if (kerasVersion == 2
&& layer instanceof KerasBidirectional && 2 * layerParamNames.size() != layer.getNumParams())
throw new InvalidKerasConfigurationException(
"Found " + layerParamNames.size() + " weights for layer with " + layer.getNumParams()
+ " trainable params (named " + layerName + ")");
Map<String, INDArray> weights = new HashMap<>();
for (String layerParamName : layerParamNames) {
// Reduce e.g. "dense_1/kernel:0" to the bare parameter name "kernel".
String paramName = KerasModelUtils.findParameterName(layerParamName, layerFragments);
INDArray paramValue;
if (kerasVersion == 2 && layer instanceof KerasBidirectional) {
// Bidirectional wrappers store each direction once; read the forward
// dataset and its mirrored "backward" twin, keyed with direction prefixes.
String backwardAttributes = baseAttributes.replace("forward", "backward");
INDArray forwardParamValue = weightsArchive.readDataSet(layerParamName,
rootPrefix + baseAttributes);
INDArray backwardParamValue = weightsArchive.readDataSet(
layerParamName, rootPrefix + backwardAttributes);
weights.put("forward_" + paramName, forwardParamValue);
weights.put("backward_" + paramName, backwardParamValue);
} else {
if (foundTfGroups) {
paramValue = weightsArchive.readDataSet(layerParamName, rootPrefix + baseAttributes);
} else {
if (layerFragments.length > 1) {
// Slash-named layer: the first fragment must prefix the dataset name.
paramValue = weightsArchive.readDataSet(
layerFragments[0] + "/" + layerParamName, rootPrefix, layerName);
} else {
if (kerasVersion == 2) {
paramValue = weightsArchive.readDataSet(
layerParamName, rootPrefix + baseAttributes);
} else {
paramValue = weightsArchive.readDataSet(layerParamName, rootPrefix, layerName);
}
}
}
weights.put(paramName, paramValue);
}
}
layer.setWeights(weights);
}
}
/* Look for layers in model with no corresponding entries in weights map. */
Set<String> layerNames = new HashSet<>(layers.keySet());
layerNames.removeAll(layerGroups);
for (String layerName : layerNames) {
if (layers.get(layerName).getNumParams() > 0)
throw new InvalidKerasConfigurationException("Could not find weights required for layer " + layerName);
}
}
}
/**
 * Parse Keras model configuration from JSON or YAML string representation.
 * JSON takes precedence when both are supplied.
 *
 * @param modelJson JSON string representing model (potentially null)
 * @param modelYaml YAML string representing model (potentially null)
 * @return Model configuration as Map&lt;String, Object&gt;
 * @throws IOException IO exception
 * @throws InvalidKerasConfigurationException when neither representation is provided
 */
public static Map<String, Object> parseModelConfig(String modelJson, String modelYaml)
        throws IOException, InvalidKerasConfigurationException {
    if (modelJson != null)
        return parseJsonString(modelJson);
    if (modelYaml != null)
        return parseYamlString(modelYaml);
    throw new InvalidKerasConfigurationException("Requires model configuration as either JSON or YAML string.");
}
// Shared JSON mapper: ObjectMapper is expensive to construct and is thread-safe
// for readValue once configured, so cache a single instance instead of building
// one per call.
private static final ObjectMapper JSON_MAPPER = new ObjectMapper();

/**
 * Convenience function for parsing JSON strings.
 *
 * @param json String containing valid JSON
 * @return Nested (key,value) map of arbitrary depth
 * @throws IOException IO exception
 */
public static Map<String, Object> parseJsonString(String json) throws IOException {
    TypeReference<HashMap<String, Object>> typeRef = new TypeReference<HashMap<String, Object>>() {
    };
    return JSON_MAPPER.readValue(json, typeRef);
}
// Shared YAML mapper: ObjectMapper is expensive to construct and is thread-safe
// for readValue once configured, so cache a single instance instead of building
// one per call.
private static final ObjectMapper YAML_MAPPER = new ObjectMapper(new YAMLFactory());

/**
 * Convenience function for parsing YAML strings.
 *
 * @param yaml String containing valid YAML
 * @return Nested (key,value) map of arbitrary depth
 * @throws IOException IO exception
 */
public static Map<String, Object> parseYamlString(String yaml) throws IOException {
    TypeReference<HashMap<String, Object>> typeRef = new TypeReference<HashMap<String, Object>>() {
    };
    return YAML_MAPPER.readValue(yaml, typeRef);
}
}