salieri/tartarus-deep

View on GitHub
src/nn/layer/concat.ts

Summary

Maintainability
C
7 hrs
Test Coverage
import _ from 'lodash';
import { Layer, LayerParams } from './layer';
import { JoiEx, JoiExSchema } from '../../util';

import {
  DeferredValue,
  DeferredCollectionWrapper,
  DeferredCollection,
} from '../symbols';

import {
  Matrix,
  MatrixDirection,
  NDArray,
  Vector,
} from '../../math';

import { KeyNotFoundError } from '../../error';
import { Dense } from './dense';


/**
 * Callback used by `Concat.traverse()`: receives the resolved output value of one
 * concatenated input together with the keys that identify it.
 */
export type ConcatOutputTraverseFunction = (
  field: DeferredValue,
  fieldKey: string,
  layerOutput: DeferredCollectionWrapper,
  layerKey: string,
) => void;

/**
 * Callback used by `Concat.traverseKeys()`: like `ConcatOutputTraverseFunction`, but only
 * receives the keys and the source layer's output collection (the field value is not read).
 */
export type ConcatOutputTraverseKeyFunction = (
  fieldKey: string,
  layerOutput: DeferredCollectionWrapper,
  layerKey: string,
) => void;


/**
 * Object form of a concat input definition.
 */
export interface ConcatLayerExtendedDefinition {
  /**
   * Name of the input layer. The "layer.field" shortcut is accepted; when it is used,
   * the field named after the dot takes precedence over `field`.
   */
  layer: string;

  /**
   * Output field of the layer to include. When null, omitted or empty, the layer's
   * default output field is used.
   */
  field?: string | null;
}

/**
 * A concat input: either a plain layer name (optionally "layer.field") or the object form.
 */
export type ConcatLayerDefinition = ConcatLayerExtendedDefinition | string;


/**
 * Parameters accepted by the `Concat` layer.
 */
export interface ConcatParams extends LayerParams {
  /**
   * Inputs to concatenate, in concatenation order. When null or omitted, every linked
   * input layer is used, in the order returned by the layer's raw input keys.
   */
  fields?: ConcatLayerDefinition[] | null;
}


/**
 * Layer that concatenates selected output fields of its input layers into one
 * flattened NDArray, published as the default output under `Concat.CONCATENATED`.
 *
 * Which inputs are used (and in which order) comes from the `fields` parameter;
 * when it is not set, every linked input layer is used with its default output
 * field, in the order returned by `raw.inputs.getKeys()`.
 *
 * During backpropagation, every source layer whose *default* output field was
 * concatenated receives the slice of the downstream `WEIGHT_MATRIX` that matches
 * that field's position in the concatenated output, plus the downstream error term.
 */
export class Concat extends Layer<ConcatParams> {
  /** Name of the output field that holds the concatenated values. */
  public static readonly CONCATENATED: string = 'concatenated';


  /** No-op: this layer has nothing to optimize. */
  protected async optimizeExec(): Promise<void> {
    // do nothing
  }


  /**
   * Selects the collection that feeds this layer's backpropagation.
   *
   * Uses the default raw backprop input when one is set; otherwise falls back to
   * the first raw backprop input.
   *
   * @throws Error when there is no backprop input at all (a concat layer cannot be an output layer).
   */
  protected resolveBackpropInput(): void {
    const backpropInput = this.data.backpropInput;
    const rawBackpropInputs = this.raw.backpropInputs;

    // This needs rewriting to deal with cases where a layer has multiple outputs
    // This needs rewriting to deal with bias, not just weight
    try {
      backpropInput.setCollection(rawBackpropInputs.getDefault());
      return;
    } catch (err) {
      // A missing default is expected here; any other failure is real.
      if (!(err instanceof KeyNotFoundError)) {
        throw err;
      }
    }

    if (rawBackpropInputs.count() < 1) {
      throw new Error(`Missing backprop input for concat layer '${this.getName()}' -- concat layer cannot be an output layer`);
    }

    // TODO: more than one backprop input is not supported yet -- only the first one is used.
    backpropInput.setCollection(rawBackpropInputs.first());
  }


