Konano/arknights-mower
arknights_mower/utils/matcher.py

from __future__ import annotations

import pickle
import traceback
from typing import Optional, Tuple

import cv2
import numpy as np
import sklearn  # imported explicitly so unpickling the SVM model below fails fast if scikit-learn is missing
from matplotlib import pyplot as plt
from skimage.metrics import structural_similarity as compare_ssim

from .. import __rootdir__
from . import typealias as tp
from .image import cropimg
from .log import logger

MATCHER_DEBUG = False
FLANN_INDEX_KDTREE = 0
GOOD_DISTANCE_LIMIT = 0.7
SIFT = cv2.SIFT_create()
with open(f'{__rootdir__}/models/svm.model', 'rb') as f:
    SVC = pickle.load(f)
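
# SVC is a scikit-learn classifier trained offline on the four-dimensional
# score tuple produced by Matcher.score (good-matches rate, good-area rate,
# aHash similarity, SSIM); Matcher.match uses it to decide whether a score
# corresponds to a genuine match.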


def getHash(data: np.ndarray) -> tp.Hash:
    """ calc binary image hash by thresholding against the mean """
    average = np.mean(data)
    return np.where(data > average, 1, 0)


def hammingDistance(hash1: tp.Hash, hash2: tp.Hash) -> int:
    """ calc Hamming distance between two hash """
    return np.count_nonzero(hash1 != hash2)
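
# e.g. hammingDistance(np.array([1, 0, 1]), np.array([1, 1, 0])) == 2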


def aHash(img1: tp.GrayImage, img2: tp.GrayImage) -> int:
    """ calc Hamming distance between the aHashes of two images """
    data1 = cv2.resize(img1, (8, 8)).flatten()
    data2 = cv2.resize(img2, (8, 8)).flatten()
    hash1 = getHash(data1)
    hash2 = getHash(data2)
    return hammingDistance(hash1, hash2)
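
# A minimal usage sketch (hypothetical file names): both images are shrunk to
# 8x8 and binarised against their own mean, so the distance lies in [0, 64]
# and identical thumbnails score 0.
#
#     img = cv2.imread('screenshot.png', cv2.IMREAD_GRAYSCALE)
#     darker = cv2.convertScaleAbs(img, alpha=0.8)
#     distance = aHash(img, darker)  # near 0: uniform dimming barely moves
#                                    # pixels across their shared mean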


class Matcher(object):
    """ image matching module """

    def __init__(self, origin: tp.GrayImage) -> None:
        logger.debug(f'Matcher init: shape ({origin.shape})')
        self.origin = origin
        self.init_sift()

    def init_sift(self) -> None:
        """ get SIFT feature points """
        self.kp, self.des = SIFT.detectAndCompute(self.origin, None)

    def match(self, query: tp.GrayImage, draw: bool = False, scope: tp.Scope = None, judge: bool = True) -> Optional[tp.Scope]:
        """ check if the image can be matched """
        rect_score = self.score(query, draw, scope)  # get matching score
        if rect_score is None:
            return None  # failed in matching
        else:
            rect, score = rect_score

        # use SVC to determine if the score falls within the legal range
        if judge and not SVC.predict([score])[0]:  # numpy.bool_
            logger.debug(f'match fail: {score}')
            return None  # failed in matching
        else:
            logger.debug(f'match success: {score}')
            return rect  # success in matching

    def score(self, query: tp.GrayImage, draw: bool = False, scope: tp.Scope = None, only_score: bool = False) -> Optional[Tuple[tp.Scope, tp.Score]]:
        """ scoring of image matching """
        try:
            # bail out if the origin image has no feature points
            if self.des is None:
                logger.debug('feature points is None')
                return None

            # specify the crop scope
            if scope is not None:
                ori_kp, ori_des = [], []
                for _kp, _des in zip(self.kp, self.des):
                    if scope[0][0] <= _kp.pt[0] <= scope[1][0] and scope[0][1] <= _kp.pt[1] <= scope[1][1]:
                        ori_kp.append(_kp)
                        ori_des.append(_des)
                logger.debug(
                    f'match crop: {scope}, {len(self.kp)} -> {len(ori_kp)}')
                ori_kp, ori_des = np.array(ori_kp), np.array(ori_des)
            else:
                ori_kp, ori_des = self.kp, self.des

            # bail out if there are fewer than 2 feature points
            if len(ori_kp) < 2:
                logger.debug('fewer than 2 feature points')
                return None

            # the height & width of query image
            h, w = query.shape

            # the feature point of query image
            qry_kp, qry_des = SIFT.detectAndCompute(query, None)

            # build FlannBasedMatcher
            index_params = dict(algorithm=FLANN_INDEX_KDTREE, trees=5)
            search_params = dict(checks=50)
            flann = cv2.FlannBasedMatcher(index_params, search_params)
            matches = flann.knnMatch(qry_des, ori_des, k=2)
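            # FLANN_INDEX_KDTREE selects FLANN's KD-tree index (5 trees) and
            # checks=50 bounds the search effort; k=2 asks for the two nearest
            # origin descriptors per query descriptor, which is what the
            # ratio test below needs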

            # store all the good matches as per Lowe's ratio test
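            # (a best match is kept only when it is clearly closer than the
            # second-best candidate; 0.7 is the ratio suggested by Lowe)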
            good = []
            for x, y in matches:
                if x.distance < GOOD_DISTANCE_LIMIT * y.distance:
                    good.append(x)
            good_matches_rate = len(good) / len(qry_des)

            # draw all the good matches, for debug
            if draw:
                result = cv2.drawMatches(
                    query, qry_kp, self.origin, ori_kp, good, None)
                plt.imshow(result, 'gray')
                plt.show()

            # bail out unless there are more than 4 good matches;
            # findHomography needs at least 4 point pairs
            if len(good) <= 4:
                logger.debug(
                    f'not enough good matches are found: {len(good)} / {len(qry_des)}')
                return None

            # get the coordinates of good matches
            qry_pts = np.float32(
                [qry_kp[m.queryIdx].pt for m in good]).reshape(-1, 1, 2)
            ori_pts = np.float32(
                [ori_kp[m.trainIdx].pt for m in good]).reshape(-1, 1, 2)

            # estimate the transformation matrix and inlier mask with RANSAC
            M, mask = cv2.findHomography(qry_pts, ori_pts, cv2.RANSAC, 5.0)

            # if the transformation matrix could not be computed
            if M is None:
                logger.debug('calculated transformation matrix failed')
                return None

            matchesMask = mask.ravel().tolist()

            # calc the location of the query image
            quad = np.float32([[[0, 0]], [[0, h-1]], [[w-1, h-1]], [[w-1, 0]]])
            quad = cv2.perspectiveTransform(quad, M)  # quadrangle
            qp = np.int32(quad).reshape(4, 2).tolist()
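            # quad now holds the query corners (top-left, bottom-left,
            # bottom-right, top-right) projected into origin coordinates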

            # draw the result, for debug
            if draw or MATCHER_DEBUG:
                cv2.polylines(self.origin, [np.int32(quad)],
                              True, 0, 2, cv2.LINE_AA)
                draw_params = dict(matchColor=(0, 255, 0), singlePointColor=None,
                                   matchesMask=matchesMask, flags=2)
                result = cv2.drawMatches(query, qry_kp, self.origin, ori_kp,
                                         good, None, **draw_params)
                plt.imshow(result, 'gray')
                plt.show()

            # reject the quadrangle if it deviates too far from an
            # axis-aligned rectangle
            if max(abs(qp[0][0] - qp[1][0]), abs(qp[2][0] - qp[3][0]), abs(qp[0][1] - qp[3][1]), abs(qp[1][1] - qp[2][1])) > 30:
                logger.debug(f'quadrangle is not a rectangle: {qp}')
                return None

            # take the axis-aligned bounding box of the quadrangle
            rect = [(np.min(quad[:, 0, 0]), np.min(quad[:, 0, 1])),
                    (np.max(quad[:, 0, 0]), np.max(quad[:, 0, 1]))]

            # if rectangle is too small
            if rect[1][0] - rect[0][0] < 10 or rect[1][1] - rect[0][1] < 10:
                logger.debug(f'rectangle is too small: {rect}')
                return None

            # measure how much of the query width (x-axis) is spanned by the
            # good matches that fall inside the rectangle; np.ptp is max - min
            better = filter(
                lambda m:
                    rect[0][0] < ori_kp[m.trainIdx].pt[0] < rect[1][0] and rect[0][1] < ori_kp[m.trainIdx].pt[1] < rect[1][1], good)
            better_kp_x = [qry_kp[m.queryIdx].pt[0] for m in better]
            if len(better_kp_x):
                good_area_rate = np.ptp(better_kp_x) / w
            else:
                good_area_rate = 0

            # rectangle: float -> int
            rect = np.array(rect, dtype=int).tolist()
            rect_img = cropimg(self.origin, rect)

            # if rect_img is too small
            if np.min(rect_img.shape) < 10:
                logger.debug(f'rect_img is too small: {rect_img.shape}')
                return None

            # resize rect_img to the query size (cv2.resize expects (w, h),
            # hence query.shape[::-1])
            rect_img = cv2.resize(rect_img, query.shape[::-1])

            # draw the result
            if draw or MATCHER_DEBUG:
                plt.subplot(1, 2, 1)
                plt.imshow(query, 'gray')
                plt.subplot(1, 2, 2)
                plt.imshow(rect_img, 'gray')
                plt.show()

            # aHash similarity between query image and rect_img; the Hamming
            # distance is divided by the author's constant 32 (it can reach 64)
            hash_score = 1 - (aHash(query, rect_img) / 32)

            # calc SSIM between query image and rect_img (note that
            # multichannel=True makes skimage treat the last axis as channels)
            ssim = compare_ssim(query, rect_img, multichannel=True)

            # return the final rectangle and the four scoring dimensions
            if only_score:
                return (good_matches_rate, good_area_rate, hash_score, ssim)
            else:
                return rect, (good_matches_rate, good_area_rate, hash_score, ssim)

        except Exception as e:
            logger.error(e)
            logger.debug(traceback.format_exc())
            return None
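

# A minimal usage sketch (hypothetical image paths), guarded so it only runs
# when this file is executed directly rather than imported:
if __name__ == '__main__':
    screenshot = cv2.imread('screenshot.png', cv2.IMREAD_GRAYSCALE)
    template = cv2.imread('button.png', cv2.IMREAD_GRAYSCALE)
    matcher = Matcher(screenshot)
    rect = matcher.match(template)  # scope of the best match, or None
    print(rect)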