deeplearning4j/deeplearning4j

View on GitHub
nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/tensorflow/TensorFlowImportValidator.java

Summary

Maintainability
F
3 days
Test Coverage
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.nd4j.imports.tensorflow;

import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.compress.archivers.ArchiveEntry;
import org.apache.commons.compress.archivers.tar.TarArchiveInputStream;
import org.apache.commons.compress.compressors.bzip2.BZip2CompressorInputStream;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.FilenameUtils;
import org.apache.commons.io.input.CloseShieldInputStream;
import org.nd4j.common.base.Preconditions;
import org.nd4j.imports.converters.DifferentialFunctionClassHolder;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.common.util.ArchiveUtils;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;

import java.io.*;
import java.util.*;
import java.util.zip.GZIPInputStream;
import java.util.zip.ZipFile;

@Slf4j
public class TensorFlowImportValidator {

    /**
     * Recursively scan the specified directory for .pb files, and evaluate which operations/graphs can/can't be imported.
     * Only .pb files are considered; this is equivalent to calling
     * {@code checkAllModelsForImport(directory, false)}.
     *
     * @param directory Directory to scan
     * @return Status for TensorFlow import for all models in the directory
     * @throws IOException If an error occurs while reading a model file
     */
    public static TFImportStatus checkAllModelsForImport(@NonNull File directory) throws IOException {
        // Delegate to the directory-scanning overload. The single-file checkModelForImport(File, boolean)
        // would try to open the directory itself as a file and fail.
        return checkAllModelsForImport(directory, false);
    }

    /**
     * Recursively scan the specified directory for model files, optionally also including files with common
     * archive/compression extensions, and evaluate which operations/graphs can/can't be imported.
     *
     * @param directory       Directory to scan
     * @param includeArchives If true, files with archive extensions (zip, tar.gz, tgz, etc) are scanned in addition to .pb files
     * @return Status for TensorFlow import for the files found
     * @throws IOException If an error occurs while reading a file
     */
    public static TFImportStatus checkAllModelsForImport(@NonNull File directory, boolean includeArchives) throws IOException {
        final String[] archiveExtensions = {
                "zip", "tar.gz", "gzip", "tgz", "gz", "7z", "tar.bz2", "tar.gz2", "tar.lz", "tar.lzma", "tg", "tar"
        };

        // Plain protobuf files are always scanned; archive extensions only on request
        List<String> extensions = new ArrayList<>();
        extensions.add("pb");
        if (includeArchives) {
            Collections.addAll(extensions, archiveExtensions);
        }

        String[] extensionArray = extensions.toArray(new String[0]);
        return checkAllModelsForImport(directory, extensionArray);
    }

