people/a/AlammarJay/js/simple_nn.js

Summary

Maintainability
F
1 wk
Test Coverage
/**
 * Created by alammar on 11/16/16.
 */





/**
 * Interactive trainer for a one-input linear model (prediction = x * weight + bias).
 *
 * Stores the configuration on the instance, draws the main chart and - when their selectors
 * are non-empty - the error-history chart, the weight/bias heatmap and the network diagram,
 * then wires the weight/bias sliders and the gradient-descent buttons.
 *
 * Optional selectors are disabled by passing an empty string ("").
 *
 * @param {string} svg_el - Selector of the element receiving the main chart.
 * @param {string} table_el - Selector of the controls container; also the prefix of the
 *   slider selectors (table_el + "-weightSlider" / table_el + "-biasSlider").
 * @param {number[]} areas - Input values (x axis).
 * @param {number[]} prices - Target values (y axis); presumably paired with `areas` by index
 *   through the external `zip` helper.
 * @param {number} weight - Initial weight.
 * @param {number} bias - Initial bias.
 * @param {number} x1 - Lower bound of the x domain; x of the first prediction-line endpoint.
 * @param {number} y1 - Lower bound of the y domain.
 * @param {number} x2 - Upper bound of the x domain; x of the second prediction-line endpoint.
 * @param {number} y2 - Upper bound of the y domain.
 * @param {string} gradientDescentButton - Selector (inside table_el) of the 1-step button.
 * @param {string} gradientDescent10Button - Selector of the 10-step button.
 * @param {string} gradientDescent100Button - Selector of the 100-step button.
 * @param {string} gradientDescentConvergeButton - Accepted for signature compatibility;
 *   this constructor does not bind anything to it.
 * @param {boolean} normalize - When truthy, feature statistics are computed up front and
 *   predictions are made in normalized space.
 * @param {string} error_chart_el - Selector for the error-history chart.
 * @param {string} heatmap_el - Selector for the weight/bias/error heatmap.
 * @param {number[]} weightRange - [min, max] weight shown on the heatmap.
 * @param {number[]} biasRange - [min, max] bias shown on the heatmap.
 * @param {string} neuralNetworkGraphEl - Selector for the network diagram.
 * @param {string} analyticsCategory - Event category reported to the analytics back ends.
 */
var NN_trainer = function (svg_el, table_el, areas, prices, weight, bias, x1, y1, x2, y2,
                           gradientDescentButton, gradientDescent10Button, gradientDescent100Button,
                           gradientDescentConvergeButton, normalize, error_chart_el, heatmap_el,
                           weightRange, biasRange, neuralNetworkGraphEl, analyticsCategory) {
  var trainer_self = this;

  this.svg_el = svg_el;
  this.table_el = table_el;
  this.areas = areas;
  this.prices = prices;
  this.weight = weight;
  this.bias = bias;
  this.x1 = x1;
  this.y1 = y1;
  this.x2 = x2;
  this.y2 = y2;
  // Two endpoints of the prediction line; their y values are recomputed on every update.
  this.data = [{x: this.x1, y: this.y1}, {x: this.x2, y: this.y2}];
  this.prediction = [];
  this.dataPoints = zip([this.areas, this.prices]);
  this.normalize = normalize;
  this.error_chart_el = error_chart_el;
  this.heatmap_el = heatmap_el;
  this.analyticsCategory = analyticsCategory;

  this.graphWidth = 350;

  this.miniGraphWidth = 170;
  this.miniGraphHeight = 180;

  this.weightRange = weightRange;
  this.biasRange = biasRange;

  this.neuralNetworkGraphEl = neuralNetworkGraphEl;

  this.miniGraphMargin = {top: 30, right: 50, bottom: 35, left: 30};

  // Normalization doesn't work quite yet. Actually decided to roll back during implementation
  // because changing the weights and biases between examples would confuse readers.
  if (normalize)
    this.normalizeFeatures(areas, prices);

  this.initializeGraph();

  if (error_chart_el !== "")
    this.initializeErrorGraph();
  if (this.heatmap_el !== "")
    this.initializeHeatmap();
  if (this.neuralNetworkGraphEl !== "")
    this.initializeNeuralNetworkGraph();

  // Set the initial values of the readouts and charts
  this.updateWeightAndBias(this.weight, this.bias, true);

  // Analytics scripts are frequently blocked or not loaded. A missing global must not
  // throw inside a UI handler, so both back ends are probed before use.
  var sendGaEvent = function (action, label) {
    if (typeof ga === "function")
      ga('send', 'event', trainer_self.analyticsCategory, action, label);
  };
  var pushPaqEvent = function (action, label) {
    if (typeof _paq !== "undefined" && _paq && typeof _paq.push === "function")
      _paq.push(['trackEvent', trainer_self.analyticsCategory, action, label]);
  };

  var weightSlider = $(this.table_el + "-weightSlider"),
    biasSlider = $(this.table_el + "-biasSlider");

  // React to the user moving the sliders. -1 means "leave this parameter unchanged".
  weightSlider.on("input change", function () {
    trainer_self.updateWeightAndBias(this.value, -1, true);
  });
  biasSlider.on("input change", function () {
    trainer_self.updateWeightAndBias(-1, this.value, true);
  });

  // Report slider interactions once the user releases the control
  weightSlider.on("mouseup touchend", function () {
    sendGaEvent("Interacted with", "Weight slider");
    pushPaqEvent("Interacted with", "Weight slider");
  });
  biasSlider.on("mouseup touchend", function () {
    sendGaEvent("Interacted with", "Bias slider");
    pushPaqEvent("Interacted with", "Bias slider");
  });

  // Attach a click handler to a gradient descent button, if a selector was supplied
  var bindStepButton = function (buttonSelector, steps, label) {
    if (buttonSelector === '')
      return;
    $(trainer_self.table_el + " " + buttonSelector).click(function () {
      trainer_self.gradientDescentStep(steps);
      sendGaEvent("Clicked on", label);
    });
  };
  bindStepButton(gradientDescentButton, 1, "1 Gradient Descent Step");
  bindStepButton(gradientDescent10Button, 10, "10 Gradient Descent Steps");
  bindStepButton(gradientDescent100Button, 100, "100 Gradient Descent Steps");

  // Move the slider thumbs to the initial weight/bias
  weightSlider.val(this.weight);
  biasSlider.val(this.bias);
};


