Konano/arknights-mower
arknights_mower/utils/character_recognize.py

from __future__ import annotations

import traceback
from copy import deepcopy

import cv2
import numpy as np
from matplotlib import pyplot as plt
from PIL import Image, ImageDraw, ImageFont

from .. import __rootdir__
from ..data import agent_list, ocr_error  # ocr_error: known OCR mis-readings -> correct names
from ..ocr import ocrhandle
from . import segment
from .image import saveimg
from .log import logger
from .recognize import RecognizeError


def poly_center(poly):
    # centroid of a polygon given as a sequence of (x, y) points
    return (np.average([x[0] for x in poly]), np.average([x[1] for x in poly]))


def in_poly(poly, p):
    # whether point p lies inside the axis-aligned box spanned by poly[0] and poly[2]
    return poly[0, 0] <= p[0] <= poly[2, 0] and poly[0, 1] <= p[1] <= poly[2, 1]


char_map = {}                                         # (cell_x, cell_y) -> character drawn in the glyph atlas
agent_sorted = sorted(deepcopy(agent_list), key=len)  # operator names, shortest first
origin = origin_kp = origin_des = None                # glyph atlas image and its cached SIFT features

FLANN_INDEX_KDTREE = 0       # FLANN index type: KD-tree
GOOD_DISTANCE_LIMIT = 0.7    # Lowe's ratio-test threshold for keeping a match
SIFT = cv2.SIFT_create()


def agent_sift_init():
    """
    Lazily render every character used in operator names onto a glyph atlas
    and cache the atlas together with its SIFT keypoints and descriptors.
    """
    global origin, origin_kp, origin_des
    if origin is None:
        logger.debug('agent_sift_init')

        height = width = 2000
        lnum = 25
        cell = height // lnum

        img = np.zeros((height, width, 3), dtype=np.uint8)
        img = Image.fromarray(img)

        font = ImageFont.truetype(
            f'{__rootdir__}/fonts/SourceHanSansSC-Bold.otf', size=30, encoding='utf-8'
        )
        chars = sorted(list(set(''.join([x for x in agent_list]))))
        assert len(chars) <= (lnum - 2) * (lnum - 2)

        for idx, char in enumerate(chars):
            x, y = idx % (lnum - 2) + 1, idx // (lnum - 2) + 1
            char_map[(x, y)] = char
            ImageDraw.Draw(img).text(
                (x * cell, y * cell), char, (255, 255, 255), font=font
            )

        origin = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2GRAY)
        origin_kp, origin_des = SIFT.detectAndCompute(origin, None)
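
# Layout note (derived from the constants above): with height == 2000 and
# lnum == 25, each atlas cell is 80 px square, which is why sift_recog() below
# maps an atlas keypoint at pixel (x, y) back to
# char_map[(int(x) // 80, int(y) // 80)].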


def sift_recog(query, resolution, draw=False, reverse=False):
    """
    Recognize an operator name from a cropped image via SIFT feature matching.
    """
    agent_sift_init()
    query = cv2.cvtColor(np.array(query), cv2.COLOR_RGB2GRAY)
    if reverse:
        # invert the image taken from the operator overview screen
        query = 255 - query
    # the height & width of query image
    height, width = query.shape

    multi = 2 * (resolution / 1080)
    query = cv2.resize(query, (int(width * multi), int(height * multi)))
    query_kp, query_des = SIFT.detectAndCompute(query, None)

    # build FlannBasedMatcher
    index_params = dict(algorithm=FLANN_INDEX_KDTREE, trees=5)
    search_params = dict(checks=50)
    flann = cv2.FlannBasedMatcher(index_params, search_params)
    matches = flann.knnMatch(query_des, origin_des, k=2)

    # store all the good matches as per Lowe's ratio test
    good = []
    for x, y in matches:
        if x.distance < GOOD_DISTANCE_LIMIT * y.distance:
            good.append(x)

    if draw:
        result = cv2.drawMatches(query, query_kp, origin, origin_kp, good, None)
        plt.imshow(result, 'gray')
        plt.show()

    # vote for the atlas character each good match points at
    count = {}

    for m in good:
        x, y = origin_kp[m.trainIdx].pt
        c = char_map[(int(x) // 80, int(y) // 80)]  # 80 == atlas cell size (2000 // 25)
        count[c] = count.get(c, 0) + 1

    # score each candidate name by its characters' votes; characters that
    # received no votes count as -1
    best = None
    best_score = 0
    for name in agent_sorted:
        score = 0
        for c in set(name):
            score += count.get(c, -1)
        if score > best_score:
            best = name
            best_score = score

    logger.debug(f'segment.sift_recog: {count}, {best}')

    return best
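
# Hypothetical usage sketch (the variable name `name_crop` is illustrative and
# not part of this module): given a colour crop of an operator name taken from
# a screenshot whose height is 1080 px, sift_recog(name_crop, resolution=1080)
# returns the best-matching operator name, or None if no candidate scores
# above zero.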


def agent(img, draw=False):
    """
    Recognize operator names on the operator overview screen.
    """
    try:
        height, width, _ = img.shape
        resolution = height
        left, right = 0, width

        # adapt to notched / irregular screens: trim dark side borders
        while np.max(img[:, right - 1]) < 100:
            right -= 1
        while np.max(img[:, left]) < 100:
            left += 1

        # skip past the operator detail panel on the left
        x0 = left + 1
        while not (
            img[height - 1, x0 - 1, 0] > img[height - 1, x0, 0] + 10
            and abs(int(img[height - 1, x0, 0]) - int(img[height - 1, x0 + 1, 0])) < 5
        ):
            x0 += 1

        # get segmentation results (name boxes and OCR candidates)
        ret, ocr = segment.agent(img, draw)

        # with the positions fixed, run precise recognition on each box
        ret_succ = []
        ret_fail = []
        ret_agent = []
        for poly in ret:
            # pick the right-most OCR result whose centre falls inside this box
            found_ocr, fx = None, 0
            for x in ocr:
                cx, cy = poly_center(x[2])
                if in_poly(poly, (cx + x0, cy)) and cx > fx:
                    fx = cx
                    found_ocr = x
            # crop the name region for this box
            __img = img[poly[0, 1]: poly[2, 1], poly[0, 0]: poly[2, 0]]
            try:
                if found_ocr is not None:
                    x = found_ocr
                    if x[1] in agent_list and x[1] not in ['砾', '陈']:  # OCR frequently confuses these two names
                        ret_agent.append(x[1])
                        ret_succ.append(poly)
                        continue
                    res = sift_recog(__img, resolution, draw)
                    if (res is not None) and res in agent_list:
                        ret_agent.append(res)
                        ret_succ.append(poly)
                        continue
                    logger.warning(
                        f'干员名称识别异常:{x[1]} 为不存在的数据,请报告至 https://github.com/Konano/arknights-mower/issues'
                    )
                    saveimg(__img, 'failure_agent')
                    raise Exception("启动 Plan B")
                else:
                    if 80 <= np.min(__img):
                        # crop contains no dark pixels, so there is nothing to recognize
                        continue
                    res = sift_recog(__img, resolution, draw)
                    if res is not None:
                        ret_agent.append(res)
                        ret_succ.append(poly)
                        continue
                    logger.warning(f'干员名称识别异常:区域 {poly.tolist()}')
                    saveimg(__img, 'failure_agent')
                    raise Exception("启动 Plan B")
            except Exception:
                # Plan A failed, so fall back to Plan B: plain OCR of the crop
                name = segment.read_screen(__img,
                                           langurage="chi_sim",
                                           type="text")
                logger.warning(f'备选方案识别结果: {name}')
                ret_fail.append(poly)
                # map known OCR mis-readings to the correct name before accepting it
                if name in ocr_error:
                    name = ocr_error[name]
                ret_agent.append(name)
                ret_succ.append(poly)
        if len(ret_fail):
            saveimg(img, 'failure')
            if draw:
                __img = img.copy()
                cv2.polylines(__img, ret_fail, True, (255, 0, 0), 3, cv2.LINE_AA)
                plt.imshow(__img)
                plt.show()

        logger.debug(f'character_recognize.agent: {ret_agent}')
        logger.debug(f'character_recognize.agent: {[x.tolist() for x in ret]}')
        return list(zip(ret_agent, ret_succ))

    except Exception as e:
        logger.debug(traceback.format_exc())
        saveimg(img, 'failure_agent')
        raise RecognizeError(e)
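
# Hypothetical usage sketch: `screenshot` is assumed to be a full-frame ndarray
# of the operator overview screen. agent(screenshot) returns a list of
# (operator_name, name_box) pairs; on an unrecoverable failure it saves the
# frame via saveimg() and raises RecognizeError.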

def agent_name(__img, height, reverse=False, draw: bool = False):
    query = cv2.cvtColor(np.array(__img), cv2.COLOR_RGB2GRAY)
    h, w = query.shape
    dim = (w * 4, h * 4)
    # upscale the crop 4x before running OCR
    resized = cv2.resize(__img, dim, interpolation=cv2.INTER_AREA)
    ocr = ocrhandle.predict(resized)
    name = ''
    try:
        if len(ocr) > 0 and ocr[0][1] in agent_list and ocr[0][1] not in ['砾', '陈']:
            name = ocr[0][1]
        else:
            res = sift_recog(__img, height, draw, reverse)
            if (res is not None) and res in agent_list:
                name = res
            else:
                raise Exception("识别错误")
    except Exception:
        saveimg(__img, 'failure_agent')
    return name
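
# Hypothetical usage sketch: `name_crop` is assumed to be a small crop that
# contains only an operator's name. agent_name(name_crop, height) tries the
# OCR backend first and falls back to sift_recog(); it returns the recognized
# name, or '' when both attempts fail (the failing crop is saved with saveimg
# for later inspection).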