deeplearning4j/deeplearning4j

View on GitHub
deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-model/src/main/java/org/deeplearning4j/ui/model/stats/BaseStatsListener.java

Summary

Maintainability
F
1 wk
Test Coverage
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.deeplearning4j.ui.model.stats;

import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.IOUtils;
import org.bytedeco.javacpp.Pointer;
import org.deeplearning4j.config.DL4JClassLoading;
import org.deeplearning4j.core.storage.StatsStorageRouter;
import org.deeplearning4j.core.storage.StorageMetaData;
import org.deeplearning4j.core.storage.listener.RoutingIterationListener;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.ui.model.stats.api.*;
import org.deeplearning4j.ui.model.storage.FileStatsStorage;
import org.deeplearning4j.ui.model.storage.InMemoryStatsStorage;
import org.deeplearning4j.ui.model.stats.impl.DefaultStatsInitializationConfiguration;
import org.deeplearning4j.ui.model.stats.impl.DefaultStatsUpdateConfiguration;
import org.deeplearning4j.core.util.UIDProvider;
import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.primitives.Pair;
import org.nd4j.nativeblas.NativeOps;
import org.nd4j.nativeblas.NativeOpsHolder;

import java.io.InputStream;
import java.io.Serializable;
import java.lang.management.GarbageCollectorMXBean;
import java.lang.management.ManagementFactory;
import java.lang.management.OperatingSystemMXBean;
import java.lang.management.RuntimeMXBean;
import java.util.*;

@Slf4j
public abstract class BaseStatsListener implements RoutingIterationListener {
    public static final String TYPE_ID = "StatsListener";

    private enum StatType {
        Mean, Stdev, MeanMagnitude
    }

    private StatsStorageRouter router;
    private final StatsInitializationConfiguration initConfig;
    private StatsUpdateConfiguration updateConfig;
    private String sessionID;
    private String workerID;

    private transient List<GarbageCollectorMXBean> gcBeans;
    private Map<String, Pair<Long, Long>> gcStatsAtLastReport;

    //NOTE: may have multiple models, due to multiple pretrain layers all using the same StatsListener
    private List<ModelInfo> modelInfos = new ArrayList<>();

    private Map<String, Histogram> activationHistograms;
    private Map<String, Double> meanActivations;        //TODO replace with Eclipse collections primitive maps...
    private Map<String, Double> stdevActivations;
    private Map<String, Double> meanMagActivations;

    private Map<String, Histogram> gradientHistograms;
    private Map<String, Double> meanGradients;        //TODO replace with Eclipse collections primitive maps...
    private Map<String, Double> stdevGradient;
    private Map<String, Double> meanMagGradients;

    private static class ModelInfo implements Serializable {
        private final Model model;
        private long initTime;
        private long lastReportTime = -1;
        private int lastReportIteration = -1;
        private int examplesSinceLastReport = 0;
        private int minibatchesSinceLastReport = 0;

        private long totalExamples = 0;
        private long totalMinibatches = 0;

        private int iterCount = 0;

        private ModelInfo(Model model) {
            this.model = model;
        }
    }

    private ModelInfo getModelInfo(Model model) {
        ModelInfo mi = null;
        for (ModelInfo m : modelInfos) {
            if (m.model == model) {
                mi = m;
                break;
            }
        }
        if (mi == null) {
            mi = new ModelInfo(model);
            modelInfos.add(mi);
        }
        return mi;
    }

    /**
     * Create a StatsListener with network information collected at every iteration.
     *
     * @param router Where/how to store the calculated stats. For example, {@link InMemoryStatsStorage} or
     *               {@link FileStatsStorage}
     */
    public BaseStatsListener(StatsStorageRouter router) {
        this(router, null, null, null, null);
    }

    /**
     * Create a StatsListener with network information collected every n >= 1 time steps
     *
     * @param router            Where/how to store the calculated stats. For example, {@link InMemoryStatsStorage} or
     *                          {@link FileStatsStorage}
     * @param listenerFrequency Frequency with which to collect stats information
     */
    public BaseStatsListener(StatsStorageRouter router, int listenerFrequency) {
        this(router, null, new DefaultStatsUpdateConfiguration.Builder().reportingFrequency(listenerFrequency).build(),
                null, null);
    }

