maael/node-ann

View on GitHub
lib/ann.js

Summary

Maintainability
F
5 days
Test Coverage
var perceptron = require('./perceptron');
var ann = function(options) {
    /*
    *   Returns `value` unless it is undefined or null, in which case `fallback`.
    *   Used where 0 is a meaningful setting that must not be replaced by a default.
    */
    function withDefault(value, fallback) {
        return (value === undefined || value === null) ? fallback : value;
    }
    /*
    *   Network options, that shortcircuit to defaults.
    *   NOTE: the caller's options object is filled in place; getNetwork()
    *   returns this same object as `netOptions`.
    */
    options = options || {};
    options.learningStep = options.learningStep || 0.1;
    options.initialisationMethod = options.initialisationMethod || 'normal';
    options.dataFormat = options.dataFormat || [];
    options.epochs = options.epochs || 10000;
    options.report = options.report || false;
    // A momentum of 0 (plain backpropagation) is valid, so only default when unset.
    options.momentum = withDefault(options.momentum, 0.9);
    options.assessment = options.assessment || 'RMSE';
    options.errorThreshold = ((options.errorThreshold === 0) ? 0 : (options.errorThreshold || 0.05));
    /*
    *   Structural variables for network.
    *   weightMatrix[a][b] holds the weight of the a<->b connection; it is stored
    *   under both key orders and every writer below keeps the two entries equal.
    */
    var layers = [],
        perceptrons = [],
        weightMatrix = {},
        reportText = '',
        activations = {},
        deltas = {};
    /*
    *   Load network from an object produced by getNetwork() (TODO: not fully finished).
    *   Perceptrons are rebuilt via the perceptron module's createPerceptron;
    *   every other field is adopted by reference.
    *   Returns `this` (the network when called as a method).
    */
    function createNetwork(network) {
        var recreatedPerceptrons = network.perceptrons.map(function (networkPerceptron) {
            return new perceptron().createPerceptron(networkPerceptron);
        });
        options = network.netOptions;
        layers = network.layers;
        perceptrons = recreatedPerceptrons;
        weightMatrix = network.weightMatrix;
        activations = network.activations;
        deltas = network.deltas;
        return this;
    }
    /*
    *   Inputs - weighting: {to: , from:, weight: }
    *   `to` and `from` must be the IDs of perceptrons already added to the network.
    *   If weight is not set (undefined/null), 1 is used; an explicit 0 is kept.
    *   Registers the connection on both perceptrons and stores the weight symmetrically.
    *   Throws Error if either ID is unknown (checked before anything is mutated).
    */
    function addWeighting(weighting) {
        var from = findPerceptron(weighting.from),
            to = findPerceptron(weighting.to),
            weight = withDefault(weighting.weight, 1);
        // findPerceptron signals "not found" by returning perceptrons.length.
        if (from === perceptrons.length) {
            throw new Error('Unknown perceptron ID in weighting.from: ' + weighting.from);
        }
        if (to === perceptrons.length) {
            throw new Error('Unknown perceptron ID in weighting.to: ' + weighting.to);
        }
        if (weightMatrix[weighting.from] === undefined) { weightMatrix[weighting.from] = {}; }
        if (weightMatrix[weighting.to] === undefined) { weightMatrix[weighting.to] = {}; }
        weightMatrix[weighting.from][weighting.to] = weight;
        weightMatrix[weighting.to][weighting.from] = weight;
        perceptrons[from].addOutput(weighting.to);
        perceptrons[to].addInput(weighting.from);
    }
    /* Getter for weight matrix, exposed for tests (returns the live object) */
    function getWeightings() {
        return weightMatrix;
    }
    /*
    *   Inputs - layerPerceptrons: [PerceptronIDs as Strings]
    *   Appends a layer; layer 0 is treated as the input layer by forwardPass.
    */
    function addLayer(layerPerceptrons) {
        layers.push(layerPerceptrons);
    }
    /* Getter for specific layer */
    function getLayer(index) {
        return layers[index];
    }
    /* Connects every perceptron of layer `fromLayer` to every perceptron of layer `toLayer` (weight 1) */
    function fullyInterconnectLayers(fromLayer, toLayer) {
        var from = layers[fromLayer],
            to = layers[toLayer],
            i, j;
        for (i = 0; i < from.length; i++) {
            for (j = 0; j < to.length; j++) {
                addWeighting({from: from[i], to: to[j]});
            }
        }
    }
    /*
    *   Inputs - perceptron: Object initialised via perceptron module
    *   Adds perceptron to network
    */
    function addPerceptron(perceptron) {
        perceptrons.push(perceptron);
    }
    /*
    *   Inputs - type: 'input', 'hidden', 'output'
    *   Returns a new array of the perceptrons of that type if type is valid,
    *   else the internal list of all perceptrons itself.
    */
    function getPerceptrons(type) {
        if (type === 'input' || type === 'hidden' || type === 'output') {
            return perceptrons.filter(function (candidate) {
                return candidate.getType() === type;
            });
        }
        return perceptrons;
    }
    /*
    *   Getter for a specific perceptron as specified by id String
    *   Returns the perceptron, or undefined if no perceptron has that id
    */
    function getPerceptron(id) {
        var result, i;
        for (i = 0; i < perceptrons.length; i++) {
            if (perceptrons[i].getID() === id) {
                result = perceptrons[i];
                break;
            }
        }
        return result;
    }
    /*
    *   Find perceptron in internal perceptron list
    *   Returns its index, or perceptrons.length if the id is not present
    */
    function findPerceptron(id) {
        var i;
        for (i = 0; i < perceptrons.length; i++) {
            if (perceptrons[i].getID() === id) { break; }
        }
        return i;
    }
    /*
    *   Network Initialisation
    *   Re-draws every existing connection weight, writing the same value under
    *   both key orders. n is the number of input perceptrons (at least 1, to
    *   avoid dividing by zero). Returns `this`.
    */
    function initialise() {
        /* Recommended Initialisation: uniform in [-2/n, 2/n) */
        function startWeightsNormal(n) {
            var limit = 2 / n;
            return (Math.random() * 2 - 1) * limit;
        }
        /*
        *   Gaussian initialisation (Bishop): zero-mean normal sample with
        *   standard deviation 1/sqrt(n), drawn via the Box-Muller transform.
        */
        function startWeightsGaussian(n) {
            var u1 = 1 - Math.random(), // in (0, 1], so the logarithm stays finite
                u2 = Math.random(),
                standardNormal = Math.sqrt(-2 * Math.log(u1)) * Math.cos(2 * Math.PI * u2);
            return standardNormal / Math.sqrt(n);
        }
        var weightFunction = ((options.initialisationMethod === 'gaussian') ? startWeightsGaussian : startWeightsNormal),
            numberOfInputs = Math.max(getPerceptrons('input').length, 1),
            from, to, weight;
        for (from in weightMatrix) {
            if (!weightMatrix.hasOwnProperty(from)) { continue; }
            for (to in weightMatrix[from]) {
                if (!weightMatrix[from].hasOwnProperty(to)) { continue; }
                // One draw per connection keeps the symmetric entries equal.
                weight = weightFunction(numberOfInputs);
                weightMatrix[from][to] = weight;
                weightMatrix[to][from] = weight;
            }
        }
        return this;
    }
    /*
    *   Extract example `setNumber` from a column-major set (set[column][row]).
    *   The target column is the first column marked 'output' in options.dataFormat,
    *   or the last column when none is marked. Every other non-output column becomes
    *   a one-row input column ([[x1], [x2], ...]), the shape forwardPass reads via example[i][0].
    */
    function getOutputlessSet(set, setNumber) {
        var format = options.dataFormat || [],
            outputColumn = format.indexOf('output'),
            inputSet = [],
            i;
        if (outputColumn === -1 || outputColumn >= set.length) { outputColumn = set.length - 1; }
        for (i = 0; i < set.length; i++) {
            if (i !== outputColumn && format[i] !== 'output') {
                inputSet.push([set[i][setNumber]]);
            }
        }
        return {
            inputSet: inputSet,
            outputSet: [set[outputColumn][setNumber]]
        };
    }
    /*
    *   Network Assessment
    *   Solves every example of a column-major assessment set and returns the
    *   measure named by options.assessment:
    *     RMSE - root mean squared error
    *     MSRE - mean squared relative error ((predicted - actual) / actual)^2
    *     CE   - 1 - SSE / sum((actual - mean(actual))^2)
    *     RSqr - squared Pearson correlation of predicted vs actual
    *   NOTE(review): single-output assumption - the last output perceptron is
    *   compared with the target column.
    *   Throws Error for an unknown options.assessment.
    */
    function assess(assessmentSet) {
        var n = assessmentSet[0].length,
            predicted = [],
            actual = [],
            sumSquaredError = 0,
            sumSquaredRelativeError = 0,
            sumActual = 0,
            sumPredicted = 0,
            totalSumOfSquares = 0,
            predictedSumOfSquares = 0,
            crossProducts = 0,
            meanActual, meanPredicted, actualDeviation, predictedDeviation,
            test, outputs, difference, errors, error, i;
        for (i = 0; i < n; i++) {
            test = getOutputlessSet(assessmentSet, i);
            outputs = solve(test.inputSet);
            predicted.push(Number(outputs[outputs.length - 1]));
            actual.push(Number(test.outputSet[0]));
            difference = predicted[i] - actual[i];
            sumSquaredError += difference * difference;
            sumSquaredRelativeError += Math.pow(difference / actual[i], 2);
            sumActual += actual[i];
            sumPredicted += predicted[i];
        }
        /* Sums of squares around the means, needed by CE and RSqr */
        meanActual = sumActual / n;
        meanPredicted = sumPredicted / n;
        for (i = 0; i < n; i++) {
            actualDeviation = actual[i] - meanActual;
            predictedDeviation = predicted[i] - meanPredicted;
            totalSumOfSquares += actualDeviation * actualDeviation;
            predictedSumOfSquares += predictedDeviation * predictedDeviation;
            crossProducts += actualDeviation * predictedDeviation;
        }
        errors = {
            RMSE: Math.sqrt(sumSquaredError / n),
            MSRE: sumSquaredRelativeError / n,
            CE: 1 - (sumSquaredError / totalSumOfSquares),
            RSqr: (crossProducts * crossProducts) / (totalSumOfSquares * predictedSumOfSquares)
        };
        if (!errors.hasOwnProperty(options.assessment)) {
            throw new Error('Unknown assessment method: ' + options.assessment);
        }
        error = errors[options.assessment];
        // Overwrite (not append) so the text cannot grow when reporting is off.
        reportText = options.assessment + ' Error: ' + error.toString();
        return error;
    }
    /*
    *   Method to report status via progress bar and the latest assessment text.
    *   Used if options.report is true. On a TTY the line is redrawn in place;
    *   otherwise (clearLine/cursorTo do not exist) one line per call is written.
    */
    function report(count, epochs) {
        function progressUpdate(count, epochs) {
            var percentage = ((count / epochs) * 100).toFixed(0),
                barLength = 25,
                bars = Math.floor(barLength * (percentage / 100)),
                progress = '[',
                i;
            for (i = 0; i < bars; i++) { progress += '#'; }
            for (i = 0; i < (barLength - bars); i++) { progress += ' '; }
            progress += '] ' + percentage + '% ';
            progress += count + '/' + epochs + '';
            return progress;
        }
        var line = progressUpdate(count, epochs) + ' | ' + reportText;
        if (process.stdout.isTTY) {
            process.stdout.clearLine();
            process.stdout.cursorTo(0);
            process.stdout.write(line);
        } else {
            process.stdout.write(line + '\n');
        }
        reportText = '';
    }
    /*
    *   Forward pass: writes input activations from `example` (either {inputs: [...]}
    *   or a one-row column set [[x1], [x2], ...]), then the sigmoid activation of
    *   every perceptron in layers 1..end, into `activations`, which is returned.
    */
    function forwardPass(example, activations) {
        /* Weighted sum of input activations plus the perceptron's bias */
        function sumInputs(perceptron) {
            var id = perceptron.getID();
            return perceptron.getInputs().reduce(function (sum, input) {
                return sum + (activations[input] * weightMatrix[input][id]);
            }, 0) + perceptron.getBias();
        }
        var inputs = getPerceptrons('input'),
            hasNamedInputs = example.hasOwnProperty('inputs'),
            perceptron, i, j;
        /* Set initial activations for inputs */
        for (i = 0; i < inputs.length; i++) {
            activations[inputs[i].getID()] = hasNamedInputs ? example.inputs[i] : example[i][0];
        }
        /* Calculate non-input activations layer by layer */
        for (i = 1; i < layers.length; i++) {
            for (j = 0; j < layers[i].length; j++) {
                perceptron = getPerceptron(layers[i][j]);
                activations[perceptron.getID()] = 1 / (1 + Math.exp(-sumInputs(perceptron)));
            }
        }
        return activations;
    }
    /*
    *   Network Training (online backpropagation with per-connection momentum).
    *   trainingSet/validationSet are column-major; options.dataFormat[i] === 'output'
    *   marks target columns. Stops when the validation error drops below
    *   options.errorThreshold or after options.epochs epochs.
    *   Throws Error if options.dataFormat marks no 'output' column.
    *   TODO: Alter so can be further trained after being trained
    */
    function train(trainingSet, validationSet) {
        var outputTargetIndex = {},
            previousChanges = {},
            count = 0,
            continueLoop = true,
            example, i;
        /*
        *   Checks to see if training should continue.
        *   False once the validation error falls below options.errorThreshold.
        *   NOTE(review): CE/RSqr are "higher is better", yet the same test is applied.
        */
        function continueCheck() {
            return !(assess(validationSet) < options.errorThreshold);
        }
        /* Builds example `exampleNumber` from the training set using options.dataFormat */
        function getExample(trainingSet, exampleNumber) {
            var example = {inputs: [], outputs: []},
                i;
            for (i = 0; i < trainingSet.length; i++) {
                if (options.dataFormat[i] === 'output') {
                    example.outputs.push(trainingSet[i][exampleNumber]);
                } else {
                    example.inputs.push(trainingSet[i][exampleNumber]);
                }
            }
            return example;
        }
        /* Backpropagation: computes the delta of every non-input perceptron, last layer first */
        function backwardPass(exampleOutputs, deltas) {
            /* Sums output delta * weight to output for a perceptron */
            function sumOfOutputWeightings(perceptron) {
                var id = perceptron.getID();
                return perceptron.getOutputs().reduce(function (sum, output) {
                    return sum + deltas[output] * weightMatrix[id][output];
                }, 0);
            }
            var p, pID, targetIndex, target, error, i, j;
            for (i = (layers.length - 1); i > 0; i--) {
                for (j = (layers[i].length - 1); j >= 0; j--) {
                    p = getPerceptron(layers[i][j]);
                    pID = p.getID();
                    if (p.getType() === 'output') {
                        // k-th output perceptron learns the k-th output column; with fewer
                        // target columns than outputs, fall back to the first target.
                        targetIndex = outputTargetIndex[pID];
                        target = (targetIndex < exampleOutputs.length) ? exampleOutputs[targetIndex] : exampleOutputs[0];
                        error = target - activations[pID];
                    } else {
                        error = sumOfOutputWeightings(p);
                    }
                    deltas[pID] = error * activations[pID] * (1 - activations[pID]);
                }
            }
            return deltas;
        }
        /*
        *   Propagate learning through network.
        *   Each connection's change is learningStep * delta * activation plus
        *   momentum times that same connection's previous change.
        */
        function updateWeights() {
            var perceptron, pID, outputs, to, change, i, j, k;
            for (i = 0; i < layers.length; i++) {
                for (j = 0; j < layers[i].length; j++) {
                    perceptron = getPerceptron(layers[i][j]);
                    pID = perceptron.getID();
                    outputs = perceptron.getOutputs();
                    if (previousChanges[pID] === undefined) { previousChanges[pID] = {}; }
                    for (k = 0; k < outputs.length; k++) {
                        to = outputs[k];
                        change = (options.learningStep * deltas[to] * activations[pID]) +
                            (options.momentum * (previousChanges[pID][to] || 0));
                        previousChanges[pID][to] = change;
                        weightMatrix[pID][to] = weightMatrix[pID][to] + change;
                        weightMatrix[to][pID] = weightMatrix[pID][to];
                    }
                    if (perceptron.getType() !== 'input') {
                        perceptron.updatedBias(perceptron.getBias() + (options.learningStep * deltas[pID] * 1));
                    }
                }
            }
        }
        if ((options.dataFormat || []).indexOf('output') === -1) {
            throw new Error("options.dataFormat must mark at least one column as 'output' before training");
        }
        getPerceptrons('output').forEach(function (outputPerceptron, index) {
            outputTargetIndex[outputPerceptron.getID()] = index;
        });
        /* Main loop of training */
        while (continueLoop) {
            for (i = 0; i < trainingSet[0].length; i++) {
                example = getExample(trainingSet, i);
                activations = forwardPass(example, activations);
                deltas = backwardPass(example.outputs, deltas);
                updateWeights();
            }
            count++;
            continueLoop = continueCheck() && (count < options.epochs);
            if (options.report) { report(count, options.epochs); }
        }
        // Only the in-place TTY progress line needs terminating.
        if (options.report && process.stdout.isTTY) { process.stdout.write('\n'); }
    }
    /*
    *   Network Solution - calculates prediction for input set via a forward pass
    *   (on fresh activations, leaving training state untouched).
    *   Returns one activation per output perceptron, in getPerceptrons('output') order.
    */
    function solve(solveSet) {
        var solvedActivations = forwardPass(solveSet, {});
        return getPerceptrons('output').map(function (outputPerceptron) {
            return solvedActivations[outputPerceptron.getID()];
        });
    }
    /*
    *   Get key components of network (live references, except perceptrons,
    *   which are converted with getPerceptron())
    */
    function getNetwork() {
        var perceptronRepresentation = perceptrons.map(function (perceptron) {
            return perceptron.getPerceptron();
        });
        return {
            netOptions: options,
            layers: layers,
            perceptrons: perceptronRepresentation,
            weightMatrix: weightMatrix,
            activations: activations,
            deltas: deltas
        };
    }
    /*
    *   Public exposure for ann
    */
    return {
        createNetwork: createNetwork,
        addWeighting: addWeighting,
        getWeightings: getWeightings,
        addLayer: addLayer,
        getLayer: getLayer,
        fullyInterconnectLayers: fullyInterconnectLayers,
        addPerceptron: addPerceptron,
        getPerceptrons: getPerceptrons,
        getPerceptron: getPerceptron,
        findPerceptron: findPerceptron,
        initialise: initialise,
        train: train,
        solve: solve,
        getNetwork: getNetwork
    };
};

/*
*   Module exports: the network factory, plus the perceptron module (the same
*   cached module object that is bound to `perceptron` at the top of this file).
*/
exports.ann = ann;
exports.perceptron = perceptron;