  /**
   * Distributes the downstream weight matrix and error term to the source layers.
   *
   * Walks the inputs in concatenation order while tracking each field's offset
   * (`curPos`) in the concatenated output. For every field that is its layer's
   * default output, the `MatrixDirection.Vertical` slice of the downstream
   * `WEIGHT_MATRIX` starting at that offset (sized by the field's element count)
   * and the unchanged error term are written to that layer's backprop output.
   */
  protected async backwardExec(): Promise<void> {
    const bpInput = this.data.backpropInput;

    const errorTerm = bpInput.getValue(Layer.ERROR_TERM, Vector);
    const w = bpInput.getValue(Dense.WEIGHT_MATRIX, Matrix);

    let curPos = 0;

    this.traverse(
      (field: DeferredValue, fieldKey: string, layerOutput: DeferredCollectionWrapper, layerKey: string): void => {
        const fieldSize = field.countElements();

        // Only the default output of a layer has a derivative.
        if (layerOutput.getDefaultKey() === fieldKey) {
          // NOTE(review): if the same layer's default field is listed more than once, the later
          // slice overwrites the earlier one here -- confirm whether contributions should be summed.
          const wm = w.slice(MatrixDirection.Vertical, curPos, fieldSize);
          const coll = this.raw.backpropOutputs.get(layerKey).getCollection();

          coll.setValue(Dense.WEIGHT_MATRIX, wm);
          coll.setValue(Layer.ERROR_TERM, errorTerm);
        }

        // Non-default fields still occupy space in the concatenated output.
        curPos += fieldSize;
      },
    );
  }


  /**
   * Flattens every selected input field and concatenates them in order; the result
   * becomes this layer's default output value.
   *
   * @throws Error when there is nothing to concatenate.
   */
  protected async forwardExec(): Promise<void> {
    let result: NDArray|undefined;

    this.traverse(
      (field: DeferredValue): void => {
        const fieldValue = field.get().flatten();

        result = result ? result.concat(fieldValue) : fieldValue;
      },
    );

    if (!result) {
      throw new Error('No layers to concatenate');
    }

    this.data.output.setDefaultValue(result);
  }


  /**
   * Input definitions in concatenation order: the `fields` parameter when set,
   * otherwise all linked input layer keys.
   */
  protected getInputInOrder(): ConcatLayerDefinition[] {
    return this.params.fields ? this.params.fields : this.raw.inputs.getKeys();
  }


  /**
   * Validates the input configuration: every referenced layer must be linked,
   * every referenced field must be declared by its layer, and every linked input
   * layer must be referenced at least once.
   *
   * @throws Error describing the first problem found.
   */
  protected verifyInputLayers(): void {
    const allKeys = this.raw.inputs.getKeys();
    const definedLayerKeys: string[] = [];

    try {
      this.traverseKeys(
        (fieldKey: string, layerOutput: DeferredCollectionWrapper, layerKey: string) => {
          definedLayerKeys.push(layerKey);

          try {
            layerOutput.require(fieldKey);
          } catch (err) {
            if (err instanceof KeyNotFoundError) {
              throw new Error(`Concat layer '${this.getName()}' requires `
                + `field '${fieldKey}' from `
                + `layer '${layerKey}', which has not been declared`);
            }

            throw err;
          }
        },
      );
    } catch (err) {
      // Raised when a referenced layer is not among the linked inputs.
      if (err instanceof KeyNotFoundError) {
        throw new Error(`Concat layer '${this.getName()}' expects input from layer '${err.key}', which is not linked to this layer`);
      }

      throw err;
    }

    const cleanedLayerKeys = _.uniq(definedLayerKeys);
    const differenceAllKeys = _.difference(allKeys, cleanedLayerKeys);

    if (differenceAllKeys.length > 0) {
      throw new Error(
        `Concat layer '${this.getName()}' has more input layers than defined in the 'fields' parameter: ${_.join(differenceAllKeys)}`,
      );
    }
  }


  /**
   * Like `traverseKeys()`, but also resolves and passes the field value
   * (`layerOutput.get(fieldKey)`) to `callback`.
   */
  protected traverse(callback: ConcatOutputTraverseFunction) : void {
    this.traverseKeys(
      (fieldKey: string, layerOutput: DeferredCollectionWrapper, layerKey: string) => {
        const field = layerOutput.get(fieldKey);

        callback(field, fieldKey, layerOutput, layerKey);
      },
    );
  }