/**
 * Computes the standard deviation, the mean and the normalized copy of both input arrays
 * and stores them on the instance (area_std / area_mean / areas_normalized and
 * prices_std / prices_mean / prices_normalized).
 *
 * @param {number[]} areas - Input feature values.
 * @param {number[]} prices - Target values.
 */
NN_trainer.prototype.normalizeFeatures = function (areas, prices) {
  var trainer = this;
  var columns = [
    {values: areas, std: "area_std", mean: "area_mean", normalized: "areas_normalized"},
    {values: prices, std: "prices_std", mean: "prices_mean", normalized: "prices_normalized"}
  ];

  columns.forEach(function (column) {
    trainer[column.std] = standardDeviation(column.values);
    trainer[column.mean] = average(column.values);
    trainer[column.normalized] =
      normalizeFeaturesArray(column.values, trainer[column.mean], trainer[column.std]);
  });
};


/**
 * Returns a new array holding (value - mean) / std for every element of `array`.
 * The input array is not modified. No guard is applied for std === 0, so the result
 * then contains NaN / Infinity as plain division would produce.
 *
 * @param {number[]} array - Values to normalize.
 * @param {number} mean - Mean to subtract from each value.
 * @param {number} std - Standard deviation to divide by.
 * @returns {number[]} Normalized values, same length and order as the input.
 */
var normalizeFeaturesArray = function (array, mean, std) {
  // Declared locally: the result used to leak into a shared global variable.
  var outputArray = [];
  for (var i = 0; i < array.length; i++)
    outputArray[i] = (array[i] - mean) / std;

  return outputArray;
};


/**
 * Builds the main chart inside `this.svg_el`: an SVG of graphWidth x 249 px containing the
 * prediction line, both axes and one dot per entry of `this.dataPoints`.
 *
 * Stores the SVG (holder), margins, inner size, inner group (g), the x / y scales, the line
 * generator (valueline) and the dot selection (dataPointDots) on the instance.
 */
NN_trainer.prototype.initializeGraph = function () {
  var self = this;

  // Root SVG element
  this.holder = d3.select(this.svg_el).append("svg");
  this.holder
    .attr("width", this.graphWidth)
    .attr("height", 249);

  // Inner drawing area = SVG size minus margins
  this.margin = {top: 20, right: 20, bottom: 50, left: 50};
  this.width = Number(this.holder.attr("width")) - this.margin.left - this.margin.right;
  this.height = Number(this.holder.attr("height")) - this.margin.top - this.margin.bottom;
  this.g = this.holder.append("g")
    .attr("transform", "translate(" + this.margin.left + "," + this.margin.top + ")");

  // Scales mapping data space to pixels (y is flipped so larger values are drawn higher)
  this.x = d3.scaleLinear()
    .rangeRound([0, this.width])
    .domain([this.x1, this.x2]);
  this.y = d3.scaleLinear()
    .rangeRound([this.height, 0])
    .domain([this.y1, this.y2]);

  // NOTE(review): the formatted ticks computed below are discarded; the two statements are
  // kept only so this rewrite executes exactly what the original did.
  var formatX = d3.tickFormat(d3.format("d"));
  this.x.ticks().map(formatX);

  // Line generator translating {x, y} data points into pixel coordinates
  this.valueline = d3.line()
    .x(function (point) {
      return self.x(point.x);
    })
    .y(function (point) {
      return self.y(point.y);
    });

  // Prediction line
  this.g.append("path")
    .attr("class", "line")
    .attr("d", this.valueline(this.data));

  // X axis, pinned to the bottom of the drawing area
  this.g.append("g")
    .attr("class", "axis axis--x")
    .attr("transform", "translate(0," + this.height + ")")
    .call(d3.axisBottom(this.x).ticks(5));

  // Y axis
  this.g.append("g")
    .attr("class", "axis axis--y")
    .call(d3.axisLeft(this.y).ticks(5));

  // One dot per [x, y] data point
  var newDots = this.g.selectAll(this.svg_el + " .dot")
    .data(this.dataPoints)
    .enter()
    .append("circle");

  this.dataPointDots = newDots
    .attr("class", "dot")
    .attr("r", 3.5)
    .attr("cx", function (pair) {
      return self.x(pair[0]);
    })
    .attr("cy", function (pair) {
      return self.y(pair[1]);
    });
};


