deeplearning4j/deeplearning4j

View on GitHub
nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/impl/UIListener.java

Summary

Maintainability
F
5 days
Test Coverage
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.nd4j.autodiff.listeners.impl;

import com.google.flatbuffers.Table;
import lombok.NonNull;
import org.nd4j.autodiff.listeners.At;
import org.nd4j.autodiff.listeners.BaseListener;
import org.nd4j.autodiff.listeners.ListenerResponse;
import org.nd4j.autodiff.listeners.Loss;
import org.nd4j.autodiff.listeners.records.LossCurve;
import org.nd4j.autodiff.listeners.Operation;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.autodiff.samediff.internal.Variable;
import org.nd4j.common.base.Preconditions;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.graph.UIGraphStructure;
import org.nd4j.graph.UIInfoType;
import org.nd4j.graph.UIStaticInfoRecord;
import org.nd4j.graph.ui.LogFileWriter;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.OpContext;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.common.primitives.Pair;

import java.io.File;
import java.io.IOException;
import java.util.*;

public class UIListener extends BaseListener {

    /**
     * Default: FileMode.CREATE_OR_APPEND<br>
     * The mode for handling behaviour when an existing UI file already exists<br>
     * CREATE: Only allow new file creation. An exception will be thrown if the log file already exists.<br>
     * APPEND: Only allow appending to an existing file. An exception will be thrown if: (a) no file exists, or (b) the
     * network configuration in the existing log file does not match the current log file.<br>
     * CREATE_OR_APPEND: As per APPEND, but create a new file if none already exists.<br>
     * CREATE_APPEND_NOCHECK: As per CREATE_OR_APPEND, but no exception will be thrown if the existing model does not
     * match the current model structure. This mode is not recommended.<br>
     */
    public enum FileMode {CREATE, APPEND, CREATE_OR_APPEND, CREATE_APPEND_NOCHECK}

    /**
     * Used to specify how the Update:Parameter ratios are computed. Only relevant when the update ratio calculation is
     * enabled via {@link Builder#updateRatios(int, UpdateRatio)}; update ratio collection is disabled by default<br>
     * L2: l2Norm(updates)/l2Norm(parameters) is used<br>
     * MEAN_MAGNITUDE: mean(abs(updates))/mean(abs(parameters)) is used<br>
     */
    public enum UpdateRatio {L2, MEAN_MAGNITUDE}

    /**
     * Used to specify which histograms should be collected. Histogram collection is disabled by default, but can be
     * enabled via {@link Builder#histograms(int, HistogramType...)}. Note that multiple histogram types may be collected simultaneously.<br>
     * Histograms may be collected for:<br>
     * PARAMETERS: All trainable parameters<br>
     * PARAMETER_GRADIENTS: Gradients corresponding to the trainable parameters<br>
     * PARAMETER_UPDATES: All trainable parameter updates, before they are applied during training (updates are gradients after applying updater and learning rate etc)<br>
     * ACTIVATIONS: Activations - ARRAY type SDVariables - those that are not constants, variables or placeholders<br>
     * ACTIVATION_GRADIENTS: Activation gradients
     */
    public enum HistogramType {PARAMETERS, PARAMETER_GRADIENTS, PARAMETER_UPDATES, ACTIVATIONS, ACTIVATION_GRADIENTS}


    private FileMode fileMode;
    private File logFile;
    private int lossPlotFreq;
    private int performanceStatsFrequency;
    private int updateRatioFrequency;
    private UpdateRatio updateRatioType;
    private int histogramFrequency;
    private HistogramType[] histogramTypes;
    private int opProfileFrequency;
    private Map<Pair<String,Integer>, List<Evaluation.Metric>> trainEvalMetrics;
    private int trainEvalFrequency;
    private TestEvaluation testEvaluation;
    private int learningRateFrequency;

    private MultiDataSet currentIterDataSet;

    private LogFileWriter writer;
    private boolean wroteLossNames;
    private boolean wroteLearningRateName;

    private Set<String> relevantOpsForEval;
    private Map<Pair<String,Integer>,Evaluation> epochTrainEval;
    private boolean wroteEvalNames;
    private boolean wroteEvalNamesIter;

    private int firstUpdateRatioIter = -1;

    private boolean checkStructureForRestore;

