wkdhkr/dedupper

View on GitHub
src/services/deepLearning/CocoSsdService.js

Summary

Maintainability
A
1 hr
Test Coverage
// @flow
import "fast-text-encoding";
// eslint-disable-next-line import/no-unresolved
import * as cocoSsd from "@tensorflow-models/coco-ssd";

import FileNameMarkHelper from "../../helpers/FileNameMarkHelper";
import { MARK_ERASE } from "../../types/FileNameMarks";
import LockHelper from "../../helpers/LockHelper";
import { canvas, saveFile } from "./faceApi/commons";
import type { Config } from "../../types";

const { Image, createCanvas } = canvas;
let model;
export default class CocoSsdService {
  config: Config;

  constructor(config: Config) {
    this.config = config;
  }

  createCanvasAndContext: (w: number, h: number) => Array<any> = (
    w: number,
    h: number
  ) => {
    const c = createCanvas(w, h);
    return [c, c.getContext("2d")];
  };

  loadModel: () => Promise<void> = async (): Promise<void> => {
    await LockHelper.lockProcess();
    if (!model) {
      model = await cocoSsd.load("mobilenet_v2");
      // model = await cocoSsd.load();
    }
    await LockHelper.unlockProcess();
  };

  demo: (targetPath: string) => Promise<Array<any>> = async (
    targetPath: string
  ): Promise<any[]> => {
    const classes = await this.predict(targetPath);
    const img = await canvas.loadImage(targetPath);
    const [c, ctx] = this.createCanvasAndContext(img.width, img.height);
    ctx.drawImage(img, 0, 0);
    const context = c.getContext("2d");
    ctx.font = "10px Arial";

    for (let i = 0; i < classes.length; i += 1) {
      context.beginPath();
      context.rect(...classes[i].bbox);
      context.lineWidth = 1;
      context.strokeStyle = "green";
      context.fillStyle = "green";
      context.stroke();
      context.fillText(
        `${classes[i].score.toFixed(3)} ${classes[i].class}`,
        classes[i].bbox[0],
        classes[i].bbox[1] > 10 ? classes[i].bbox[1] - 5 : 10
      );
    }
    const destPath = FileNameMarkHelper.mark(targetPath, new Set([MARK_ERASE]));
    saveFile(destPath, c.toBuffer("image/jpeg"));
    return classes;
  };

  predict: (targetPath: string) => Promise<Array<any>> = async (
    targetPath: string
  ): Promise<any[]> => {
    return new Promise((resolve, reject) => {
      try {
        const img = new Image();
        img.onload = async () => {
          const [c, ctx] = this.createCanvasAndContext(img.width, img.height);
          ctx.drawImage(img, 0, 0);
          // classify
          if (!model) {
            await this.loadModel();
          }
          if (model) {
            const maxDetectionSize = 10;
            const classes = await model.detect(c, maxDetectionSize);
            // Classify the image
            resolve(classes);
            return;
          }
          reject(new Error("model not loaded"));
        };
        img.onerror = err => {
          reject(err);
        };
        img.src = targetPath;
      } catch (e) {
        reject(e);
      }
    });
  };
}