/**
 * Builds the error-history chart inside `this.error_chart_el`.
 *
 * Creates a miniGraphWidth x miniGraphHeight SVG with an x axis indexed by history position
 * (0 .. error_chart_history_x), a y axis for the error value, a clipped path with class
 * "error-history-line" bound to `this.error_history`, a rotated y-axis label and a title.
 * Scales, the line generator and the selections needed later are stored on the instance.
 */
NN_trainer.prototype.initializeErrorGraph = function () {
  var self = this,
    margin = {top: 30, right: 30, bottom: 35, left: 65};

  this.error_chart_history_x = 200;       // How many error data points to show
  this.error_chart_history_y = 100000;    // Not read anywhere in the visible source
  this.error_history = [10000];           // Seed value so the y scale has a maximum

  this.miniErrorChartMargin = margin;

  // Root SVG element, sized like the other mini charts
  this.errorHolder = d3.select(this.error_chart_el)
    .append("svg")
    .attr("width", this.miniGraphWidth)
    .attr("height", this.miniGraphHeight);

  this.errorChartWidth = +this.errorHolder.attr("width") - margin.left - margin.right;
  this.errorChartHeight = +this.errorHolder.attr("height") - margin.top - margin.bottom;
  this.errorG = this.errorHolder.append("g")
    .attr("transform", "translate(" + margin.left + "," + margin.top + ")");

  // X scale: position within the error history. It starts at 0 because the line generator
  // feeds it array indices; the main chart's data bound (x1) has no meaning on this axis.
  this.error_x = d3.scaleLinear()
    .rangeRound([0, this.errorChartWidth])
    .domain([0, this.error_chart_history_x]);

  // Y scale: error value, with 30% headroom above the current maximum
  this.error_y = d3.scaleLinear()
    .rangeRound([this.errorChartHeight, 2])
    .domain([1, d3.max(this.error_history) * 1.3]);

  // Colour scale used for the y-axis tick labels
  this.errorGraphScaleColors = ['#F8CA00', '#feb24c', '#fd8d3c', '#fc4e2a'];
  this.errorGraphScale = d3.scaleLinear()
    .domain([400, 10000, 100000, 1000000])
    .range(this.errorGraphScaleColors);

  // Line generator: x from the index in the history, y from the error value
  this.errorGraphLine = d3.line()
    .x(function (d, i) {
      return self.error_x(i);
    })
    .y(function (d) {
      return self.error_y(d);
    });

  // Draw X axis
  this.errorG.append("g")
    .attr("class", "axis axis--x")
    .attr("transform", "translate(0," + this.errorChartHeight + ")")
    .call(d3.axisBottom(this.error_x).ticks(4));
  // Draw Y axis
  this.errorYAxis = this.errorG.append("g")
    .attr("class", "axis axis--y")
    .call(d3.axisLeft(this.error_y).ticks(5));

  // Y axis label; text-anchor "middle" centres the text on the transform's anchor point
  this.errorHolder.append("text")
    .attr("text-anchor", "middle")
    .attr("transform", "translate(" + 15 + ","
      + (this.errorChartHeight / 2 + margin.top) + ")rotate(-90)")
    .attr("class", "error-axis-label")
    .text("Erreur");

  // Clip the history line to the plotting area.
  // NOTE(review): the id "clip" is document-global; a second error chart on the same page
  // would register the same id.
  this.errorG.append("defs").append("clipPath")
    .attr("id", "clip")
    .append("rect")
    .attr("width", this.errorChartWidth)
    .attr("height", this.errorChartHeight);

  this.errorG.append("g")
    .attr("clip-path", "url(#clip)")
    .append("path")
    .datum(this.error_history)
    .attr("class", "error-history-line");

  // Chart title, centred over the whole SVG width
  this.errorHolder.append("text")
    .attr("text-anchor", "middle")
    .attr("transform", "translate(" +
      ((margin.left + this.errorChartWidth + margin.right) / 2)
      + "," +
      20 + ")")
    .attr("class", "chart-title")
    .text("Historique d'erreur");
};

/**
 * Appends one error value to the history, redraws the history line and rescales the chart.
 * Scrolling approach based on https://gist.github.com/mbostock/1642874
 *
 * @param {number} value - Error value to append.
 */