    private UIListener(Builder b){
        fileMode = b.fileMode;
        logFile = b.logFile;
        lossPlotFreq = b.lossPlotFreq;
        performanceStatsFrequency = b.performanceStatsFrequency;
        updateRatioFrequency = b.updateRatioFrequency;
        updateRatioType = b.updateRatioType;
        histogramFrequency = b.histogramFrequency;
        histogramTypes = b.histogramTypes;
        opProfileFrequency = b.opProfileFrequency;
        trainEvalMetrics = b.trainEvalMetrics;
        trainEvalFrequency = b.trainEvalFrequency;
        testEvaluation = b.testEvaluation;
        learningRateFrequency = b.learningRateFrequency;

        switch (fileMode){
            case CREATE:
                Preconditions.checkState(!logFile.exists(), "Log file already exists and fileMode is set to CREATE: %s\n" +
                        "Either delete the existing file, specify a path that doesn't exist, or set the UIListener to another mode " +
                        "such as CREATE_OR_APPEND", logFile);
                break;
            case APPEND:
                Preconditions.checkState(logFile.exists(), "Log file does not exist and fileMode is set to APPEND: %s\n" +
                        "Either specify a path to an existing log file for this model, or set the UIListener to another mode " +
                        "such as CREATE_OR_APPEND", logFile);
                break;
        }

        if(logFile.exists())
            restoreLogFile();

    }

    protected void restoreLogFile(){
        if(logFile.length() == 0 && fileMode == FileMode.CREATE_OR_APPEND || fileMode == FileMode.APPEND){
            logFile.delete();
            return;
        }

        try {
            writer = new LogFileWriter(logFile);
        } catch (IOException e){
            throw new RuntimeException("Error restoring existing log file at path: " + logFile.getAbsolutePath(), e);
        }

        if(fileMode == FileMode.APPEND || fileMode == FileMode.CREATE_OR_APPEND){
            //Check the graph structure, if it exists.
            //This is to avoid users creating UI log file with one network configuration, then unintentionally appending data
            // for a completely different network configuration

            LogFileWriter.StaticInfo si;
            try {
                si = writer.readStatic();
            } catch (IOException e){
                throw new RuntimeException("Error restoring existing log file, static info at path: " + logFile.getAbsolutePath(), e);
            }

            List<Pair<UIStaticInfoRecord, Table>> staticList = si.getData();
            if(si != null) {
                for (int i = 0; i < staticList.size(); i++) {
                    UIStaticInfoRecord r = staticList.get(i).getFirst();
                    if (r.infoType() == UIInfoType.GRAPH_STRUCTURE){
                        //We can't check structure now (we haven't got SameDiff instance yet) but we can flag it to check on first iteration
                        checkStructureForRestore = true;
                    }
                }
            }

        }
    }

    protected void checkStructureForRestore(SameDiff sd){
        LogFileWriter.StaticInfo si;
        try {
            si = writer.readStatic();
        } catch (IOException e){
            throw new RuntimeException("Error restoring existing log file, static info at path: " + logFile.getAbsolutePath(), e);
        }

        List<Pair<UIStaticInfoRecord, Table>> staticList = si.getData();
        if(si != null) {
            UIGraphStructure structure = null;
            for (int i = 0; i < staticList.size(); i++) {
                UIStaticInfoRecord r = staticList.get(i).getFirst();
                if (r.infoType() == UIInfoType.GRAPH_STRUCTURE){
                    structure = (UIGraphStructure) staticList.get(i).getSecond();
                    break;
                }
            }

            if(structure != null){
                int nInFile = structure.inputsLength();
                List<String> phs = new ArrayList<>(nInFile);
                for( int i=0; i<nInFile; i++ ){
                    phs.add(structure.inputs(i));
                }

                List<String> actPhs = sd.inputs();
                if(actPhs.size() != phs.size() || !actPhs.containsAll(phs)){
                    throw new IllegalStateException("Error continuing collection of UI stats in existing model file " + logFile.getAbsolutePath() +
                            ": Model structure differs. Existing (file) model placeholders: " + phs + " vs. current model placeholders: " + actPhs +
                            ". To disable this check, use FileMode.CREATE_APPEND_NOCHECK though this may result issues when rendering data via UI");
                }

                //Check variables:
                int nVarsFile = structure.variablesLength();
                List<String> vars = new ArrayList<>(nVarsFile);
                for( int i=0; i<nVarsFile; i++ ){
                    vars.add(structure.variables(i).name());
                }
                List<SDVariable> sdVars = sd.variables();
                List<String> varNames = new ArrayList<>(sdVars.size());
                for(SDVariable v : sdVars){
                    varNames.add(v.name());
                }

                if(varNames.size() != vars.size() || !varNames.containsAll(vars)){
                    int countDifferent = 0;
                    List<String> different = new ArrayList<>();
                    for(String s : varNames){
                        if(!vars.contains(s)){
                            countDifferent++;
                            if(different.size() < 10){
                                different.add(s);
                            }
                        }
                    }
                    StringBuilder msg = new StringBuilder();
                    msg.append("Error continuing collection of UI stats in existing model file ")
                            .append(logFile.getAbsolutePath())
                            .append(": Current model structure differs vs. model structure in file - ").append(countDifferent).append(" variable names differ.");
                    if(different.size() == countDifferent){
                        msg.append("\nVariables in new model not present in existing (file) model: ").append(different);
                    } else {
                        msg.append("\nFirst 10 variables in new model not present in existing (file) model: ").append(different);
                    }
                    msg.append("\nTo disable this check, use FileMode.CREATE_APPEND_NOCHECK though this may result issues when rendering data via UI");

                    throw new IllegalStateException(msg.toString());
                }
            }
        }

        checkStructureForRestore = false;
    }