    /**
     * Recursively scan the specified directory for files with the given extensions, and evaluate which
     * operations/graphs can/can't be imported.
     * <p>
     * Files whose path ends with ".pb" are checked directly. Any other file is treated as an archive or compressed
     * file and dispatched on its lower-cased extension:
     * <ul>
     *     <li>zip: every entry whose name ends with ".pb" is checked</li>
     *     <li>tar, tgz, tar.gz, tar.bz2: every tar entry whose name ends with ".pb" is checked</li>
     *     <li>gz, gzip, bz2 (not a tar): the file is decompressed and checked as a single model</li>
     * </ul>
     * Archives and compressed files that can't be read are logged at warn level and skipped.
     *
     * @param directory      Directory to scan (recursively); must be an existing directory
     * @param fileExtensions File extensions (without the leading dot) of the files to include in the scan
     * @return Merged status for all models that were checked. May be null if every file found was skipped as unreadable
     * @throws IOException                   If an error occurs reading a .pb model file or opening a zip archive
     * @throws UnsupportedOperationException If a non-.pb file with an extension that isn't handled is found
     */
    public static TFImportStatus checkAllModelsForImport(File directory, String[] fileExtensions) throws IOException {
        Preconditions.checkState(directory.isDirectory(), "Specified directory %s is not actually a directory", directory);

        Collection<File> files = FileUtils.listFiles(directory, fileExtensions, true);
        Preconditions.checkState(!files.isEmpty(), "No model files found in directory %s", directory);

        TFImportStatus status = null;
        for (File f : files) {
            String path = f.getAbsolutePath();
            if (isArchiveFile(f)) {
                log.info("Checking archive file for .pb files: {}", path);

                //Locale.ROOT: extension matching must not depend on the default locale (e.g. Turkish dotless i)
                String lowerPath = path.toLowerCase(Locale.ROOT);
                String ext = FilenameUtils.getExtension(path).toLowerCase(Locale.ROOT);
                switch (ext) {
                    case "zip":
                        status = checkZipArchive(f, status);
                        break;
                    case "tar":
                    case "tgz":
                        status = checkTarArchive(f, ext, status);
                        break;
                    case "gz":
                    case "bz2":
                        //FilenameUtils.getExtension("x.tar.gz") returns "gz", so the full suffix decides tar vs. single file
                        if (lowerPath.endsWith(".tar.gz") || lowerPath.endsWith(".tar.bz2")) {
                            status = checkTarArchive(f, ext, status);
                        } else {
                            status = checkCompressedModelFile(f, status);
                        }
                        break;
                    case "gzip":
                        status = checkCompressedModelFile(f, status);
                        break;
                    default:
                        throw new UnsupportedOperationException("Archive type not yet implemented: " + path);
                }
            } else {
                log.info("Checking model file: {}", path);
                status = mergeStatus(status, checkModelForImport(f));
            }

            int totalOps = (status == null ? 0 : status.getOpNames().size());
            int supportedOps = (status == null ? 0 : status.getImportSupportedOpNames().size());
            int unsupportedOps = (status == null ? 0 : status.getUnsupportedOpNames().size());
            log.info("DONE FILE: {} - totalOps = {} - supported ops: {} - unsupported ops: {}",
                    path, totalOps, supportedOps, unsupportedOps);
        }
        return status;
    }

    /** Null-safe merge: returns {@code next} when there is no status so far, otherwise {@code soFar.merge(next)}. */
    private static TFImportStatus mergeStatus(TFImportStatus soFar, TFImportStatus next) {
        return soFar == null ? next : soFar.merge(next);
    }

    /** Logs whether the status of a single archive entry lists any model paths that can't be imported. */
    private static void logArchiveEntryResult(TFImportStatus entryStatus, String entryPath) {
        //NOTE(review): a non-empty cantImportModelPaths is reported as "not a frozen model" - confirm whether the
        // read-error paths were the intended condition here
        if (entryStatus.getCantImportModelPaths() != null && !entryStatus.getCantImportModelPaths().isEmpty()) {
            log.info("Unable to load - not a frozen model .pb file: {}", entryPath);
        } else {
            log.info("Found frozen model .pb file in archive: {}", entryPath);
        }
    }

    /** Checks every ".pb" entry of a zip archive and merges the results into {@code soFar}; an unlistable archive is skipped. */
    private static TFImportStatus checkZipArchive(File f, TFImportStatus soFar) throws IOException {
        String path = f.getAbsolutePath();
        List<String> filesInZip;
        try {
            filesInZip = ArchiveUtils.zipListFiles(f);
        } catch (Exception e) {
            log.warn("Unable to read from file, skipping: {}", path, e);
            return soFar;
        }

        TFImportStatus status = soFar;
        //Open the archive once and reuse it for all entries
        try (ZipFile zf = new ZipFile(f)) {
            for (String s : filesInZip) {
                if (!s.endsWith(".pb")) {
                    continue;
                }
                try (InputStream is = zf.getInputStream(zf.getEntry(s))) {
                    String p2 = path + "/" + s;
                    log.info("Found possible frozen model (.pb) file in zip archive: {}", p2);
                    TFImportStatus currStatus = checkModelForImport(p2, is, false);
                    logArchiveEntryResult(currStatus, p2);
                    status = mergeStatus(status, currStatus);
                }
            }
        }
        return status;
    }