NN_trainer.prototype.addErrorPoint = function (value) {
  var history = this.error_history;

  history.push(value);

  // Redraw the path from its bound data, shifted one history step to the left
  var historyLine = d3.select(this.error_chart_el + " .error-history-line");
  historyLine.attr("d", this.errorGraphLine);
  historyLine.attr("transform", "translate(" + this.error_x(-1) + ",0)");

  // Once the window is full, discard the oldest sample
  if (history.length >= this.error_chart_history_x) {
    history.shift();
  }

  this.rescaleErrorGraph();
};


/**
 * Appends several error values at once, trims the history to at most
 * `error_chart_history_x` entries (keeping the newest), rebinds the trimmed history to the
 * history line, redraws it and rescales the chart.
 *
 * @param {number[]} valuesArray - Error values to append, oldest first.
 */
NN_trainer.prototype.batchAddErrorPoint = function (valuesArray) {
  var capacity = this.error_chart_history_x,
    combined = this.error_history.concat(valuesArray);

  // Keep only the most recent `capacity` samples
  if (combined.length > capacity) {
    combined = combined.slice(combined.length - capacity);
  }
  this.error_history = combined;

  // The history is a new array, so the path has to be bound to it again before redrawing
  var historyLine = d3.select(this.error_chart_el + " .error-history-line");
  historyLine
    .datum(this.error_history)
    .attr("d", this.errorGraphLine)
    .attr("transform", "translate(" + this.error_x(-valuesArray.length) + ",0)");

  this.rescaleErrorGraph();
};


/**
 * Refits the error chart's y scale to [1, largest value in the history], redraws the y axis
 * and recolours its tick labels through `errorGraphScale`.
 */
NN_trainer.prototype.rescaleErrorGraph = function () {
  var self = this,
    yAxisSelector = this.error_chart_el + " .axis--y",
    largestError = d3.max(this.error_history);

  this.error_y.domain([1, largestError]);

  // Redraw the axis with the new domain
  var yAxis = this.errorG.select(yAxisSelector);
  yAxis.call(d3.axisLeft(this.error_y).ticks(5));

  // Tint every tick label according to the error magnitude it represents
  var tickLabels = this.errorG.selectAll(yAxisSelector + " .tick text");
  tickLabels.attr("fill", function (tickValue) {
    return self.errorGraphScale(tickValue);
  });
};


/**
 * Returns the model's prediction for input `x` using the current weight and bias.
 *
 * Without normalization this is x * weight + bias. With normalization, x is first
 * standardized with the stored area mean/std, the linear model is applied, and the result
 * is mapped back with the stored price std/mean.
 *
 * @param {number} x - Input value.
 * @returns {number} Predicted output.
 */
NN_trainer.prototype.calculatePrediction = function (x) {
  if (!this.normalize) {
    return x * this.weight + this.bias;
  }

  var standardizedX = (x - this.area_mean) / this.area_std;
  var standardizedPrediction = standardizedX * this.weight + this.bias;
  return standardizedPrediction * this.prices_std + this.prices_mean;
};

/**
 * Applies a new weight and/or bias, recomputes the prediction-line endpoints and the
 * per-sample predictions, and returns the mean squared error over all samples.
 *
 * @param {number|string} weight - New weight (parsed with parseFloat), or -1 to keep the
 *   current weight. Because -1 is the sentinel, a weight of exactly -1 cannot be set.
 * @param {number|string} bias - New bias (parsed with parseFloat), or -1 to keep the
 *   current bias.
 * @param {boolean} updateUI - When truthy, the charts and readouts are refreshed through
 *   this.updateUI(meanSquaredError, errorLineValues).
 * @returns {number} Mean of the squared differences between prices and predictions
 *   (NaN when there are no samples).
 */
NN_trainer.prototype.updateWeightAndBias = function (weight, bias, updateUI) {
  var squaredErrorSum = 0,
    meanSquaredError,
    residual,
    errorLineValues,
    i;

  // Number(...) keeps the numeric-coercing behaviour of the former loose comparison, so
  // the string "-1" coming from an input element is still treated as the sentinel.
  if (Number(weight) !== -1) this.weight = parseFloat(weight);
  if (Number(bias) !== -1) this.bias = parseFloat(bias);

  // Move both endpoints of the prediction line
  this.data[0].y = this.calculatePrediction(this.data[0].x);
  this.data[1].y = this.calculatePrediction(this.data[1].x);

  // Predict every sample and accumulate the squared error
  this.prediction = [];
  for (i = 0; i < this.areas.length; i++) {
    this.prediction[i] = this.calculatePrediction(this.areas[i]);
    residual = this.prices[i] - this.prediction[i];
    squaredErrorSum = squaredErrorSum + Math.pow(residual, 2);
  }
  meanSquaredError = squaredErrorSum / this.prediction.length;

  if (updateUI) {
    // Vertical segments joining each data point to its prediction; only the UI needs
    // them, so they are not built during silent intermediate steps.
    errorLineValues = [];
    for (i = 0; i < this.areas.length; i++) {
      errorLineValues[i] = [
        {x: this.areas[i], y: this.prices[i]},
        {x: this.areas[i], y: this.prediction[i]}
      ];
    }
    this.updateUI(meanSquaredError, errorLineValues);
  }

  return meanSquaredError;
};

