guimc233/lgz-bot

View on GitHub
src/main/kotlin/ltd/guimc/lgzbot/listener/message/ImageOCRFilter.kt

Summary

Maintainability
D
1 day
Test Coverage
package ltd.guimc.lgzbot.listener.message

import io.github.mymonstercat.Model
import io.github.mymonstercat.ocr.InferenceEngine
import ltd.guimc.lgzbot.PluginMain
import ltd.guimc.lgzbot.PluginMain.adPinyinRegex
import ltd.guimc.lgzbot.PluginMain.adRegex
import ltd.guimc.lgzbot.PluginMain.disableImageCheck
import ltd.guimc.lgzbot.PluginMain.logger
import ltd.guimc.lgzbot.PluginMain.seriousRegex
import ltd.guimc.lgzbot.files.ModuleStateConfig
import ltd.guimc.lgzbot.listener.message.MessageFilter.historyMessage
import ltd.guimc.lgzbot.listener.message.MessageFilter.messagesHandled
import ltd.guimc.lgzbot.listener.message.MessageFilter.mute
import ltd.guimc.lgzbot.listener.message.MessageFilter.recalledMessage
import ltd.guimc.lgzbot.listener.message.MessageFilter.riskList
import ltd.guimc.lgzbot.listener.message.MessageFilter.setVl
import ltd.guimc.lgzbot.utils.AsciiUtil.sbc2dbcCase
import ltd.guimc.lgzbot.utils.HttpUtils
import ltd.guimc.lgzbot.utils.ImageUtils
import ltd.guimc.lgzbot.utils.LL4JUtils
import ltd.guimc.lgzbot.utils.MemberUtils.mute
import ltd.guimc.lgzbot.utils.RegexUtils
import ltd.guimc.lgzbot.utils.TextUtils.removeInterference
import ltd.guimc.lgzbot.utils.TextUtils.removeNonVisible
import net.mamoe.mirai.console.permission.PermissionService.Companion.hasPermission
import net.mamoe.mirai.console.permission.PermitteeId.Companion.permitteeId
import net.mamoe.mirai.console.util.cast
import net.mamoe.mirai.event.events.GroupMessageEvent
import net.mamoe.mirai.message.data.ForwardMessage
import net.mamoe.mirai.message.data.Image
import net.mamoe.mirai.message.data.Image.Key.queryUrl
import net.mamoe.mirai.message.data.MessageChain
import net.mamoe.mirai.message.data.MessageSource.Key.recall
import org.apache.commons.lang3.time.StopWatch
import java.io.ByteArrayInputStream
import java.io.File
import java.lang.Thread.sleep
import java.security.MessageDigest
import java.sql.Connection
import java.sql.DriverManager
import java.sql.PreparedStatement
import java.sql.ResultSet
import java.util.concurrent.locks.ReentrantLock


object ImageOCRFilter {
    lateinit var engine: InferenceEngine
    lateinit var connection: Connection
    private val lock = ReentrantLock()
    var memberReviewing = mutableMapOf<Long, String>()
    var supported: Boolean = false

    suspend fun filter(e: GroupMessageEvent) {
        if (!supported) return
        findImageToFilter(e, e.message)
    }

    suspend fun findImageToFilter(e: GroupMessageEvent, m: MessageChain): Boolean {
        for (message in m) {
            if (message is Image) {
                if (filterImage(e, message)) {
                    return true
                }
            }/* else if (message is ForwardMessage) {
                if (findImageToFilter(e, message.nodeList)) {
                    return true
                }
            }*/
        }
        return false
    }

    suspend fun findImageToFilter(e: GroupMessageEvent, m: List<ForwardMessage.Node>): Boolean {
        for (fm in m) {
            if (findImageToFilter(e, fm.messageChain)) {
                return true
            }
        }
        return false
    }