    public BaseStatsListener(StatsStorageRouter router, StatsInitializationConfiguration initConfig,
                             StatsUpdateConfiguration updateConfig, String sessionID, String workerID) {
        this.router = router;
        if (initConfig == null) {
            this.initConfig = new DefaultStatsInitializationConfiguration(true, true, true);
        } else {
            this.initConfig = initConfig;
        }
        if (updateConfig == null) {
            this.updateConfig = new DefaultStatsUpdateConfiguration.Builder().build();
        } else {
            this.updateConfig = updateConfig;
        }
        if (sessionID == null) {
            //TODO handle syncing session IDs across different listeners in the same model...
            this.sessionID = UUID.randomUUID().toString();
        } else {
            this.sessionID = sessionID;
        }
        if (workerID == null) {
            this.workerID = UIDProvider.getJVMUID() + "_" + Thread.currentThread().getId();
        } else {
            this.workerID = workerID;
        }
    }

    public abstract StatsInitializationReport getNewInitializationReport();

    public abstract StatsReport getNewStatsReport();

    //    public abstract StorageMetaData getNewStorageMetaData();
    public abstract StorageMetaData getNewStorageMetaData(long initTime, String sessionID, String workerID);
    //                                                          Class<? extends StatsInitializationReport> initializationReportClass,
    //                                                          Class<? extends StatsReport> statsReportClass);
    //new SbeStorageMetaData(initTime, getSessionID(model), TYPE_ID, workerID, SbeStatsInitializationReport.class, SbeStatsReport.class);


    public StatsInitializationConfiguration getInitConfig() {
        return initConfig;
    }

    public StatsUpdateConfiguration getUpdateConfig() {
        return updateConfig;
    }

    public void setUpdateConfig(StatsUpdateConfiguration newConfig) {
        this.updateConfig = newConfig;
    }

    @Override
    public void setStorageRouter(StatsStorageRouter router) {
        this.router = router;
    }

    @Override
    public StatsStorageRouter getStorageRouter() {
        return router;
    }

    @Override
    public void setWorkerID(String workerID) {
        this.workerID = workerID;
    }

    @Override
    public String getWorkerID() {
        return workerID;
    }

    @Override
    public void setSessionID(String sessionID) {
        this.sessionID = sessionID;
    }

    @Override
    public String getSessionID() {
        return sessionID;
    }

    private String getSessionID(Model model) {
        if (model instanceof MultiLayerNetwork || model instanceof ComputationGraph)
            return sessionID;
        if (model instanceof Layer) {
            //Keep in mind MultiLayerNetwork implements Layer also...
            Layer l = (Layer) model;
            int layerIdx = l.getIndex();
            return sessionID + "_layer" + layerIdx;
        }
        return sessionID; //Should never happen
    }

    @Override
    public void onEpochStart(Model model) {

    }

    @Override
    public void onEpochEnd(Model model) {

    }

    @Override
    public void onForwardPass(Model model, List<INDArray> activations) {
        int iterCount = getModelInfo(model).iterCount;
        if (calcFromActivations() && (iterCount == 0 || iterCount % updateConfig.reportingFrequency() == 0)) {
            //Assumption: we have input, layer 0, layer 1, ...
            Map<String, INDArray> activationsMap = new HashMap<>();
            int count = 0;
            for (INDArray arr : activations) {
                String layerName = (count == 0 ? "input" : String.valueOf(count - 1));
                activationsMap.put(layerName, arr);
                count++;
            }
            onForwardPass(model, activationsMap);
        }
    }

    @Override
    public void onForwardPass(Model model, Map<String, INDArray> activations) {
        int iterCount = getModelInfo(model).iterCount;
        if (calcFromActivations() && updateConfig.reportingFrequency() > 0
                && (iterCount == 0 || iterCount % updateConfig.reportingFrequency() == 0)) {
            if (updateConfig.collectHistograms(StatsType.Activations)) {
                activationHistograms = getHistograms(activations, updateConfig.numHistogramBins(StatsType.Activations));
            }
            if (updateConfig.collectMean(StatsType.Activations)) {
                meanActivations = calculateSummaryStats(activations, StatType.Mean);
            }
            if (updateConfig.collectStdev(StatsType.Activations)) {
                stdevActivations = calculateSummaryStats(activations, StatType.Stdev);
            }
            if (updateConfig.collectMeanMagnitudes(StatsType.Activations)) {
                meanMagActivations = calculateSummaryStats(activations, StatType.MeanMagnitude);
            }
        }
    }