/**
 * Refreshes everything that displays the current model state: the error-history chart,
 * the weight / bias / error readouts and the comment under them, the error segments, the
 * prediction line, the prediction dots, the heatmap cell and the network diagram.
 *
 * @param {number} mean_delta_sum - Current mean squared error.
 * @param {Array} errorLineValues - One [{x, y}, {x, y}] segment per sample, joining the
 *   data point to its prediction.
 */
NN_trainer.prototype.updateUI = function (mean_delta_sum, errorLineValues) {
  // All selections are locals; errorLines and predictions used to leak as globals.
  var self = this,
    messageElem,
    errorLines,
    predictions,
    predictionDataPoints;

  var segmentPath = function (d) {
    return self.valueline(d);
  };
  var dotX = function (d) {
    return self.x(d[0]);
  };
  var dotY = function (d) {
    return self.y(d[1]);
  };

  // Update error chart if available
  if (this.error_chart_el !== "")
    this.addErrorPoint(mean_delta_sum);

  // Update the error/weight/bias indicators
  $(this.table_el + " span.weight").text(this.weight.toLocaleString('fr', {maximumFractionDigits: 3}));
  $(this.table_el + " span.bias").text(this.bias.toLocaleString('fr', {maximumFractionDigits: 1}));
  $(this.table_el + " span.error-value").text(numberWithCommas(Math.round(mean_delta_sum)));

  // Update comment on the score
  messageElem = $(this.table_el + " .error-value-message");
  if (mean_delta_sum < 450) {
    messageElem.html("Honnêtement je ne pensais pas que c'était 'humainement' possible..");
  } else if (mean_delta_sum < 500) {
    messageElem.html("Salut à toi, overlord IA superintelligent..");
  } else if (mean_delta_sum < 600) {
    messageElem.html("Wow wow ! Les doigts dans le nez, <a href='https://en.wikipedia.org/wiki/Yann_LeCun'>LeCun</a> !!");
  } else if (mean_delta_sum < 750) {
    messageElem.text("Bravo ! Vous avez trouvé 750 !");
  } else if (mean_delta_sum < 799) {
    messageElem.text("Bien joué");
  } else if (mean_delta_sum >= 800000) {
    messageElem.text("Vous essayez vraiment, là ?");
  } else if (mean_delta_sum >= 100000) {
    messageElem.text("Encore loin, mon pote");
  } else {
    messageElem.text("");
  }

  // ERROR SEGMENTS
  // Data join: one path per sample. Paths are created on the first call (enter) and only
  // repositioned on later calls (update).
  errorLines = this.g.selectAll(this.svg_el + " .error-line")
    .data(errorLineValues);
  errorLines.enter().append("path")
    .attr("class", "error-line")
    .attr("d", segmentPath);
  errorLines.attr("d", segmentPath);

  // PREDICTION LINE
  d3.select(this.svg_el + " .line")
    .attr("d", this.valueline(this.data));

  // PREDICTION DOTS
  // Same enter/update pattern as the error segments.
  predictionDataPoints = zip([this.areas, this.prediction]);
  predictions = this.g.selectAll(this.svg_el + " .prediction-dot")
    .data(predictionDataPoints);
  predictions.enter().append("circle")
    .attr("class", "prediction-dot")
    .attr("r", 3.5)
    .attr("cx", dotX)
    .attr("cy", dotY);
  predictions
    .attr("cx", dotX)
    .attr("cy", dotY);

  if (this.heatmap_el !== "")
    this.updateHeatmapElement(this.weight, this.bias, mean_delta_sum);

  if (this.neuralNetworkGraphEl !== "")
    this.updateNeuralNetworkGraph();

  // moveUp is not defined in the visible source; presumably a selection extension that
  // re-appends the data dots so they render above the newly added elements - TODO confirm.
  this.dataPointDots.moveUp();
};


/**
 * Runs `number_of_steps` gradient descent updates on weight and bias.
 *
 * Each step averages the gradient of the squared error over all samples (using the
 * predictions stored by the previous update) and moves the bias and the weight against it.
 * Intermediate steps update the model silently; their errors are added to the error chart
 * as one batch right before the final step, which refreshes the whole UI. Afterwards both
 * sliders are moved to the resulting weight and bias.
 *
 * @param {number} number_of_steps - How many updates to perform; 0 or less performs none.
 */
