salieri/tartarus-deep

View on GitHub
examples/dense-2x.ts

Summary

Maintainability
A
2 hrs
Test Coverage
/**
 * Simple example network which learns to multiply input by two
 *
 * Input: 5, 3, 8
 * Output: 10, 6, 16
 */

import _ from 'lodash';

import {
  MemoryInputFeed,
  Dense,
  Model,
  NDArray,
} from '../src';

import { SampleGenerator } from './sample-generator';
import { Stochastic } from '../src/nn/optimizer';


/**
 * Sample generator for the "multiply by two" example: provides both the
 * network definition and the training samples (`x` paired with `2 * x`).
 */
export class Dense2x extends SampleGenerator {
  /**
   * Builds the example network.
   *
   * The model uses a `Stochastic` optimizer with a rate of 0.001, the
   * 'mean-squared-error' loss, and the seed taken from `this.params.seed`.
   *
   * @returns the configured model with all of its layers attached
   */
  public model(): Model {
    const network = new Model({
      optimizer: new Stochastic({ rate: 0.001 }),
      seed: this.params.seed,
      loss: 'mean-squared-error',
    });

    // An intentionally, unnecessarily complex topology for figuring out 2X:
    // a single input feeding three stacked identity-activated dense layers
    // (4 units, then 5 units, then 1 output unit).
    network
      .input(1)
      .push(new Dense({ units: 4, activation: 'identity' }, 'hidden-1'))
      .push(new Dense({ units: 5, activation: 'identity' }, 'hidden-2'))
      .push(new Dense({ units: 1, activation: 'identity' }, 'output'));

    // An explicit `.output('result')` call is deliberately left out here;
    // per the original author's note, the output will be automatically set
    // to the last node unless you choose to declare it.

    return network;
  }


  /**
   * Produces an in-memory feed of training samples.
   *
   * For the sample at index `n`, the input value is `n % 25` and the expected
   * output is twice that value; each is wrapped in a one-element `NDArray`.
   *
   * @param count number of samples requested from `_.times`
   * @returns a feed created by `MemoryInputFeed.factory` from the samples
   */
  public samples(count: number): MemoryInputFeed {
    const buildSample = (index: number) => {
      const input = index % 25;
      const doubled = input * 2;

      return { x: new NDArray([input]), y: new NDArray([doubled]) };
    };

    const generated = _.times(count, buildSample);

    return MemoryInputFeed.factory(generated);
  }
}