    @Override
    public void onGradientCalculation(Model model) {
        int iterCount = getModelInfo(model).iterCount;
        if (calcFromGradients() && updateConfig.reportingFrequency() > 0
                && (iterCount == 0 || iterCount % updateConfig.reportingFrequency() == 0)) {
            Gradient g = model.gradient();
            if (updateConfig.collectHistograms(StatsType.Gradients)) {
                gradientHistograms = getHistograms(g.gradientForVariable(), updateConfig.numHistogramBins(StatsType.Gradients));
            }

            if (updateConfig.collectMean(StatsType.Gradients)) {
                meanGradients = calculateSummaryStats(g.gradientForVariable(), StatType.Mean);
            }
            if (updateConfig.collectStdev(StatsType.Gradients)) {
                stdevGradient = calculateSummaryStats(g.gradientForVariable(), StatType.Stdev);
            }
            if (updateConfig.collectMeanMagnitudes(StatsType.Gradients)) {
                meanMagGradients = calculateSummaryStats(g.gradientForVariable(), StatType.MeanMagnitude);
            }
        }
    }

    private boolean calcFromActivations() {
        return updateConfig.collectMean(StatsType.Activations) || updateConfig.collectStdev(StatsType.Activations)
                || updateConfig.collectMeanMagnitudes(StatsType.Activations)
                || updateConfig.collectHistograms(StatsType.Activations);
    }

    private boolean calcFromGradients() {
        return updateConfig.collectMean(StatsType.Gradients) || updateConfig.collectStdev(StatsType.Gradients)
                || updateConfig.collectMeanMagnitudes(StatsType.Gradients)
                || updateConfig.collectHistograms(StatsType.Gradients);
    }

    @Override
    public void onBackwardPass(Model model) {
        //No op
    }