NN_trainer.prototype.gradientDescentStep = function (number_of_steps) {
  var sampleCount = this.areas.length,
    intermediateErrors = [],
    step,
    i,
    sum_for_bias,
    sum_for_weight,
    nextBias,
    nextWeight;

  // I probably shouldn't do this. I started doing feature normalization so we can keep to one learning rate.
  // I decided to do it this way to maintain narrative continuity.
  this.learningRate = 0.00000001;  // applied to the weight
  this.learningRate2 = 1;          // applied to the bias

  for (step = 0; step < number_of_steps; step++) {
    sum_for_bias = 0;
    sum_for_weight = 0;
    for (i = 0; i < sampleCount; i++) {
      sum_for_bias = sum_for_bias + this.prediction[i] - this.prices[i];
      sum_for_weight = sum_for_weight + (this.prediction[i] - this.prices[i]) * this.areas[i];
    }

    nextBias = this.bias - this.learningRate2 * (sum_for_bias / sampleCount);
    nextWeight = this.weight - this.learningRate * (sum_for_weight / sampleCount);

    // Only the last step refreshes the UI. ">=" rather than "===" so that a non-integer
    // step count still ends with a refresh.
    if (step + 1 >= number_of_steps) {
      if (intermediateErrors.length !== 0)
        this.batchAddErrorPoint(intermediateErrors);
      this.updateWeightAndBias(nextWeight, nextBias, true);
    } else {
      intermediateErrors.push(this.updateWeightAndBias(nextWeight, nextBias, false));
    }
  }

  // Read the values back from the instance so the sliders stay valid even when no step ran
  $(this.table_el + "-weightSlider").val(this.weight);
  $(this.table_el + "-biasSlider").val(this.bias);
};


/**
 * Builds the weight / bias / error heatmap inside `this.heatmap_el`.
 *
 * Creates a 15 x 15 grid of initially empty cells, the axis scales over weightRange and
 * biasRange, quantile scales translating a weight or bias into a cell index, a colour
 * scale for the error, both axes, the axis labels and the chart title. Everything needed
 * later is stored on the instance.
 */
NN_trainer.prototype.initializeHeatmap = function () {
  var margin = this.miniGraphMargin,
    cellsPerSide = 15,
    svg;

  this.heatmapSideNumberOfElements = cellsPerSide;
  this.heatmapColors = ['#fcfc99', '#feb24c', '#fd8d3c', '#fc4e2a'];
  this.heatmapEmptyBoxColor = "#fbfbfb";
  this.heatmapData = this.generateHeatmapData(cellsPerSide);

  // Root SVG element, sized like the other mini charts
  svg = d3.select(this.heatmap_el).append("svg")
    .attr("width", this.miniGraphWidth)
    .attr("height", this.miniGraphHeight);
  this.heatmapHolder = svg;

  this.heatmapWidth = +svg.attr("width") - margin.left - margin.right;
  this.heatmapHeight = +svg.attr("height") - margin.top - margin.bottom;
  this.heatmapG = svg.append("g")
    .attr("transform", "translate(" + (margin.left + 15) + "," + margin.top + ")");

  // Cells are square and sized from the height, so the grid spans heatmapHeight both ways
  this.heatmapBoxSize = this.heatmapHeight / cellsPerSide;

  // Axis scales. The x range also uses heatmapHeight, matching the square grid.
  this.heatmapXAxisScale = d3.scaleLinear()
    .domain(this.weightRange)
    .rangeRound([0, this.heatmapHeight]);
  this.heatmapYAxisScale = d3.scaleLinear()
    .domain(this.biasRange)
    .range([this.heatmapHeight, 0]);

  // Weight -> column index 0 .. cellsPerSide - 1
  this.heatmapX = d3.scaleQuantile()
    .domain(this.weightRange)
    .range(d3.range(cellsPerSide));
  // Bias -> row index, reversed so the largest bias lands in row 0 (top of the chart)
  this.heatmapY = d3.scaleQuantile()
    .domain(this.biasRange)
    .range(d3.range(cellsPerSide).reverse());

  // Error -> colour
  this.heatmapColorScale = d3.scaleLinear()
    .domain([400, 10000, 100000, 1000000])
    .range(this.heatmapColors);

  // X axis along the bottom of the grid
  this.heatmapG.append("g")
    .attr("class", "axis axis--x")
    .attr("transform", "translate(0," + this.heatmapHeight + ")")
    .call(d3.axisBottom(this.heatmapXAxisScale).ticks(5));
  // Y axis
  this.heatmapG.append("g")
    .attr("class", "axis axis--y")
    .call(d3.axisLeft(this.heatmapYAxisScale).ticks(5));

  // Adds a text element centred on the anchor point of the given transform
  var addCenteredText = function (transform, cssClass, label) {
    svg.append("text")
      .attr("text-anchor", "middle")
      .attr("transform", transform)
      .attr("class", cssClass)
      .text(label);
  };

  // Weight axis label
  addCenteredText(
    "translate(" +
      (this.heatmapWidth / 2 + margin.left) + ","
      + (margin.top + this.heatmapHeight + margin.bottom - 5) + ")",
    "weight-axis-label",
    "Poids");

  // Bias axis label, rotated to read bottom-to-top
  addCenteredText(
    "translate(" + (margin.left / 2) + ","
      + (this.heatmapHeight / 2 + margin.top) + ")rotate(-90)",
    "bias-axis-label",
    "Biais");

  // Chart title
  addCenteredText(
    "translate(" +
      ((margin.left + this.heatmapWidth + margin.right) / 2)
      + "," +
      20 + ")",
    "chart-title",
    "Poids vs. biais vs. erreur");

  this.updateHeatmap(this.heatmapData);
};