    private suspend fun filterImage(e: GroupMessageEvent, m: Image): Boolean {
        var muted = false
        try {
            var content = sbc2dbcCase(recognizeAndStoreImage(m.cast<Image>().queryUrl()))
                .lowercase()
                .removeInterference()
                .removeNonVisible()
            if (!(e.sender.permission.level < e.group.botPermission.level || ModuleStateConfig.slientmute)) return false
            if (e.group.permitteeId.hasPermission(disableImageCheck)) return false

            val predictedResult = LL4JUtils.predictAllResult(content)
            val predicted = predictedResult[1] > predictedResult[0]
            if (RegexUtils.matchRegex(seriousRegex, content)) {
                recalledMessage++
                e.message.recall()
                e.sender.mute(7 * 24 * 60 * 60, "非法发言内容 (图片OCR识别) (敏感内容)")
                muted = true
            }
            if (!muted && RegexUtils.matchRegex(adRegex, content) && content.length >= 30) {
                try {
                    recalledMessage++
                    e.message.recall()
                    if (predictedResult[1] - predictedResult[0] >= 0.22) {
                        e.group.mute(e.sender, "非法发言内容 (图片OCR识别) (模型证实)")
                        riskList.add(e.sender)
                        setVl(e.sender.id, 99.0)
                    } else if (predictedResult[1] - predictedResult[0] >= 0.12) {
                        e.sender.mute(120, "非法发言内容 (图片OCR识别) (模型猜测)")
                    } else {
                        logger.info("开始 ${e.sender.id} 的追溯检查")
                        memberReviewing.put(e.sender.id, content)
                    }
                    muted = true
                } catch (_: Exception) {
                }
                messagesHandled++
            }

            if (memberReviewing.containsKey(e.sender.id)) {
                memberReviewing.put(
                    e.sender.id,
                    memberReviewing.get(e.sender.id) + "\n" + content.replace("\n", "")
                )
                if (memberReviewing.get(e.sender.id)?.let { RegexUtils.countLines(it) }!! >= 5) {
                    logger.info("结束 ${e.sender.id} 的追溯检查")
                    memberReviewing.remove(e.sender.id)
                    if (memberReviewing.get(e.sender.id)!!.length <= 50) {
                        memberReviewing.remove(e.sender.id)
                        return true
                    }
                    if (predictedResult[1] - predictedResult[0] >= 0.22) {
                        e.group.mute(e.sender, "追溯检查 (图片OCR) (模型复查)")
                        muted = true
                        recalledMessage++
                        e.message.recall()
                        try {
                            historyMessage[e.sender.id]?.forEach {
                                recalledMessage++
                                it.recall()
                                sleep(100)
                            }
                        } catch (_: Exception) {
                        }
                        historyMessage[e.sender.id]?.clear()
                        setVl(e.sender.id, .0)
                        messagesHandled++
                    }
                }
            }

            // 拼音检查发言
            if (!muted && riskList.indexOf(e.sender) != -1 && RegexUtils.matchRegexPinyin(adPinyinRegex, content)) {
                try {
                    recalledMessage++
                    e.message.recall()
                    e.group.mute(e.sender, "非法发言内容 (Risk) (图片OCR识别)")
                    muted = true
                } catch (_: Exception) {
                }
                setVl(e.sender.id, 99.0)
                messagesHandled++
            }
        } catch (e: Exception) {
            e.printStackTrace()
        }
        return muted
    }

    fun calculateSha256(fileData: ByteArray): String {
        val md = MessageDigest.getInstance("SHA-256")
        val digest = md.digest(fileData)
        return digest.fold("") { str, it -> str + "%02x".format(it) }
    }

    fun recognizeAndStoreImage(imageUrl: String): String {
        try {
            // 获取图片数据
            val imageRaw = HttpUtils.getBytesResponse(imageUrl)
            if (imageRaw == null) {
                throw RuntimeException("Image data is null")
            }

            // 计算 SHA-256 哈希值
            val sha256Hash = calculateSha256(imageRaw)

            // 检查哈希是否已存在
            val querySql = "SELECT image_content FROM recognized_images WHERE sha256_hash = ?"
            val pstmt: PreparedStatement = connection.prepareStatement(querySql)
            pstmt.setString(1, sha256Hash)
            val rs: ResultSet = pstmt.executeQuery()
            var content: String = "";

            if (rs.next()) {
                // 如果哈希已存在,获取内容
                content = rs.getString("image_content")
            } else {
                // 如果哈希不存在,进行OCR处理并存储
                val imageData = ByteArrayInputStream(imageRaw.clone())
                val imageType = ImageUtils.imgType(imageData)
                if (imageType.equals("gif") || imageType.equals("webp") || imageType.equals("tif")) return ""
                val tempFile = File.createTempFile("temp", ".$imageType")
                tempFile.writeBytes(imageRaw)

                // OCR 识别
                val stopWatch = StopWatch()
                stopWatch.start()
                content = engine.runOcr(tempFile.path).strRes.trim()
                stopWatch.stop()
                logger.info("OCR已完成 耗时: ${stopWatch.time} ms")

                // 删除临时文件
                tempFile.delete()

                // 插入识别内容到数据库
                lock.lock()
                val insertSql = "INSERT INTO recognized_images (sha256_hash, image_content) VALUES (?, ?)"
                val insertPstmt: PreparedStatement = connection.prepareStatement(insertSql)
                insertPstmt.setString(1, sha256Hash)
                insertPstmt.setString(2, content)
                insertPstmt.executeUpdate()
                insertPstmt.close()
                lock.unlock()
            }

            rs.close()
            pstmt.close()
            return content
        } finally {
            // 确保锁被释放
            if (lock.isLocked) {
                lock.unlock()
            }
        }
    }


    fun init() {
        try {
            engine = InferenceEngine.getInstance(Model.ONNX_PPOCR_V3)
            connection = initializeDatabase(
                "jdbc:sqlite:" + File(
                    PluginMain.dataFolderPath.toFile(),
                    "recognized_images.db"
                ).absolutePath
            )
            supported = true
        } catch (e: Exception) {
            logger.warning(e)
            supported = false
        }
    }

    fun initializeDatabase(dbUrl: String): Connection {
        val dbFile = File(dbUrl.substringAfter("jdbc:sqlite:"))
        val isNewDatabase = !dbFile.exists()
        val conn = DriverManager.getConnection(dbUrl)

        if (isNewDatabase) {
            val stmt = conn.createStatement()
            val createTableSql = """
            CREATE TABLE IF NOT EXISTS recognized_images (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                sha256_hash TEXT UNIQUE,
                image_content TEXT
            )
        """
            stmt.execute(createTableSql)
            stmt.close()
        }

        return conn
    }
}