    @Override
    public void iterationDone(Model model, int iteration, int epoch) {

        ModelInfo modelInfo = getModelInfo(model);
        boolean backpropParamsOnly = backpropParamsOnly(model);

        long currentTime = getTime();
        if (modelInfo.iterCount == 0) {
            modelInfo.initTime = currentTime;
            doInit(model);
        }

        if (updateConfig.collectPerformanceStats()) {
            updateExamplesMinibatchesCounts(model);
        }

        if (updateConfig.reportingFrequency() > 1 && (iteration == 0 || iteration % updateConfig.reportingFrequency() != 0)) {
            modelInfo.iterCount = iteration;
            return;
        }

        StatsReport report = getNewStatsReport();
        report.reportIDs(getSessionID(model), TYPE_ID, workerID, System.currentTimeMillis()); //TODO support NTP time

        //--- Performance and System Stats ---
        if (updateConfig.collectPerformanceStats()) {
            //Stats to collect: total runtime, total examples, total minibatches, iterations/second, examples/second
            double examplesPerSecond;
            double minibatchesPerSecond;
            if (modelInfo.iterCount == 0) {
                //Not possible to work out perf/second: first iteration...
                examplesPerSecond = 0.0;
                minibatchesPerSecond = 0.0;
            } else {
                long deltaTimeMS = currentTime - modelInfo.lastReportTime;
                examplesPerSecond = 1000.0 * modelInfo.examplesSinceLastReport / deltaTimeMS;
                minibatchesPerSecond = 1000.0 * modelInfo.minibatchesSinceLastReport / deltaTimeMS;
            }
            long totalRuntimeMS = currentTime - modelInfo.initTime;
            report.reportPerformance(totalRuntimeMS, modelInfo.totalExamples, modelInfo.totalMinibatches,
                    examplesPerSecond, minibatchesPerSecond);

            modelInfo.examplesSinceLastReport = 0;
            modelInfo.minibatchesSinceLastReport = 0;
        }

        if (updateConfig.collectMemoryStats()) {

            Runtime runtime = Runtime.getRuntime();
            long jvmTotal = runtime.totalMemory();
            long jvmMax = runtime.maxMemory();

            //Off-heap memory
            long offheapTotal = Pointer.totalBytes();
            long offheapMax = Pointer.maxBytes();

            //GPU
            long[] gpuCurrentBytes = null;
            long[] gpuMaxBytes = null;
            NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
            int nDevices = nativeOps.getAvailableDevices();
            if (nDevices > 0) {
                gpuCurrentBytes = new long[nDevices];
                gpuMaxBytes = new long[nDevices];
                for (int i = 0; i < nDevices; i++) {
                    try {
                        gpuMaxBytes[i] = nativeOps.getDeviceTotalMemory(0);
                        gpuCurrentBytes[i] = gpuMaxBytes[i] - nativeOps.getDeviceFreeMemory(0);
                    } catch (Exception e) {
                        log.error("",e);
                    }
                }
            }

            report.reportMemoryUse(jvmTotal, jvmMax, offheapTotal, offheapMax, gpuCurrentBytes, gpuMaxBytes);
        }

        if (updateConfig.collectGarbageCollectionStats()) {
            if (modelInfo.lastReportIteration == -1 || gcBeans == null) {
                //Haven't reported GC stats before...
                gcBeans = ManagementFactory.getGarbageCollectorMXBeans();
                gcStatsAtLastReport = new HashMap<>();
                for (GarbageCollectorMXBean bean : gcBeans) {
                    long count = bean.getCollectionCount();
                    long timeMs = bean.getCollectionTime();
                    gcStatsAtLastReport.put(bean.getName(), new Pair<>(count, timeMs));
                }
            } else {
                for (GarbageCollectorMXBean bean : gcBeans) {
                    long count = bean.getCollectionCount();
                    long timeMs = bean.getCollectionTime();
                    Pair<Long, Long> lastStats = gcStatsAtLastReport.get(bean.getName());
                    long deltaGCCount = count - lastStats.getFirst();
                    long deltaGCTime = timeMs - lastStats.getSecond();

                    lastStats.setFirst(count);
                    lastStats.setSecond(timeMs);
                    report.reportGarbageCollection(bean.getName(), (int) deltaGCCount, (int) deltaGCTime);
                }
            }
        }

        //--- General ---
        report.reportScore(model.score()); //Always report score

        if (updateConfig.collectLearningRates()) {
            Map<String, Double> lrs = new HashMap<>();
            if (model instanceof MultiLayerNetwork) {
                //Need to append "0_", "1_" etc to param names from layers...
                int layerIdx = 0;
                for (Layer l : ((MultiLayerNetwork) model).getLayers()) {
                    NeuralNetConfiguration conf = l.conf();
                    List<String> paramkeys = l.conf().getLayer().initializer().paramKeys(l.conf().getLayer());
                    for (String s : paramkeys) {
                        double lr = conf.getLayer().getUpdaterByParam(s).getLearningRate(l.getIterationCount(), l.getEpochCount());
                        if (Double.isNaN(lr)) {
                            //Edge case: No-Op updater, AdaDelta etc - don't have a LR hence return NaN for IUpdater.getLearningRate
                            lr = 0.0;
                        }
                        lrs.put(layerIdx + "_" + s, lr);
                    }
                    layerIdx++;
                }
            } else if (model instanceof ComputationGraph) {
                for (Layer l : ((ComputationGraph) model).getLayers()) {
                    NeuralNetConfiguration conf = l.conf();
                    String layerName = conf.getLayer().getLayerName();
                    List<String> paramkeys = l.conf().getLayer().initializer().paramKeys(l.conf().getLayer());
                    for (String s : paramkeys) {
                        double lr = conf.getLayer().getUpdaterByParam(s).getLearningRate(l.getIterationCount(), l.getEpochCount());
                        if (Double.isNaN(lr)) {
                            //Edge case: No-Op updater, AdaDelta etc - don't have a LR hence return NaN for IUpdater.getLearningRate
                            lr = 0.0;
                        }
                        lrs.put(layerName + "_" + s, lr);
                    }
                }
            } else if (model instanceof Layer) {
                Layer l = (Layer) model;
                List<String> paramkeys = l.conf().getLayer().initializer().paramKeys(l.conf().getLayer());
                for (String s : paramkeys) {
                    double lr = l.conf().getLayer().getUpdaterByParam(s).getLearningRate(l.getIterationCount(), l.getEpochCount());
                    lrs.put(s, lr);
                }
            }
            report.reportLearningRates(lrs);
        }


        //--- Histograms ---

        if (updateConfig.collectHistograms(StatsType.Parameters)) {
            Map<String, Histogram> paramHistograms = getHistograms(model.paramTable(backpropParamsOnly),
                    updateConfig.numHistogramBins(StatsType.Parameters));
            report.reportHistograms(StatsType.Parameters, paramHistograms);
        }

        if (updateConfig.collectHistograms(StatsType.Gradients)) {
            report.reportHistograms(StatsType.Gradients, gradientHistograms);
        }

        if (updateConfig.collectHistograms(StatsType.Updates)) {
            Map<String, Histogram> updateHistograms = getHistograms(model.gradient().gradientForVariable(),
                    updateConfig.numHistogramBins(StatsType.Updates));
            report.reportHistograms(StatsType.Updates, updateHistograms);
        }

        if (updateConfig.collectHistograms(StatsType.Activations)) {
            report.reportHistograms(StatsType.Activations, activationHistograms);
        }


        //--- Summary Stats: Mean, Variance, Mean Magnitudes ---

        if (updateConfig.collectMean(StatsType.Parameters)) {
            Map<String, Double> meanParams = calculateSummaryStats(model.paramTable(backpropParamsOnly), StatType.Mean);
            report.reportMean(StatsType.Parameters, meanParams);
        }

        if (updateConfig.collectMean(StatsType.Gradients)) {
            report.reportMean(StatsType.Gradients, meanGradients);
        }

        if (updateConfig.collectMean(StatsType.Updates)) {
            Map<String, Double> meanUpdates =
                    calculateSummaryStats(model.gradient().gradientForVariable(), StatType.Mean);
            report.reportMean(StatsType.Updates, meanUpdates);
        }

        if (updateConfig.collectMean(StatsType.Activations)) {
            report.reportMean(StatsType.Activations, meanActivations);
        }


        if (updateConfig.collectStdev(StatsType.Parameters)) {
            Map<String, Double> stdevParams =
                    calculateSummaryStats(model.paramTable(backpropParamsOnly), StatType.Stdev);
            report.reportStdev(StatsType.Parameters, stdevParams);
        }

        if (updateConfig.collectStdev(StatsType.Gradients)) {
            report.reportStdev(StatsType.Gradients, stdevGradient);
        }

        if (updateConfig.collectStdev(StatsType.Updates)) {
            Map<String, Double> stdevUpdates =
                    calculateSummaryStats(model.gradient().gradientForVariable(), StatType.Stdev);
            report.reportStdev(StatsType.Updates, stdevUpdates);
        }

        if (updateConfig.collectStdev(StatsType.Activations)) {
            report.reportStdev(StatsType.Activations, stdevActivations);
        }


        if (updateConfig.collectMeanMagnitudes(StatsType.Parameters)) {
            Map<String, Double> meanMagParams =
                    calculateSummaryStats(model.paramTable(backpropParamsOnly), StatType.MeanMagnitude);
            report.reportMeanMagnitudes(StatsType.Parameters, meanMagParams);
        }

        if (updateConfig.collectMeanMagnitudes(StatsType.Gradients)) {
            report.reportMeanMagnitudes(StatsType.Gradients, meanMagGradients);
        }

        if (updateConfig.collectMeanMagnitudes(StatsType.Updates)) {
            Map<String, Double> meanMagUpdates =
                    calculateSummaryStats(model.gradient().gradientForVariable(), StatType.MeanMagnitude);
            report.reportMeanMagnitudes(StatsType.Updates, meanMagUpdates);
        }

        if (updateConfig.collectMeanMagnitudes(StatsType.Activations)) {
            report.reportMeanMagnitudes(StatsType.Activations, meanMagActivations);
        }


        long endTime = getTime();
        report.reportStatsCollectionDurationMS((int) (endTime - currentTime)); //Amount of time required to alculate all histograms, means etc.
        modelInfo.lastReportTime = currentTime;
        modelInfo.lastReportIteration = iteration;
        report.reportIterationCount(iteration);

        this.router.putUpdate(report);

        modelInfo.iterCount = iteration;
        activationHistograms = null;
        meanActivations = null;
        stdevActivations = null;
        meanMagActivations = null;
        gradientHistograms = null;
        meanGradients = null;
        stdevGradient = null;
        meanMagGradients = null;
    }