    protected void initalizeWriter(SameDiff sd) {
        try{
            initializeHelper(sd);
        }catch (IOException e){
            throw new RuntimeException(e);
        }
    }

    protected void initializeHelper(SameDiff sd) throws IOException {
        writer = new LogFileWriter(logFile);

        //Write graph structure:
        writer.writeGraphStructure(sd);

        //Write system info:
        //TODO

        //All static info completed
        writer.writeFinishStaticMarker();
    }

    @Override
    public boolean isActive(Operation operation) {
        return operation == Operation.TRAINING;
    }

    @Override
    public void epochStart(SameDiff sd, At at) {
        epochTrainEval = null;
    }

    @Override
    public ListenerResponse epochEnd(SameDiff sd, At at, LossCurve lossCurve, long epochTimeMillis) {

        //If any training evaluation, report it here:
        if(epochTrainEval != null){
            long time = System.currentTimeMillis();
            for(Map.Entry<Pair<String,Integer>,Evaluation> e : epochTrainEval.entrySet()){
                String n = "evaluation/" + e.getKey().getFirst();   //TODO what if user does same eval with multiple labels? Doesn't make sense... add validation to ensure this?

                List<Evaluation.Metric> l = trainEvalMetrics.get(e.getKey());
                for(Evaluation.Metric m : l) {
                    String mName = n + "/train/" + m.toString().toLowerCase();
                    if (!wroteEvalNames) {
                        if(!writer.registeredEventName(mName)) {    //Might have been registered if continuing training
                            writer.registerEventNameQuiet(mName);
                        }
                    }

                    double score = e.getValue().scoreForMetric(m);
                    try{
                        writer.writeScalarEvent(mName, LogFileWriter.EventSubtype.EVALUATION, time, at.iteration(), at.epoch(), score);
                    } catch (IOException ex){
                        throw new RuntimeException("Error writing to log file", ex);
                    }
                }

                wroteEvalNames = true;
            }
        }

        epochTrainEval = null;
        return ListenerResponse.CONTINUE;
    }

    @Override
    public void iterationStart(SameDiff sd, At at, MultiDataSet data, long etlMs) {
        if(writer == null)
            initalizeWriter(sd);
        if(checkStructureForRestore)
            checkStructureForRestore(sd);

        //If there's any evaluation to do in opExecution method, we'll need this there
        currentIterDataSet = data;
    }

