deeplearning4j/deeplearning4j

View on GitHub
deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/DL4JKerasModelValidator.java

Summary

Maintainability
A
2 hrs
Test Coverage
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.deeplearning4j.nn.modelimport.keras.utils;

import lombok.NonNull;
import org.deeplearning4j.nn.modelimport.keras.Hdf5Archive;
import org.deeplearning4j.nn.modelimport.keras.config.KerasModelConfiguration;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.common.validation.Nd4jCommonValidator;
import org.nd4j.common.validation.ValidationResult;

import java.io.File;
import java.util.Collections;

public class DL4JKerasModelValidator {

    private DL4JKerasModelValidator(){ }

    /**
     * Validate whether the file represents a valid Keras Sequential model (HDF5 archive)
     *
     * @param f File that should represent an saved Keras Sequential model (HDF5 archive)
     * @return Result of validation
     */
    public static ValidationResult validateKerasSequential(@NonNull File f){
        return validateKeras(f, "Keras Sequential Model HDF5", MultiLayerNetwork.class);
    }

    /**
     * Validate whether the file represents a valid Keras Functional model (HDF5 archive)
     *
     * @param f File that should represent an saved Keras Functional model (HDF5 archive)
     * @return Result of validation
     */
    public static ValidationResult validateKerasFunctional(@NonNull File f){
        return validateKeras(f, "Keras Functional Model HDF5", ComputationGraph.class);
    }

    protected static ValidationResult validateKeras(@NonNull File f, String format, Class<?> cl){
        ValidationResult vr = Nd4jCommonValidator.isValidFile(f, format, false);
        if(vr != null && !vr.isValid()) {
            return vr;
        }

        KerasModelConfiguration c = new KerasModelConfiguration();
        Hdf5Archive archive = null;
        try{
            archive = new Hdf5Archive(f.getPath());

            //Check JSON
            try{
                String json = archive.readAttributeAsJson(c.getTrainingModelConfigAttribute());
                vr = Nd4jCommonValidator.isValidJSON(json);
                if(vr != null && !vr.isValid()){
                    vr.setFormatType(format);
                    return vr;
                }
            } catch (Throwable t){
                return ValidationResult.builder()
                        .formatType(format)
                        .formatClass(cl)
                        .valid(false)
                        .path(Nd4jCommonValidator.getPath(f))
                        .issues(Collections.singletonList("Unable to read JSON configuration from Keras Sequential model HDF5 file"))
                        .exception(t)
                        .build();
            }

        } catch (Throwable t){
            return ValidationResult.builder()
                    .formatType(format)
                    .formatClass(cl)
                    .valid(false)
                    .path(Nd4jCommonValidator.getPath(f))
                    .issues(Collections.singletonList("Unable to read from " + format + " file - file is corrupt or not a valid Keras HDF5 archive?"))
                    .exception(t)
                    .build();
        }


        return ValidationResult.builder()
                .formatType(format)
                .formatClass(cl)
                .valid(true)
                .path(Nd4jCommonValidator.getPath(f))
                .build();
    }
}