    private long getTime() {
        //Abstraction to allow NTP to be plugged in later...
        return System.currentTimeMillis();
    }

    private void doInit(Model model) {
        boolean backpropParamsOnly = backpropParamsOnly(model);
        long initTime = System.currentTimeMillis(); //TODO support NTP
        StatsInitializationReport initReport = getNewInitializationReport();
        initReport.reportIDs(getSessionID(model), TYPE_ID, workerID, initTime);

        if (initConfig.collectSoftwareInfo()) {
            OperatingSystemMXBean osBean = ManagementFactory.getOperatingSystemMXBean();
            RuntimeMXBean runtime = ManagementFactory.getRuntimeMXBean();

            String arch = osBean.getArch();
            String osName = osBean.getName();
            String jvmName = runtime.getVmName();
            String jvmVersion = System.getProperty("java.version");
            String jvmSpecVersion = runtime.getSpecVersion();

            String nd4jBackendClass = Nd4j.getNDArrayFactory().getClass().getName();
            String nd4jDataTypeName = DataTypeUtil.getDtypeFromContext().name();

            String hostname = System.getenv("COMPUTERNAME");
            if (hostname == null || hostname.isEmpty()) {
                try {
                    Process proc = Runtime.getRuntime().exec("hostname");
                    try (InputStream stream = proc.getInputStream()) {
                        hostname = IOUtils.toString(stream);
                    }
                } catch (Exception e) {
                }
            }

            Properties p = Nd4j.getExecutioner().getEnvironmentInformation();
            Map<String, String> envInfo = new HashMap<>();
            for (Map.Entry<Object, Object> e : p.entrySet()) {
                Object v = e.getValue();
                String value = (v == null ? "" : v.toString());
                envInfo.put(e.getKey().toString(), value);
            }

            initReport.reportSoftwareInfo(arch, osName, jvmName, jvmVersion, jvmSpecVersion, nd4jBackendClass,
                    nd4jDataTypeName, hostname, UIDProvider.getJVMUID(), envInfo);
        }

        if (initConfig.collectHardwareInfo()) {
            int availableProcessors = Runtime.getRuntime().availableProcessors();
            NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
            int nDevices = nativeOps.getAvailableDevices();

            long[] deviceTotalMem = null;
            String[] deviceDescription = null; //TODO
            if (nDevices > 0) {
                deviceTotalMem = new long[nDevices];
                deviceDescription = new String[nDevices];
                for (int i = 0; i < nDevices; i++) {
                    try {
                        deviceTotalMem[i] = nativeOps.getDeviceTotalMemory(i);
                        deviceDescription[i] = nativeOps.getDeviceName(i);
                        if (nDevices > 1) {
                            deviceDescription[i] = deviceDescription[i] + " (" + i + ")";
                        }
                    } catch (Exception e) {
                        log.debug("Error getting device info", e);
                    }
                }
            }
            long jvmMaxMemory = Runtime.getRuntime().maxMemory();
            long offheapMaxMemory = Pointer.maxBytes();

            initReport.reportHardwareInfo(availableProcessors, nDevices, jvmMaxMemory, offheapMaxMemory, deviceTotalMem,
                    deviceDescription, UIDProvider.getHardwareUID());
        }

        if (initConfig.collectModelInfo()) {
            String jsonConf;
            int numLayers;
            long numParams;
            if (model instanceof MultiLayerNetwork) {
                MultiLayerNetwork net = ((MultiLayerNetwork) model);
                jsonConf = net.getLayerWiseConfigurations().toJson();
                numLayers = net.getnLayers();
                numParams = net.numParams();
            } else if (model instanceof ComputationGraph) {
                ComputationGraph cg = ((ComputationGraph) model);
                jsonConf = cg.getConfiguration().toJson();
                numLayers = cg.getNumLayers();
                numParams = cg.numParams();
            } else if (model instanceof Layer) {
                Layer l = (Layer) model;
                jsonConf = l.conf().toJson();
                numLayers = 1;
                numParams = l.numParams();
            } else {
                throw new RuntimeException("Invalid model: Expected MultiLayerNetwork or ComputationGraph. Got: "
                        + (model == null ? null : model.getClass()));
            }

            Map<String, INDArray> paramMap = model.paramTable(backpropParamsOnly);
            String[] paramNames = new String[paramMap.size()];
            int i = 0;
            for (String s : paramMap.keySet()) { //Assuming sensible iteration order - LinkedHashMaps are used in MLN/CG for example
                paramNames[i++] = s;
            }

            initReport.reportModelInfo(model.getClass().getName(), jsonConf, paramNames, numLayers, numParams);
        }

        StorageMetaData meta = getNewStorageMetaData(initTime, getSessionID(model), workerID);

        router.putStorageMetaData(meta);
        router.putStaticInfo(initReport); //TODO error handling
    }