/**
 * Joins `data` to the heatmap squares and creates a rect for every datum that has no
 * square yet. Existing squares are not touched. The join selection is stored in
 * `this.heatmap`.
 *
 * @param {Array} data - Cells of the form {x, y, error}; x and y are grid indices and an
 *   error of 0 marks a cell as empty.
 */
NN_trainer.prototype.updateHeatmap = function (data) {
  var self = this,
    side = this.heatmapBoxSize,
    colorOf = this.heatmapColorScale;

  this.heatmap = this.heatmapG.selectAll(".heatmap-square").data(data);

  var newSquares = this.heatmap.enter().append("rect");

  // Position and size: grid index times the cell side
  newSquares
    .attr("x", function (cell) {
      return cell.x * side;
    })
    .attr("y", function (cell) {
      return cell.y * side;
    })
    .attr("width", side)
    .attr("height", side);

  // Identity and colour: empty cells get the placeholder colour
  newSquares
    .attr("class", "heatmap-square")
    .attr("id", function (cell) {
      return "box_" + cell.x + "_" + cell.y;
    })
    .attr("fill", function (cell) {
      return cell.error == 0 ? self.heatmapEmptyBoxColor : colorOf(cell.error);
    });
};

/**
 * Colours the heatmap cell matching (weight, bias) with the given error, if that cell is
 * still empty (stored error of 0). Already coloured cells are left untouched. A newly
 * coloured cell gets a stroke that fades out over one second.
 *
 * @param {number} weight - Current weight; mapped to a column through heatmapX.
 * @param {number} bias - Current bias; mapped to a row through heatmapY.
 * @param {number} error - Error to record and display for that cell.
 */
NN_trainer.prototype.updateHeatmapElement = function (weight, bias, error) {
  var column = this.heatmapX(weight),
    row = this.heatmapY(bias),
    colorScale = this.heatmapColorScale;

  // Look the cell up inside this trainer's heatmap container. The "box_x_y" ids are not
  // unique across the document once a second heatmap exists, so a bare id lookup could
  // colour a cell that belongs to another chart.
  var emptyCell = d3.select(this.heatmap_el + " #box_" + column + "_" + row)
    .filter(function (d) {
      return d.error === 0;
    })
    .datum({x: column, y: row, error: error});

  emptyCell
    .attr("fill", function (d) {
      return colorScale(d.error);
    })
    .attr("stroke", "#888")
    .attr("stroke-opacity", 1)
    .transition().duration(1000)
    .attr("stroke-opacity", 0);
};


/**
 * Creates the initial (empty) heatmap grid.
 *
 * @param {number} size - Number of cells per side.
 * @returns {Array} size * size cells {x, y, error: 0}, ordered by x and then by y.
 */
NN_trainer.prototype.generateHeatmapData = function (size) {
  var cells = [],
    column = 0,
    row;

  while (column < size) {
    row = 0;
    while (row < size) {
      cells[cells.length] = {x: column, y: row, error: 0};
      row++;
    }
    column++;
  }

  return cells;
};


/**
 * Draws the network diagram inside `this.neuralNetworkGraphEl`: an SVG of graphWidth x 150
 * px with an arrow running from the input node to the output node, an ellipse showing the
 * weight and a rounded square showing the bias. The two value labels (#weightValue and
 * #biasValue) are filled in later by updateNeuralNetworkGraph.
 */