    /**
     * Checks every ".pb" entry of a (possibly gzip or bzip2 compressed) tar archive in a single pass and merges the
     * results into {@code soFar}. If the archive can't be read to the end, it is skipped and {@code soFar} is returned unchanged.
     */
    private static TFImportStatus checkTarArchive(File f, String ext, TFImportStatus soFar) {
        String path = f.getAbsolutePath();
        List<TFImportStatus> entryStatuses = new ArrayList<>();
        try (InputStream raw = new BufferedInputStream(new FileInputStream(f));
             TarArchiveInputStream tar = new TarArchiveInputStream(openDecompressed(raw, path))) {
            ArchiveEntry entry;
            while ((entry = tar.getNextTarEntry()) != null) {
                String name = entry.getName();
                if (!name.endsWith(".pb")) {
                    continue;
                }
                String p2 = path + "/" + name;
                log.info("Found possible frozen model (.pb) file in {} archive: {}", ext, p2);

                //Shield the tar stream: checkModelForImport closes the stream it is given
                TFImportStatus currStatus = checkModelForImport(p2, new CloseShieldInputStream(tar), false);
                logArchiveEntryResult(currStatus, p2);
                entryStatuses.add(currStatus);
            }
        } catch (IOException | RuntimeException e) {
            log.warn("Unable to read from file, skipping: {}", path, e);
            return soFar;
        }

        //Merge only after the whole archive was read successfully, in entry order
        TFImportStatus status = soFar;
        for (TFImportStatus s : entryStatuses) {
            status = mergeStatus(status, s);
        }
        return status;
    }

    /** Decompresses a single gzip/bzip2 file, checks it as one model and merges the result into {@code soFar}; unreadable files are skipped. */
    private static TFImportStatus checkCompressedModelFile(File f, TFImportStatus soFar) {
        String path = f.getAbsolutePath();
        TFImportStatus currStatus;
        //Assume single file...
        try (InputStream raw = new BufferedInputStream(new FileInputStream(f));
             InputStream is = openDecompressed(raw, path)) {
            currStatus = checkModelForImport(path, is, false);
        } catch (IOException | RuntimeException e) {
            log.warn("Unable to read from file, skipping: {}", path, e);
            return soFar;
        }
        return mergeStatus(soFar, currStatus);
    }

    /**
     * Wraps {@code raw} in the decompressor matching the path suffix: bzip2 for ".bz2", gzip for ".gz", ".tgz" and
     * ".gzip". Any other suffix (such as ".tar") returns {@code raw} unchanged.
     */
    private static InputStream openDecompressed(InputStream raw, String path) throws IOException {
        String lowerPath = path.toLowerCase(Locale.ROOT);
        if (lowerPath.endsWith(".bz2")) {
            return new BZip2CompressorInputStream(raw);
        }
        if (lowerPath.endsWith(".gz") || lowerPath.endsWith(".tgz") || lowerPath.endsWith(".gzip")) {
            return new GZIPInputStream(raw);
        }
        return raw;
    }

    /**
     * Determine whether the file should be treated as an archive rather than a protobuf model file.
     *
     * @param f File to check
     * @return False if the file's path ends with ".pb" (case sensitive), true otherwise
     */
    public static boolean isArchiveFile(File f){
        String filePath = f.getPath();
        boolean isProtobufFile = filePath.endsWith(".pb");
        return !isProtobufFile;
    }

    /**
     * Check whether the TensorFlow frozen model (protobuf format) can be imported into SameDiff or not.
     * See {@link #checkModelForImport(File, boolean)}. Defaults to exceptionOnRead = false
     *
     * @param file Protobuf file
     * @return Status for importing the file
     * @throws IOException If error
     */
    public static TFImportStatus checkModelForImport(@NonNull File file) throws IOException {
        final boolean exceptionOnRead = false;
        return checkModelForImport(file, exceptionOnRead);
    }

    /**
     * Check whether the TensorFlow frozen model (protobuf format) can be imported into SameDiff or not.
     * The file is opened as a stream and passed, together with its absolute path, to
     * {@link #checkModelForImport(String, InputStream, boolean)}.
     *
     * @param file            Protobuf file
     * @param exceptionOnRead If true, and the model can't be read from the opened file, throw an exception.
     *                        If false, return an "empty" TFImportStatus
     * @return Status for importing the file
     * @throws IOException If the file can't be opened, or if exceptionOnRead is true and the model can't be read
     */
    public static TFImportStatus checkModelForImport(@NonNull File file, boolean exceptionOnRead) throws IOException {
        try (InputStream modelStream = new FileInputStream(file)) {
            String modelPath = file.getAbsolutePath();
            TFImportStatus result = checkModelForImport(modelPath, modelStream, exceptionOnRead);
            return result;
        }
    }