    private Map<Integer, Pointer> devPointers = new HashMap<>();

    private synchronized Pointer getDevicePointer(int device) {
        if (devPointers.containsKey(device)) {
            return devPointers.get(device);
        }
        try {
            Pointer pointer = DL4JClassLoading.createNewInstance(
                    "org.nd4j.jita.allocator.pointers.CudaPointer",
                    Pointer.class,
                    new Class[] { long.class },
                    new Object[]{(long) device});

            devPointers.put(device, pointer);
            return pointer;
        } catch (Throwable t) {
            devPointers.put(device, null); //Stops attempting the failure again later...
            return null;
        }
    }

    private void updateExamplesMinibatchesCounts(Model model) {
        ModelInfo modelInfo = getModelInfo(model);
        int examplesThisMinibatch = 0;
        if (model instanceof MultiLayerNetwork) {
            examplesThisMinibatch = model.batchSize();
        } else if (model instanceof ComputationGraph) {
            examplesThisMinibatch = model.batchSize();
        } else if (model instanceof Layer) {
            examplesThisMinibatch = ((Layer) model).getInputMiniBatchSize();
        }
        modelInfo.examplesSinceLastReport += examplesThisMinibatch;
        modelInfo.totalExamples += examplesThisMinibatch;
        modelInfo.minibatchesSinceLastReport++;
        modelInfo.totalMinibatches++;
    }

