hv0905/NekoImageGallery

View on GitHub
app/Services/index_service.py

Summary

Maintainability
A
35 mins
Test Coverage
from PIL import Image
from fastapi.concurrency import run_in_threadpool

from app.Models.errors import PointDuplicateError
from app.Models.mapped_image import MappedImage
from app.Services.lifespan_service import LifespanService
from app.Services.ocr_services import OCRService
from app.Services.transformers_service import TransformersService
from app.Services.vector_db_context import VectorDbContext
from app.config import config


class IndexService(LifespanService):
    """Prepare images (vectors, optional OCR) and store them through the vector DB context."""

    def __init__(self, ocr_service: OCRService, transformers_service: TransformersService, db_context: VectorDbContext):
        """Keep references to the collaborating services.

        :param ocr_service: Used to extract text from images.
        :param transformers_service: Used to compute image and text vectors.
        :param db_context: Used to check ids and insert the prepared records.
        """
        self._ocr_service = ocr_service
        self._transformers_service = transformers_service
        self._db_context = db_context

    def _prepare_image(self, image: Image.Image, image_data: MappedImage, skip_ocr=False) -> None:
        """Populate ``image_data`` in place from ``image``.

        Sets ``width``, ``height``, ``aspect_ratio`` and ``image_vector``. When OCR
        is enabled in the config and not skipped, also sets ``ocr_text`` and, if
        any text was found, ``text_contain_vector``.

        :param image: The picture to analyse; the caller's object is not rebound.
        :param image_data: The record that receives the computed fields.
        :param skip_ocr: When true, OCR is skipped regardless of the config.
        """
        image_data.width = image.width
        image_data.height = image.height
        image_data.aspect_ratio = float(image.width) / image.height

        if image.mode != 'RGB':
            # Convert once here so the later steps do not each have to convert again.
            image = image.convert('RGB')
        else:
            # NOTE(review): presumably copied so the services below never work on
            # the caller's own object - confirm before removing.
            image = image.copy()
        image_data.image_vector = self._transformers_service.get_image_vector(image)
        if not skip_ocr and config.ocr_search.enable:
            image_data.ocr_text = self._ocr_service.ocr_interface(image)
            if image_data.ocr_text != "":
                image_data.text_contain_vector = self._transformers_service.get_bert_vector(image_data.ocr_text)
            else:
                # Store "no text" as None rather than as an empty string.
                image_data.ocr_text = None

    # For now a simple existence check is all that is needed here.
    async def _is_point_duplicate(self, image_data: list[MappedImage]) -> bool:
        """Return True when ``validate_ids`` reports a non-empty result for the given records.

        NOTE(review): ``validate_ids`` presumably returns the ids that already
        exist in the database - confirm against ``VectorDbContext``.
        """
        image_id_list = [str(item.id) for item in image_data]
        result = await self._db_context.validate_ids(image_id_list)
        return len(result) != 0

    async def index_image(self, image: Image.Image, image_data: MappedImage, skip_ocr=False, skip_duplicate_check=False,
                          background=False) -> None:
        """Prepare a single image and insert its record.

        :param image: The picture to index.
        :param image_data: The record to fill in and insert.
        :param skip_ocr: When true, OCR is skipped regardless of the config.
        :param skip_duplicate_check: When true, the id is not checked against the database first.
        :param background: When true, the preparation runs through ``run_in_threadpool``
            instead of inline in this coroutine.
        :raises PointDuplicateError: If the duplicate check is active and reports a match.
        """
        if not skip_duplicate_check and (await self._is_point_duplicate([image_data])):
            raise PointDuplicateError("The uploaded points are contained in the database!", image_data.id)

        if background:
            await run_in_threadpool(self._prepare_image, image, image_data, skip_ocr)
        else:
            self._prepare_image(image, image_data, skip_ocr)

        await self._db_context.insertItems([image_data])

    async def index_image_batch(self, image: list[Image.Image], image_data: list[MappedImage],
                                skip_ocr=False, allow_overwrite=False) -> None:
        """Prepare several images and insert all their records in one call.

        ``image[i]`` is paired with ``image_data[i]``.

        :param image: The pictures to index.
        :param image_data: The records to fill in and insert, one per picture.
        :param skip_ocr: When true, OCR is skipped regardless of the config.
        :param allow_overwrite: When true, the ids are not checked against the database first.
        :raises ValueError: If ``image`` and ``image_data`` differ in length.
        :raises PointDuplicateError: If the duplicate check is active and reports a match.
        """
        # zip() stops at the shorter input, so a length mismatch would either drop
        # images silently or insert trailing records that were never prepared.
        # Reject it before touching the database.
        if len(image) != len(image_data):
            raise ValueError(
                f"image and image_data must have the same length (got {len(image)} and {len(image_data)})"
            )
        if not allow_overwrite and (await self._is_point_duplicate(image_data)):
            raise PointDuplicateError("The uploaded points are contained in the database!")
        for img, img_data in zip(image, image_data):
            self._prepare_image(img, img_data, skip_ocr)
        await self._db_context.insertItems(image_data)