    @Override
    public void iterationDone(SameDiff sd, At at, MultiDataSet dataSet, Loss loss) {
        long time = System.currentTimeMillis();

        //iterationDone method - just writes loss values (so far)

        if(!wroteLossNames){
            for(String s : loss.getLossNames()){
                String n = "losses/" + s;
                if(!writer.registeredEventName(n)) {    //Might have been registered if continuing training
                    writer.registerEventNameQuiet(n);
                }
            }

            if(loss.numLosses() > 1){
                String n = "losses/totalLoss";
                if(!writer.registeredEventName(n)) {    //Might have been registered if continuing training
                    writer.registerEventNameQuiet(n);
                }
            }
            wroteLossNames = true;
        }

        List<String> lossNames = loss.getLossNames();
        double[] lossVals = loss.getLosses();
        for( int i=0; i<lossVals.length; i++ ){
            try{
                String eventName = "losses/" + lossNames.get(i);
                writer.writeScalarEvent(eventName, LogFileWriter.EventSubtype.LOSS, time, at.iteration(), at.epoch(), lossVals[i]);
            } catch (IOException e){
                throw new RuntimeException("Error writing to log file", e);
            }
        }

        if(lossVals.length > 1){
            double total = loss.totalLoss();
            try{
                String eventName = "losses/totalLoss";
                writer.writeScalarEvent(eventName, LogFileWriter.EventSubtype.LOSS, time, at.iteration(), at.epoch(), total);
            } catch (IOException e){
                throw new RuntimeException("Error writing to log file", e);
            }
        }

        currentIterDataSet = null;

        if(learningRateFrequency > 0){
            //Collect + report learning rate
            if(!wroteLearningRateName){
                String name = "learningRate";
                if(!writer.registeredEventName(name)) {
                    writer.registerEventNameQuiet(name);
                }
                wroteLearningRateName = true;
            }

            if(at.iteration() % learningRateFrequency == 0) {
                IUpdater u = sd.getTrainingConfig().getUpdater();
                if (u.hasLearningRate()) {
                    double lr = u.getLearningRate(at.iteration(), at.epoch());
                    try {
                        writer.writeScalarEvent("learningRate", LogFileWriter.EventSubtype.LEARNING_RATE, time, at.iteration(), at.epoch(), lr);
                    } catch (IOException e){
                        throw new RuntimeException("Error writing to log file");
                    }
                }
            }
        }
    }