    /**
     * Check whether a TensorFlow frozen model (protobuf GraphDef), read from the given stream, can be imported.
     * Ops of all nodes other than variable and placeholder nodes are counted, and each distinct op name is classified
     * as supported or unsupported depending on whether {@link DifferentialFunctionClassHolder} has an op registered
     * for that TensorFlow name.
     *
     * @param path            Path of the model; used for the paths reported in the returned status and in log/exception messages
     * @param is              Stream to read the model from. It is wrapped in a BufferedInputStream, which is closed after parsing
     * @param exceptionOnRead If true, any failure is rethrown wrapped in an IOException. If false, the failure is logged
     *                        and a status listing only {@code path} as a read error is returned
     * @return Status for importing the model
     * @throws IOException If exceptionOnRead is true and the model can't be read or analysed
     */
    public static TFImportStatus checkModelForImport(String path, InputStream is, boolean exceptionOnRead) throws IOException {

        try {
            Set<String> opNames = new HashSet<>();
            Map<String,Integer> opCounts = new HashMap<>();
            int opCount = 0;

            try (InputStream buffered = new BufferedInputStream(is)) {
                GraphDef graphDef = GraphDef.parseFrom(buffered);
                int numNodes = graphDef.getNodeCount();
                if (numNodes == 0) {
                    throw new IllegalStateException("Error loading model for import - loaded graph def has no nodes (empty/corrupt file?): " + path);
                }

                for (int i = 0; i < numNodes; i++) {
                    NodeDef node = graphDef.getNode(i);
                    boolean skipNode = TFGraphMapper.isVariableNode(node) || TFGraphMapper.isPlaceHolder(node);
                    if (!skipNode) {
                        String opName = node.getOp();
                        opNames.add(opName);
                        Integer previousCount = opCounts.get(opName);
                        opCounts.put(opName, previousCount == null ? 1 : previousCount + 1);
                        opCount++;
                    }
                }
            }

            Set<String> importSupportedOpNames = new HashSet<>();
            Set<String> unsupportedOpNames = new HashSet<>();
            Map<String,Set<String>> unsupportedOpModel = new HashMap<>();

            for (String opName : opNames) {
                boolean supported = DifferentialFunctionClassHolder.getInstance().getOpWithTensorflowName(opName) != null;
                if (supported) {
                    importSupportedOpNames.add(opName);
                    continue;
                }

                unsupportedOpNames.add(opName);
                if (!unsupportedOpModel.containsKey(opName)) {
                    //Record which model the unsupported op was found in
                    Set<String> modelsWithOp = new HashSet<>();
                    modelsWithOp.add(path);
                    unsupportedOpModel.put(opName, modelsWithOp);
                }
            }

            List<String> modelPaths = Collections.singletonList(path);
            List<String> cantImportPaths;
            if (unsupportedOpNames.isEmpty()) {
                cantImportPaths = Collections.emptyList();
            } else {
                cantImportPaths = Collections.singletonList(path);
            }
            List<String> readErrorPaths = Collections.emptyList();

            return new TFImportStatus(
                    modelPaths,
                    cantImportPaths,
                    readErrorPaths,
                    opCount,
                    opNames.size(),
                    opNames,
                    opCounts,
                    importSupportedOpNames,
                    unsupportedOpNames,
                    unsupportedOpModel);
        } catch (Throwable t) {
            if (!exceptionOnRead) {
                log.warn("Failed to import model from: " + path + " - not a TensorFlow frozen model in ProtoBuf format?", t);

                List<String> noPaths = Collections.emptyList();
                Set<String> noOps = Collections.emptySet();
                Map<String,Integer> noCounts = Collections.emptyMap();
                Map<String,Set<String>> noUnsupportedModels = Collections.emptyMap();

                //Only the read-error list is populated for a model that couldn't be read
                return new TFImportStatus(
                        noPaths,
                        noPaths,
                        Collections.singletonList(path),
                        0,
                        0,
                        noOps,
                        noCounts,
                        noOps,
                        noOps,
                        noUnsupportedModels);
            }
            throw new IOException("Error reading model from path " + path + " - not a TensorFlow frozen model in ProtoBuf format?", t);
        }
    }
}