deeplearning4j/deeplearning4j

View on GitHub
omnihub/src/main/java/org/eclipse/deeplearning4j/omnihub/OmniHubUtils.java

Summary

Maintainability
A
35 mins
Test Coverage
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */
package org.eclipse.deeplearning4j.omnihub;

import lombok.SneakyThrows;
import org.apache.commons.io.FileUtils;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.omnihub.OmnihubConfig;
import org.nd4j.autodiff.samediff.SameDiff;

import java.io.BufferedInputStream;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.net.*;

/**
 * Various utils for dealing with downloading files
 * and handling the omnihub cache.
 *
 * @author Adam Gibson
 */
public class OmniHubUtils {


    /**
     * Loads a {@link MultiLayerNetwork} from the local cache
     * or downloads the file from the remote model zoo.
     * Note this is the same as calling {@link #loadNetwork(String, boolean)}
     * with a value of false
     * @param name the name of the file to load
     * @return the loaded {@link MultiLayerNetwork}
     * @throws IOException
     */
    public static MultiLayerNetwork loadNetwork(String name) throws IOException {
        return  loadNetwork(name,false);
    }


    /**
     * Loads a {@link MultiLayerNetwork} from the local cache
     * or downloads the file from the remote model zoo.
     * @param name the name of the file to load
     * @param forceDownload whether to force a new download if the file exists or not
     * @return the loaded {@link MultiLayerNetwork}
     * @throws IOException
     */
    public static MultiLayerNetwork loadNetwork(String name,boolean forceDownload) throws IOException {
        return MultiLayerNetwork.load(downloadAndLoadFromZoo("dl4j",name,forceDownload),true);
    }



    /**
     * Loads a {@link ComputationGraph} from the local cache
     * or downloads the file from the remote model zoo.
     * Note this is the same as calling {@link #loadCompGraph(String, boolean)}
     * with a value of false.
     * @param name the name of the file to load
     * @return the loaded {@link ComputationGraph}
     * @throws IOException
     */
    public static ComputationGraph loadCompGraph(String name) throws IOException {
        return loadCompGraph(name,false);
    }

    /**
     * Loads a {@link ComputationGraph} from the local cache
     * or downloads the file from the remote model zoo.
     * @param name the name of the file to load
     * @param forceDownload whether to force a new download if the file exists or not
     * @return the loaded {@link ComputationGraph}
     * @throws IOException
     */
    public static ComputationGraph loadCompGraph(String name,boolean forceDownload) throws IOException {
        return ComputationGraph.load(downloadAndLoadFromZoo("dl4j",name,forceDownload),true);
    }


    /**
     * Loads a samediff model from either the local cache
     * or downloads it to the model zoo
     * @param name the name of the file to download
     * @return
     */
    public static SameDiff loadSameDiffModel(String name) {
        return loadSameDiffModel(name,false);
    }

    /**
     * Loads a samediff model from either the local cache
     * or downloads it to the model zoo
     * @param name the name of the file to download
     * @param forceDownload whether to force the download of the file
     * @return the loaded samediff model
     */
    public static SameDiff loadSameDiffModel(String name,boolean forceDownload) {
        return SameDiff.load(downloadAndLoadFromZoo("samediff",name,forceDownload),true);
    }

    /**
     * Download and load a model from the model zoo using the given file name
     * for the given framework
     * @param framework the framework to load from
     * @param name the name of the file to load
     * @param forceDownload whether to force the download
     * @return the
     */
    @SneakyThrows
    public static File downloadAndLoadFromZoo(String framework,String name,boolean forceDownload) {
        File destination = new File(OmnihubConfig.getOmnihubHome(),framework);
        File destFile = new File(destination,name);
        if(forceDownload && destFile.exists()) {
            destFile.delete();
        }
        if(!destFile.exists()) {
            String url = new StringBuilder()
                    .append(OmnihubConfig.getOmnihubUrl()).append("/").append(framework).append("/").
                    append(name).toString();
            URL remoteUrl = URI.create(url).toURL();
            long size = getFileSize(remoteUrl);
            try(InputStream is = new ProgressInputStream(new BufferedInputStream(URI.create(url).toURL().openStream()),size)) {
                FileUtils.copyInputStreamToFile(is,destFile);

            } catch (MalformedURLException e) {
                e.printStackTrace();
            } catch (IOException e) {
                e.printStackTrace();
            }
        }

        return destFile;
    }


    private static int getFileSize(URL url) {
        URLConnection conn = null;
        try {
            conn = url.openConnection();
            if(conn instanceof HttpURLConnection) {
                ((HttpURLConnection)conn).setRequestMethod("HEAD");
            }
            conn.getInputStream();
            return conn.getContentLength();
        } catch (IOException e) {
            throw new RuntimeException(e);
        } finally {
            if(conn instanceof HttpURLConnection) {
                ((HttpURLConnection)conn).disconnect();
            }
        }
    }

}