    private boolean backpropParamsOnly(Model model) {
        //For pretrain layers (VAE, AE) we *do* want pretrain params also; for MLN and CG we only want backprop params
        // as we only have backprop gradients
        return model instanceof MultiLayerNetwork || model instanceof ComputationGraph;
    }

    private static Map<String, Double> calculateSummaryStats(Map<String, INDArray> source, StatType statType) {
        Map<String, Double> out = new LinkedHashMap<>();

        if (source == null)
            return out;

        for (Map.Entry<String, INDArray> entry : source.entrySet()) {
            String name = entry.getKey();
            double value;
            switch (statType) {
                case Mean:
                    value = entry.getValue().meanNumber().doubleValue();
                    break;
                case Stdev:
                    value = entry.getValue().stdNumber().doubleValue();
                    break;
                case MeanMagnitude:
                    value = entry.getValue().norm1Number().doubleValue() / entry.getValue().length();
                    break;
                default:
                    throw new RuntimeException(); //Should never happen
            }
            out.put(name, value);
        }
        return out;
    }

    private static Map<String, Histogram> getHistograms(Map<String, INDArray> map, int nBins) {
        Map<String, Histogram> out = new LinkedHashMap<>();

        if (map == null)
            return out;

        for (Map.Entry<String, INDArray> entry : map.entrySet()) {

            org.nd4j.linalg.api.ops.impl.transforms.Histogram hOp =
                    new org.nd4j.linalg.api.ops.impl.transforms.Histogram(entry.getValue(), nBins);
            Nd4j.exec(hOp);

            INDArray bins = hOp.getOutputArgument(0);
            int[] count = new int[nBins];
            for (int i = 0; i < bins.length(); i++) {
                count[i] = (int) bins.getDouble(i);
            }

            double min = entry.getValue().minNumber().doubleValue();
            double max = entry.getValue().maxNumber().doubleValue();

            Histogram h = new Histogram(min, max, nBins, count);

            out.put(entry.getKey(), h);
        }
        return out;
    }

    @Override
    public abstract BaseStatsListener clone();
}