NN_trainer.prototype.initializeNeuralNetworkGraph = function () {
  var nodeRadius = 30,
    margin = {top: 10, right: 10, bottom: 10, left: 10},
    svg,
    canvas,
    centerY;

  // Root SVG element
  svg = d3.select(this.neuralNetworkGraphEl).append("svg")
    .attr("width", this.graphWidth)
    .attr("height", 150);

  this.nnGraphHolder = svg;
  this.neuralNetworkMargin = margin;
  this.neuralNetworkWidth = +svg.attr("width") - margin.left - margin.right;
  this.neuralNetworkHeight = +svg.attr("height") - margin.top - margin.bottom;
  canvas = svg.append("g");
  this.neuralNetworkG = canvas;

  // Every shape is centred on this horizontal line
  centerY = margin.top + this.neuralNetworkHeight / 2;

  // Arrow head marker
  // https://bl.ocks.org/tomgp/d59de83f771ca2b6f1d4
  svg.append("defs")
    .append("marker")
    .attrs({
      "id": "arrow",
      "viewBox": "0 -5 10 10",
      "refX": 5,
      "refY": 0,
      "markerWidth": 4,
      "markerHeight": 4,
      "orient": "auto"
    })
    .append("path")
    .attr("d", "M0,-5L10,0L0,5")
    .attr("class", "arrowHead");

  // Arrow shaft from the right edge of the input node towards the output node
  canvas.append('line')
    .attrs({
      "class": "arrow",
      "marker-end": "url(#arrow)",
      "x1": margin.left + 2 * nodeRadius,
      "y1": centerY,
      "x2": this.neuralNetworkWidth - 2 * nodeRadius + margin.left - 8,
      "y2": centerY
    });

  // Input node
  this.inputNode = canvas.append("circle")
    .attr("class", "input-node")
    .attr("r", nodeRadius)
    .attr("cx", margin.left + nodeRadius)
    .attr("cy", centerY);

  // Weight node: an ellipse with the weight value centred inside it
  this.weightG = canvas.append("g")
    .attr("transform", "translate(" +
      (margin.left + nodeRadius + this.neuralNetworkWidth / 3 - 10)
      + ","
      + centerY + ")");
  this.weightNode = this.weightG.append("ellipse")
    .attr("class", "weightNode")
    .attr("rx", nodeRadius * 1.7)
    .attr("ry", nodeRadius)
    .attr("cx", 0)
    .attr("cy", 0);
  this.weightG.append("text")
    .attr("id", "weightValue")
    .attr("text-anchor", "middle")
    .attr("y", 5)
    .text("");

  // Bias node: a rounded square with the bias value centred inside it
  this.biasG = canvas.append("g")
    .attr("transform", "translate(" +
      (this.neuralNetworkWidth * 2 / 3 - 20)
      + ","
      + (centerY - nodeRadius) + ")");
  this.biasNode = this.biasG.append("rect")
    .attr("class", "biasNode")
    .attr("width", nodeRadius * 2)
    .attr("height", nodeRadius * 2)
    .attr("rx", nodeRadius / 4)
    .attr("ry", nodeRadius / 4)
    .attr("x", 0)
    .attr("y", 0);
  this.biasG.append("text")
    .attr("id", "biasValue")
    .attr("text-anchor", "middle")
    .attr("x", nodeRadius)
    .attr("y", nodeRadius + 5)
    .text("-");

  // Output node
  this.outputNode = canvas.append("circle")
    .attr("class", "output-node")
    .attr("r", nodeRadius)
    .attr("cx", this.neuralNetworkWidth - nodeRadius + margin.left)
    .attr("cy", centerY);
};


/**
 * Writes the current weight (at most 3 fraction digits) and bias (at most 1 fraction
 * digit), formatted for the 'fr' locale, into the network diagram's value labels.
 */
NN_trainer.prototype.updateNeuralNetworkGraph = function () {
  var container = this.neuralNetworkGraphEl;

  // Sets the text of one label inside this trainer's diagram
  var showValue = function (labelSelector, value, maxFractionDigits) {
    d3.select(container + labelSelector)
      .text(value.toLocaleString('fr', {maximumFractionDigits: maxFractionDigits}));
  };

  showValue(" #weightValue", this.weight, 3);
  showValue(" #biasValue", this.bias, 1);
};


// Visualisation 1: weight and bias are adjusted by hand with the sliders. No gradient
// descent buttons, no error chart and no heatmap (their selectors are empty strings).
var trainer = new NN_trainer(
  "#training-one-chart",      // svg_el
  "#training-one",            // table_el
  [2104, 1600, 2400],         // areas
  [399.9, 329.9, 369],        // prices
  0.1,                        // initial weight
  150,                        // initial bias
  0,                          // x1
  0,                          // y1
  2600,                       // x2
  410,                        // y2
  "",                         // gradientDescentButton
  "",                         // gradientDescent10Button
  "",                         // gradientDescent100Button
  "",                         // gradientDescentConvergeButton
  false,                      // normalize
  "",                         // error_chart_el
  "",                         // heatmap_el
  "",                         // weightRange
  "",                         // biasRange
  "#neural-network-graph",    // neuralNetworkGraphEl
  "Basics of Neural Networks - Viz 1 weight & bias"  // analyticsCategory
);


// Visualisation 2: same data, starting from zero weight and bias, with the gradient
// descent buttons, the error-history chart and the weight/bias heatmap enabled.
var trainer2 = new NN_trainer(
  "#training-one-gd-chart",             // svg_el
  "#training-one-gd",                   // table_el
  [2104, 1600, 2400],                   // areas
  [399.9, 329.9, 369],                  // prices
  0,                                    // initial weight
  0,                                    // initial bias
  0,                                    // x1
  0,                                    // y1
  2600,                                 // x2
  410,                                  // y2
  ".gradient-descent-button",           // gradientDescentButton
  ".gradient-descent-10-button",        // gradientDescent10Button
  ".gradient-descent-100-button",       // gradientDescent100Button
  ".gradient-descent-converge-button",  // gradientDescentConvergeButton
  false,                                // normalize
  "#training-one-gd-error-chart",       // error_chart_el
  "#training-one-gd-heatmap",           // heatmap_el
  [0, 0.4],                             // weightRange
  [0, 460],                             // biasRange
  "#neural-network-gd-graph",           // neuralNetworkGraphEl
  "Basics of Neural Networks - Viz 2 gradient descent"  // analyticsCategory
);