<!--#include virtual="/header-start.html" -->
<title>A Visual And Interactive Look at Basic Neural Network Math</title>
<meta content="https://jalammar.github.io/visual-interactive-guide-basics-neural-networks/" name="url">
<meta content="Alammar, Jay" name="author">
<meta content="Jay Alammar" name="copyright">
<meta content="A Visual And Interactive Look at Basic Neural Network Math" property="og:title"/>
<meta content="A Visual And Interactive Look at Basic Neural Network Math" property="twitter:title"/>
<script src="https://cdnjs.cloudflare.com/ajax/libs/d3/5.16.0/d3.js" type="text/javascript"></script>
<script src="../js/d3-selection-multi.v0.4.min.js" type="text/javascript"></script>
<script src="../js/d3-jetpack.js" type="text/javascript"></script>
<link href="../style.css" rel="stylesheet" type="text/css"/>
<link href="https://jalammar.github.io/feed.xml" rel="alternate"
    title="Jay Alammar - Visualizing machine learning one concept at a time." type="application/rss+xml"/>
<script type="text/javascript"> var _paq = _paq || [];</script>
<!--#include virtual="/header-end.html" -->
<div class="prediction">
  <p>In the <a href="https://jalammar.github.io/visual-interactive-guide-basics-neural-networks/">previous post, we
    looked at the basic concepts of neural networks</a>. Let us now work through another example to explore some of
    the basic mathematical ideas involved in prediction with neural networks.</p>
  <video autoplay="" class="img-div-large" controls="" loop="">
    <source src="titanic_nn_calculation.mp4" type="video/mp4">
    Your browser does not support the video tag.
  </video>
  <p>If you had been aboard the Titanic, would you have survived the sinking event? Let’s build a model to predict one’s
    odds of survival.</p>
  <p>This will be a neural network model building on what we discussed in the <a
      href="https://jalammar.github.io/feedforward-neural-networks-visual-interactive/">previous post</a>, but will have
    a higher prediction accuracy because it utilizes hidden layers and activation functions.</p>
  <p>The dataset we’ll use this time will be the <a
      href="https://jalammar.github.io/feedforward-neural-networks-visual-interactive/">Titanic passenger list</a> from
    Kaggle. It lists the names and other information of the passengers and shows whether each passenger survived the
    sinking event or not.</p>
  <p>The raw dataset looks like this:</p>
  <div class="titanic-dataset">
    <table>
      <thead>
      <tr>
        <th>PassengerId</th>
        <th>Survived</th>
        <th>Pclass</th>
        <th>Name</th>
        <th>Sex</th>
        <th>Age</th>
        <th>SibSp</th>
        <th>Parch</th>
        <th>Ticket</th>
        <th>Fare</th>
        <th>Cabin</th>
        <th>Embarked</th>
      </tr>
      </thead>
      <tbody>
      <tr>
        <td>1</td>
        <td>0</td>
        <td>3</td>
        <td>Braund, Mr. Owen Harris</td>
        <td>male</td>
        <td>22.0</td>
        <td>1</td>
        <td>0</td>
        <td>A/5 21171</td>
        <td>7.2500</td>
        <td>NaN</td>
        <td>S</td>
      </tr>
      <tr>
        <td>2</td>
        <td>1</td>
        <td>1</td>
        <td>Cumings, Mrs. John Bradley (Florence Briggs Th…</td>
        <td>female</td>
        <td>38.0</td>
        <td>1</td>
        <td>0</td>
        <td>PC 17599</td>
        <td>71.2833</td>
        <td>C85</td>
        <td>C</td>
      </tr>
      <tr>
        <td>3</td>
        <td>1</td>
        <td>3</td>
        <td>Heikkinen, Miss. Laina</td>
        <td>female</td>
        <td>26.0</td>
        <td>0</td>
        <td>0</td>
        <td>STON/O2. 3101282</td>
        <td>7.9250</td>
        <td>NaN</td>
        <td>S</td>
      </tr>
      </tbody>
    </table>
  </div>
  <p>We won’t bother with most of the columns for now. We’ll just use the sex and age columns as our features, and
    survival as our label that we’ll try to predict.</p>
  <div class="two_variables">
    <table>
      <thead>
      <tr>
        <th>Age</th>
        <th>Sex</th>
        <th>Survived?</th>
      </tr>
      </thead>
      <tbody>
      <tr>
        <td>22</td>
        <td>0</td>
        <td>0</td>
      </tr>
      <tr>
        <td>38</td>
        <td>1</td>
        <td>1</td>
      </tr>
      <tr>
        <td>26</td>
        <td>1</td>
        <td>1</td>
      </tr>
      <tr>
        <td colspan="3">… 891 rows total</td>
      </tr>
      </tbody>
    </table>
  </div>
  <p>We’ll attempt to build a network that predicts whether a passenger survived or not.</p>
  <p>Neural networks need their inputs to be numeric. So we had to change the sex column – male is now 0, female is 1.
    You’ll notice the dataset already uses something similar for the survival column – survived is 1, did not survive is
    0.</p>
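  <p>As a rough sketch of that encoding step, the snippet below converts a few raw rows into numeric features and a
    label. The names <code>raw_rows</code>, <code>SEX_TO_NUMBER</code> and <code>encode_row</code> are made up for
    this illustration, and the rows are copied by hand from the sample table above rather than loaded from the Kaggle
    file.</p>

```python
# Hypothetical in-memory copy of the three sample rows shown above.
raw_rows = [
    {"Age": 22.0, "Sex": "male", "Survived": 0},
    {"Age": 38.0, "Sex": "female", "Survived": 1},
    {"Age": 26.0, "Sex": "female", "Survived": 1},
]

# The mapping described in the text: male is 0, female is 1.
SEX_TO_NUMBER = {"male": 0, "female": 1}


def encode_row(row):
    """Return ([age, sex_as_number], label) for one raw row."""
    features = [row["Age"], SEX_TO_NUMBER[row["Sex"]]]
    label = row["Survived"]
    return features, label


encoded = [encode_row(row) for row in raw_rows]
for features, label in encoded:
    print(features, label)
```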
  <p>The simplest neural network we can use to train to make this prediction looks like this:</p>
  <div class="img-div">
    <img alt="neural network with two inputs, one output, and sigmoid output activation"
        src="./index_files/two-input-one-output-sigmoid-network.png"> Calculating a prediction is done by plugging in a
    value for "age" and "sex". The calculation then flows from the left to the right. Before we can use this net for
    prediction, however, we'll have to run a "training" process that will give us the values for the weights (<span
      class="weight-node-text">w</span>) and bias (<span class="bias-node-text">b</span>). <br> Note: we have slightly
    adjusted the way we represent the networks from the previous post. The bias node specifically is more commonly
    represented like this.
  </div>
  <p>Let’s recap the elements that make up this network and how they work:</p>
  <div class="row neuron-expo vertical-align">
    <div class="col-sm-4 small-column">
      <p><img alt="input neuron" src="./index_files/input-neuron.png"></p>
    </div>
    <div class="col-sm-8 side-column">
      <p>An input neuron is where we plug in an input value (e.g. the age of a person). It’s where the calculation
        starts. The outgoing connection and the rest of the graph tell us what other calculations we need to do to
        calculate a prediction.</p>
    </div>
  </div>
  <div class="row neuron-expo vertical-align">
    <div class="col-sm-4 small-column">
      <p><img alt="weighted neuron image" src="./index_files/weight.png"></p>
    </div>
    <div class="col-sm-8 side-column">
      <p>If a connection has a weight, then the value is multiplied by that weight as it passes through it.</p>
      <div class="language-plaintext highlighter-rouge">
        <div class="highlight"><pre class="highlight"><code>connection_output = weight * connection_input
</code></pre>
        </div>
      </div>
    </div>
  </div>
  <div class="row neuron-expo vertical-align">
    <div class="col-sm-4 small-column">
      <img alt=" neuron image" src="./index_files/neuron.png">
    </div>
    <div class="col-sm-8 side-column">
      <p>If a neuron has inputs, it sums their value and sends that sum along its outgoing connection(s).</p>
      <div class="language-plaintext highlighter-rouge">
        <div class="highlight"><pre class="highlight"><code>node_output = input_1 + input_2
</code></pre>
        </div>
      </div>
    </div>
  </div>
  <section>
    <h2>Sigmoid</h2>
    <div class="row neuron-expo vertical-align">
      <div class="col-sm-4 small-column">
        <p><img alt="sigmoid neuron" src="./index_files/sigmoid-neuron.png"></p>
      </div>
      <div class="col-sm-8 side-column">
        <p>To turn the network’s calculation into a probability value between 0 and 1, we have to pass the value from
          the output layer through a “sigmoid” formula. Sigmoid squashes the output value of a neuron to between 0 and 1
          according to a specific curve.</p>
        <p>`f(x)=1/(1+e^-x)`</p>
        <p>Here, e is the mathematical constant approximately equal to 2.71828.</p>
        <div class="language-plaintext highlighter-rouge">
          <div class="highlight"><pre class="highlight"><code>import math

def sigmoid(x):
    return 1/(1 + math.exp(-x))

output = sigmoid(value)
</code></pre>
          </div>
        </div>
      </div>
    </div>
  </section>
  <section>
    <h2>Sigmoid Visualization</h2>
    <div class="row neuron-expo vertical-align">
      <div class="col-sm-4 small-column">
        <!-- ==== SIGMOID ACTIVATION GRAPH ==== -->
        <table>
          <tbody>
          <tr>
            <td class="sigmoid-input-value explicit-slider-weight-value">-1.63</td>
            <td><img alt="sigmoid neuron" src="./index_files/sigmoid-neuron.png"></td>
            <td class="explicit-activation-output-value">0.16383036122</td>
          </tr>
          </tbody>
        </table>
      </div>
      <div class="col-sm-8 side-column">
        <p>Interact a little with sigmoid to see how it transforms various values.</p>
        <!-- ==== SIGMOID SLIDER ==== -->
        <table class="activation-graph-slider">
          <tbody>
          <tr>
            <td>
              <input class="weight" id="sigmoid-slider" max="20" min="-20" step="0.01"
                  style="width: 320px; margin-left: 40px;" type="range">
            </td>
            <td class="slider-value">
              <span class="weight sigmoid-input-value">-1.63</span>
            </td>
          </tr>
          </tbody>
        </table>
        <!-- ==== SIGMOID FORMULA ==== -->
        <p style="margin-left:40px">f(<span class="slider-value"><span
            class="sigmoid-input-value weight">-1.63</span></span>) = <span
            id="sigmoid_function" style="font-size:200%">`1/(1+e^-(x))`</span> = <span
            id="sigmoid-result">0.16383036122</span></p>
        <!-- ==== SIGMOID GRAPH ==== -->
        <div id="sigmoid-graph" style="width:100%">
          <svg class="activation-graph" height="160" width="400">
            <g transform="translate(60,30)">
              <g class="x-axis" fill="none" font-family="sans-serif" font-size="10" text-anchor="middle"
                  transform="translate(0,110)">
                <g class="tick" opacity="1" transform="translate(0,0)">
                  <line stroke="#000" x1="0.5" x2="0.5" y2="6"></line>
                  <text dy="0.71em" fill="#000" x="0.5" y="9">-20</text>
                </g>
                <g class="tick" opacity="1" transform="translate(35,0)">
                  <line stroke="#000" x1="0.5" x2="0.5" y2="6"></line>
                  <text dy="0.71em" fill="#000" x="0.5" y="9">-15</text>
                </g>
                <g class="tick" opacity="1" transform="translate(70,0)">
                  <line stroke="#000" x1="0.5" x2="0.5" y2="6"></line>
                  <text dy="0.71em" fill="#000" x="0.5" y="9">-10</text>
                </g>
                <g class="tick" opacity="1" transform="translate(105,0)">
                  <line stroke="#000" x1="0.5" x2="0.5" y2="6"></line>
                  <text dy="0.71em" fill="#000" x="0.5" y="9">-5</text>
                </g>
                <g class="tick" opacity="1" transform="translate(140,0)">
                  <line stroke="#000" x1="0.5" x2="0.5" y2="6"></line>
                  <text dy="0.71em" fill="#000" x="0.5" y="9">0</text>
                </g>
                <g class="tick" opacity="1" transform="translate(175,0)">
                  <line stroke="#000" x1="0.5" x2="0.5" y2="6"></line>
                  <text dy="0.71em" fill="#000" x="0.5" y="9">5</text>
                </g>
                <g class="tick" opacity="1" transform="translate(210,0)">
                  <line stroke="#000" x1="0.5" x2="0.5" y2="6"></line>
                  <text dy="0.71em" fill="#000" x="0.5" y="9">10</text>
                </g>
                <g class="tick" opacity="1" transform="translate(245,0)">
                  <line stroke="#000" x1="0.5" x2="0.5" y2="6"></line>
                  <text dy="0.71em" fill="#000" x="0.5" y="9">15</text>
                </g>
                <g class="tick" opacity="1" transform="translate(280,0)">
                  <line stroke="#000" x1="0.5" x2="0.5" y2="6"></line>
                  <text dy="0.71em" fill="#000" x="0.5" y="9">20</text>
                </g>
              </g>
              <g class="y-axis" fill="none" font-family="sans-serif" font-size="10" text-anchor="end">
                <path class="domain" d="M-6,110.5H0.5V0.5H-6" stroke="#000"></path>
                <g class="tick" opacity="1" transform="translate(0,110)">
                  <line stroke="#000" x2="-6" y1="0.5" y2="0.5"></line>
                  <text dy="0.32em" fill="#000" x="-9" y="0.5">0.0</text>
                </g>
                <g class="tick" opacity="1" transform="translate(0,55)">
                  <line stroke="#000" x2="-6" y1="0.5" y2="0.5"></line>
                  <text dy="0.32em" fill="#000" x="-9" y="0.5">0.5</text>
                </g>
                <g class="tick" opacity="1" transform="translate(0,0)">
                  <line stroke="#000" x2="-6" y1="0.5" y2="0.5"></line>
                  <text dy="0.32em" fill="#000" x="-9" y="0.5">1.0</text>
                </g>
                <text dy="0.71em" fill="#000" text-anchor="end" transform="rotate(-90)" y="6">Output</text>
              </g>
              <g class="grid" fill="none" font-family="sans-serif" font-size="10" text-anchor="middle"
                  transform="translate(0,110)">
                <path class="domain" d="M0.5,-110V0.5H280.5V-110" stroke="#000"></path>
                <g class="tick" opacity="1" transform="translate(140,0)">
                  <line stroke="#000" x1="0.5" x2="0.5" y2="-110"></line>
                  <text dy="0.71em" fill="#000" x="0.5" y="3"></text>
                </g>
              </g>
              <g class="grid" fill="none" font-family="sans-serif" font-size="10" text-anchor="end">
                <path class="domain" d="M280,110.5H0.5V0.5H280" stroke="#000"></path>
                <g class="tick" opacity="1" transform="translate(0,110)">
                  <line stroke="#000" x2="280" y1="0.5" y2="0.5"></line>
                  <text dy="0.32em" fill="#000" x="-3" y="0.5"></text>
                </g>
                <g class="tick" opacity="1" transform="translate(0,55)">
                  <line stroke="#000" x2="280" y1="0.5" y2="0.5"></line>
                  <text dy="0.32em" fill="#000" x="-3" y="0.5"></text>
                </g>
                <g class="tick" opacity="1" transform="translate(0,0)">
                  <line stroke="#000" x2="280" y1="0.5" y2="0.5"></line>
                  <text dy="0.32em" fill="#000" x="-3" y="0.5"></text>
                </g>
              </g>
              <path class="sigmoid-line"
                  d="M0,110C0,110,5.833333333333334,110,7,110C8.166666666666666,110,12.833333333333334,110,14,110C15.166666666666666,110,19.833333333333332,110,21,110C22.166666666666668,110,26.833333333333332,110,28,110C29.166666666666668,110,33.833333333333336,110,35,110C36.166666666666664,110,40.833333333333336,110,42,110C43.166666666666664,110,47.833333333333336,110,49,110C50.166666666666664,110,54.833333333333336,110,56,110C57.166666666666664,110,61.833333333333336,110,63,110C64.16666666666667,110,68.83333333333333,110,70,110C71.16666666666667,110,75.83333333333333,110,77,110C78.16666666666667,110,82.83333333333333,110,84,110C85.16666666666667,110,89.83333333333333,110,91,110C92.16666666666667,110,96.83333333333333,110.08333333333333,98,110C99.16666666666667,109.91666666666667,103.83333333333333,109.16666666666667,105,109C106.16666666666667,108.83333333333333,110.83333333333333,108.33333333333333,112,108C113.16666666666667,107.66666666666667,117.83333333333333,105.91666666666667,119,105C120.16666666666667,104.08333333333333,124.83333333333333,99.08333333333333,126,97C127.16666666666667,94.91666666666667,131.83333333333334,83.5,133,80C134.16666666666666,76.5,138.83333333333334,59.166666666666664,140,55C141.16666666666666,50.833333333333336,145.83333333333334,33.5,147,30C148.16666666666666,26.5,152.83333333333334,15.083333333333332,154,13C155.16666666666666,10.916666666666668,159.83333333333334,5.916666666666667,161,5C162.16666666666666,4.083333333333333,166.83333333333334,2.3333333333333335,168,2C169.16666666666666,1.6666666666666667,173.83333333333334,1.1666666666666667,175,1C176.16666666666666,0.8333333333333334,180.83333333333334,0.08333333333333333,182,0C183.16666666666666,-0.08333333333333333,187.83333333333334,0,189,0C190.16666666666666,0,194.83333333333334,0,196,0C197.16666666666666,0,201.83333333333334,0,203,0C204.16666666666666,0,208.83333333333334,0,210,0C211.16666666666666,0,215.83333333333334,0,217,0C218.16666666666666,0,222.83333333333334,0,224,0C225.16666666666666,0,229.83333333333334,0,231,0C232.16666666666666,0,236.83333333333334,0,238,0C239.16666666666666,0,243.83333333333334,0,245,0C246.16666666666666,0,250.83333333333334,0,252,0C253.16666666666666,0,257.8333333333333,0,259,0C260.1666666666667,0,264.8333333333333,0,266,0C267.1666666666667,0,271.8333333333333,0,273,0C274.1666666666667,0,280,0,280,0"></path>
              <g class="value-point" transform="translate(129,92)">
                <ellipse class="sigmoid-value-dot" cx="0" cy="0" rx="5" ry="5"></ellipse>
                <text class="sigmoid-value-text" fill="red" font-size="12" text-anchor="middle" y="-8">0.16383036122
                </text>
              </g>
            </g>
          </svg>
        </div>
      </div>
    </div>
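    <p>If you would rather check the curve in code than with the slider, here is a small sketch that evaluates sigmoid
      at a handful of inputs across the slider's range. The sample inputs are arbitrary picks, apart from -1.63, which
      is the value displayed in the widget above.</p>

```python
import math


def sigmoid(x):
    return 1 / (1 + math.exp(-x))


# Arbitrary sample points spanning the slider range of -20 to 20.
for x in [-20, -5, -1.63, 0, 1.63, 5, 20]:
    print(f"sigmoid({x:6.2f}) = {sigmoid(x):.6f}")
```

    <p>Large negative inputs land very close to 0, large positive inputs very close to 1, and an input of 0 gives
      exactly 0.5. The curve is also symmetric around that point: sigmoid(-x) equals 1 - sigmoid(x).</p>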

    <div class="row neuron-expo vertical-align">
      <div class="col-sm-4 small-column">
        <p><img alt="neural network with two inputs, one output, and sigmoid output activation" src="./index_files/two-input-one-output-sigmoid-network.png"></p>
      </div>
      <div class="col-sm-8 side-column">

        <p>To bring it all together, calculating a prediction with this shallow network looks like this:</p>

        <div class="language-plaintext highlighter-rouge">
          <div class="highlight"><pre class="highlight"><code>import math

def sigmoid(x):
    return 1/(1 + math.exp(-x))

def calculate_prediction(age, sex, weight_1, weight_2, bias):

    # Multiply the inputs by their weights, sum the results up
    layer_2_node = age * weight_1 + sex * weight_2 + 1 * bias

    prediction = sigmoid(layer_2_node)
    return prediction
</code></pre>
          </div>
        </div>

      </div>
    </div>

    <p>Now that we know the structure of our network, we can train it using gradient descent running on the first 600
      rows of the 891-row dataset. I will not be addressing the training process in this post because that’s a separate
      concern at the moment. For now, I just want you to be comfortable with how a trained network calculates a
      prediction. Once you get this intuition down, we’ll proceed to training in a future post.</p>

    <p>The training process gives us the following values (with an accuracy of 73.20%):</p>

    <div class="language-plaintext highlighter-rouge">
      <div class="highlight"><pre class="highlight"><code>weight_1 =   -0.016852 # Associated with "Age"
weight_2 =   0.704039  # Associated with "Sex" (where male is 0, female is 1)
bias =       -0.116309
</code></pre>
      </div>
    </div>
    <p>Intuitively, the weights indicate how much their associated property contributes to the prediction – odds of
      survival improve the younger a person is (since a larger age multiplied by the negative weight value gives a
      bigger negative number). They improve more if the person is female.</p></section>
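  <p>To make that intuition concrete, the sketch below splits the weighted sum (the value before sigmoid is applied)
    into the part coming from age, the part coming from sex, and the bias. The helper name
    <code>weighted_sum_parts</code> and the three example passengers are chosen only for illustration; the weights are
    the trained values listed above.</p>

```python
weight_1 = -0.016852  # Associated with "Age"
weight_2 = 0.704039   # Associated with "Sex" (male is 0, female is 1)
bias = -0.116309


def weighted_sum_parts(age, sex):
    """Return the three terms that get summed before sigmoid is applied."""
    return age * weight_1, sex * weight_2, bias


# Illustrative passengers: a young man, an older man, a young woman.
for age, sex in [(22, 0), (54, 0), (22, 1)]:
    age_part, sex_part, bias_part = weighted_sum_parts(age, sex)
    total = age_part + sex_part + bias_part
    print(f"age={age} sex={sex}: {age_part:+.4f} {sex_part:+.4f} {bias_part:+.4f} = {total:+.4f}")
```

  <p>The age term becomes more negative as age grows, while the sex term adds a fixed positive amount for female
    passengers. A larger total means a higher predicted probability once it passes through sigmoid.</p>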
  <section>
    <h2>Prediction Calculation</h2>
    <p>The trained network now looks like this: (hover or click on values in the table to see their individual
      predictions)</p>
    <div class="row vertical-align">
      <div class="col-sm-3" id="neural-network-calculation-table">
        <table class="collapsed-style">
          <thead>
          <tr>
            <th class="title"></th>
            <th class="title">Age</th>
            <th class="center">Sex</th>
            <th class="center">Survived</th>
          </tr>
          </thead>
          <tbody>
          <tr>
            <td class="title"><input class="radio_0" name="person" type="radio"></td>
            <td class="title">22</td>
            <td class="center">0</td>
            <td class="center">0</td>
          </tr>
          <tr>
            <td class="title"><input class="radio_1" name="person" type="radio"></td>
            <td class="title">38</td>
            <td class="center">1</td>
            <td class="center">1</td>
          </tr>
          <tr>
            <td class="title"><input class="radio_2" name="person" type="radio"></td>
            <td class="title">26</td>
            <td class="center">1</td>
            <td class="center">1</td>
          </tr>
          <tr>
            <td class="title"><input class="radio_3" name="person" type="radio"></td>
            <td class="title">35</td>
            <td class="center">1</td>
            <td class="center">1</td>
          </tr>
          <tr>
            <td class="title"><input class="radio_4" name="person" type="radio"></td>
            <td class="title">35</td>
            <td class="center">0</td>
            <td class="center">0</td>
          </tr>
          <tr>
            <td class="title"><input class="radio_5" name="person" type="radio"></td>
            <td class="title">14</td>
            <td class="center">1</td>
            <td class="center">0</td>
          </tr>
          <tr>
            <td class="title"><input class="radio_6" name="person" type="radio"></td>
            <td class="title">25</td>
            <td class="center">0</td>
            <td class="center">0</td>
          </tr>
          <tr>
            <td class="title"><input class="radio_7" name="person" type="radio"></td>
            <td class="title">54</td>
            <td class="center">0</td>
            <td class="center">0</td>
          </tr>
          </tbody>
        </table>
      </div>
      <div class="col-sm-9" id="neural-network-calculation-viz">
        <svg height="300" width="500">
          <g>
            <line class="nn-arrow softmax-to-output-line" marker-end="url(#arrow)" x1="413.5" x2="440" y1="150"
                y2="150"></line>
            <text class="softmax-output-class-name" x="413.5" y="175">Survived</text>
            <line class="nn-arrow bias-to-softmax-line" marker-end="url(#arrow)" x1="300" x2="357.5" y1="150"
                y2="150"></line>
            <line class="nn-arrow input-to-bias-line input-0 output-0" x1="35" x2="300" y1="35" y2="150"></line>
            <line class="nn-arrow input-to-bias-line input-1 output-0" x1="35" x2="300" y1="145" y2="150"></line>
            <line class="nn-arrow input-to-bias-line input-2 output-0" x1="35" x2="300" y1="255" y2="150"></line>
            <g class="input-group" transform="translate(35,35)">
              <circle class="outlined-input-node nn-node" cx="0" cy="0" r="25"></circle>
              <text class="node-text" id="input-name" text-anchor="middle" x="0" y="5">Age</text>
            </g>
            <g class="input-group" transform="translate(35,145)">
              <circle class="outlined-input-node nn-node" cx="0" cy="0" r="25"></circle>
              <text class="node-text" id="input-name" text-anchor="middle" x="0" y="5">Sex</text>
            </g>
            <g class="input-group" transform="translate(35,255)">
              <circle class="outlined-bias-node nn-node" cx="0" cy="0" r="25"></circle>
              <text class="node-text" id="input-name" text-anchor="middle" x="0" y="5">Bias</text>
            </g>
            <g class="weight-group input-0-weight output-0-weight" transform="translate(100,63.20754716981132)">
              <ellipse class="outlined-weight-node nn-node" cx="0" cy="0" rx="22" ry="10"></ellipse>
              <text class="weightNodeText" id="weight0Value" text-anchor="middle" y="3">-0.0169</text>
            </g>
            <g class="weight-group input-1-weight output-0-weight" transform="translate(100,146.22641509433961)">
              <ellipse class="outlined-weight-node nn-node" cx="0" cy="0" rx="22" ry="10"></ellipse>
              <text class="weightNodeText" id="weight0Value" text-anchor="middle" y="3">0.704</text>
            </g>
            <g class="weight-group input-2-weight output-0-weight" transform="translate(100,229.24528301886792)">
              <ellipse class="outlined-bias-node nn-node" cx="0" cy="0" rx="22" ry="10"></ellipse>
              <text class="weightNodeText" id="weight0Value" text-anchor="middle" y="3">-0.1163</text>
            </g>
            <g class="output-group" transform="translate(300,150)">
              <circle class="outlined-output-node nn-node" cx="0" cy="0" r="25"></circle>
              <text class="node-text" id="output-name" text-anchor="middle" x="0" y="5"></text>
            </g>
            <g class="sigmoid activation" transform="translate(382.5,150)">
              <rect class="outlined-sigmoid-node nn-node" height="140" rx="6.25" ry="6.25" width="37.5" x="-18.75"
                  y="-70"></rect>
              <text id="sigmoid-label" text-anchor="middle" x="0" y="-2">σ</text>
            </g>
          </g>
          <defs>
            <marker id="arrow" markerHeight="4" markerWidth="4" orient="auto" refX="5" refY="0" viewBox="0 -5 10 10">
              <path class="arrowHead" d="M0,-5L10,0L0,5"></path>
            </marker>
          </defs>
        </svg>
      </div>
    </div>
    <div class="nn-tooltip" style="opacity: 0"></div>
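    <p>If you prefer to reproduce the numbers without hovering, the following sketch runs the trained weights over the
      eight rows of the table above. Turning the probability into a 0 or 1 by cutting at 0.5 is our assumption for
      this illustration; the post does not state which cut-off was used when measuring accuracy.</p>

```python
import math

weight_1 = -0.016852  # Associated with "Age"
weight_2 = 0.704039   # Associated with "Sex" (male is 0, female is 1)
bias = -0.116309


def sigmoid(x):
    return 1 / (1 + math.exp(-x))


def calculate_prediction(age, sex, weight_1, weight_2, bias):
    # Multiply the inputs by their weights, add the bias, then squash with sigmoid.
    return sigmoid(age * weight_1 + sex * weight_2 + 1 * bias)


# (age, sex, survived) rows copied from the table above.
passengers = [
    (22, 0, 0), (38, 1, 1), (26, 1, 1), (35, 1, 1),
    (35, 0, 0), (14, 1, 0), (25, 0, 0), (54, 0, 0),
]

for age, sex, survived in passengers:
    probability = calculate_prediction(age, sex, weight_1, weight_2, bias)
    predicted = int(probability >= 0.5)  # assumed cut-off of 0.5
    print(f"age={age:2d} sex={sex} probability={probability:.4f} predicted={predicted} actual={survived}")
```

    <p>Not every row has to agree with its label. With an accuracy of 73.20%, some passengers will end up on the wrong
      side of the cut-off.</p>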
    <p>An accuracy of 73.20% isn’t very impressive. This is a case that can benefit from adding a hidden layer. Hidden
      layers give the model more capacity to represent more sophisticated prediction functions that may do a better job
      (<a href="https://www.deeplearningbook.org/contents/ml.html">Deep Learning ch.5 page 113</a>).</p>
    <div class="row neuron-expo vertical-align">
      <div class="col-sm-4 small-column">
        <p><img alt="weighted neuron with activation" src="./index_files/neuron_with_activation.png"></p>
      </div>
      <div class="col-sm-8 side-column">
        <p>It’s often useful to apply certain math functions to the weighted outputs. These are called “activation
          functions” because historically they translated the output of the neuron into either 1 (On/active) or 0
          (Off).</p>
        <div class="language-plaintext highlighter-rouge">
          <div class="highlight"><pre class="highlight"><code>def activation_function(x):
    # Do something to the value
    ...

weighted_sum = weight * (input_1 + input_2)
output = activation_function(weighted_sum)
</code></pre>
          </div>
        </div>
        <p>Activation functions are vital for hidden layers. Without them, deep networks would be no better than a
          shallow linear network. Read the “Commonly used activation functions” section from <a
              href="https://cs231n.github.io/neural-networks-1/">Neural Networks Part 1: Setting up the Architecture</a>
          for a look at various activation functions.</p>
      </div>
    </div>
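    <p>Here is a small sketch of why that is. Two stacked layers without an activation function collapse into a single
      weighted sum, so the extra layer adds nothing. The weights and biases below are made-up numbers for the
      demonstration, and the network is reduced to one input and one neuron per layer to keep the arithmetic easy to
      follow.</p>

```python
def linear(x, weight, bias):
    return weight * x + bias


# Hypothetical weights and biases for two stacked one-neuron layers.
w1, b1 = 2.0, -1.0
w2, b2 = 0.5, 3.0


def two_layers_no_activation(x):
    return linear(linear(x, w1, b1), w2, b2)


def single_equivalent_layer(x):
    # w2 * (w1 * x + b1) + b2 == (w2 * w1) * x + (w2 * b1 + b2)
    return linear(x, w2 * w1, w2 * b1 + b2)


def relu(x):
    return max(0.0, x)


def two_layers_with_relu(x):
    return linear(relu(linear(x, w1, b1)), w2, b2)


for x in [-2.0, 0.0, 3.0]:
    print(x, two_layers_no_activation(x), single_equivalent_layer(x), two_layers_with_relu(x))
```

    <p>The first two functions always agree, whatever the input. The version with an activation between the layers
      gives a different answer for inputs such as -2.0, which is what lets a deeper network represent something a
      single weighted sum cannot.</p>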
  </section>
  <section>
    <h3 id="relu-">ReLU</h3>
    <div class="row neuron-expo vertical-align">
      <div class="col-sm-4 small-column">
        <p><img alt="ReLU activation" src="./index_files/relu.png"></p>
      </div>
      <div class="col-sm-8 side-column">
        <p>A leading choice for activation function is called ReLU. It returns 0 if its input is negative, and returns
          the number itself otherwise. Very simple!</p>
        <p>f(x) = max(0, x)</p>
        <div class="language-plaintext highlighter-rouge">
          <div class="highlight"><pre class="highlight"><code># Naive scalar relu implementation. In the real world, most calculations are done on vectors
def relu(x):
    if x &lt; 0:
        return 0
    else:
        return x


output = relu(value)
</code></pre>
          </div>
        </div>
      </div>
    </div>
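    <p>Since most real calculations run on vectors rather than single numbers, here is one minimal way to apply the
      same idea to a whole list of values. This sketch sticks to plain Python lists; <code>relu_vector</code> is a
      name invented for this example, and numerical libraries offer their own element-wise equivalents.</p>

```python
def relu(x):
    return max(0, x)


def relu_vector(values):
    """Apply relu to every element of a list."""
    return [relu(v) for v in values]


print(relu_vector([-3.5, -1, 0, 2, 7.25]))  # [0, 0, 0, 2, 7.25]
```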
  </section>
  <section>
    <h2 id="relu-visualization-">ReLU Visualization</h2>
    <div class="row neuron-expo vertical-align">
      <div class="col-sm-4 small-column">
        <!-- ==== RELU ACTIVATION GRAPH ==== -->
        <table>
          <tbody>
          <tr>
            <td class="relu-input-value explicit-slider-weight-value">0</td>
            <td>
              <img alt="Relu" src="./index_files/relu.png">
            </td>
            <td class="explicit-relu-activation-output-value">0</td>
          </tr>
          </tbody>
        </table>
      </div>
      <div class="col-sm-8 side-column">
        <p>Interact a little with ReLU to see how it transforms various values.</p>
        <!-- ==== RELU SLIDER ==== -->
        <table class="activation-graph-slider">
          <tbody>
          <tr>
            <td>
              <input class="weight" id="relu-slider" max="20" min="-20" step="0.01"
                  style="width: 320px; margin-left: 40px;" type="range">
            </td>
            <td class="slider-value">
              <span class="weight relu-input-value">0</span>
            </td>
          </tr>
          </tbody>
        </table>
        <!-- ==== RELU FORMULA ==== -->
        <p style="margin-left:40px">f(<span class="slider-value"><span class="relu-input-value weight">0</span></span>)
          = max( 0, <span class="mord mathit" id="relu-formula-input"><span
              class="relu-value-input-number">0.00</span></span>) = <span id="relu-result">0</span></p>
        <!-- ==== RELU GRAPH ==== -->
        <div id="relu-graph" style="width:100%">
          <svg class="activation-graph" height="160" width="400">
            <g transform="translate(60,30)">
              <g class="x-axis" fill="none" font-family="sans-serif" font-size="10" text-anchor="middle"
                  transform="translate(0,110)">
                <g class="tick" opacity="1" transform="translate(0,0)">
                  <line stroke="#000" x1="0.5" x2="0.5" y2="6"></line>
                  <text dy="0.71em" fill="#000" x="0.5" y="9">-20</text>
                </g>
                <g class="tick" opacity="1" transform="translate(35,0)">
                  <line stroke="#000" x1="0.5" x2="0.5" y2="6"></line>
                  <text dy="0.71em" fill="#000" x="0.5" y="9">-15</text>
                </g>
                <g class="tick" opacity="1" transform="translate(70,0)">
                  <line stroke="#000" x1="0.5" x2="0.5" y2="6"></line>
                  <text dy="0.71em" fill="#000" x="0.5" y="9">-10</text>
                </g>
                <g class="tick" opacity="1" transform="translate(105,0)">
                  <line stroke="#000" x1="0.5" x2="0.5" y2="6"></line>
                  <text dy="0.71em" fill="#000" x="0.5" y="9">-5</text>
                </g>
                <g class="tick" opacity="1" transform="translate(140,0)">
                  <line stroke="#000" x1="0.5" x2="0.5" y2="6"></line>
                  <text dy="0.71em" fill="#000" x="0.5" y="9">0</text>
                </g>
                <g class="tick" opacity="1" transform="translate(175,0)">
                  <line stroke="#000" x1="0.5" x2="0.5" y2="6"></line>
                  <text dy="0.71em" fill="#000" x="0.5" y="9">5</text>
                </g>
                <g class="tick" opacity="1" transform="translate(210,0)">
                  <line stroke="#000" x1="0.5" x2="0.5" y2="6"></line>
                  <text dy="0.71em" fill="#000" x="0.5" y="9">10</text>
                </g>
                <g class="tick" opacity="1" transform="translate(245,0)">
                  <line stroke="#000" x1="0.5" x2="0.5" y2="6"></line>
                  <text dy="0.71em" fill="#000" x="0.5" y="9">15</text>
                </g>
                <g class="tick" opacity="1" transform="translate(280,0)">
                  <line stroke="#000" x1="0.5" x2="0.5" y2="6"></line>
                  <text dy="0.71em" fill="#000" x="0.5" y="9">20</text>
                </g>
              </g>
              <g class="y-axis" fill="none" font-family="sans-serif" font-size="10" text-anchor="end">
                <path class="domain" d="M-6,110.5H0.5V0.5H-6" stroke="#000"></path>
                <g class="tick" opacity="1" transform="translate(0,110)">
                  <line stroke="#000" x2="-6" y1="0.5" y2="0.5"></line>
                  <text dy="0.32em" fill="#000" x="-9" y="0.5">0</text>
                </g>
                <g class="tick" opacity="1" transform="translate(0,99)">
                  <line stroke="#000" x2="-6" y1="0.5" y2="0.5"></line>
                  <text dy="0.32em" fill="#000" x="-9" y="0.5">2</text>
                </g>
                <g class="tick" opacity="1" transform="translate(0,88)">
                  <line stroke="#000" x2="-6" y1="0.5" y2="0.5"></line>
                  <text dy="0.32em" fill="#000" x="-9" y="0.5">4</text>
                </g>
                <g class="tick" opacity="1" transform="translate(0,77)">
                  <line stroke="#000" x2="-6" y1="0.5" y2="0.5"></line>
                  <text dy="0.32em" fill="#000" x="-9" y="0.5">6</text>
                </g>
                <g class="tick" opacity="1" transform="translate(0,66)">
                  <line stroke="#000" x2="-6" y1="0.5" y2="0.5"></line>
                  <text dy="0.32em" fill="#000" x="-9" y="0.5">8</text>
                </g>
                <g class="tick" opacity="1" transform="translate(0,55)">
                  <line stroke="#000" x2="-6" y1="0.5" y2="0.5"></line>
                  <text dy="0.32em" fill="#000" x="-9" y="0.5">10</text>
                </g>
                <g class="tick" opacity="1" transform="translate(0,44)">
                  <line stroke="#000" x2="-6" y1="0.5" y2="0.5"></line>
                  <text dy="0.32em" fill="#000" x="-9" y="0.5">12</text>
                </g>
                <g class="tick" opacity="1" transform="translate(0,33)">
                  <line stroke="#000" x2="-6" y1="0.5" y2="0.5"></line>
                  <text dy="0.32em" fill="#000" x="-9" y="0.5">14</text>
                </g>
                <g class="tick" opacity="1" transform="translate(0,22)">
                  <line stroke="#000" x2="-6" y1="0.5" y2="0.5"></line>
                  <text dy="0.32em" fill="#000" x="-9" y="0.5">16</text>
                </g>
                <g class="tick" opacity="1" transform="translate(0,11)">
                  <line stroke="#000" x2="-6" y1="0.5" y2="0.5"></line>
                  <text dy="0.32em" fill="#000" x="-9" y="0.5">18</text>
                </g>
                <g class="tick" opacity="1" transform="translate(0,0)">
                  <line stroke="#000" x2="-6" y1="0.5" y2="0.5"></line>
                  <text dy="0.32em" fill="#000" x="-9" y="0.5">20</text>
                </g>
                <text dy="0.71em" fill="#000" text-anchor="end" transform="rotate(-90)" y="6">Output</text>
              </g>
              <g class="grid" fill="none" font-family="sans-serif" font-size="10" text-anchor="middle"
                  transform="translate(0,110)">
                <path class="domain" d="M0.5,-110V0.5H280.5V-110" stroke="#000"></path>
                <g class="tick" opacity="1" transform="translate(140,0)">
                  <line stroke="#000" x1="0.5" x2="0.5" y2="-110"></line>
                  <text dy="0.71em" fill="#000" x="0.5" y="3"></text>
                </g>
              </g>
              <g class="grid" fill="none" font-family="sans-serif" font-size="10" text-anchor="end">
                <path class="domain" d="M280,110.5H0.5V0.5H280" stroke="#000"></path>
                <g class="tick" opacity="1" transform="translate(0,110)">
                  <line stroke="#000" x2="280" y1="0.5" y2="0.5"></line>
                  <text dy="0.32em" fill="#000" x="-3" y="0.5"></text>
                </g>
                <g class="tick" opacity="1" transform="translate(0,0)">
                  <line stroke="#000" x2="280" y1="0.5" y2="0.5"></line>
                  <text dy="0.32em" fill="#000" x="-3" y="0.5"></text>
                </g>
              </g>
              <path class="relu-line activation-line-straight" d="M0,110L70,110L140,110L210,55L280,0"></path>
              <g class="value-point" transform="translate(140,110)">
                <ellipse class="relu-value-dot activation-value-dot" cx="0" cy="0" rx="5" ry="5"></ellipse>
                <text class="relu-value-text activation-value-text" fill="red" font-size="12" text-anchor="middle"
                    y="-8">0
                </text>
              </g>
            </g>
          </svg>
        </div>
      </div>
    </div>
  </section>
  <section>
    <h2>Closing</h2>
    <p>This post has been parked for more than a year. I had attempted to visualize a deeper network after this point,
      but that never materialized. I hope you enjoyed it. If you have any feedback, let me know on Twitter at <a
          href="https://twitter.com/JayAlammar">@JayAlammar</a>.</p>
  </section>
</div>
<div class="date">Written on December 14, 2016</div>
<div>
  <a href="https://creativecommons.org/licenses/by-nc-sa/4.0/" rel="license"><img alt="Creative Commons License"
      src="https://i.creativecommons.org/l/by-nc-sa/4.0/88x31.png" style="border-width:0"/></a><br/>This work is
  licensed under a <a href="https://creativecommons.org/licenses/by-nc-sa/4.0/" rel="license">Creative Commons
  Attribution-NonCommercial-ShareAlike 4.0 International License</a>. <br/> Attribution example: <br/> <i>Alammar, Jay
  (2018). The Illustrated Transformer [Blog post]. Retrieved from <a
      href="https://jalammar.github.io/illustrated-transformer/">https://jalammar.github.io/illustrated-transformer/</a></i>
  <br/><br/> Note: If you translate any of the posts, let me know so I can link your translation to the original post.
  My email is in the <a href="https://jalammar.github.io/about">about page</a>.
</div>
<script>
  (function (i, s, o, g, r, a, m) {
    i['GoogleAnalyticsObject'] = r;
    i[r] = i[r] || function () {
      (i[r].q = i[r].q || []).push(arguments)
    };
    i[r].l = 1 * new Date();
    a = s.createElement(o);
    m = s.getElementsByTagName(o)[0];
    a.async = 1;
    a.src = g;
    m.parentNode.insertBefore(a, m)
  })(window, document, 'script', '//www.google-analytics.com/analytics.js', 'ga');
  ga('create', 'UA-71956058-1', 'auto');
  ga('send', 'pageview', {
    'page': '/visual-interactive-guide-basics-neural-networks/',
    'title': 'A Visual and Interactive Guide to the Basics of Neural Networks'
  });
</script>
<!--#include virtual="/footer.html" -->
<script src="../js/bootstrap.min.js"></script>
<script>
  $(document).ready(function () {
    setTimeout(() => {
      $.getScript("../js/nnVizUtils.js");
      $.getScript("../js/sigmoid_graph.js");
      $.getScript("../js/nn_calc.js");
      $.getScript("../js/relu_graph.js");
      $.getScript("../js/accuracy-graph.js");
    }, 4000)  // Wait 4 seconds for Angular to finish setting up the DOM
  });
</script>
<style>.mjx-math * {
  line-height: 0;
}</style>
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.4/latest.js?config=AM_CHTML"></script>