    @Override
    public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, OpContext opContext, INDArray[] outputs) {


        //Do training set evaluation, if required
        //Note we'll do it in opExecution not iterationDone because we can't be sure arrays will be stil be around in the future
        //i.e., we'll eventually add workspaces and clear activation arrays once they have been consumed
        if(at.operation() == Operation.TRAINING && trainEvalMetrics != null && trainEvalMetrics.size() > 0){
            long time = System.currentTimeMillis();

            //First: check if this op is relevant at all to evaluation...
            if(relevantOpsForEval == null){
                //Build list for quick lookups to know if we should do anything for this op
                relevantOpsForEval = new HashSet<>();
                for (Pair<String, Integer> p : trainEvalMetrics.keySet()) {
                    Variable v = sd.getVariables().get(p.getFirst());
                    String opName = v.getOutputOfOp();
                    Preconditions.checkState(opName != null, "Cannot evaluate on variable of type %s - variable name: \"%s\"",
                            v.getVariable().getVariableType(), opName);
                    relevantOpsForEval.add(v.getOutputOfOp());
                }
            }

            if(!relevantOpsForEval.contains(op.getName())){
                //Op outputs are not required for eval
                return;
            }

            if(epochTrainEval == null) {
                epochTrainEval = new HashMap<>();

                for (Pair<String, Integer> p : trainEvalMetrics.keySet()) {
                    epochTrainEval.put(p, new Evaluation());
                }
            }

            //Perform evaluation:
            boolean wrote = false;
            for (Pair<String, Integer> p : trainEvalMetrics.keySet()) {
                int idx = op.getOutputsOfOp().indexOf(p.getFirst());
                INDArray out = outputs[idx];
                INDArray label = currentIterDataSet.getLabels(p.getSecond());
                INDArray mask = currentIterDataSet.getLabelsMaskArray(p.getSecond());

                epochTrainEval.get(p).eval(label, out, mask);

                if(trainEvalFrequency > 0 && at.iteration() > 0 && at.iteration() % trainEvalFrequency == 0){
                    for(Evaluation.Metric m : trainEvalMetrics.get(p)) {
                        String n = "evaluation/train_iter/" + p.getKey() + "/" + m.toString().toLowerCase();
                        if (!wroteEvalNamesIter) {
                            if(!writer.registeredEventName(n)) {    //Might have been written previously if continuing training
                                writer.registerEventNameQuiet(n);
                            }
                            wrote = true;
                        }

                        double score = epochTrainEval.get(p).scoreForMetric(m);

                        try {
                            writer.writeScalarEvent(n, LogFileWriter.EventSubtype.EVALUATION, time, at.iteration(), at.epoch(), score);
                        } catch (IOException e) {
                            throw new RuntimeException("Error writing to log file");
                        }
                    }
                }
            }
            wroteEvalNamesIter = wrote;
        }
    }

    @Override
    public void preUpdate(SameDiff sd, At at, Variable v, INDArray update) {
        if(writer == null)
            initalizeWriter(sd);

        if(updateRatioFrequency > 0 && at.iteration() % updateRatioFrequency == 0){
            if(firstUpdateRatioIter < 0) {
                firstUpdateRatioIter = at.iteration();
            }

            if(firstUpdateRatioIter == at.iteration()){
                //Register name
                String name = "logUpdateRatio/" + v.getName();
                if(!writer.registeredEventName(name)){  //Might have already been registered if continuing
                    writer.registerEventNameQuiet(name);
                }
            }

            double params;
            double updates;
            if(updateRatioType == UpdateRatio.L2) {
                params = v.getVariable().getArr().norm2Number().doubleValue();
                updates = update.norm2Number().doubleValue();
            } else {
                //Mean magnitude - L1 norm divided by N. But in the ratio later, N cancels out...
                params = v.getVariable().getArr().norm1Number().doubleValue();
                updates = update.norm1Number().doubleValue();
            }

            double ratio = updates / params;
            if(params == 0.0) {
                ratio = 0.0;
            } else {
                ratio = Math.max(-10, Math.log10(ratio));   //Clip to -10, when updates are too small
            }


            try{
                String name = "logUpdateRatio/" + v.getName();
                writer.writeScalarEvent(name, LogFileWriter.EventSubtype.LOSS, System.currentTimeMillis(), at.iteration(), at.epoch(), ratio);
            } catch (IOException e){
                throw new RuntimeException("Error writing to log file", e);
            }
        }
    }





    public static Builder builder(File logFile){
        return new Builder(logFile);
    }

    public static class Builder {

        private FileMode fileMode = FileMode.CREATE_OR_APPEND;
        private File logFile;

        private int lossPlotFreq = 1;
        private int performanceStatsFrequency = -1;     //Disabled by default

        private int updateRatioFrequency = -1;          //Disabled by default
        private UpdateRatio updateRatioType = UpdateRatio.MEAN_MAGNITUDE;

        private int histogramFrequency = -1;            //Disabled by default
        private HistogramType[] histogramTypes;

        private int opProfileFrequency = -1;            //Disabled by default

        private Map<Pair<String,Integer>, List<Evaluation.Metric>> trainEvalMetrics;
        private int trainEvalFrequency = 10;            //Report evaluation metrics every 10 iterations by default

        private TestEvaluation testEvaluation = null;

        private int learningRateFrequency = 10;         //Whether to plot learning rate or not

        public Builder(@NonNull File logFile){
            this.logFile = logFile;
        }

        public Builder fileMode(FileMode fileMode){
            this.fileMode = fileMode;
            return this;
        }

        public Builder plotLosses(int frequency){
            this.lossPlotFreq = frequency;
            return this;
        }

        public Builder performanceStats(int frequency){
            this.performanceStatsFrequency = frequency;
            return this;
        }

        public Builder trainEvaluationMetrics(String name, int labelIdx, Evaluation.Metric... metrics){
            if(trainEvalMetrics == null){
                trainEvalMetrics = new LinkedHashMap<>();
            }
            Pair<String,Integer> p = new Pair<>(name, labelIdx);
            if(!trainEvalMetrics.containsKey(p)){
                trainEvalMetrics.put(p, new ArrayList<Evaluation.Metric>());
            }
            List<Evaluation.Metric> l = trainEvalMetrics.get(p);
            for(Evaluation.Metric m : metrics){
                if(!l.contains(m)){
                    l.add(m);
                }
            }
            return this;
        }

        public Builder trainAccuracy(String name, int labelIdx){
            return trainEvaluationMetrics(name, labelIdx, Evaluation.Metric.ACCURACY);
        }

        public Builder trainF1(String name, int labelIdx){
            return trainEvaluationMetrics(name, labelIdx, Evaluation.Metric.F1);
        }

        public Builder trainEvalFrequency(int trainEvalFrequency){
            this.trainEvalFrequency = trainEvalFrequency;
            return this;
        }

        public Builder updateRatios(int frequency){
            return updateRatios(frequency, UpdateRatio.MEAN_MAGNITUDE);
        }

        public Builder updateRatios(int frequency, UpdateRatio ratioType){
            this.updateRatioFrequency = frequency;
            this.updateRatioType = ratioType;
            return this;
        }

        public Builder histograms(int frequency, HistogramType... types){
            this.histogramFrequency = frequency;
            this.histogramTypes = types;
            return this;
        }

        public Builder profileOps(int frequency){
            this.opProfileFrequency = frequency;
            return this;
        }

        public Builder testEvaluation(TestEvaluation testEvalConfig){
            this.testEvaluation = testEvalConfig;
            return this;
        }

        public Builder learningRate(int frequency){
            this.learningRateFrequency = frequency;
            return this;
        }

        public UIListener build(){
            return new UIListener(this);
        }
    }

    public static class TestEvaluation {
        //TODO
    }
}