  /**
   * Resolves every configured input to `(fieldKey, layerOutput, layerKey)` and passes
   * it to `callback`, in concatenation order.
   *
   * Accepted definitions:
   *  - `'layer'`        -- the layer's default output field;
   *  - `'layer.field'`  -- everything after the FIRST `.` is the field key;
   *  - `{ layer, field }` -- `field`, or the default field when null/empty; a dotted
   *    `layer` value takes precedence over `field`.
   *
   * Presumably throws `KeyNotFoundError` (from `raw.inputs.get`) when a referenced
   * layer is not linked -- `verifyInputLayers()` relies on that.
   */
  protected traverseKeys(callback: ConcatOutputTraverseKeyFunction) : void {
    _.each(
      this.getInputInOrder(),
      (layer: ConcatLayerDefinition) => {
        const reference = _.isString(layer) ? layer : layer.layer;

        // Split on the first '.' only, so field keys that themselves contain dots
        // ("layer.some.field") are kept whole instead of being silently truncated.
        const dotPos = reference.indexOf('.');
        const layerKey = dotPos >= 0 ? reference.substring(0, dotPos) : reference;
        const shortcutFieldKey = dotPos >= 0 ? reference.substring(dotPos + 1) : '';

        const layerOutput = this.raw.inputs.get(layerKey);

        const explicitFieldKey = _.isString(layer) ? null : layer.field;
        const fieldKey = shortcutFieldKey || explicitFieldKey || layerOutput.getDefaultKey();

        callback(fieldKey, layerOutput, layerKey);
      },
    );
  }


  /**
   * Total number of elements across all concatenated fields, i.e. the size of
   * the `CONCATENATED` output.
   */
  protected determineOutputSize(): number {
    let total = 0;

    this.traverseKeys(
      (fieldKey: string, layerOutput: DeferredCollectionWrapper): void => {
        // Require the field before reading it, so an undeclared field is reported by `require`.
        layerOutput.require(fieldKey);

        total += layerOutput.get(fieldKey).countElements();
      },
    );

    return total;
  }


  /** Publishes this layer's output collection as its default raw output. */
  protected async compileInitialization(): Promise<void> {
    this.raw.outputs.setDefault(this.data.output);
  }


  /**
   * Validates the inputs, declares the `CONCATENATED` output field (sized to the
   * total element count of the inputs), makes it the default output key, and
   * creates the per-layer backprop output collections.
   */
  protected async compileForwardPropagation(): Promise<void> {
    this.verifyInputLayers();

    this.data.output.declare(Concat.CONCATENATED, this.determineOutputSize());
    this.data.output.setDefaultKey(Concat.CONCATENATED);

    this.prepareForBackprop();
  }


  /** Creates one empty backprop output collection per distinct input layer. */
  protected prepareForBackprop(): void {
    const layerKeys: string[] = [];

    this.traverseKeys(
      (fieldKey: string, layerOutput: DeferredCollectionWrapper, layerKey: string): void => {
        layerKeys.push(layerKey);
      },
    );

    // Block body on purpose: lodash's _.each stops iterating when the iteratee
    // returns `false`, so the return value of `set()` must not leak into the loop.
    _.each(_.uniq(layerKeys), (layerKey: string): void => {
      this.raw.backpropOutputs.set(layerKey, new DeferredCollection());
    });
  }


  /**
   * Resolves the backprop input, requires its `ERROR_TERM` and `WEIGHT_MATRIX`, and
   * declares the shapes of the values `backwardExec()` writes for each source layer
   * whose default output field is concatenated.
   */
  protected async compileBackPropagation(): Promise<void> {
    this.resolveBackpropInput();

    const backpropInput = this.data.backpropInput;

    backpropInput.require(Layer.ERROR_TERM);
    backpropInput.require(Dense.WEIGHT_MATRIX);

    // Loop-invariant: the same downstream error term is forwarded to every source layer.
    const errorTerm = backpropInput.get(Layer.ERROR_TERM);
    const nextLayerUnits = errorTerm.countElements();

    this.traverse(
      (field: DeferredValue, fieldKey: string, layerOutput: DeferredCollectionWrapper, layerKey: string): void => {
        if (layerOutput.getDefaultKey() !== fieldKey) {
          return; // Only default output will have a derivative
        }

        const bpOutput = this.raw.backpropOutputs.get(layerKey).getCollection();

        // `field` already is `raw.inputs.get(layerKey).get(fieldKey)`; no second lookup needed.
        bpOutput.declare(Dense.WEIGHT_MATRIX, [nextLayerUnits, field.countElements()]);
        bpOutput.declare(Layer.ERROR_TERM, errorTerm.getDims());
      },
    );
  }


  /** No-op: this layer has nothing to initialize. */
  protected async initializeExec(): Promise<void> {
    // do nothing
  }


  /** Joi schema for `ConcatParams`. */
  public getParamSchema(): JoiExSchema {
    return JoiEx.object().keys(
      {
        fields: JoiEx.array()
          .optional()
          .items(
            JoiEx.string(),
            JoiEx.object().keys(
              {
                layer: JoiEx.string().required().description('Name of the layer ("layer.field" shortcut allowed)'),
                field: JoiEx.string().optional().allow(null).default(null)
                  .description('Output field to include'),
              },
            ),
          )
          .allow(null)
          .default(null)
          .description('Order in which layers are concatenated'),
      },
    );
  }
}